How we use Self-Supervised Learning in Personalised Medicine

Operating in the realm of physiological and behavioural data has never been more fruitful, nor more challenging. With the widespread adoption of wearables, and the ability to infer behavioural information from apps and wearables, users effectively generate data continuously.

While the software engineering challenges of storing and querying this information are largely solved by accessible cloud infrastructure APIs, integrating this data into efficient machine learning workflows is non-trivial. Framing models in a Bayesian or reinforcement learning (RL) paradigm allows online posterior updates that tune models in light of additional information, but this does not help us exploit new, unlabeled data sources.

How do we efficiently leverage newly acquired, large data sources without rebuilding our entire data pipeline?

Data Flow

In the medical machine learning world our data often comprises:

  • Clinical & academic studies: small, highly reliable, explicit datasets that investigate the relationships between physiological, psychological and behavioural phenomena, often addressing a specific question of interest.
  • Realtime user-generated data: large, messy, unlabeled data that is regularly generated during the deployment of products.

Example product

Some of our applications are concerned with stress management. The premise and initial machine learning models are built on data from validated clinical studies. These studies allow us to make assumptions about behaviours that both cause and follow periods of high stress.

Consider a product that uses EEG, EKG and behavioural data (extracted from a mobile device) to monitor levels of stress and anxiety. Behavioural data may include where somebody spends their time (work, outdoors, social, etc.), social media use, spending habits, etc. The product can nudge behaviour to manage stress, and detect signs of deteriorating mental health.

Clinical studies may be used to build the initial product. After launch, the subsequent data flow may be inconsistent and diverse.

Assume this product requires EEG readings and provides diagnostics and long-term monitoring of stress levels. New users generate large amounts of unlabeled additional EEG data.

Objective

Can we leverage the large (unlabeled) EEG datasets to enhance the product offering? To this end we rely on Self-Supervised Learning (SSL).

Self-Supervised Learning (SSL)

For context, here is an SSL primer introducing the broad language, techniques and ideas used in the literature.

The success of human and animal learning in sparse-data environments (sometimes from a single example) is often attributed to common sense. Object permanence and an understanding of gravity are two of the most widely cited examples of how humans parse information about the world in ways that prime subsequent learning. Even in infancy, humans develop an intuition for physics that shapes their understanding of reality.

Statistically, one can consider this common sense as a function that:

  1. Regularises the parameter space when learning a new task: allowing the individual to approximate generalisable functions despite data sparsity.
  2. Maps the new task to previously learnt ideas via some distance metric: capturing the relationship between ideas.

Self-supervised learning attempts to emulate the acquisition of background knowledge (analogous to common sense) to greatly improve the efficiency of learning new tasks.

Basic Idea

Most SSL applications entail learning latent, generalisable structure from large unlabeled datasets, and subsequently exploiting this structure to boost performance on downstream tasks.

(Left) Autoencoders are regularly used in SSL models, where latent structure is extracted by compressing the data. (Right) Autoencoders map the data off, and then back onto, the observable manifold. The left panel is an example of a denoising autoencoder (compare its input and output images) that learns a smooth (regularised) transformation. The latent space f(x) represents smooth properties of the data that can be exploited in downstream tasks.
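A minimal NumPy sketch of the denoising idea, assuming a one-hidden-layer autoencoder (tanh encoder, linear decoder) trained with full-batch gradient descent on synthetic low-rank data. The helper name `train_denoising_autoencoder`, the layer sizes, noise level and learning rate are illustrative assumptions, not a production configuration. The key detail is that the input is corrupted with noise while the reconstruction is scored against the clean input, which pushes the latent code f(x) towards the smooth underlying structure.

```python
import numpy as np


def train_denoising_autoencoder(X, n_latent=2, noise_std=0.1, lr=0.01,
                                epochs=3000, seed=0):
    """Hypothetical helper: fit a tiny denoising autoencoder by gradient descent."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W_enc = rng.normal(0.0, 0.1, size=(d, n_latent))
    b_enc = np.zeros(n_latent)
    W_dec = rng.normal(0.0, 0.1, size=(n_latent, d))
    b_dec = np.zeros(d)
    losses = []
    for _ in range(epochs):
        X_noisy = X + rng.normal(0.0, noise_std, size=X.shape)  # corrupt the input
        H = np.tanh(X_noisy @ W_enc + b_enc)  # latent code f(x)
        X_hat = H @ W_dec + b_dec             # map back onto the data space
        err = X_hat - X                       # score against the CLEAN input
        losses.append(float(np.mean(err ** 2)))
        # Backpropagate the mean squared error by hand.
        g_out = 2.0 * err / err.size
        g_W_dec = H.T @ g_out
        g_b_dec = g_out.sum(axis=0)
        g_pre = (g_out @ W_dec.T) * (1.0 - H ** 2)  # chain rule through tanh
        g_W_enc = X_noisy.T @ g_pre
        g_b_enc = g_pre.sum(axis=0)
        W_enc -= lr * g_W_enc
        b_enc -= lr * g_b_enc
        W_dec -= lr * g_W_dec
        b_dec -= lr * g_b_dec
    return (W_enc, b_enc, W_dec, b_dec), losses


# Synthetic data lying on a 2-D subspace of a 6-D space.
rng = np.random.default_rng(1)
Z = rng.normal(size=(256, 2))
A = rng.normal(size=(2, 6)) / np.sqrt(2)
X = Z @ A
params, losses = train_denoising_autoencoder(X)
print(f"reconstruction loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

In practice one would use a deep learning framework with automatic differentiation; the hand-written gradients only keep the example dependency-free.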

Text vs Image Data

Recent developments have seen much greater success with SSL models in natural language processing (NLP) than with comparable computer vision (CV) models. The current community consensus is that this stems from the difficulty of adequately quantifying uncertainty in CV.

Quantifying uncertainty in alternative model paradigms. NLP models can quantify uncertainty intuitively because the vocabulary is finite: a probability (informed by word frequencies in the text corpora) can be assigned to every possible word. CV models, however, face effectively infinite variability in continuous, high-dimensional pixel space, which adds great uncertainty and dimensionality. Contrastive methods (a subclass of SSL) are able to reduce this complexity.
Energy-based models (EBMs) are used as a framework to explain SSL. Directly analogous to loss functions, EBMs take in a set of covariates X and a response Y and compute a single value (the energy) that quantifies their relationship. Loosely based on principles from physics and information theory, low energy means the variables are similar, and high energy means they are distant. Training an EBM requires learning some approximate distance between data points. Although it is generally easy to ensure that nearby data points are mapped to low energy, it is difficult to ensure that sufficiently distant (negative) examples produce high energy; if they do not, the energy becomes low everywhere (a problem known as collapse).

Contrastive SSL: uses both positive and negative examples to learn the energy function.

Non-contrastive SSL: uses only positive examples in the data, and is primarily used to extract predictors (feature engineering).
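To make the energy view concrete, here is a minimal NumPy sketch of a contrastive objective. It assumes squared Euclidean distance between embeddings as the energy and a margin-based (hinge) loss; both are illustrative choices and the function names are hypothetical. Positive pairs are penalised by their energy, while negative pairs are penalised only when their energy falls below the margin, which is what stops every pair from collapsing to zero energy.

```python
import numpy as np


def energy(a, b):
    """Energy of embedding pairs: squared Euclidean distance (low = compatible)."""
    return np.sum((a - b) ** 2, axis=-1)


def contrastive_loss(anchors, others, is_positive, margin=1.0):
    """Margin-based contrastive loss over a batch of embedding pairs.

    Positive pairs are pulled together (penalised by their energy); negative
    pairs are pushed apart until their energy reaches `margin`.
    """
    e = energy(anchors, others)
    pos_term = is_positive * e
    neg_term = (1 - is_positive) * np.maximum(0.0, margin - e)
    return float(np.mean(pos_term + neg_term))


anchors = np.array([[0.0, 0.0], [0.0, 0.0]])
others = np.array([[0.1, 0.0], [0.2, 0.0]])
labels = np.array([1, 0])  # first pair positive, second pair negative
print(contrastive_loss(anchors, others, labels))
```

The negative pair sits well inside the margin, so it dominates the loss. A collapsed encoder that maps everything to the same point would incur the full margin on every negative pair.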

Now we return to our problem: scaling medical machine learning.

Physiological Data

Physiological data regularly constitutes high-frequency signals: EEG, EKG or MEG; or high-resolution images (fMRI). fMRI more readily fits into the existing SSL literature — which is heavily centred around images — whereas neurological and physiological signals require more tailoring.

SSL Implementation

The SSL pipeline is relatively straightforward:

  1. Latent Space Mapping: z ~ f(x)

Use large (unstructured) datasets, in conjunction with the primary (often labelled) dataset, to learn robust, generalisable latent feature representations.

  2. Downstream Tasks: y ~ g(z)

Leverage the transformed space to perform downstream tasks that are applicable to the specific use case.

SSL essentially adds a layer of feature engineering conducted on additional data. Unsupervised learning is used to extract features from large datasets (with similar signals to those in the main dataset but without labels); these structures are used to learn generalisable mappings in the transformed space.
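The two stages can be sketched in a few lines of NumPy. This is a toy illustration under stated assumptions: PCA stands in for the self-supervised encoder f, a nearest-centroid classifier stands in for the downstream model g, and the data are synthetic. The point is only the data flow: f is fitted on the labelled set X and the large unlabeled set D together, while g sees only the labels of X.

```python
import numpy as np


def fit_encoder(data, n_components):
    """Stage 1: learn a latent mapping z = f(x) from ALL available data.

    PCA (via SVD) stands in here for any unsupervised / self-supervised encoder.
    """
    mean = data.mean(axis=0)
    _, _, vt = np.linalg.svd(data - mean, full_matrices=False)
    components = vt[:n_components]

    def f(x):
        return (x - mean) @ components.T

    return f


def fit_nearest_centroid(z, y):
    """Stage 2: a simple downstream model y = g(z), trained on labelled data only."""
    classes = np.unique(y)
    centroids = np.stack([z[y == c].mean(axis=0) for c in classes])

    def g(z_new):
        dists = ((z_new[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        return classes[np.argmin(dists, axis=1)]

    return g


rng = np.random.default_rng(0)
mixing = rng.normal(size=(2, 20))  # hidden 2-D structure inside 20-D signals
D = rng.normal(size=(5000, 2)) @ mixing + 0.5 * rng.normal(size=(5000, 20))  # large, unlabeled
z_true = rng.normal(size=(40, 2))
X = z_true @ mixing + 0.5 * rng.normal(size=(40, 20))  # small, labelled
y = (z_true[:, 0] > 0).astype(int)

f = fit_encoder(np.vstack([X, D]), n_components=2)  # f learnt on X and D together
g = fit_nearest_centroid(f(X), y)                   # g learnt on labelled X only
accuracy = np.mean(g(f(X)) == y)
print("training accuracy:", accuracy)
```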

EEG SSL Applications

Returning to our example product, how exactly do we apply SSL to learn invariant EEG feature transformations?

Artifact Detection

One promising use of SSL in EEG is to detect and remove artifacts.

EEG datasets generally contain a series of montages, which may be offset to compute relative frequency activity. However, many sources of noise distort the observed frequencies. The most common include cardiac, pulse, respiratory, sweat, glossokinetic, eye-movement (blinks, and lateral rectus spikes from lateral eye movement), and muscle and movement artifacts (EEG Intro).

Categorisation of EEG artifacts.

Signal Types: From a signal processing perspective it is useful to categorise artifacts by oscillatory properties:

a) Spikes: Sporadic high-frequency changes in the signal. Often a consequence of movement, muscular impulses, or other irregular activity. Observable by substantial changes in signal variation.

b) Prolonged disturbances: Electrical interference from auxiliary cortical regions, machinery, myogenic cardiac signals, etc.

Artifact Detection, Smoothing and Removal: We employ a number of artifact removal strategies, the most effective of which include:

  • Wavelet transforms, GAMs, splines.
  • Variance models.
  • Fast Fourier Transforms.
  • Dynamic time warping (DTW).
  • Independent Component Analysis (ICA).
  • Matrix profiles (MPs).
  • Slow Feature Analysis (SFA).
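As an illustration of the simplest entry on this list, a variance model, here is a NumPy sketch that flags spike-like artifacts by comparing a rolling variance against the median rolling variance of the recording. The window length, threshold, sampling rate and the `flag_spikes` helper are assumptions chosen for a synthetic signal, not tuned clinical values.

```python
import numpy as np


def flag_spikes(signal, window=50, threshold=5.0):
    """Flag samples lying in windows whose variance far exceeds the typical level.

    The median window variance serves as the baseline, so a handful of noisy
    windows does not inflate it.
    """
    windows = np.lib.stride_tricks.sliding_window_view(signal, window)
    local_var = windows.var(axis=1)             # one variance per window start
    ratio = local_var / np.median(local_var)
    flagged = np.zeros(signal.shape[0], dtype=bool)
    for start in np.flatnonzero(ratio > threshold):
        flagged[start:start + window] = True    # mark every sample of a noisy window
    return flagged


fs = 250                                        # assumed sampling rate (Hz)
t = np.arange(0, 10, 1 / fs)
rng = np.random.default_rng(0)
eeg = np.sin(2 * np.pi * 10 * t) + 0.3 * rng.normal(size=t.size)  # synthetic 10 Hz rhythm
eeg[1200:1230] += 8 * rng.normal(size=30)       # injected movement-like spike
mask = flag_spikes(eeg)
print(mask[1200:1230].all(), mask[:1000].any())
```

Because the baseline is a statistic over the whole recording, more (unlabeled) data makes it more stable.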

Applied to EEG-SSL: These strategies are largely unsupervised, so they naturally improve with additional data. The latent space mapping can therefore be learnt from all available data, and the resulting feature transformation then applied to downstream tasks.

An illustration of using mixing matrices (as done when deploying ICA, MPs, and SFA) to decouple frequencies captured during EEG.
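ICA is the most familiar of these mixing-matrix methods. Below is a minimal NumPy sketch of FastICA (whitening, then the deflation scheme with a tanh nonlinearity) unmixing a synthetic two-channel recording; the sources, the mixing matrix and the `fast_ica` helper are illustrative assumptions. For real EEG, a maintained implementation such as scikit-learn's `FastICA` would be the natural choice.

```python
import numpy as np


def fast_ica(X, n_components, n_iter=200, tol=1e-6, seed=0):
    """Minimal FastICA (deflation, tanh contrast).

    X has shape (n_channels, n_samples); returns sources (n_components, n_samples).
    """
    rng = np.random.default_rng(seed)
    X = X - X.mean(axis=1, keepdims=True)
    # Whiten: decorrelate the channels and scale them to unit variance.
    eigvals, eigvecs = np.linalg.eigh(np.cov(X))
    whitening = eigvecs @ np.diag(1.0 / np.sqrt(eigvals)) @ eigvecs.T
    Xw = whitening @ X
    W = np.zeros((n_components, Xw.shape[0]))
    for i in range(n_components):
        w = rng.normal(size=Xw.shape[0])
        w /= np.linalg.norm(w)
        for _ in range(n_iter):
            wx = w @ Xw
            g, g_prime = np.tanh(wx), 1.0 - np.tanh(wx) ** 2
            # Fixed-point update: E[x g(w.x)] - E[g'(w.x)] w
            w_new = (Xw * g).mean(axis=1) - g_prime.mean() * w
            # Deflation: remove projections onto components already found.
            w_new -= W[:i].T @ (W[:i] @ w_new)
            w_new /= np.linalg.norm(w_new)
            converged = abs(abs(w_new @ w) - 1.0) < tol
            w = w_new
            if converged:
                break
        W[i] = w
    return W @ Xw


t = np.linspace(0, 8, 2000)
s1 = np.sin(2 * np.pi * 1.0 * t)              # smooth "cortical" rhythm
s2 = np.sign(np.sin(2 * np.pi * 0.3 * t))     # square-wave "artifact"
S = np.vstack([s1, s2])
A = np.array([[1.0, 0.5], [0.7, 1.0]])        # mixing matrix (unknown in practice)
X = A @ S                                     # what the electrodes record
S_hat = fast_ica(X, n_components=2)
corr = np.abs(np.corrcoef(np.vstack([S, S_hat]))[:2, 2:])
print(np.round(corr, 2))
```

ICA recovers sources only up to order, sign and scale, which is why the comparison uses absolute correlations.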

Hyperarousal

Hyperarousal, another application of EEG-SSL, is observable in raw EEG data and provides insight into an individual’s resting state. Hyperarousal detection built on large datasets naturally lends itself to downstream models.

Particularly relevant in psychiatry, hyperarousal and spindles present as prolonged, elevated electrical signals. This phenomenon reveals prolonged states of agitated mental stimulation, which prevent sufficient rest and are often present in individuals with anxiety and depression.

This is where nuanced (or simply large) datasets become very powerful: hyperarousal and artifacts can overlap substantially, making them hard to distinguish.

Again, we employ a series of statistical learning algorithms to flag signs of hyperarousal. Many of these are identical to the artifact-removal methods above (though with different permissible ranges), coupled with fitting statistical distributions to the data to quantify significant deviations from the expected cortical behaviour.

Empirical experimentation has led us to believe that Gamma distributions can adequately model Cortical (Hyper) Arousal.

EEG-SSL example: We leverage large unlabeled data to fit Gamma distributions to quantify the variability in cortical activity, and then use the associated mappings to predict elevated prolonged anxiety.
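A minimal NumPy sketch of the Gamma step, under several assumptions: the activity measure is synthetic, the Gamma is fitted by the method of moments (shape k = mean^2 / variance, scale θ = variance / mean) rather than by maximum likelihood, and the 99th-percentile threshold and "share of elevated epochs" rule are illustrative, not the criteria used in the product.

```python
import numpy as np


def fit_gamma_moments(x):
    """Method-of-moments Gamma fit: shape k = mean^2 / var, scale theta = var / mean."""
    mean, var = np.mean(x), np.var(x)
    return mean ** 2 / var, var / mean


rng = np.random.default_rng(0)
# Stand-in for an activity measure (e.g. per-epoch band power) pooled across many unlabeled users.
pooled = rng.gamma(shape=2.0, scale=1.5, size=50_000)
k, theta = fit_gamma_moments(pooled)

# 99th percentile of the fitted Gamma, estimated by sampling to stay NumPy-only.
threshold = np.quantile(rng.gamma(k, theta, size=200_000), 0.99)

# Hypothetical rule: a recording with many epochs above the threshold is flagged
# as showing prolonged elevated arousal.
new_user = rng.gamma(shape=2.0, scale=3.0, size=500)
share_elevated = np.mean(new_user > threshold)
print(f"k={k:.2f}, theta={theta:.2f}, threshold={threshold:.2f}, elevated={share_elevated:.1%}")
```

`scipy.stats.gamma.fit` would give a maximum-likelihood fit instead; the moment estimator simply keeps the example dependency-free.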

Big Data Domain

As datasets scale, many traditional statistical time-series techniques suffer from severe computational limitations. EEG-SSL is scaled by leveraging deep learning, using neural networks as efficient universal function approximators in data-rich environments.

The aforementioned techniques are able to “encode” medical and psychological information by setting theoretically appropriate hyper-parameter configurations. This is mostly lost in the deep learning domain. There is increasing evidence, however, to suggest that the lack of explicit medical and psychological information can be offset by the nuanced information extracted from the raw data as datasets scale.

The conceptual methodology remains unchanged: we wish to extract meaningful (reduced) representations of the data that provide a suitable transformation to improve performance on downstream tasks.

Neural Network-based EEG-SSL Methods

Three broad approaches are generally employed:

  1. Data Compression
  2. Signal Segmentation
  3. Temporal Dependence

1. Data Compression

Recurrent neural networks (usually LSTMs) are regularly coupled with autoencoders to extract meaningful low-dimensional representations of variable-length sequences.

A schema of the raw waveform (including EEG) processing pipeline.

RNNs are used out of necessity, as feed-forward neural networks require fixed input dimensions. LSTMs (a variant of RNNs) are able to represent both long- and short-term temporal dependencies, capturing the flow of information through time.
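To show why recurrence removes the fixed-input constraint, here is a NumPy sketch of an untrained recurrent encoder: the same weights are applied at every time step, so sequences of any length map to a hidden state of fixed size. A plain tanh (Elman) RNN stands in for an LSTM to keep it short; the class name, sizes and random weights are hypothetical.

```python
import numpy as np


class RNNEncoder:
    """Untrained Elman RNN mapping a (length, n_channels) sequence to a fixed-size vector.

    LSTM gating, which helps retain long-range information, is omitted for brevity.
    """

    def __init__(self, n_channels, n_hidden, seed=0):
        rng = np.random.default_rng(seed)
        self.W_x = rng.normal(0.0, 0.1, size=(n_channels, n_hidden))
        self.W_h = rng.normal(0.0, 0.1, size=(n_hidden, n_hidden))
        self.b = np.zeros(n_hidden)

    def encode(self, sequence):
        h = np.zeros(self.b.shape)
        for x_t in sequence:  # one time step at a time, shared weights
            h = np.tanh(x_t @ self.W_x + h @ self.W_h + self.b)
        return h              # final hidden state summarises the whole sequence


rng = np.random.default_rng(1)
encoder = RNNEncoder(n_channels=4, n_hidden=8)
short_rec = rng.normal(size=(120, 4))  # e.g. 4 EEG channels, 120 samples
long_rec = rng.normal(size=(900, 4))   # same channels, a much longer recording
print(encoder.encode(short_rec).shape, encoder.encode(long_rec).shape)
```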

The autoencoder is then used to compress the data to its essential latent constituents. An autoencoder is comparable to PCA (learning orthogonal projections) with additional non-linear link functions and increased dimensionality.
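The PCA comparison can be made precise for the linear case: under squared error, the best reconstruction a linear autoencoder with k latent units can achieve is the projection onto the top-k principal subspace (the Eckart-Young theorem). A NumPy sketch of that closed-form linear autoencoder follows; the helper name and data are illustrative. Non-linear autoencoders replace these projections with learnt non-linear maps.

```python
import numpy as np


def pca_autoencoder(X, k):
    """Linear autoencoder solved in closed form with the top-k principal directions."""
    mean = X.mean(axis=0)
    _, s, vt = np.linalg.svd(X - mean, full_matrices=False)

    def encode(x):
        return (x - mean) @ vt[:k].T   # orthogonal projection to k dimensions

    def decode(z):
        return z @ vt[:k] + mean       # map back to the data space

    return encode, decode, s


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3)) @ rng.normal(size=(3, 10))  # rank-3 data in 10-D
encode, decode, s = pca_autoencoder(X, k=2)
err = np.sum((decode(encode(X)) - X) ** 2)
# Eckart-Young: the squared error equals the sum of the discarded squared singular values.
print(np.isclose(err, np.sum(s[2:] ** 2)))
```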


2. Signal Segmentation

Another emergent trend is reliance on generative models. Generative models (which include many autoencoder variants) allow us to capture assumptions about the data generating process.

We regularly assume that signals are generated by a series of basis functions (some global long-term trends; some local sporadic movements and some noise). Signal segmentation techniques aim to decompose signals into their constituents by learning the latent data generating process.

A graphical representation of a Factorised Hierarchical Variational Autoencoder (FHVAE) proposed by Zhang et al. The model is split into generative and inference models. The data X is assumed to be generated by 2 latent functions Z1 and Z2 (with Z2 dependent on μ_2). The inference model is used to extract this information from the observed data X. θ and ϕ are the respective free parameters. Although this is a very specific example, the principle may be extrapolated to any generative model configuration.
An illustration of denoising signals with the FHVAE. The reconstructed signal is able to recapitulate the true signal and remove the noise, by decomposing the signal into latent basis functions.
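The FHVAE learns this decomposition from data. As a much simpler illustration of the underlying assumption (not of the VAE itself), here is a NumPy sketch in which the basis functions (a global drift, a 10 Hz rhythm and a local transient) are fixed by hand and their weights recovered by least squares; whatever the basis cannot explain is treated as noise. The basis choices, sampling rate and weights are assumptions.

```python
import numpy as np

fs = 250                                          # assumed sampling rate (Hz)
t = np.arange(0, 4, 1 / fs)
rng = np.random.default_rng(0)

# Hand-picked basis functions: global trend, oscillation, local transient at t = 2 s.
basis = np.column_stack([
    np.ones_like(t),                              # offset
    t,                                            # slow global drift
    np.sin(2 * np.pi * 10 * t),                   # 10 Hz oscillatory component
    np.cos(2 * np.pi * 10 * t),
    np.exp(-((t - 2.0) ** 2) / (2 * 0.1 ** 2)),   # short, local transient
])
true_weights = np.array([0.5, 0.3, 1.0, 0.0, 2.0])
signal = basis @ true_weights + 0.2 * rng.normal(size=t.size)

weights, *_ = np.linalg.lstsq(basis, signal, rcond=None)
components = basis * weights                      # each column is one recovered constituent
denoised = components.sum(axis=1)
residual_noise = signal - denoised
print(np.round(weights, 2))
```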

3. Temporal Dependence

Another, more novel, EEG-SSL implementation attempts to uncover temporal dependence in raw EEG data by pseudo-labelling sequences with a simple distance metric. Banville et al. propose an SSL algorithm that randomly samples pairs of EEG windows and then labels each pair binarily by their temporal distance. A (CNN) classifier is then trained to estimate which windows are close to one another.

The idea is that nearby signals share some similarities. For example, when using EEG for sleep staging, temporal dependence exists because individuals gradually transition between states (from Wake → N1 → N2 → N3 → REM, etc.). The SSL step can therefore be considered preprocessing: learning a latent representation that measures the ‘closeness’ between signal segments.

These features are used for downstream tasks (in this example, sleep staging), showing promising results.

A graphical representation of the algorithm proposed by Banville et al. Temporal dependence is extracted to train a classifier to estimate (binarily) if signals are near one another. To generate the data, random segments are sampled from the raw signal and labelled according to their absolute distance.
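A minimal NumPy sketch of the pair-sampling step, assuming two thresholds as in the relative-positioning task described by Banville et al.: onsets at most `tau_pos` samples apart are labelled 1 ("close"), onsets at least `tau_neg` apart are labelled 0 ("far"), and pairs in between are discarded as ambiguous. The window length, thresholds, sampling rate and helper name are illustrative assumptions, and the CNN that is then trained on these pairs is not shown.

```python
import numpy as np


def sample_relative_positioning_pairs(signal, window, tau_pos, tau_neg, n_pairs, seed=0):
    """Sample pairs of EEG windows and pseudo-label them by temporal distance.

    Returns the window pairs (as views), their binary labels and their onsets.
    """
    rng = np.random.default_rng(seed)
    max_start = signal.shape[-1] - window
    pairs, labels, onsets = [], [], []
    while len(labels) < n_pairs:
        t1, t2 = (int(v) for v in rng.integers(0, max_start + 1, size=2))
        gap = abs(t1 - t2)
        if gap <= tau_pos:
            label = 1
        elif gap >= tau_neg:
            label = 0
        else:
            continue  # neither clearly close nor clearly far: discard
        pairs.append((signal[..., t1:t1 + window], signal[..., t2:t2 + window]))
        labels.append(label)
        onsets.append((t1, t2))
    return pairs, np.array(labels), np.array(onsets)


fs = 100                                                        # assumed sampling rate (Hz)
eeg = np.random.default_rng(1).normal(size=(2, 60 * 60 * fs))  # 1 hour, 2 channels (synthetic)
pairs, labels, onsets = sample_relative_positioning_pairs(
    eeg, window=30 * fs, tau_pos=120 * fs, tau_neg=600 * fs, n_pairs=256)
print(len(pairs), labels.mean())
```

With these settings most accepted pairs are "far", so in practice one would probably balance the two classes before training.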

Summary

There’s a lot of information here! Let’s see if we can compress it to its latent constructs ;).

It is our belief, and experience, that combining purely data-driven interpolation with medically and psychologically informed inference produces the best results.

Objective

We wish to use additional unlabeled data to boost the performance of EEG based algorithms.

This is achieved in 3 steps.

1. Start with the primary dataset X and additional unlabeled data D.

2. Learn medically relevant feature transformations z ~ P(X, D).

3. Use the learnt transformation on the original set X to perform downstream tasks y ~ z(X).

Method

This is achieved by learning latent representations on ALL the available data. The additional information: (1) regularises the smaller labelled set, (2) produces more domain-general results, and (3) can be updated in light of new datasets.

Medical or Psychological Domain Expertise

We often use medical and psychological domain knowledge to define the context of the pre-training (SSL) tasks: an advantage seldom seen in the literature, as it is less applicable to CV and NLP applications. We draw on the medical literature to guide hyper-parameterisation. Hyperarousal detection is one example of this.

These methods allow us to maximise the utility of available data resources. Turning raw physiology into information.

Additional References

The techniques employed are largely inspired by the NeurIPS 2021 SSL workshop and an illustrative article by Meta AI.

Zach Wolpe

Statistician, scientist, technologist: writing about stats, data science, math, philosophy, poetry & any other flavours that occupy my mind.