As we've seen, calling .backward() on a scalar tensor (like a loss value) triggers the computation of gradients for all tensors in the computation graph that have requires_grad=True. A significant behavior to understand is that PyTorch accumulates gradients by default. When you call .backward() multiple times without clearing the gradients in between, PyTorch adds the newly computed gradients to the existing values stored in the .grad attribute of the leaf tensors (parameters).
Let's illustrate this with a simple example:
import torch
# Create a tensor that requires gradients
x = torch.tensor([2.0], requires_grad=True)
# Perform some operations
y = x * x
z = y * 3 # z = 3 * x^2
# First backward pass
# dz/dx = 6*x = 6*2 = 12
z.backward(retain_graph=True) # retain_graph=True allows subsequent backward calls
print(f"After first backward pass, x.grad: {x.grad}")
# Perform another operation (can be the same or different)
# For simplicity, let's use the same z again for demonstration
# Note: In a real scenario, you'd likely compute a new loss
# based on a new input or different part of the model.
z.backward() # Second backward pass
# We expect the new gradient (12) to be added to the existing one (12)
print(f"After second backward pass, x.grad: {x.grad}")
# Manually zero the gradient
x.grad.zero_()
print(f"After zeroing, x.grad: {x.grad}")
Running this code produces output similar to:
After first backward pass, x.grad: tensor([12.])
After second backward pass, x.grad: tensor([24.])
After zeroing, x.grad: tensor([0.])
Notice how the second call to z.backward() added the newly computed gradient (12) to the previously stored gradient (12), resulting in 24. This accumulation is intentional and has important applications.
The primary reason for this default behavior is to facilitate gradient accumulation. This technique is useful when training large models that require large batch sizes for stable convergence, but the available GPU memory cannot accommodate such large batches at once.
Instead of processing one large batch, you can:
1. Split the large batch into smaller mini-batches that fit in memory.
2. For each mini-batch, run the forward pass and call .backward() on the loss for the current mini-batch. The gradients computed for this mini-batch will be added to the .grad attributes of the model parameters. Do not zero the gradients between these mini-batches.
3. After processing all the mini-batches, call optimizer.step(). This step updates the model weights using the sum of gradients from all the mini-batches, effectively simulating a single update step for the larger batch size.
4. Call optimizer.zero_grad() before starting the processing for the next large batch (or the next mini-batch if not accumulating).

This allows you to train with the effective batch size your model needs, even if it doesn't fit into memory all at once, trading off computation time for memory efficiency (a minimal sketch of the pattern follows).
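As a minimal sketch of this pattern, assuming model, dataloader, loss_fn, and optimizer are already defined, and using a hypothetical accumulation_steps to denote the number of mini-batches per effective batch:
# Sketch of gradient accumulation (assumes model, dataloader,
# loss_fn, and optimizer are defined)
accumulation_steps = 4  # hypothetical: mini-batches per effective batch

optimizer.zero_grad()  # start from a clean slate
for i, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    # Scale the loss so the accumulated gradients approximate the
    # average over the full effective batch
    loss = loss_fn(outputs, labels) / accumulation_steps
    loss.backward()  # gradients are added to the existing .grad values
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # update weights using the summed gradients
        optimizer.zero_grad()  # clear gradients for the next effective batch
Dividing the loss by accumulation_steps is a common refinement: without it, the accumulated gradient corresponds to the sum of the mini-batch losses rather than their mean.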
optimizer.zero_grad() in Standard Training

In a standard training loop where you process one batch, calculate the loss, compute gradients, and update weights in each iteration, you typically do not want gradients from previous batches to influence the current update step. Each batch's gradient calculation should be independent.
Because PyTorch accumulates gradients by default, failing to clear them before computing gradients for a new batch would lead to incorrect updates. The optimizer would use a mixture of old and new gradients, corrupting the training process.
This is why you almost always see optimizer.zero_grad() called within a standard PyTorch training loop. It resets the .grad attribute of all parameters the optimizer is managing, ensuring that the subsequent .backward() call computes gradients solely based on the current batch's loss.
A typical training iteration structure looks like this:
# Assume model, dataloader, loss_fn, optimizer, device, and num_epochs are defined
for epoch in range(num_epochs):
    for data_batch in dataloader:
        # 1. Get the data batch and move it to the appropriate device
        inputs, labels = data_batch
        inputs, labels = inputs.to(device), labels.to(device)

        # 2. Zero the gradients
        # IMPORTANT: Clear previous gradients before processing the new batch
        optimizer.zero_grad()

        # 3. Forward pass: compute model predictions
        outputs = model(inputs)

        # 4. Calculate the loss
        loss = loss_fn(outputs, labels)

        # 5. Backward pass: compute gradients
        loss.backward()

        # 6. Optimizer step: update model weights
        optimizer.step()

        # ... (logging, evaluation, etc.)
The placement of optimizer.zero_grad() is important. It should happen before you perform the backward pass for the current iteration, ensuring a clean slate for the gradient computation of the current batch. While often placed at the beginning of the loop, it technically only needs to happen before loss.backward(). However, placing it at the start is common practice and clearly delineates the start of processing for a new batch.
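As a small sketch of this equivalence, the following loop body zeroes the gradients immediately before the backward pass and behaves the same as the loop above (same assumed model, loss_fn, and optimizer):
# Equivalent placement: zero the gradients just before the backward pass
outputs = model(inputs)
loss = loss_fn(outputs, labels)
optimizer.zero_grad()  # old gradients cleared before new ones are computed
loss.backward()
optimizer.step()
Note that in recent PyTorch versions, optimizer.zero_grad() sets the .grad attributes to None rather than to zero-filled tensors by default (via its set_to_none parameter), which saves a small amount of memory; the effect on the training loop is the same.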
In summary, gradient accumulation is a built-in PyTorch feature useful for simulating larger batch sizes. However, in standard training loops, you must explicitly prevent this accumulation by calling optimizer.zero_grad() at the beginning of each iteration to ensure correct model updates based only on the current batch's data.