(Fast) Probabilistic Graphs

Fast Probabilistic Programming with JAX & NumPyro

Zach Wolpe
7 min read · May 23, 2023

Probabilistic Inference is extremely useful. Many machine learning models fail to adequately quantify sensitivity to, & variability in, the space of possible models. Probabilistic (Bayesian) models allow us to be more explicit about our assumptions & to encode priors & variational constraints.

Probabilistic graphs take this a step further, allowing us to build probabilistic models over any arbitrary set of operations. A real-world system has both stochastic and deterministic mechanics.

These methods are old and deeply statistical, meaning they have been rigorously studied (backed by decades of mathematical scrutiny). The problem is these old statistical paradigms are:

  1. Not built to scale.
  2. Incompatible with contemporary Deep Learning frameworks.

The sampling-heavy optimization procedures required to fit probabilistic programs become intractable as data grows, and many of the tools built to handle these models do not allow for sufficient integration with deep learning architectures: enter JAX.

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

This is exactly what we need to compute gradients & optimize parameters in large probabilistic graphs. So how do we build probabilistic graphs on JAX? Enter NumPyro.
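As a rough illustration (a toy example, not the article's model), JAX turns the gradient of an arbitrary log-density into a compiled function:

import jax
import jax.numpy as jnp

# toy log-density: Normal likelihood with unit variance
def log_density(mu, data):
    return jnp.sum(-0.5 * (data - mu) ** 2)

data = jnp.array([1.2, 0.7, 1.9, 1.4])

# grad() differentiates w.r.t. the first argument; jit() compiles the result
grad_fn = jax.jit(jax.grad(log_density))
print(grad_fn(1.0, data))  # gradient of the log-density at mu = 1.0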

NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for Pyro. We rely on JAX for automatic differentiation and JIT compilation to GPU / CPU.
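A minimal NumPyro program looks like this (a toy sketch to show the API, not the model used below):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# y ~ Normal(mu, sigma), with weak priors on mu & sigma
def model(y=None):
    mu    = numpyro.sample('mu', dist.Normal(0., 10.))
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=y)

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), y=jnp.array([1.2, 0.7, 1.9, 1.4]))
mcmc.print_summary()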

I have been deploying probabilistic graphs to quantify the variability between groups; here's an example.

Data

Consider the Countries of the World dataset, which contains various metrics per country. We wish to understand the relationship between GDP per capita & social connectivity (using Phones/1000 people as a proxy for connectivity).

Are there (latent) groups that significantly affect the data-generating process?

Examining the data by Region yields:

Feature Engineering: Clustering

The geographic regions are useful, but we might find cleaner groupings by clustering the data, modelling the variability across a series of socioeconomic measurements.

# KPCA + K-Means clustering
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
from sklearn.decomposition import KernelPCA
from sklearn.pipeline import Pipeline
from sklearn.cluster import KMeans

# standardize -> project with Kernel PCA -> cluster with K-Means
scaler = StandardScaler()
kpca = KernelPCA(n_components=2, kernel='rbf', gamma=0.1)
kmeans = KMeans(n_clusters=3, random_state=1)
pipe = Pipeline([('scaler', scaler), ('kpca', kpca), ('kmeans', kmeans)])

# fit on the numeric columns & attach the cluster labels
cntry = CD.countries.dropna()
pipe.fit(cntry.iloc[:, 2:])
cntry['Cluster'] = pipe.predict(cntry.iloc[:, 2:])
cntry['Cluster'] = cntry['Cluster'].astype(int)
cntry['Cluster2'] = cntry['Cluster'].astype(str)
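The silhouette_score import above hints at how the number of clusters could be validated. Here is a quick sketch (not part of the original pipeline) that scores a few candidate values of k on the KPCA projection:

# score candidate cluster counts on the Kernel PCA projection
proj = Pipeline([('scaler', StandardScaler()),
                 ('kpca', KernelPCA(n_components=2, kernel='rbf', gamma=0.1))])
X_proj = proj.fit_transform(cntry.iloc[:, 2:])

for k in range(2, 7):
    labels = KMeans(n_clusters=k, random_state=1, n_init=10).fit_predict(X_proj)
    print(f'k={k}: silhouette = {silhouette_score(X_proj, labels):.3f}')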

The clustering appears to represent the data better than the raw regions, so we'll use this feature going forward.

Machine Learning: GAMs

We need a set of flexible, regularized functions to model the data. A GAM (Generalized Additive Model) is well suited to the task. GAMs use a series of non-linear basis functions to approximate trends in the data.
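For intuition, here is a hedged sketch of an explicit pyGAM specification, where each s(i) term expands a feature into spline basis functions (the wrapper class below simply uses LinearGAM's defaults; X_train and y_train are assumed to come from a standard train/test split):

from pygam import LinearGAM, s

# one spline term per feature; lam penalizes wiggliness of the basis coefficients
gam = LinearGAM(s(0, n_splines=20, lam=0.6))
gam.fit(X_train, y_train)   # X_train / y_train: assumed train split
gam.summary()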

from modules.dependencies import *   # provides LinearRegression, LinearGAM, r2_score, etc.

class FrequentistModels:

    models = {
        'Linear Regression': LinearRegression,
        'Linear GAM':        LinearGAM
    }

    def __init__(self, X_train, y_train, X_test, y_test, model_name='Linear Regression') -> None:
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.model = self.models[model_name]().fit(self.X_train, self.y_train)
        self.model_name = model_name

    def set_model(self, model_name):
        assert model_name in FrequentistModels.models.keys(), \
            'Model must be one of [{}]'.format(list(FrequentistModels.models.keys()))
        # instantiate the chosen model (not just store its name)
        self.model = FrequentistModels.models[model_name]()
        self.model_name = model_name
        return self

    def fit(self):
        self.model.fit(self.X_train, self.y_train)
        return self

    def predict(self, X=None):
        if X is not None:
            return self.model.predict(X)
        self.y_pred = self.model.predict(self.X_test)
        return self

    def print_model(self):
        if hasattr(self.model, 'summary'):
            print(self.model.summary())
        else:
            print(self.model)
        return self

    def print_score(self):
        print('R^2: ', r2_score(self.y_test, self.y_pred))
        return self
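Hypothetical usage of the wrapper, assuming X_train / X_test / y_train / y_test come from a standard split of the phones ~ GDP data:

# fit a GAM, print its summary & report validation R^2
gam = FrequentistModels(X_train, y_train, X_test, y_test, model_name='Linear GAM')
gam.predict().print_model().print_score()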

Fitting a GAM works great, yielding a validation R² of roughly 0.74 without much tuning. We have no information, however, about the variability across clusters.

Probabilistic Inference

JAX + NumPyro

Superb! Our model is looking great, but can we be confident in predicting outside of the training domain? Unlike a pure research setting, engineering systems require additional robustness before deployment.

Latent Space Representation

It is advantageous to identify points of potential fragility & adequately represent uncertainty in the system. Techniques for semi-supervised learning, domain generalization, or domain adaptation can take many forms. The approach generally follows:

Learn some latent space representation that is able to generate the original data — maximizing information retention during compression. In theory, this maximizes the likelihood of learning the latent process that generates the data (and thus generates data outside of the training set).

Our feature-engineered Kernel PCA clusters represent such a subspace.

Bayesian Inference

By quantifying the variability and sensitivity of the subspace, we are able to:

  1. Inject better features into downstream models.
  2. Return confidence in the prediction space.
  3. Flag potential weak points before productionizing.

We define a general-purpose Bayesian modelling interface to enforce a common structure across all downstream concrete classes.

from modules.dependencies import *   # provides ABC, MCMC, numpyro, arviz (az), etc.

class params:
    num_warmup  = 500
    num_samples = 1000
    num_chains  = 2
    disable_progbar = False

class Bayesian_Model_Interface(ABC):
    def __init__(self, model, runtime_params, *args, **kwargs):
        self.runtime_params = runtime_params
        self._model = model
        self._mcmc = None

    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, model):
        self._model = model

    @property
    def mcmc(self):
        if self._mcmc is None:
            raise ValueError('Model not fit. Run .fit() first.')
        return self._mcmc

    @mcmc.setter
    def mcmc(self, value):
        if not isinstance(value, MCMC):
            raise ValueError('Value must be a numpyro MCMC object.')
        self._mcmc = value

    def compute_model_accuracy(self):
        pass

    def print_summary(self):
        self.mcmc.print_summary()

    def plot_trace(self):
        az.plot_trace(self.mcmc)

Concrete Class

Using this interface, we can define a concrete class that represents any arbitrary set of probabilistic graphs. The BayesianModelEngine.InterceptModelGenerator() method is able to instantiate any linear combination of:

  • Linear covariates.
  • Random intercepts.
  • Random slopes.
  • Variance pooling techniques.

class BayesianModelEngine(Bayesian_Model_Interface):

    def __init__(self, runtime_params=params, *args, **kwargs):
        super().__init__(self.InterceptModelGenerator, runtime_params, *args, **kwargs)
        self._mcmc = None
        self._model = self.InterceptModelGenerator

    @staticmethod
    def get_plate_length(y=None, Group=None, Covariates=None):
        """
        Get plate length for the numpyro plate context manager.
        """
        if y is not None:
            return len(y)
        elif Group is not None:
            return len(Group)
        else:
            if Covariates is None:
                return 1
            Cov1 = list(Covariates.values())[0]
            return len(Cov1)

    @staticmethod
    def Group_Intercept(Group, Technique):
        """
        Technique:
            -> FullyPooled
            -> PartiallyPooled
            -> Unpooled
        """
        if Technique == 'FullyPooled':
            # a single intercept, independent of group
            rand_int = numpyro.sample('rand_int', dist.Normal(0., 1.))
            return rand_int

        if Technique == 'PartiallyPooled':
            # group intercepts share a learned hyperprior
            mu_a = numpyro.sample('mu_a', dist.Normal(0., 5.))
            sg_a = numpyro.sample('sg_a', dist.HalfNormal(5.))
            n_grp = len(np.unique(Group))
            with numpyro.plate("plate_i", n_grp):
                rand_int = numpyro.sample('rand_int', dist.Normal(mu_a, sg_a))
            return rand_int[Group]

        if Technique == 'Unpooled':
            # an independent intercept per group
            n_grp = len(np.unique(Group))
            with numpyro.plate("plate_i", n_grp):
                rand_int = numpyro.sample('rand_int', dist.Normal(0., 0.3))
            return rand_int[Group]

        raise ValueError('Technique must be one of [FullyPooled, PartiallyPooled, Unpooled]')

    @staticmethod
    def random_slope(Group, Xvar):
        # a slope per group, sharing a learned hyperprior
        n_grp = len(np.unique(Group))
        mu_b = numpyro.sample('mu_b', dist.Normal(0., 5.), sample_shape=(n_grp,))
        sg_b = numpyro.sample('sg_b', dist.HalfNormal(5.), sample_shape=(n_grp,))
        with numpyro.plate("plate_s", n_grp):
            rand_slope = numpyro.sample('rand_slope', dist.Normal(mu_b, sg_b))
        return rand_slope[Group] * Xvar

    def InterceptModelGenerator(self, y=None, Group=None, Group_Technique='PartiallyPooled',
                                Intercept: bool = False, RandSlopeVar=None, **Covariates):
        """
        Flexible model generator: linear covariates + (optional) grand intercept
        + (optional) random intercepts & slopes per group.
        """
        Z = 0.

        if Intercept:
            a = numpyro.sample('intercept', dist.Normal(0., 0.2))
            Z += a

        n_covs = len(Covariates.keys())
        if n_covs > 0:
            Beta = numpyro.sample('Beta', dist.Normal(0., 0.5), sample_shape=(n_covs,))
            for i, (k, v) in enumerate(Covariates.items()):
                Z += v.dot(Beta[i])

        if Group is not None:
            # random intercepts | Group
            rand_int = BayesianModelEngine.Group_Intercept(Group, Group_Technique)
            Z += rand_int

            # random slopes | Group
            if RandSlopeVar is not None:
                # RandSlopeVar names the covariate to attach a random slope to
                try:
                    Z += BayesianModelEngine.random_slope(Group, Covariates[RandSlopeVar])
                except Exception as e:
                    print(e)
                    print('Covariates: ', Covariates.keys())

        sigma = numpyro.sample('sigma', dist.Exponential(1.))
        n = BayesianModelEngine.get_plate_length(y=y, Group=Group, Covariates=Covariates)

        with numpyro.plate('data', n):
            numpyro.sample('obs', dist.Normal(Z, sigma), obs=y)
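The fitting step itself is not shown in the class above; here is a minimal sketch of how InterceptModelGenerator could be handed to NumPyro's NUTS sampler (y, X1 and Group are assumed to be JAX arrays, as in the Fit_1 parameterization below):

from jax import random
from numpyro.infer import MCMC, NUTS

engine = BayesianModelEngine()
mcmc = MCMC(NUTS(engine.model),
            num_warmup=params.num_warmup,
            num_samples=params.num_samples,
            num_chains=params.num_chains,
            progress_bar=not params.disable_progbar)

# keyword arguments are forwarded to InterceptModelGenerator
mcmc.run(random.PRNGKey(0), y=y, X1=X1, Group=Group,
         Group_Technique='PartiallyPooled', Intercept=True)

engine.mcmc = mcmc        # stored via the property setter on the interface
engine.print_summary()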

Experiments & Tuning

We fit a suite of models and approximate out-of-sample performance (WAIC/LOO) to select a model (a comparison sketch follows the list below). Our suite is run in stages:

  • Stage 1. Pooling Technique: Select the optimal grouping technique from {Unpooled, FullyPooled, PartiallyPooled}, controlling how variance is shared across clusters.
  • Stage 2. Grand Intercept: After selecting a pooling technique, we assess whether a grand/overall intercept is beneficial.
  • Stage 3. Random Slopes: Thereafter we test the relevance of attaching random slopes to each cluster. The slope variable is interchangeable.
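WAIC can be approximated from the posterior samples with ArviZ. A hedged sketch of how two fitted candidates might be compared (mcmc_a and mcmc_b are hypothetical MCMC objects from the stages above):

import arviz as az

# convert fitted NumPyro MCMC objects to InferenceData & rank them by WAIC
idata_a = az.from_numpyro(mcmc_a)   # e.g. PartiallyPooled
idata_b = az.from_numpyro(mcmc_b)   # e.g. Unpooled
print(az.compare({'PartiallyPooled': idata_a, 'Unpooled': idata_b}, ic='waic'))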

Results

Approximating WAIC leads us to select the following parameterization:

Fit_1 = dict(
    Inference_Engine = BayesianModelEngine(),
    params           = params,
    y                = fd.to_jax(fd.y_training),
    X1               = fd.to_jax(fd.X_training.astype(float)),
    Group            = fd.to_jax(fd.group_training),
    Group_Technique  = 'PartiallyPooled',
    Intercept        = True
)

This instantiates a partially pooled, random-intercept model with a grand intercept.

An R_hat of ≈ 1 for each parameter indicates that the chains have converged and mixed well. We also don't see any chain divergences, another sign of successful Markov chain convergence.
Our model reveals a component of the (latent) data-generating process by converging on distinct sampling distributions over each socioeconomic cluster. These posteriors can be used as additional downstream features, to improve variance estimates and to catch potential out-of-sample draws that may hurt performance when predicting on unseen data in production. Note the different mean/std across posteriors.

Conclusion

These models are fast, easy to build, easy to customize, and amenable to many deep-learning architectures (NumPyro sits on the JAX ecosystem, while Pyro offers the same modelling style on PyTorch). Any business/operational logic can be encoded as a series of deterministic operations. Uncertainty/parameter estimation can be encoded as stochastic processes.

The shortcoming of these models is their inherent complexity, requiring deep expert knowledge. This may be (somewhat) mitigated by specifying uninformative priors over ambiguous parameters (which will produce behaviour similar to frequentist models in the limit).

The full implementation is available on GitHub.
