As you've learned to construct neural network architectures and understand the core mechanics of a training loop, you'll find that the loop itself can become quite busy. In addition to calculating gradients and updating weights, you'll want to monitor progress, save your best models, stop training if it's no longer productive, and perhaps adjust parameters like the learning rate dynamically. Stuffing all this logic directly into your main training code can lead to a tangled mess. This is precisely where callbacks shine, offering a clean, modular way to inject custom actions into the training lifecycle.
Callbacks are essentially functions that Flux.jl can execute at predefined points during the training process, for instance, at the beginning or end of an epoch, or even after processing each mini-batch. They act as hooks, allowing you to extend the training functionality without directly modifying the core Flux.train! logic. This approach keeps your primary training script focused and makes your auxiliary tasks reusable and easier to manage.
Employing callbacks in your training routine brings several practical benefits: monitoring and logging, checkpointing your best models, stopping unproductive runs early, and adjusting hyperparameters such as the learning rate can all be expressed as small, reusable pieces rather than ad hoc code scattered through the loop. By delegating these tasks to callbacks, your main training loop remains uncluttered and dedicated to the essential learning process.
Flux.jl's Flux.train! function supports callbacks through its cb keyword argument. You typically pass an array of functions or callable objects, which Flux will execute at appropriate intervals, usually after each training step (parameter update).
# A simplified view of using callbacks with Flux.train!
# Flux.train!(loss_function, parameters, data_iterable, optimizer, cb = my_callbacks)
Each element in my_callbacks is a function that Flux.train! will call. Flux also provides utilities like Flux.throttle that are often used in conjunction with callbacks to control the frequency of actions, such as logging.
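For instance, a callback list might combine a cheap progress indicator with a throttled logger. The sketch below is illustrative only: it assumes model, loss_fn, train_loader, and opt are already defined, and the callback names are made up for the example.
using Flux
# Two illustrative callbacks: a cheap progress indicator and a parameter-norm logger.
# `model` is assumed to exist in the enclosing scope.
print_dot() = print(".")
log_weight_norm() = @info "Weight norm: $(sum(p -> sum(abs2, p), Flux.params(model)))"
# Combine them; the second is throttled to run at most once every 10 seconds.
my_callbacks = [print_dot, Flux.throttle(log_weight_norm, 10)]
# Flux.train!(loss_fn, Flux.params(model), train_loader, opt, cb = my_callbacks)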
Let's explore some widely used callback patterns and how you might implement them.
Frequently, you'll want to log the training loss. Flux.throttle is a handy utility here: it wraps a function and ensures it's only called if a certain amount of time has passed since its last execution.
using Flux, Logging
# Assume model, loss_function, train_loader, and opt are defined,
# where loss_function(x, y) calculates the loss for a batch.
# Flux.train! evaluates the loss itself, so a per-step callback does not
# need to recompute it; here we simply log progress, throttled to every 5 seconds.
iter_count = 0
log_callback = () -> begin
    global iter_count += 1
    # You could also recompute and log a loss here if needed;
    # for this example we just log the iteration count.
    @info "Iteration: $(iter_count)"
end
# Throttle the logging callback to run at most once every 5 seconds
throttled_logger = Flux.throttle(log_callback, 5)
# Example usage (actual data flow depends on your loop)
# Flux.train!(loss_function, params(model), train_loader, opt, cb = throttled_logger)
When Flux.train! is called, throttled_logger will be invoked after parameter updates. Flux.throttle ensures that the log_callback logic doesn't execute too frequently, preventing your console from being overwhelmed with messages.
For epoch-level logging or more complex metrics, you often write a more structured training loop around Flux.gradient and Flux.update!, giving you explicit points to call your custom epoch-end callbacks.
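Here is a minimal sketch of such a loop, assuming model, loss_function, train_loader, and opt are defined as before; num_epochs and the epoch_callbacks vector are illustrative names.
using Flux
# A custom loop with explicit epoch-end hooks (implicit-parameters style).
ps = Flux.params(model)
for epoch in 1:num_epochs
    for (x, y) in train_loader
        gs = Flux.gradient(() -> loss_function(x, y), ps)  # per-batch gradients
        Flux.update!(opt, ps, gs)                          # parameter update
    end
    for cb in epoch_callbacks
        cb(epoch)  # epoch-level work: logging, checkpointing, validation, ...
    end
end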
Saving your model during a long training run is essential. A checkpointing callback can automate this, saving the model whenever, for instance, the validation loss improves.
using Flux, BSON
using Dates  # For timestamping filenames
# Assume `model` is your Flux model,
# `val_loss_fn()` is a function that computes validation loss,
# and the current epoch number is tracked by your training loop.
# State for the checkpointing callback (often encapsulated in a callable struct)
best_validation_loss = Float32(Inf)
checkpoint_dir = "model_checkpoints/"
mkpath(checkpoint_dir)  # Create directory if it doesn't exist
function save_best_model_callback(epoch_num::Int)  # Pass the current epoch
    global best_validation_loss  # Access global state (or use a callable struct)
    current_val_loss = val_loss_fn()  # User-defined function to get validation loss
    @info "Epoch $(epoch_num) - Validation Loss: $(current_val_loss)"
    if current_val_loss < best_validation_loss
        best_validation_loss = current_val_loss
        model_cpu = cpu(model)  # Move model to CPU before saving
        timestamp = Dates.format(now(), "yyyymmdd_HHMMSS")
        filename = joinpath(checkpoint_dir, "model_epoch_$(epoch_num)_val_$(round(best_validation_loss, digits=4))_$(timestamp).bson")
        BSON.@save filename model_state=Flux.state(model_cpu) epoch=epoch_num val_loss=best_validation_loss
        @info "Checkpoint saved: $(filename)"
    end
end
# This callback would be called at the end of each epoch in a custom training loop.
# If using Flux.train!, you'd need to adapt it or use a helper library,
# since Flux.train!'s `cb` is invoked per step, not per epoch.
In this example, save_best_model_callback checks the validation loss and saves the model's state using BSON.jl whenever a new best loss is achieved. Using Flux.state(model) is often preferred over saving the entire model object directly, as it is more robust to changes in Flux versions or model definitions.
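Restoring a checkpoint later is the mirror image. The sketch below assumes a hypothetical filename and that build_model() reconstructs the same architecture used during training.
using Flux, BSON
# Load the saved state (filename is hypothetical) and copy it into a fresh model.
BSON.@load "model_checkpoints/model_epoch_12.bson" model_state epoch val_loss
model = build_model()                # assumed: same constructor used for training
Flux.loadmodel!(model, model_state)  # copies the saved parameters/state into `model`
@info "Restored model from epoch $epoch (val_loss = $val_loss)"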
The following diagram shows where different types of callbacks can integrate into the training process:
Callbacks provide hooks at various stages of training: overall start/end, epoch start/end, and batch start/end. This allows for fine-grained control and monitoring.
Early stopping is a critical callback for preventing overfitting and saving computation. It halts training if a monitored metric (like validation loss) stops improving for a "patience" number of epochs.
While Flux.jl itself doesn't provide a built-in EarlyStopping callback that directly plugs into Flux.train! to halt it, you can easily implement the logic if you're writing your own training loop, or use libraries like FluxTraining.jl which offer this. Here's how you might structure an EarlyStopper as a callable struct:
mutable struct EarlyStopper
    patience::Int
    min_delta::Float64
    best_val_loss::Float64
    epochs_without_improvement::Int
    verbose::Bool
    triggered::Bool  # To signal stopping
end
# Keyword constructor so calls like EarlyStopper(patience=3) work
EarlyStopper(; patience::Int=5, min_delta::Float64=0.001, verbose::Bool=true) =
    EarlyStopper(patience, min_delta, Inf, 0, verbose, false)
function (es::EarlyStopper)(current_val_loss::Float64; epoch::Int=0)
    es.triggered && return true  # Already triggered
    if es.verbose
        @info "Epoch $epoch: EarlyStopper checking val_loss $(current_val_loss). Best: $(es.best_val_loss). No improvement for $(es.epochs_without_improvement) epochs."
    end
    if current_val_loss < es.best_val_loss - es.min_delta
        es.best_val_loss = current_val_loss
        es.epochs_without_improvement = 0
        if es.verbose
            @info "Epoch $epoch: Val_loss improved to $(es.best_val_loss)."
        end
    else
        es.epochs_without_improvement += 1
        if es.epochs_without_improvement >= es.patience
            if es.verbose
                @info "Epoch $epoch: Early stopping triggered after $(es.epochs_without_improvement) epochs without improvement (val_loss: $(current_val_loss))."
            end
            es.triggered = true
            return true  # Signal to stop
        end
    end
    return false  # Signal to continue
end
# Usage in a custom loop:
# stopper = EarlyStopper(patience=3, verbose=true)
# for epoch = 1:num_epochs
#     # ... training for one epoch ...
#     val_loss = calculate_validation_loss()  # Your function
#     if stopper(val_loss, epoch=epoch)
#         @info "Stopping training early."
#         break
#     end
# end
This EarlyStopper keeps track of the best validation loss and the number of epochs since the last improvement. If patience is exceeded, it signals that training should stop.
The chart below illustrates how early stopping can prevent overfitting by halting training when validation loss begins to degrade.
Training is halted at Epoch 10 when validation loss (orange) starts increasing, even as training loss (blue) continues to decrease. This prevents the model from further overfitting to the training data.
Callbacks can also manage learning rate adjustments. For instance, you might want to reduce the learning rate if the validation loss plateaus.
using Flux
# Assume `opt` is your optimizer, e.g., opt = Adam(0.001),
# and `val_loss_fn()` calculates validation loss.
mutable struct ReduceLROnPlateau
    optimizer::Any
    factor::Float64
    patience::Int
    min_lr::Float64
    best_val_loss::Float64
    epochs_without_improvement::Int
    verbose::Bool
end
ReduceLROnPlateau(optimizer; factor=0.1, patience=3, min_lr=1e-6, verbose=true) =
    ReduceLROnPlateau(optimizer, factor, patience, min_lr, Inf, 0, verbose)
# Find the object holding the learning rate: either the optimizer itself
# (e.g. Adam has a mutable `eta` field) or the first component of an
# optimizer chain (e.g. Flux.Optimise.Optimiser, which stores them in `os`).
function lr_holder(opt)
    hasproperty(opt, :eta) && return opt
    if hasproperty(opt, :os)
        for o in opt.os
            hasproperty(o, :eta) && return o
        end
    end
    return nothing
end
function (rlrop::ReduceLROnPlateau)(current_val_loss::Float64; epoch::Int=0)
    if current_val_loss < rlrop.best_val_loss
        rlrop.best_val_loss = current_val_loss
        rlrop.epochs_without_improvement = 0
    else
        rlrop.epochs_without_improvement += 1
        if rlrop.epochs_without_improvement >= rlrop.patience
            holder = lr_holder(rlrop.optimizer)
            if holder !== nothing && holder.eta > rlrop.min_lr
                new_lr = max(holder.eta * rlrop.factor, rlrop.min_lr)
                if rlrop.verbose
                    @info "Epoch $epoch: Reducing learning rate from $(holder.eta) to $new_lr due to plateau."
                end
                holder.eta = new_lr
                rlrop.epochs_without_improvement = 0  # Reset patience after reduction
            end
        end
    end
end
# Usage in a custom loop:
# lr_scheduler = ReduceLROnPlateau(opt, factor=0.5, patience=2)
# for epoch = 1:num_epochs
#     # ... training ...
#     val_loss = calculate_validation_loss()
#     lr_scheduler(val_loss, epoch=epoch)
# end
This ReduceLROnPlateau callback monitors validation loss and reduces the optimizer's learning rate if no improvement is seen for a set number of epochs. Accessing and modifying the learning rate (opt.eta or similar) depends on the specific optimizer being used, and optimizer chains require careful handling to adjust the learning rate of the appropriate internal component. If you use the newer explicit optimizer state created with Flux.setup (where rules such as Flux's OptimiserChain live), Flux.adjust! is the supported way to change the learning rate.
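For that newer explicit optimizer state, a sketch of the same adjustment might look like the following; model and new_lr are assumed to exist, and the chain shown is only an example.
using Flux
# Explicit optimizer state: a chain of weight decay plus Adam (recent Flux versions).
opt_state = Flux.setup(OptimiserChain(WeightDecay(1e-4), Adam(1e-3)), model)
# ... training steps use Flux.update!(opt_state, model, grads) ...
# Adjust the learning rate (eta) wherever it appears in the state tree.
Flux.adjust!(opt_state, new_lr)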
When using Flux.train!, you can pass an array of callback functions. In a custom loop, you would simply call each registered callback at the appropriate stage.
# Example with a custom loop structure
# all_callbacks = [logging_cb, checkpointing_cb, early_stopping_cb, lr_scheduler_cb]
# stop_training = false
# for epoch in 1:max_epochs
#     # ... run training for one epoch ...
#     # At the end of the epoch:
#     current_val_loss = calculate_validation_loss()
#     for cb in all_callbacks
#         if cb isa EarlyStopper || cb isa ReduceLROnPlateau  # callbacks that need val_loss
#             cb(current_val_loss, epoch=epoch)
#         else  # other callbacks might not need val_loss or the epoch
#             cb()
#         end
#     end
#     if early_stopping_cb.triggered
#         @info "Stopping training early."
#         stop_training = true
#     end
#     stop_training && break
# end
When a callback needs to carry state between calls (such as best_val_loss or epochs_without_improvement), callable structs are an excellent way to encapsulate this state. Many callback systems also name their hooks after lifecycle stages (for example on_epoch_begin, on_batch_end), mirroring the hook points shown in the diagram above.

Callbacks are indispensable for orchestrating complex training processes efficiently. They allow you to automate monitoring, adapt training dynamically, and ensure your model development is both effective and insightful. As you build more models, you'll likely create a personalized toolkit of callbacks that streamline your common deep learning tasks in Julia.