Okay, you've built your neural network architectures using Flux.jl. Now, it's time to make them learn. The core of this learning process is the model training loop. This loop is where the model repeatedly processes data, measures its errors, and adjusts its internal parameters to improve. While Flux.jl provides high-level abstractions like Flux.train! that encapsulate this loop, understanding its mechanics is valuable for debugging, customization, and grasping more advanced training procedures.
Training a deep learning model is an iterative refinement process. We don't just show the model the entire dataset once and expect it to learn everything. Instead, we typically make multiple passes over the data.
Epochs: An epoch represents one complete pass through the entire training dataset. For instance, if you have 10,000 training samples and you complete one epoch, your model has seen each of those 10,000 samples once. Training usually involves running for multiple epochs, allowing the model to see the data repeatedly and learn more complex patterns.
Batches: Processing the entire dataset at once can be computationally expensive and memory-intensive, especially for large datasets. It can also lead to a less stable learning process. To manage this, the dataset is commonly divided into smaller, more manageable chunks called mini-batches. The model processes one mini-batch at a time, calculates the error, and updates its weights.
The number of samples processed together in one of these chunks is called the batch_size. Common batch sizes range from 32 to 256, but can vary depending on the dataset size and available memory. In Julia, you'll often use a data loader, such as those provided by MLUtils.jl (which we touched upon in Chapter 3), to iterate through your dataset in mini-batches, as sketched below.
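As a concrete illustration, a mini-batch iterator over a small randomly generated dataset could be set up as follows; the array shapes, the batch size of 64, and the variable names are placeholder choices, not fixed by anything earlier in the chapter.
using MLUtils: DataLoader

# 1,000 synthetic samples: 4 features each, 1 target value each.
# Flux expects observations along the last dimension.
x_data = rand(Float32, 4, 1000)
y_data = rand(Float32, 1, 1000)

# Yields (x_batch, y_batch) tuples of 64 samples, reshuffled every epoch
train_loader = DataLoader((x_data, y_data), batchsize=64, shuffle=true)

for (x_batch, y_batch) in train_loader
    # size(x_batch) == (4, 64) for every batch except possibly the last one
end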
Within each epoch, the model iterates through all the mini-batches. For each mini-batch, the following sequence of operations is performed:
Obtain a Mini-Batch: The data loader provides the next mini-batch of input features (e.g., x_batch) and corresponding target labels (e.g., y_batch).
Forward Pass: The input features (x_batch) are fed into the neural network. The network processes these inputs through its layers, applying weights and activation functions, to produce predictions.
predictions = model(x_batch)
Here, model is your Flux.jl neural network.
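To make this concrete, here is a small illustrative classifier (not a model defined elsewhere in this book) applied to a randomly generated stand-in batch; the output has one column of class probabilities per sample.
using Flux

# An example network: 4 input features, 16 hidden units, 3 output classes.
# The final softmax turns raw scores into probabilities.
model = Chain(Dense(4 => 16, relu), Dense(16 => 3), softmax)

x_batch = rand(Float32, 4, 64)   # features × batch_size
predictions = model(x_batch)     # size(predictions) == (3, 64)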
Calculate Loss: The predictions are compared against the true target labels (y_batch) using a predefined loss function (e.g., Flux.crossentropy for classification, Flux.mse for regression). The loss quantifies how "wrong" the model's predictions are for the current batch.
current_loss = loss_function(predictions, y_batch)
A lower loss value generally indicates better performance on that batch.
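Continuing the small illustrative classifier above, the targets would typically be one-hot encoded, and the loss collapses the whole batch into a single number. The integer labels here are randomly generated stand-ins.
using Flux: onehotbatch, crossentropy

# Fake integer class labels for the 64 samples, one-hot encoded into a 3×64 matrix
labels = rand(1:3, 64)
y_batch = onehotbatch(labels, 1:3)

current_loss = crossentropy(predictions, y_batch)   # a single Float32 value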
Compute Gradients (Backward Pass): This is where automatic differentiation comes into play, typically handled by Zygote.jl in the Flux ecosystem. We calculate the gradients of the loss function with respect to all trainable parameters in the model (weights and biases). These gradients indicate the direction and magnitude of change needed for each parameter to reduce the loss.
# ps contains the model's trainable parameters (e.g., Flux.params(model))
grads = Flux.gradient(() -> loss_function(model(x_batch), y_batch), ps)
The Flux.gradient function takes an anonymous function (the first argument) that computes the loss, and the parameters ps (obtained, for example, via Flux.params(model)) for which gradients are needed. It returns a Zygote.Grads object containing the gradients.
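A Zygote.Grads object can be indexed by the parameter arrays themselves, which is convenient when you want to inspect or sanity-check individual gradients. Continuing the illustrative classifier sketched above (where model[1] is its first Dense layer):
ps = Flux.params(model)
grads = Flux.gradient(() -> crossentropy(model(x_batch), y_batch), ps)

# Look up the gradient of the first layer's weight matrix
W = model[1].weight
@show size(grads[W])      # same shape as W itself
@show any(isnan, grads[W])  # a quick check when debugging unstable training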
Update Parameters (Optimization Step): The optimizer (e.g., Adam, SGD) uses the computed gradients to update the model's parameters. The learning rate, a hyperparameter of the optimizer, controls the size of these updates. The goal is to nudge the parameters in the direction that minimizes the loss.
# opt is your optimizer (e.g., ADAM())
Flux.Optimise.update!(opt, ps, grads)
The Flux.Optimise.update! function modifies the parameters in ps based on the gradients grads and the optimizer's logic.
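To make the optimizer's role concrete: for plain gradient descent (Flux's Descent optimizer), the update applied by update! is essentially the in-place subtraction sketched below, with the learning rate η chosen here purely for illustration. Adaptive optimizers such as ADAM follow the same pattern but also maintain per-parameter state (moment estimates) that shapes each step.
η = 0.01   # learning rate (illustrative value)

for p in ps
    # Parameters that received no gradient (e.g., unused layers) map to nothing
    grads[p] === nothing && continue
    p .-= η .* grads[p]   # nudge each parameter against its gradient
end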
(Implicit) Gradient Reset: For most standard training loops, the gradients are calculated afresh for each batch based on that batch's specific loss. Zygote's Flux.gradient computes gradients for the specific function call, so there is no persistent gradient accumulation across batches unless you explicitly design your loop to do so. After the parameters are updated, the gradients from the current batch have served their purpose.
These steps are repeated for every mini-batch in the training dataset. Once all mini-batches have been processed, one epoch is complete. The entire process (all epochs) continues until a stopping criterion is met, such as a predefined number of epochs or when the model's performance on a validation set stops improving.
Let's see how these components fit together in a simplified Julia code structure. Assume model, loss_function, opt (optimizer), and train_loader (your batched data) are already defined.
# Assume:
# model = Your Flux model (e.g., Chain(...))
# loss_function = Your chosen loss (e.g., Flux.crossentropy)
# opt = Your optimizer (e.g., ADAM())
# train_loader = Your data loader providing (x_batch, y_batch) tuples
# num_epochs = Number of epochs to train for
# Get model parameters
ps = Flux.params(model)
for epoch in 1:num_epochs
    epoch_loss = 0.0
    num_batches = 0

    for (x_batch, y_batch) in train_loader
        # Step 4: Compute Gradients
        # The loss calculation (Step 3) and the forward pass (Step 2)
        # both happen inside the gradient computation
        current_loss, grads = Flux.withgradient(ps) do
            # Forward pass and loss calculation
            predictions = model(x_batch)
            loss_function(predictions, y_batch)
        end

        # Step 5: Update Parameters
        Flux.Optimise.update!(opt, ps, grads)

        epoch_loss += current_loss
        num_batches += 1
    end

    average_epoch_loss = epoch_loss / num_batches
    println("Epoch: $epoch, Average Loss: $average_epoch_loss")
end
println("Training finished.")
In this example, Flux.withgradient is a useful function that computes the value of the provided anonymous function (which is our loss) and its gradient with respect to ps simultaneously. This is often more efficient than calling Flux.gradient separately after calculating the loss if you also need the loss value.
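For comparison, the less efficient alternative mentioned above runs the forward pass twice: once to obtain the loss value for logging and once inside the gradient call.
# Two forward passes instead of one:
current_loss = loss_function(model(x_batch), y_batch)   # for logging only
grads = Flux.gradient(() -> loss_function(model(x_batch), y_batch), ps)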
The full loop above iterates through epochs and then through batches within each epoch. For every batch, it calculates the loss and gradients, then updates the model parameters. It also accumulates the loss to print an average for the epoch, giving you a basic way to monitor training progress.
While Flux.jl offers Flux.train!(loss, params, data, opt; cb = ...) as a high-level utility to handle this entire loop, breaking it down as we've done provides several benefits:
Debugging: If training goes wrong (e.g., the loss becomes NaN, the model doesn't learn), understanding each step allows you to inspect intermediate values (predictions, loss, gradients) and pinpoint issues more effectively.
Understanding: You'll know what convenience functions like Flux.train! are doing under the hood, making their usage less of a "black box."
As you progress, you'll see how this fundamental loop structure serves as the foundation for training various types of models and for implementing more sophisticated training strategies, which we will cover in subsequent sections of this chapter, like using callbacks and evaluating model performance more formally.
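Before moving on, here is a rough sketch of how the same training could be launched with the high-level Flux.train! utility in the implicit-parameters form shown above. The throttled callback and the validation arrays x_val and y_val are illustrative assumptions, not objects defined earlier in this chapter.
# Loss written in the two-argument form train! expects for (x, y) batches
loss(x, y) = loss_function(model(x), y)

# Print a validation loss at most once every 10 seconds
evalcb = Flux.throttle(10) do
    @show loss(x_val, y_val)
end

for epoch in 1:num_epochs
    Flux.train!(loss, ps, train_loader, opt; cb = evalcb)
end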