In the process of training a neural network, we need a way to quantify how well our model is performing. Specifically, we need to measure the discrepancy between the model's predictions and the actual target values. This measure is provided by a loss function, also known as an error function or cost function. It computes a single scalar value that represents the "badness" of the model's current predictions for a given set of data. The smaller the loss, the better the model is performing. This loss value is fundamental because it's what the optimizer will try to minimize by adjusting the model's parameters.
Flux.jl offers a suite of predefined loss functions, conveniently accessible within the Flux module. The choice of loss function depends heavily on the nature of your machine learning task, primarily whether it's a regression or a classification problem.
When your model aims to predict continuous numerical values, you'll typically use a regression loss function.
Mean Squared Error (MSE) is a very common choice for regression. It calculates the average of the squared differences between the predicted values ($\hat{y}_i$) and the true values ($y_i$):
$$L_{\text{MSE}} = \frac{1}{N}\sum_{i=1}^{N}\left(y_i - \hat{y}_i\right)^2$$

Here, $N$ is the number of samples. Squaring the error term has two effects: it makes every error's contribution non-negative, and it penalizes larger errors more heavily than smaller ones. In Flux, you can use Flux.mse(y_pred, y_true).
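As a worked example of the formula, the sketch below computes MSE by hand in plain Julia, using only the Statistics standard library. The helper name mse_by_hand is hypothetical and the sample values are illustrative; for arrays of the same shape, Flux.mse(y_pred, y_true) averages the same squared differences.

```julia
using Statistics: mean

# Hypothetical helper: mean of squared differences, as in the MSE formula
mse_by_hand(y_pred, y_true) = mean((y_pred .- y_true) .^ 2)

y_true = [1.5, 2.0, 3.5]
y_pred = [1.3, 2.4, 3.1]

# Errors are -0.2, 0.4, -0.4; squared: 0.04, 0.16, 0.16; mean: 0.36 / 3 = 0.12
loss = mse_by_hand(y_pred, y_true)
println("MSE: ", loss)
```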
Another popular loss function for regression is Mean Absolute Error (MAE), which computes the average of the absolute differences between predictions and true values:

$$L_{\text{MAE}} = \frac{1}{N}\sum_{i=1}^{N}\left|y_i - \hat{y}_i\right|$$

MAE measures the average magnitude of the errors in a set of predictions, without considering their direction. Unlike MSE, whose penalty grows quadratically, MAE's penalty grows linearly with the size of the error. This makes MAE less sensitive to outliers than MSE. You can use it in Flux with Flux.mae(y_pred, y_true).
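The same hand computation for MAE, again a plain-Julia sketch with a hypothetical helper name and illustrative data, averages absolute differences instead of squared ones.

```julia
using Statistics: mean

# Hypothetical helper: mean of absolute differences, as in the MAE formula
mae_by_hand(y_pred, y_true) = mean(abs.(y_pred .- y_true))

y_true = [1.5, 2.0, 3.5]
y_pred = [1.3, 2.4, 3.1]

# Absolute errors are 0.2, 0.4, 0.4; mean: 1.0 / 3
loss = mae_by_hand(y_pred, y_true)
println("MAE: ", loss)
```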
[Figure: Comparison of how Mean Squared Error (MSE) and Mean Absolute Error (MAE) penalize prediction errors of different magnitudes. MSE's quadratic nature results in a steeper increase in loss for larger errors.]
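The comparison in the figure can be approximated numerically. This plain-Julia sketch, with illustrative values and hypothetical helpers named mse and mae, prints the per-sample penalty each loss assigns to a few error sizes, then measures how much one badly wrong prediction inflates each average.

```julia
using Statistics: mean

errors = [0.5, 1.0, 2.0, 4.0]

# Per-sample penalty: squared error for MSE, absolute error for MAE
sq_penalty  = errors .^ 2
abs_penalty = abs.(errors)
for (e, s, a) in zip(errors, sq_penalty, abs_penalty)
    println("error = $e  squared = $s  absolute = $a")
end

# Hypothetical helpers for the averaged losses
mse(yp, y) = mean((yp .- y) .^ 2)
mae(yp, y) = mean(abs.(yp .- y))

# Same targets; in the second prediction set, the last value is far off
y_true    = [1.0, 2.0, 3.0, 4.0]
y_clean   = [1.1, 1.9, 3.2, 3.9]
y_outlier = [1.1, 1.9, 3.2, 13.9]

mse_ratio = mse(y_outlier, y_true) / mse(y_clean, y_true)
mae_ratio = mae(y_outlier, y_true) / mae(y_clean, y_true)
println("MSE grows by a factor of ", round(mse_ratio; digits=1))
println("MAE grows by a factor of ", round(mae_ratio; digits=1))
```

With these numbers, the single outlier multiplies MSE by roughly 1400 but MAE by only about 20, which is the sense in which MAE is less sensitive to outliers.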
For tasks where the model predicts a class label from a set of discrete categories, classification loss functions are appropriate.
For binary classification problems, where there are only two possible output classes (e.g., 0 or 1, spam or not-spam), binary cross-entropy is the standard loss function. It measures the performance of a classification model whose output is a probability value between 0 and 1.
$$L_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \log(\hat{p}_i) + (1 - y_i)\log(1 - \hat{p}_i)\right]$$

Here, $y_i$ is the true label (0 or 1) and $\hat{p}_i$ is the predicted probability for class 1. This loss is minimized when each predicted probability $\hat{p}_i$ is close to its true label $y_i$.
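Translating the BCE formula directly into Julia gives the sketch below. The helper name bce_by_hand and the probabilities are illustrative assumptions; the last two lines show how much more heavily a confident wrong prediction is penalized than a confident correct one.

```julia
using Statistics: mean

# Hypothetical helper implementing the BCE formula.
# p_hat: predicted probabilities of class 1; y: true labels (0 or 1)
bce_by_hand(p_hat, y) = -mean(y .* log.(p_hat) .+ (1 .- y) .* log.(1 .- p_hat))

y     = [1.0, 0.0, 1.0, 0.0]
p_hat = [0.9, 0.2, 0.6, 0.4]

loss = bce_by_hand(p_hat, y)
println("BCE: ", loss)

# Confident and correct versus confident and wrong, for a true label of 1
println("true 1, p = 0.99: ", bce_by_hand([0.99], [1.0]))
println("true 1, p = 0.01: ", bce_by_hand([0.01], [1.0]))
```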
In Flux, you use Flux.binarycrossentropy(y_pred_probs, y_true). It's important that y_pred_probs are probabilities, typically obtained by passing the raw model outputs (logits) through a sigmoid activation function. Alternatively, Flux.logitbinarycrossentropy(y_pred_logits, y_true) takes logits directly as input, which is often more numerically stable and is generally recommended.
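To see why the logit-based form is recommended, compare a naive sigmoid-then-log computation with an algebraically equivalent form that works on the logit directly. Both functions below are hypothetical plain-Julia illustrations, not Flux's source. They rely on the identities $-\log\sigma(z) = \log(1 + e^{-z})$ and $-\log(1 - \sigma(z)) = z + \log(1 + e^{-z})$, evaluated with an overflow-safe softplus.

```julia
# Sigmoid (logistic) function
sigmoid(z) = 1 / (1 + exp(-z))

# Naive: convert the logit to a probability first, then apply the BCE formula
naive_bce(z, y) = -(y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z)))

# Numerically stable softplus: log(1 + exp(x)) without overflow
softplus(x) = max(x, zero(x)) + log1p(exp(-abs(x)))

# Same loss written in terms of the logit: (1 - y) * z + log(1 + exp(-z))
stable_bce(z, y) = (1 - y) * z + softplus(-z)

# Moderate logit: both forms agree
println(naive_bce(0.7, 1.0), "  ", stable_bce(0.7, 1.0))

# Large logit with label 0: sigmoid(40.0) rounds to exactly 1.0,
# so the naive form evaluates log(0) and returns Inf
println(naive_bce(40.0, 0.0), "  ", stable_bce(40.0, 0.0))
```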
When dealing with multi-class classification problems (more than two classes), categorical cross-entropy is the go-to loss function. It measures the dissimilarity between the true class distribution and the predicted probability distribution over the classes.
$$L_{\text{CCE}} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{C} y_{ij}\log(\hat{p}_{ij})$$

In this formula, $N$ is the number of samples, $C$ is the number of classes, $y_{ij}$ is a binary indicator (1 if sample $i$ belongs to class $j$, 0 otherwise; usually one-hot encoded), and $\hat{p}_{ij}$ is the model's predicted probability that sample $i$ belongs to class $j$.
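A direct translation of the CCE formula, using the layout of the Flux example later in this section (one row per class, one column per sample), looks like this. The helper name and the probabilities are hypothetical; the one-hot targets zero out every term except the log-probability of each sample's true class.

```julia
using Statistics: mean

# Hypothetical helper implementing the CCE formula.
# Rows are classes and columns are samples; p_hat holds probabilities,
# y_onehot holds the one-hot targets.
cce_by_hand(p_hat, y_onehot) = -mean(sum(y_onehot .* log.(p_hat); dims=1))

# 3 classes, 2 samples (each column of p_hat sums to 1)
p_hat = [0.7 0.1;
         0.2 0.3;
         0.1 0.6]
y_onehot = [1 0;
            0 0;
            0 1]   # sample 1 is class 1, sample 2 is class 3

loss = cce_by_hand(p_hat, y_onehot)
println("CCE: ", loss)   # equals -(log(0.7) + log(0.6)) / 2
```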
The predictions $\hat{p}_{ij}$ are typically the output of a softmax activation function applied to the final layer of the network, which ensures they sum to 1 across all classes for each sample. Flux provides Flux.crossentropy(y_pred_probs, y_true_onehot). Similar to binary cross-entropy, Flux also offers Flux.logitcrossentropy(y_pred_logits, y_true_onehot), which accepts raw logits and applies the log-softmax transformation internally for improved numerical stability and efficiency. This is the preferred function when working with raw model outputs. For y_true_onehot, you'll usually represent your target labels using one-hot encoding (e.g., with Flux.onehot or Flux.onehotbatch).
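The logit-based path can be sketched the same way in plain Julia. Below, a hypothetical onehot_by_hand builds a Bool matrix in the same class-by-sample layout that Flux.onehotbatch produces, and a hypothetical logsoftmax_cols computes a log-softmax down each column after subtracting the column maximum, the usual trick that keeps exp from overflowing. This illustrates the idea behind Flux.logitcrossentropy rather than reproducing its source; the logits and labels are the ones used in the Flux example later in this section.

```julia
using Statistics: mean

# Hypothetical helper mirroring the layout of Flux.onehotbatch:
# one row per class, one column per sample (a Bool matrix here)
onehot_by_hand(labels, classes) = [label == c for c in classes, label in labels]

# Numerically stable log-softmax down each column (subtract the column max first)
function logsoftmax_cols(z)
    m = maximum(z; dims=1)
    return z .- m .- log.(sum(exp.(z .- m); dims=1))
end

# Cross-entropy computed directly from logits
logit_ce_by_hand(logits, y_onehot) = -mean(sum(y_onehot .* logsoftmax_cols(logits); dims=1))

# 2 classes (rows) x 3 samples (columns); true labels are 2, 1, 1
logits = [0.2 -0.5  1.2;
          0.8  0.5 -0.1]
y_onehot = onehot_by_hand([2, 1, 1], 1:2)

loss = logit_ce_by_hand(logits, y_onehot)
println("Cross-entropy from logits: ", loss)

# A logit of 1000 would overflow exp on its own; the shifted form stays finite
println(logsoftmax_cols([1000.0, 0.0]))
```

With these inputs the loss comes out at about 0.664.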
The loss function is the final computation in the forward pass of your neural network during training. It distills all the information about the model's performance on a batch of data into a single number. This scalar loss value is then used by Zygote.jl, Flux's automatic differentiation engine, to calculate the gradients of the loss with respect to all the model's parameters (weights and biases). These gradients indicate how each parameter should be adjusted to reduce the loss, forming the basis for the optimizer's update step. The choice and correct implementation of the loss function are therefore directly tied to the success of the training process.
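To illustrate what the gradients of the loss with respect to the parameters mean, here is a plain-Julia sketch that differentiates MSE by hand for a hypothetical one-feature linear model $\hat{y} = wx + b$ and takes a single gradient-descent step. From the MSE formula, $\partial L/\partial w = \frac{2}{N}\sum_i (\hat{y}_i - y_i)x_i$ and $\partial L/\partial b = \frac{2}{N}\sum_i (\hat{y}_i - y_i)$. The data, learning rate, and function names are illustrative assumptions; in Flux, Zygote would compute these gradients automatically.

```julia
using Statistics: mean

# Hypothetical one-feature linear model and its MSE loss
predict(w, b, x) = w .* x .+ b
loss(w, b, x, y) = mean((predict(w, b, x) .- y) .^ 2)

# Analytic gradients of the MSE with respect to w and b
function grads(w, b, x, y)
    r = predict(w, b, x) .- y      # residuals: prediction minus target
    return 2 * mean(r .* x), 2 * mean(r)
end

x = [1.0, 2.0, 3.0, 4.0]
y = [3.0, 5.0, 7.0, 9.0]   # generated by y = 2x + 1
w, b = 0.0, 0.0            # initial parameters
lr = 0.05                  # learning rate

gw, gb = grads(w, b, x, y)
w_new, b_new = w - lr * gw, b - lr * gb   # one step against the gradient

println("loss before: ", loss(w, b, x, y))
println("loss after:  ", loss(w_new, b_new, x, y))
```

Both gradients are negative at the starting point, so the step increases w and b, moving the predictions toward the targets and lowering the loss.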
Selecting an appropriate loss function depends on your specific problem:

- Flux.mse: A good default for regression. Sensitive to outliers, since errors are squared.
- Flux.mae: More robust to outliers than MSE.
- Flux.binarycrossentropy: For binary classification when model outputs are probabilities (after a sigmoid). Target labels should be 0 or 1.
- Flux.logitbinarycrossentropy: Preferred for binary classification when model outputs are raw logits (before a sigmoid), as it is more numerically stable. Target labels should be 0 or 1.
- Flux.crossentropy: For multi-class classification when model outputs are probability distributions (after a softmax). Target labels should be one-hot encoded.
- Flux.logitcrossentropy: Preferred when model outputs are raw logits (before a softmax). Target labels should be one-hot encoded. This function combines the softmax activation and the cross-entropy calculation for better numerical stability.

Always ensure that your model's output activation function (or the lack of one, if you use a logit-based loss) is compatible with the chosen loss function.
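A quick way to see why the output activation must match the loss: the plain-Julia sketch below, with hypothetical helper names, evaluates a logit-based cross-entropy once on raw logits (the intended use) and once on outputs that have already been passed through softmax (a common mistake). The second call still returns a finite number, so the error is silent.

```julia
using Statistics: mean

# Column-wise log-softmax and softmax (rows = classes, columns = samples)
function logsoftmax_cols(z)
    m = maximum(z; dims=1)
    return z .- m .- log.(sum(exp.(z .- m); dims=1))
end
softmax_cols(z) = exp.(logsoftmax_cols(z))

# Hypothetical logit-based cross-entropy: applies log-softmax internally
logit_ce(logits, y_onehot) = -mean(sum(y_onehot .* logsoftmax_cols(logits); dims=1))

logits   = [2.0 -1.0;
            0.0  3.0]
y_onehot = [1 0;
            0 1]   # both samples are predicted correctly and confidently

correct = logit_ce(logits, y_onehot)                 # raw logits in
wrong   = logit_ce(softmax_cols(logits), y_onehot)   # softmax applied twice
println("logits in:        ", correct)
println("probabilities in: ", wrong)
```

Here the mismatched call reports a loss several times larger than the correct one. Because softmax squashes inputs that already lie between 0 and 1, that loss can never approach zero however confident the model is, which distorts training.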
Let's see a quick example of how to use these functions in Flux:
```julia
using Flux
using Statistics: mean  # used by the custom loss example later

# Sample data for a regression task
y_actual_reg = [1.5f0, 2.0f0, 3.5f0]
y_predicted_reg = [1.3f0, 2.4f0, 3.1f0]

# Mean Squared Error (prediction first, target second)
mse_loss = Flux.mse(y_predicted_reg, y_actual_reg)
println("MSE Loss: ", mse_loss)

# Mean Absolute Error
mae_loss = Flux.mae(y_predicted_reg, y_actual_reg)
println("MAE Loss: ", mae_loss)

# Sample data for a classification task (2 classes, 3 samples)
# Raw model outputs (logits), laid out as (number_of_classes, number_of_samples):
# each column holds the two class logits for one sample
y_predicted_logits_clf = Float32[0.2 -0.5  1.2;
                                 0.8  0.5 -0.1]

# True labels as one-hot columns
# Sample 1: class 2, sample 2: class 1, sample 3: class 1
y_actual_clf = Flux.onehotbatch([2, 1, 1], 1:2)  # 2×3 OneHotMatrix

# Categorical cross-entropy computed from logits
# Flux.logitcrossentropy expects (logits, targets)
ce_loss = Flux.logitcrossentropy(y_predicted_logits_clf, y_actual_clf)
println("Categorical Cross-Entropy Loss: ", ce_loss)

# Binary classification: one sample, raw logit for class 1
y_predicted_logit_bin = 0.7f0  # raw logit output for a single sample
y_actual_bin = 1.0f0           # true label is 1
bce_loss = Flux.logitbinarycrossentropy(y_predicted_logit_bin, y_actual_bin)
println("Binary Cross-Entropy Loss (from logit): ", bce_loss)
```
In these examples, y_predicted_reg and y_predicted_logits_clf would typically be the outputs of your Flux model (e.g., model(input_data)). The actual labels y_actual_reg and y_actual_clf come from your dataset.
While Flux provides a comprehensive set of standard loss functions, Julia's flexibility allows you to easily define your own if your problem requires a specialized measure of error. A custom loss function in Flux is just a regular Julia function that takes the model's predictions and the true targets as input and returns a single scalar value representing the loss.
For example, if you needed a weighted version of Mean Squared Error, you could define it as:
```julia
using Statistics: mean

# Weighted MSE: each squared error is scaled by a per-sample weight
# before averaging. y_pred, y_true, and weights must be broadcast-compatible.
function my_weighted_mse(y_pred, y_true, weights)
    return mean(weights .* (y_pred .- y_true) .^ 2)
end

# Example usage (assuming model, x_batch, y_batch_labels, and
# batch_sample_weights are already defined):
# model_outputs = model(x_batch)
# loss_value = my_weighted_mse(model_outputs, y_batch_labels, batch_sample_weights)
```
This custom function can then be used in your training loop just like any built-in Flux loss function. Zygote.jl will be able to differentiate through it, provided all operations within it are themselves differentiable. This composability is a powerful feature of the Julia ecosystem for deep learning.
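Zygote computes derivatives automatically, but when writing a custom loss it can be reassuring to compare a hand-derived gradient against a finite-difference estimate. The sketch below does this in plain Julia for the weighted MSE with respect to the predictions; the analytic gradient $\partial L/\partial \hat{y}_i = 2 w_i(\hat{y}_i - y_i)/N$ follows from the definition, while the helper names, data, and step size are illustrative assumptions.

```julia
using Statistics: mean

my_weighted_mse(y_pred, y_true, weights) = mean(weights .* (y_pred .- y_true) .^ 2)

# Analytic gradient with respect to y_pred: 2 * w_i * (ŷ_i - y_i) / N
analytic_grad(y_pred, y_true, weights) = 2 .* weights .* (y_pred .- y_true) ./ length(y_pred)

# Hypothetical helper: central finite differences, one coordinate at a time
function fd_grad(f, x; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

y_true  = [1.0, 2.0, 3.0]
y_pred  = [1.5, 1.0, 3.5]
weights = [1.0, 2.0, 0.5]

g_exact = analytic_grad(y_pred, y_true, weights)
g_fd = fd_grad(yp -> my_weighted_mse(yp, y_true, weights), y_pred)
println("analytic:          ", g_exact)
println("finite difference: ", g_fd)
```

If the two vectors disagree by more than rounding noise, the hand derivation or the loss itself deserves a second look before relying on it in training.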
© 2025 ApX Machine Learning