Debugging PyTorch Transfer Learning Jobs

Zach Wolpe
3 min read · Dec 10, 2024


Before computing the gradients (and performing backpropagation), requires_grad == False on b.

PyTorch is a marvel of engineering. Every time I use it, its design blows me away.

A well-designed tool makes it easy to get started, but allows power users to discover advanced features when needed.

PyTorch achieves this with its modular design and variable levels of abstraction.

Like all software projects, this flexibility can come at a cost if code complexity is not managed.

Transfer Learning

Transfer learning & fine-tuning are straightforward in theory:

  1. Take a pre-trained model.
  2. Freeze the majority of the weights, usually only allowing the last layer to update.
  3. Train the unfrozen nodes in the network on custom data.
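The three steps above can be sketched in a few lines. Here a toy two-layer model stands in for a pre-trained network, and the optimizer settings are illustrative:

```python
import torch
import torch.nn as nn

# Toy stand-in for a pre-trained model: a frozen "body" and a trainable "head".
model = nn.Sequential(
    nn.Linear(8, 4),   # "pre-trained" body (frozen below)
    nn.ReLU(),
    nn.Linear(4, 2),   # "head" left trainable
)

# Step 2: freeze everything, then unfreeze only the final layer.
for param in model.parameters():
    param.requires_grad = False
for param in model[-1].parameters():
    param.requires_grad = True

# Step 3: pass only the trainable parameters to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)
```

Filtering on `requires_grad` when building the optimizer keeps frozen weights out of the update rule entirely, rather than relying on their gradients staying `None`.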

Transfer Learning in PyTorch

To get started with transfer learning, see this tutorial: Transfer Learning in PyTorch.

  1. Load a pre-trained model.
  2. Prepare a data loader.
  3. Freeze all layers except the final (or desired) layers.
  4. Define a loss function, optimizer and scheduler (optional).
  5. Define a training loop.
  6. Evaluate and save the updated model.
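The six steps can be condensed into a minimal, end-to-end sketch. A toy model and synthetic data stand in for a real pre-trained network and dataset (in practice you would load e.g. a torchvision model); the filename is arbitrary:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Step 1 (stand-in): a toy model in place of a pre-trained network.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# Step 2: a data loader over synthetic data.
X, y = torch.randn(32, 8), torch.randint(0, 2, (32,))
loader = DataLoader(TensorDataset(X, y), batch_size=8)

# Step 3: freeze all layers except the final one.
for p in model.parameters():
    p.requires_grad = False
for p in model[-1].parameters():
    p.requires_grad = True

# Step 4: loss, optimizer, and an (optional) scheduler.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)

# Step 5: training loop.
for epoch in range(3):
    for inputs, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()

# Step 6: save the fine-tuned weights.
torch.save(model.state_dict(), "finetuned.pt")
```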

This is a straightforward process if the model and its supporting code are easy to follow. As a project grows in complexity, bugs are inevitable.

Potential Bugs

In practice, several bugs — particularly relating to the PyTorch computational graph, such as weights that silently fail to update — can arise.
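The most common symptom is a layer that silently never updates, because it is frozen or missing from the optimizer. A quick sanity check (a hypothetical helper, not part of the debugger) flags parameters that receive no gradient after a backward pass:

```python
import torch
import torch.nn as nn

def report_stale_params(model: nn.Module) -> list[str]:
    """Return names of parameters that will not update during training."""
    stale = []
    for name, param in model.named_parameters():
        # Either the parameter is frozen, or backward() never reached it.
        if not param.requires_grad or param.grad is None:
            stale.append(name)
    return stale

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
for p in model[0].parameters():
    p.requires_grad = False  # simulate a frozen backbone

loss = model(torch.randn(3, 4)).sum()
loss.backward()
print(report_stale_params(model))  # → ['0.weight', '0.bias']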

Solution

To mitigate these issues, I’ve written a PyTorch Transfer Learning Debugger. It provides an interactive debugging experience for PyTorch code, enabling users to diagnose issues related to weight updates during training and to address problems with slow or unstable convergence.

Example

Building on this example, here is how to implement the debugger. See the full example on GitHub or PyPI.


Key Components

1. Debug Mode Features

The debugger is implemented through a torch_debugger class that provides two main debugging modes:

  1. Basic Debug Mode (debug_mode=True):
  • Tracks model weight updates
  • Only monitors weights with requires_grad=True
  • Populates torch_debugger._track_weights dictionary for analysis
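As a minimal sketch of what basic debug mode might record, here is a class that snapshots the L2 norm of every trainable weight each time it is called, keyed by parameter name. The class and dictionary names mirror the article; the internals are illustrative assumptions, not the library's actual implementation:

```python
import torch
import torch.nn as nn

class torch_debugger:
    """Illustrative sketch of basic debug mode (internals are assumptions)."""

    def __init__(self, enable_log_stage=False):
        self.enable_log_stage = enable_log_stage
        self._track_weights = {}  # parameter name -> list of per-step norms

    def track_weights(self, named_parameters):
        for name, param in named_parameters:
            if param.requires_grad:  # only monitor trainable weights
                self._track_weights.setdefault(name, []).append(
                    param.detach().norm().item())

dbg = torch_debugger()
model = nn.Linear(4, 2)
dbg.track_weights(model.named_parameters())
```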

  2. Granular Debug Mode (granular_logging=True):
  • Exits after one cycle through the dataloader
  • Logs each stage of the computational graph:

  • Optimizer verification
  • Gradient zeroing
  • Data movement to the device
  • Forward pass and loss computation
  • Gradient computation and optimizer steps
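Stage logging of this kind can be sketched as a class that records a timestamped message at each step of the loop. The class name here is hypothetical; the article does not show the real log_stage implementation:

```python
import time

class StageLogger:
    """Illustrative sketch of granular stage logging."""

    def __init__(self):
        self.stages = []  # (timestamp, message) pairs, in call order

    def log_stage(self, message: str) -> None:
        self.stages.append((time.perf_counter(), message))

logger = StageLogger()
logger.log_stage('Zeroing gradients')
logger.log_stage('Forward pass and loss computation')
```

Keeping timestamps alongside the messages means the same log can later answer both "which stage ran?" and "which stage was slow?".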

2. Implementation Example

Here’s how the debugger is integrated into the training loop:

def train_model(model, criterion, optimizer, scheduler, num_epochs=25,
                debug_mode=True, granular_logging=False):
    # Initialize debugger if debug_mode is enabled
    torch_debugger_inst = None
    if debug_mode:
        torch_debugger_inst = torch_debugger(enable_log_stage=granular_logging)
        torch_debugger_inst.verify_optimizer_state(optimizer)
        torch_debugger_inst.initial_model_weights_state = model.state_dict().copy()

    # Training loop with debugging hooks
    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            for inputs, labels in dataloaders[phase]:
                if debug_mode:
                    torch_debugger_inst.log_stage('Moving data to device')
                # ... more logging stages ...

                if debug_mode and phase == 'train':
                    torch_debugger_inst.track_weights(model.named_parameters())

    return torch_debugger_inst, model

3. Usage

To use the debugger, launch a training job; train_model returns the tracked results and the model:

torch_debugger_inst, model = train_model(
    model_ft,
    criterion,
    optimizer_ft,
    exp_lr_scheduler,
    num_epochs=25,
    debug_mode=True,
    granular_logging=True,  # For detailed stage logging
)

4. Analysis Tools

The debugger provides visualization tools to analyze training:

# Plot weight updates over time
torch_debugger_inst.plot_weight_updates()

# Access tracked weights
weights = torch_debugger_inst._track_weights['conv1.weight']
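A flat run of zero deltas in the tracked norms is the tell-tale sign of a layer that never updates. Assuming _track_weights stores one norm per step, a small illustrative helper (not part of the debugger's API) summarises the update sizes:

```python
def update_magnitudes(norms):
    """Per-step change in a tracked weight norm; a run of zeros means frozen."""
    return [abs(b - a) for a, b in zip(norms, norms[1:])]

healthy = update_magnitudes([1.00, 1.05, 1.04, 1.10])  # noisy but moving
frozen = update_magnitudes([1.00, 1.00, 1.00, 1.00])   # never updates
```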

5. Result

We can track the convolutional layer weights during training — and observe steady, progressive optimisation with some sampling noise/exploration of the search space.

6. Benefits

  • Helps identify if weights are properly updating during training
  • Provides visibility into each stage of the training process
  • Enables early detection of training issues
  • Supports both high-level and granular debugging approaches

This debugging tool is particularly useful when:

  • Troubleshooting training convergence issues
  • Verifying proper weight updates
  • Understanding the flow of data through the model
  • Diagnosing optimization problems
