Domain Adaptation & Generalisation

Maximising model performance when predicting out of sample.

Zach Wolpe
Apr 20, 2023
Learning efficient embeddings to generalize out of sample.

Deploying machine learning can be precarious. Traditional software requires rigorous testing before deployment. Machine Learning carries the same constraints, as well as the additional burden of assessing the stochastic nature of the system.

Are we confident that our model performance is not degrading over time?

Good MLOps accounts for the inherent stochasticity generated by sampling outside of the training domain.

Here are some techniques we use to monitor & mitigate ML drift.

The Fragility of Machine Learning

All machine & statistical learning methods rely heavily on mathematical assumptions, & more often than not, poor model performance is a consequence of violating these assumptions.

The primary assumption is that the in-sample data Xi and the out-of-sample data Xj are identically distributed, P(Xi) = P(Xj), & further that the conditional distribution of the response given the input, P(y|X), is equivalent across both samples: P(yi|Xi) = P(yj|Xj).

Engineering real systems often necessitates violating these assumptions.

  1. Out-of-Domain Prediction: We may wish to transform data outside of the training domain. For example, suppose we build a model to predict cancer. Applying the model in hospitals not included in the dataset may result in poor performance since P(Xi) != P(Xj).
  2. Change in data over time: The distribution of the data may change over time in unobvious ways. For example, suppose you deploy a model to predict consumer spending habits. External changes such as the move towards remote work or the consumer base ageing could change consumer behaviour (degrading model performance over time).
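Out-of-domain prediction is easy to demonstrate. The sketch below (an illustrative toy, not from the original article) fits a linear model where the data-generating process happens to be locally linear, then evaluates it on inputs drawn from a shifted region, i.e. P(Xi) != P(Xj):

```python
import numpy as np

rng = np.random.default_rng(0)
f = lambda x: np.sin(x)  # the (unknown) data-generating process

# Train a linear model on x in [0, 1.5], where sin(x) is roughly linear.
x_train = rng.uniform(0, 1.5, 200)
y_train = f(x_train) + rng.normal(0, 0.1, 200)
coefs = np.polyfit(x_train, y_train, deg=1)

def mse(x):
    """Mean squared error of the linear fit against the true process."""
    return np.mean((np.polyval(coefs, x) - f(x)) ** 2)

x_in = rng.uniform(0, 1.5, 200)   # P(X) matches training
x_out = rng.uniform(3, 4.5, 200)  # P(X) shifted: covariate shift
print(mse(x_in), mse(x_out))      # small vs orders of magnitude larger
```

The model is not wrong in-domain; it simply has no information about the regions it was never trained on.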

It is impractical to prevent these violations entirely. It is, therefore, imperative to factor them explicitly into the machine learning process: i.e. Domain Adaptation & Generalisation.

Preliminaries

Data Drift

There are several types of data drift, broadly categorised as:

Concept Drift: The decision boundary P(y|X) has changed. Also referred to as Class Drift or Posterior Probability Shift, Concept Drift changes the mapping from X to y.

Data Drift: Also referred to as Covariate Shift, Data Drift refers to the instance where the decision boundary has not changed, but the probability distribution of the input P(X) or the response P(y) has.

Concept drift. Left) Original sample. Centre) Concept drift: a shift in the data generating process P(y|X). Right) The green region represents misclassification by failure to account for the concept drift.
Definitions taken from A Survey on Domain Generalisation.

Solutions can broadly be split into two categories: Monitoring (Detection), concerned with identifying a violation, & Weighting (Prevention), concerned with updating the cost function or optimisation method used during training.

Dangers of ML shift

A demonstration of performance decay due to concept drift &/or covariate shift. Setup: suppose we fit a B-Splines model (right) to capture the relationship in some data. The (unknown) data-generating process (left) is sufficiently captured by the model.
Over time, a large trend in the data-generating process emerges. (left). If remodelled, the new model adequately captures the data (centre). The original Spline fails to predict out-of-sample (right).

Monitoring

Good MLOps includes both model versioning & performance monitoring. A retraining & lifelong learning schedule should be derived from live feedback from models in production.

Here are methods to monitor model performance.

Supervised Learning Models

  • Metrics: statistical measures, accuracy, precision, FPR, AUC etc.
  • Supervised learning: see “A survey of concept drift adaptation”.
  • Sequential analysis (SPRT) to tune alarms on false positives.
  • Statistical process control (SPC) — the rate of change.
  • Monitoring 2 distributions (ADWIN): more precise, but more overhead.
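To make the SPC idea concrete, here is a minimal sketch of a DDM-style control chart over a stream of 0/1 prediction errors. The class name, warm-up length, and three-sigma rule are illustrative choices, not a reference implementation:

```python
import numpy as np

class SPCDriftDetector:
    """Statistical-process-control monitor on a stream of 0/1 errors.

    Alarms when the running error rate exceeds the best rate seen so
    far by three standard deviations (the classic DDM rule).
    """

    def __init__(self, warmup=30):
        self.n, self.p = 0, 0.0
        self.p_min, self.s_min = np.inf, np.inf
        self.warmup = warmup

    def update(self, error):
        self.n += 1
        self.p += (error - self.p) / self.n          # running error rate
        s = np.sqrt(self.p * (1 - self.p) / self.n)  # its standard error
        if self.p + s < self.p_min + self.s_min:     # track best state
            self.p_min, self.s_min = self.p, s
        return self.n > self.warmup and self.p + s > self.p_min + 3 * self.s_min

# 10% errors for 500 predictions, then concept drift: every prediction fails.
stream = [1 if i % 10 == 0 else 0 for i in range(500)] + [1] * 200
detector = SPCDriftDetector()
alarm_at = next(i for i, e in enumerate(stream) if detector.update(e))
print(alarm_at)  # fires shortly after the shift at index 500
```

The detector needs only a few floats of state per model, which is why SPC-style rules are cheap enough to run on every prediction stream in production.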

Unsupervised Models

  • PSI (population stability index).
  • KL divergence.
  • Jensen-Shannon (JS).
  • KS test.
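The unsupervised metrics above compare a live feature distribution against the training-time reference. A sketch using scipy for KS and Jensen-Shannon, with PSI implemented by hand over quantile bins (the bin counts and shift size are illustrative):

```python
import numpy as np
from scipy.stats import ks_2samp
from scipy.spatial.distance import jensenshannon

rng = np.random.default_rng(0)
reference = rng.normal(0, 1, 5000)   # training-time feature distribution
current = rng.normal(0.5, 1, 5000)   # live feature distribution (shifted)

def psi(expected, actual, bins=10):
    """Population Stability Index over quantile bins of the reference."""
    edges = np.quantile(expected, np.linspace(0, 1, bins + 1))
    edges[0], edges[-1] = -np.inf, np.inf
    e = np.histogram(expected, edges)[0] / len(expected)
    a = np.histogram(actual, edges)[0] / len(actual)
    e, a = np.clip(e, 1e-6, None), np.clip(a, 1e-6, None)
    return np.sum((a - e) * np.log(a / e))

# KS: nonparametric two-sample test on the raw samples.
stat, p_value = ks_2samp(reference, current)

# Jensen-Shannon distance over a shared histogram.
edges = np.histogram_bin_edges(np.concatenate([reference, current]), bins=30)
p = np.histogram(reference, edges)[0]
q = np.histogram(current, edges)[0]
js = jensenshannon(p, q)

print(f"PSI={psi(reference, current):.3f}, KS p={p_value:.2e}, JS={js:.3f}")
```

A common rule of thumb treats PSI above roughly 0.25 as a major shift, though alert thresholds should be tuned per feature.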

Monitoring is straightforward if your infrastructure is set up adequately.

Domain Adaptation & Generalisation

Monitoring is important to ensure the longevity of a model, but there are also preventative measures to improve the likelihood of extrapolating results beyond the training set. These primarily follow a procedure akin to:

Given the input data X, learn a lower-dimensional manifold ø(X) (compressing the data to extract the latent data-generating process) & use the manifold transformation either to engineer features or as model input, such that y ~ f(ø(X)).
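As a minimal stand-in for this procedure, the sketch below uses PCA for ø and a ridge regression for f (both illustrative substitutes for a learned domain-invariant embedding), on synthetic data driven by a low-dimensional latent process:

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(0)

# 50 observed features driven by a 3-dimensional latent process.
latent = rng.normal(size=(400, 3))
X = latent @ rng.normal(size=(3, 50)) + 0.05 * rng.normal(size=(400, 50))
y = latent @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=400)

# y ~ f(ø(X)): ø is PCA (a linear stand-in for the learned manifold),
# f is a ridge regression on the embedded features.
model = make_pipeline(PCA(n_components=3), Ridge(alpha=1.0))
model.fit(X[:300], y[:300])
r2 = model.score(X[300:], y[300:])
print(r2)  # high held-out R², since the embedding recovers the latent space
```

The point of the pattern is that f only ever sees ø(X), so if ø maps both domains to the same latent space, f transfers with it.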

We tend to remove segments of the data during training & assess the sensitivity of model pipelines to these injected perturbations.

Assumption: the training data is drawn from the same set of sources as the target (test) data.

This formula supports many embedding/model variants. One common approach — which we have implemented & tested rigorously — is the TCA++ framework. The general trend is as follows:

  1. Derive a simple dimensionality-reduction objective function ø(x) that compresses the raw input in a way that maintains local & global structure (data that is clustered/distant in x is clustered/distant in ø(x)).
  2. Add terms to account for supervision ø(x, y).
  3. Add terms to account for regularisation (usually just L1 or L2).
  4. Use linear algebra to reduce the final objective function, usually by eigenvalue decomposition after expressing the function in its Lagrangian form.
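The quantity these objectives minimise across domains is the Maximum Mean Discrepancy (MMD). A minimal numpy sketch of its (biased) kernel estimator, assuming an RBF kernel with an illustrative bandwidth:

```python
import numpy as np

def rbf_kernel(A, B, gamma=0.5):
    """Gaussian RBF kernel matrix between the rows of A and B."""
    d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-gamma * d2)

def mmd2(Xs, Xt, gamma=0.5):
    """Squared MMD between source & target samples (biased estimator)."""
    return (rbf_kernel(Xs, Xs, gamma).mean()
            + rbf_kernel(Xt, Xt, gamma).mean()
            - 2 * rbf_kernel(Xs, Xt, gamma).mean())

rng = np.random.default_rng(0)
same = mmd2(rng.normal(0, 1, (200, 2)), rng.normal(0, 1, (200, 2)))
shifted = mmd2(rng.normal(0, 1, (200, 2)), rng.normal(2, 1, (200, 2)))
print(same, shifted)  # near zero vs clearly positive
```

MMD is near zero when the two samples come from the same distribution and grows with the gap between them, which is what makes it a usable cost term for learning ø.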

TCA++ is a flexible framework that supports all of these variants, proving to be a useful starting point.

Transfer Component Analysis (TCA++)

Assuming P(Xi) != P(Xj), our goal is to find a feature map ø such that P(ø(Xi)) == P(ø(Xj)).

Beginning in an unsupervised setting, TCA can be used to learn a latent embedded transformation ø that is likely to map across domains. TCA is extendable in the following ways:

  • SSTCA: Semi-supervised TCA incorporates labels.
  • Multi-TCA: Extends TCA to domain generalisation adding regularisation.
  • Multi-SSTCA: Extends SSTCA to domain generalisation adding regularisation.

The model formulation is provided below, here is the high-level takeaway:

  1. MMD forms the base of the cost function. It relies on ø.
  2. MMD can be expressed in (computationally tractable) matrix algebra. This is the solution to standard TCA (equation 4).
  3. Multi-TCA extends TCA to domain generalisation by adding a regularization term to curb complexity (equation 5).
  4. MMD (in both its original & regularized form) can be approximated with eigenvectors (equation 6).
  5. W is an orthogonal weight matrix used to transform the data into the reduced space. Since the objective — identical to PCA — is to maximise variance in the reduced space, W is given as the m<<N leading eigenvectors (equation 6).
  6. Semi-Supervised SSTCA adds a term to the cost function to include label dependence ø(yi, yj) (equation 7).
  7. Multi-SSTCA adds a term to the cost function to preserve locality (equations 9 & 10).
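Steps 1-5 can be sketched in a few lines of numpy. This is an illustrative, unsupervised TCA only (no SSTCA label term, no locality term); the kernel bandwidth and regulariser mu are arbitrary choices:

```python
import numpy as np

def tca(Xs, Xt, m=2, mu=1.0, gamma=0.5):
    """Minimal unsupervised TCA sketch.

    Returns m-dimensional embeddings of source & target obtained from
    the leading eigenvectors of the regularised MMD objective.
    """
    X = np.vstack([Xs, Xt])
    ns, nt = len(Xs), len(Xt)
    n = ns + nt

    # RBF kernel over the pooled sample.
    d2 = np.sum(X**2, 1)[:, None] + np.sum(X**2, 1)[None, :] - 2 * X @ X.T
    K = np.exp(-gamma * d2)

    # MMD coefficient matrix L and centring matrix H.
    L = np.full((n, n), -1.0 / (ns * nt))
    L[:ns, :ns] = 1.0 / ns**2
    L[ns:, ns:] = 1.0 / nt**2
    H = np.eye(n) - np.ones((n, n)) / n

    # W = m leading eigenvectors of (K L K + mu I)^(-1) K H K:
    # minimise cross-domain MMD while preserving variance.
    M = np.linalg.solve(K @ L @ K + mu * np.eye(n), K @ H @ K)
    vals, vecs = np.linalg.eig(M)
    W = np.real(vecs[:, np.argsort(-np.real(vals))[:m]])

    Z = K @ W
    return Z[:ns], Z[ns:]

rng = np.random.default_rng(0)
Zs, Zt = tca(rng.normal(0, 1, (60, 5)), rng.normal(1, 1, (60, 5)))
print(Zs.shape, Zt.shape)  # (60, 2) (60, 2)
```

Note the solution is closed-form: one linear solve and one eigendecomposition, with no gradient-based training.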

Mathematical Formulation

Extract from Thomas Grubinger et al.
