(Fast) Probabilistic Graphs
Fast Probabilistic Programming with JAX & NumPyro
Probabilistic inference is extremely useful. Many machine learning models fail to adequately quantify sensitivity to & variability in the space of possible models. Probabilistic, Bayesian models allow us both to be more explicit about our assumptions & to encode priors & variational constraints.
Probabilistic graphs take this a step further, allowing us to build probabilistic models over any arbitrary set of operations. A real-world system has both stochastic and deterministic mechanics.
These methods are old and deeply statistical, meaning they have been rigorously studied (backed by decades of mathematical scrutiny). The problem is these old statistical paradigms are:
- Not built to scale.
- Incompatible with contemporary Deep Learning frameworks.
The sampling-heavy optimization procedures required to fit probabilistic programs become intractable as data grows, and many of the tools built to handle these models do not allow for sufficient integration with deep learning architectures: enter JAX.
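To make that concrete, here is a minimal, self-contained sketch of the two primitives JAX gives us: automatic differentiation and JIT compilation. The toy loss function and data below are illustrative, not part of this project.
import jax
import jax.numpy as jnp
# A toy objective: mean squared error of a linear predictor parameterized by theta.
def loss(theta, x, y):
    mu = theta[0] + theta[1] * x
    return jnp.mean((y - mu) ** 2)
# jax.grad yields exact gradients; jax.jit compiles the computation for CPU/GPU.
grad_fn = jax.jit(jax.grad(loss))
x = jnp.linspace(0., 1., 100)
y = 1. + 2. * x
print(grad_fn(jnp.zeros(2), x, y))  # gradient of the loss w.r.t. theta at theta = 0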
This is exactly what we need to compute gradients & optimize parameters in large probabilistic graphs. So how do we build probabilistic graphs on JAX? Enter NumPyro.
NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for Pyro. It relies on JAX for automatic differentiation and JIT compilation to GPU / CPU.
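In NumPyro, a probabilistic graph is just a Python function that mixes stochastic nodes (numpyro.sample) with deterministic ones (numpyro.deterministic). A minimal sketch, with purely illustrative names and transforms (this is not the model built later):
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
# Illustrative only: a graph combining stochastic nodes (sample) and a deterministic node.
def toy_graph(x, y=None):
    slope = numpyro.sample('slope', dist.Normal(0., 1.))        # stochastic
    capacity = numpyro.sample('capacity', dist.HalfNormal(5.))  # stochastic
    # deterministic operational logic applied to the latent parameters
    mu = numpyro.deterministic('mu', capacity * jnp.tanh(slope * x))
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=y)        # likelihood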
I have been deploying probabilistic graphs to quantify the variability between groups; here's an example.
Data
Consider the Countries of the World dataset, which contains a variety of socioeconomic metrics per country. We wish to understand the relationship between GDP per capita & social connectivity (using Phones/1000 people as a proxy for connectivity).
Are there (latent) groups that significantly affect the data-generating process?
Examining the data by Region yields:
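For example, a quick regional breakdown can be produced as below. This is a sketch: the column names assume the Kaggle Countries of the World schema, and CD.countries is the project's loader for the dataset (used again in the clustering step).
import seaborn as sns
# Regional medians and a scatter of connectivity vs. GDP per capita.
# Column names assume the Kaggle "Countries of the World" schema.
cntry = CD.countries.dropna()
print(cntry.groupby('Region')[['GDP ($ per capita)', 'Phones (per 1000)']].median())
sns.scatterplot(data=cntry, x='Phones (per 1000)', y='GDP ($ per capita)', hue='Region')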
Feature Engineering: Clustering
The geographic regions are useful, but we might find cleaner groupings by clustering the data, modelling the variability across a series of socioeconomic measurements.
# KPCA + KMeans clustering pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
from sklearn.decomposition import KernelPCA
from sklearn.pipeline import Pipeline
from sklearn.cluster import KMeans
# standardize -> embed with RBF kernel PCA -> cluster the 2D embedding
scaler = StandardScaler()
kpca = KernelPCA(n_components=2, kernel='rbf', gamma=0.1)
kmeans = KMeans(n_clusters=3, random_state=1)
pipe = Pipeline([('scaler', scaler), ('kpca', kpca), ('kmeans', kmeans)])
# fit on the numeric columns, dropping rows with missing values
cntry = CD.countries.dropna()
pipe.fit(cntry.iloc[:, 2:])
cntry['Cluster'] = pipe.predict(cntry.iloc[:, 2:])
cntry['Cluster'] = cntry['Cluster'].astype(int)
cntry['Cluster2'] = cntry['Cluster'].astype(str)  # string copy for categorical plotting
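The silhouette_score import above hints at how the cluster count can be chosen; a quick sketch (not necessarily the original selection procedure) scores candidate values of k on the scaled KPCA embedding:
# Score candidate cluster counts on the scaled + KPCA-embedded features.
embedded = pipe[:-1].transform(cntry.iloc[:, 2:])
for k in range(2, 7):
    labels = KMeans(n_clusters=k, random_state=1).fit_predict(embedded)
    print(f'k={k}: silhouette = {silhouette_score(embedded, labels):.3f}')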
The clustering appears to better represent the data, so we'll use this feature going forward.
Machine Learning: GAMs
We need a set of flexible, regularized functions to model the data. A GAM (Generalized Additive Model) is well suited to the task. GAMs use a series of non-linear basis functions to approximate trends in data.
from modules.dependencies import *
class FrequentistModels:
models = {
'Linear Regression': LinearRegression,
'Linear GAM': LinearGAM
}
def __init__(self, X_train, y_train, X_test, y_test, model_name='Linear Regression') -> None:
self.X_train = X_train
self.y_train = y_train
self.X_test = X_test
self.y_test = y_test
self.model = self.models[model_name]().fit(self.X_train, self.y_train)
self.model_name = model_name
    def set_model(self, model_name):
        assert model_name in FrequentistModels.models, 'Model must be one of [{}]'.format(list(FrequentistModels.models.keys()))
        # instantiate the requested model class rather than storing the name string
        self.model = self.models[model_name]()
        self.model_name = model_name
        return self
def fit(self):
self.model.fit(self.X_train, self.y_train)
return self
def predict(self, X=None):
if X is not None:
return self.model.predict(X)
self.y_pred = self.model.predict(self.X_test)
return self
def print_model(self):
if hasattr(self.model, 'summary'):
print(self.model.summary())
else:
print(self.model)
return self
def print_score(self):
print('R^2: ', r2_score(self.y_test, self.y_pred))
return self
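A hypothetical usage of the wrapper above; the feature/target columns and the train/test split are assumptions rather than the exact variables used in the original notebook.
from sklearn.model_selection import train_test_split
# Assumed columns: Phones (per 1000) as the predictor, GDP ($ per capita) as the target.
X = cntry[['Phones (per 1000)']].astype(float).values
y = cntry['GDP ($ per capita)'].astype(float).values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)
gam = (FrequentistModels(X_train, y_train, X_test, y_test, model_name='Linear GAM')
       .predict()        # populates y_pred on the held-out set
       .print_score())   # prints validation R^2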
Fitting a GAM works great, yielding a validation R² of 74% without much tuning. We have no information, however, about the variability across clusters.
Probabilistic Inference
JAX + NumPyro
Superb! Our model is looking great, but can we be confident in predicting outside of the training domain? Unlike a pure research setting, engineering systems require additional robustness before deployment.
Latent Space Representation
It is advantageous to identify points of potential fragility & adequately represent uncertainty in the system. Techniques for semisupervised learning, domain generalization, or domain adaptation can take many forms, but the approach generally follows the same pattern: identify a latent subspace that captures the group structure, then quantify uncertainty over it.
Our feature-engineered Kernel PCA clusters represent such a subspace.
Bayesian Inference
By quantifying the variability and sensitivity of the subspace, we are able to:
- Inject better features into downstream models.
- Return confidence in the prediction space.
- Flag potential weak points before productionizing.
We define a general-purpose Bayesian Modelling Interface to enforce a common structure in all downstream concrete classes.
from modules.dependencies import *
class params:
num_warmup = 500
num_samples = 1000
num_chains = 2
disable_progbar = False
class Bayesian_Model_Interface(ABC):
def __init__(self, model, runtime_params, *args, **kwargs):
self.runtime_params = runtime_params
self._model = model
self._mcmc = None
@property
def model(self):
return self._model
@model.setter
def model(self, model):
self._model = model
@property
def mcmc(self):
if self._mcmc is None:
raise ValueError('Model not fit. Run .fit() first.')
return self._mcmc
@mcmc.setter
def mcmc(self, value):
if not isinstance(value, MCMC):
raise ValueError('Value must be numpyro MCMC object.')
self._mcmc = value
    def compute_model_accuracy(self):
        # placeholder: concrete subclasses can implement an accuracy/ELPD check here
        pass
def print_summary(self):
self.mcmc.print_summary()
def plot_trace(self):
az.plot_trace(self.mcmc)
Concrete Class
Using this interface, we can define a concrete class that represents any arbitrary set of probabilistic graphs. The BayesianModelEngine.InterceptModelGenerator() method can instantiate any linear combination of:
- Linear covariates.
- Random intercepts.
- Random slopes.
- Variance pooling techniques.
class BayesianModelEngine(Bayesian_Model_Interface):
def __init__(self, runtime_params=params, *args, **kwargs):
        super().__init__(self.InterceptModelGenerator, runtime_params, *args, **kwargs)
@staticmethod
def get_plate_length(y=None, Group=None, Covariates=None):
"""
Get plate length for numpyro plate context manager.
"""
if y is not None:
return len(y)
elif Group is not None:
return len(Group)
else:
if Covariates is None:
return 1
Cov1 = list(Covariates.values())[0]
return len(Cov1)
@staticmethod
def Group_Intercept(Group, Technique):
"""
Technique:
-> FullyPooled
-> PartiallyPooled
-> Unpooled
"""
if Technique == 'FullyPooled':
# independent of group
rand_int = numpyro.sample('rand_int', dist.Normal(0., 1.))
return rand_int
if Technique == 'PartiallyPooled':
mu_a = numpyro.sample('mu_a', dist.Normal(0., 5.))
sg_a = numpyro.sample('sg_a', dist.HalfNormal(5.))
n_grp = len(np.unique(Group))
with numpyro.plate("plate_i", n_grp):
rand_int = numpyro.sample('rand_int', dist.Normal(mu_a, sg_a))
return rand_int[Group]
if Technique == 'Unpooled':
n_grp = len(np.unique(Group))
with numpyro.plate("plate_i", n_grp):
rand_int = numpyro.sample('rand_int', dist.Normal(0., 0.3))
return rand_int[Group]
raise ValueError('Technique must be one of [FullyPooled, PartiallyPooled, Unpooled]')
@staticmethod
def random_slope(Group, Xvar):
n_grp = len(np.unique(Group))
mu_b = numpyro.sample('mu_b', dist.Normal(0., 5.), sample_shape=(n_grp,))
sg_b = numpyro.sample('sg_b', dist.HalfNormal(5.), sample_shape=(n_grp,))
with numpyro.plate("plate_s", n_grp):
rand_slope = numpyro.sample('rand_slope', dist.Normal(mu_b, sg_b))
return rand_slope[Group] * Xvar
def InterceptModelGenerator(self, y=None, Group=None, Group_Technique='PartiallyPooled', Intercept:bool=False, RandSlopeVar=None, **Covariates):
"""
Flexible Model Generator.
"""
Z = 0.
if Intercept:
a = numpyro.sample('intercept', dist.Normal(0., 0.2))
Z += a
n_covs = len(Covariates.keys())
if n_covs > 0:
Beta = numpyro.sample('Beta', dist.Normal(0., 0.5), sample_shape=(n_covs,))
for i,(k,v) in enumerate(Covariates.items()):
Z += v.dot(Beta[i])
if Group is not None:
# random intercepts | Group
rand_int = BayesianModelEngine.Group_Intercept(Group, Group_Technique)
Z += rand_int
# random slopes | Group
if RandSlopeVar is not None:
                # Random slopes: RandSlopeVar names the covariate to attach a random slope to.
try:
Z += BayesianModelEngine.random_slope(Group, Covariates[RandSlopeVar])
except Exception as e:
print(e)
print('Covariates: ', Covariates.keys())
sigma = numpyro.sample('sigma', dist.Exponential(1.))
n = BayesianModelEngine.get_plate_length(y=y, Group=Group, Covariates=Covariates)
with numpyro.plate('data', n):
numpyro.sample('obs', dist.Normal(Z, sigma), obs=y)
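The interface's mcmc property expects a fit step that is not shown above. A minimal sketch of how the engine could be fit with NumPyro's NUTS sampler follows; the fit helper and its signature are assumptions, not the original implementation.
from jax import random
from numpyro.infer import MCMC, NUTS
# Hypothetical fit helper: run NUTS over the generated model and store the MCMC object.
def fit(engine, rng_key=random.PRNGKey(0), **model_kwargs):
    kernel = NUTS(engine.model)
    mcmc = MCMC(kernel,
                num_warmup=engine.runtime_params.num_warmup,
                num_samples=engine.runtime_params.num_samples,
                num_chains=engine.runtime_params.num_chains,
                progress_bar=not engine.runtime_params.disable_progbar)
    mcmc.run(rng_key, **model_kwargs)
    engine.mcmc = mcmc  # uses the property setter defined on the interface
    return engine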
Experiments & Tuning
We fit a suite of models and approximate out-of-sample performance (WAIC/LOO) to select a model; a comparison sketch follows the stage list below. Our suite is run in 4 stages:
- Stage 1. Pooling Technique: Select the optimal grouping technique from {Unpooled, FullyPooled, PartiallyPooled}, controlling variance constraints across clusters.
- Stage 2. Grand Intercept: After selecting a pooling technique, we assess whether a grand/overall intercept is beneficial.
- Stage 3. Random Slopes: Thereafter we test the relevance of attaching random slopes for each cluster. The slope variable is interchangeable.
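A sketch of the selection step, assuming the fitted engines from each stage are collected in a dictionary (the fitted_engines name is illustrative): convert each MCMC result to ArviZ InferenceData and compare by WAIC.
import arviz as az
# fitted_engines is a hypothetical {name: BayesianModelEngine} mapping of fitted candidates.
idata = {name: az.from_numpyro(engine.mcmc) for name, engine in fitted_engines.items()}
print(az.compare(idata, ic='waic'))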
Results
Approximating WAIC results in selecting the parameterization:
Fit_1 = dict(
Inference_Engine = BayesianModelEngine(),
params = params,
y = fd.to_jax(fd.y_training),
X1 = fd.to_jax(fd.X_training.astype(float)),
Group = fd.to_jax(fd.group_training),
Group_Technique = 'PartiallyPooled',
Intercept = True
)
This instantiates a partially pooled model with random intercepts per cluster and a grand intercept.
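As an illustration, the selected configuration can be run through the fit sketch above and inspected with the interface's helpers (the glue code is assumed):
# Hypothetical glue code: fit the selected configuration and inspect the posterior.
engine = fit(Fit_1['Inference_Engine'],
             y=Fit_1['y'], X1=Fit_1['X1'], Group=Fit_1['Group'],
             Group_Technique=Fit_1['Group_Technique'], Intercept=Fit_1['Intercept'])
engine.print_summary()  # posterior summaries per parameter
engine.plot_trace()     # trace plots via ArviZ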
Conclusion
These models are fast, easy to build, easy to customize, and amenable to many deep-learning architectures (able to integrate with PyTorch / TensorFlow). Any business/operational logic can be encoded as a series of deterministic operations. Uncertainty/parameter estimation can be encoded as stochastic processes.
The shortcoming of these models is their inherent complexity, requiring deep expert knowledge. This may be (somewhat) mitigated by specifying uninformative priors over ambiguous parameters (which will produce behaviour similar to frequentist models in the limit).
The full implementation is available on GitHub.