You've learned about the building blocks of neural networks in Flux.jl, from individual layers to combining them into models. Now, it's time to put that knowledge into practice by constructing a Convolutional Neural Network (CNN) for a common task: image classification. This exercise will walk you through loading a dataset, defining a CNN architecture, and preparing it for training.
We'll use the well-known MNIST dataset, which consists of grayscale images of handwritten digits (0-9). Our goal is to build a CNN that can look at one of these images and correctly predict the digit it represents.
First, ensure you have the necessary Julia packages. If you've been following along, Flux.jl should already be in your environment. For this exercise, we'll also need MLDatasets.jl to easily access MNIST, MLUtils.jl for data handling utilities, and potentially Images.jl for image-specific transformations if we were working with raw image files (though MLDatasets often provides data in a fairly ready-to-use format). We'll also use Statistics for mean calculations and Random for reproducibility.
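If any of these packages are missing, one way to add them is with Pkg (a minimal sketch; adjust to your own project environment):
using Pkg
Pkg.add(["Flux", "MLDatasets", "MLUtils", "BSON"]) # add "Images" only if you need raw image transforms
With the packages available, bring them into scope: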
using Flux
using MLDatasets
using MLUtils: DataLoader, unsqueeze # onehotbatch and flatten are provided by Flux
using Statistics: mean
using Random # For reproducible results (gives access to Random.seed!)
using Printf # For formatted output
It's good practice to set a random seed if you want your results to be reproducible, especially during development and debugging.
const T = Float32 # Define a default float type for our data and model
Random.seed!(123); # Seed the global RNG so weight initialization is reproducible
MLDatasets.jl makes loading standard datasets like MNIST straightforward.
# Load MNIST dataset
train_x_raw, train_y_raw = MNIST(T, :train)[:]
test_x_raw, test_y_raw = MNIST(T, :test)[:]
The raw image data train_x_raw will typically be a 3D array of size (width, height, num_samples), for example (28, 28, 60000) for the MNIST training images. CNNs in Flux usually expect input data in (width, height, channels, batch_size) format, often referred to as WHCN. Since MNIST images are grayscale, they have one channel.
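A quick, illustrative check of the raw array shapes (the expected sizes assume the standard MNIST splits):
@show size(train_x_raw) # expected (28, 28, 60000)
@show size(train_y_raw) # expected (60000,)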
Let's reshape and preprocess the data:
# Reshape data to WHCN format (Width, Height, Channels, Batch)
# MNIST images are 28x28, grayscale (1 channel)
train_x = unsqueeze(train_x_raw, dims=3) # Add channel dimension: 28x28x1x60000
test_x = unsqueeze(test_x_raw, dims=3) # Add channel dimension: 28x28x1x10000
# Normalize pixel values to [0, 1] (already done by MLDatasets for MNIST, but good practice to be aware of)
# MNIST from MLDatasets loads data as Float32 in [0,1] range by default.
# If not, you might do: train_x = train_x ./ T(255.0)
# One-hot encode the labels
# Labels are 0-9, so 10 classes
train_y = onehotbatch(train_y_raw, 0:9) # Output: 10x60000
test_y = onehotbatch(test_y_raw, 0:9) # Output: 10x10000
# Create DataLoaders for batching
batch_size = 128
train_loader = DataLoader((train_x, train_y), batchsize=batch_size, shuffle=true)
test_loader = DataLoader((test_x, test_y), batchsize=batch_size) # No need to shuffle test data
Here, unsqueeze(data, dims=3) adds a new dimension at the 3rd position, transforming our (28, 28, 60000) array into (28, 28, 1, 60000), which is the WHCN format Flux expects. onehotbatch converts integer labels (e.g., 5) into one-hot vectors (e.g., [0,0,0,0,0,1,0,0,0,0]). DataLoader from MLUtils.jl is a convenient utility for iterating over data in mini-batches, which is essential for training neural networks efficiently.
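To confirm the batch shapes, a minimal sanity check (assuming the loaders defined above) is to inspect the first mini-batch:
xb, yb = first(train_loader)
@show size(xb) # expected (28, 28, 1, 128)
@show size(yb) # expected (10, 128)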
Now, let's define the layers of our CNN. A common CNN structure for image classification includes:
- Convolutional layers (Conv) to detect features.
- Activation functions (e.g., relu) to introduce non-linearity.
- Pooling layers (MaxPool) to downsample and reduce dimensionality.
- A flattening step (Flux.flatten) to convert the 2D feature maps into a 1D vector.
- Fully connected layers (Dense) for final classification.
Here's an example architecture:
model = Chain(
# First convolutional block
Conv((5, 5), 1=>6, relu), # 28x28x1 -> 24x24x6 (assuming no padding)
MaxPool((2, 2)), # 24x24x6 -> 12x12x6
# Second convolutional block
Conv((5, 5), 6=>16, relu), # 12x12x6 -> 8x8x16
MaxPool((2, 2)), # 8x8x16 -> 4x4x16
# Flatten output and feed to dense layers
Flux.flatten, # 4x4x16 -> 256
Dense(256, 120, relu), # 256 -> 120
Dense(120, 84, relu), # 120 -> 84
Dense(84, 10) # 84 -> 10 (logits for 10 classes)
)
Let's break this down:
- Conv((5, 5), 1=>6, relu): A convolutional layer with a 5x5 kernel, taking 1 input channel (grayscale) and producing 6 output channels (feature maps). relu is applied as the activation function.
- MaxPool((2, 2)): A max pooling layer with a 2x2 window, reducing the spatial dimensions by half.
- The second Conv and MaxPool block further processes the features.
- Flux.flatten: Converts the multi-dimensional output of the convolutional/pooling layers (e.g., 4x4x16) into a flat vector (e.g., 256 elements) suitable for dense layers.
- The Dense layers perform classification based on the learned features. The final Dense layer has 10 outputs, corresponding to the 10 digit classes. We don't apply softmax here because we'll use a loss function (logitcrossentropy) that expects raw logits.
You can visualize this architecture as a sequence of transformations:
Figure: The flow of data through our CNN, from input image to output logits for each class. N represents the batch size.
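As a quick check that the shapes line up, you can push a dummy batch through the untrained model (illustrative; any small batch size works):
dummy_x = rand(T, 28, 28, 1, 4) # a fake batch of four 28x28 grayscale images
@show size(model(dummy_x)) # expected (10, 4): 10 logits per image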
To train our model, we need a loss function to measure how wrong its predictions are and an optimizer to adjust its parameters.
# Loss function: logitcrossentropy is suitable for classification with raw logit outputs
loss(m, x, y) = Flux.logitcrossentropy(m(x), y)
# Optimizer: ADAM is a popular choice
opt_state = Flux.setup(Adam(0.001), model) # Adam learning rate 0.001
Flux.logitcrossentropy is efficient and numerically stable for classification tasks where the model outputs raw scores (logits) rather than probabilities (which would be obtained after a softmax). Adam is an adaptive learning rate optimization algorithm that generally works well. Flux.setup initializes the optimizer state for the given model parameters.
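As an illustration, you can evaluate the loss on a single untrained batch; for 10 roughly balanced classes it should land near -log(1/10) ≈ 2.3 (a quick sanity check using the loader and loss defined above):
xb, yb = first(train_loader)
@printf("Loss on one untrained batch: %.4f\n", loss(model, xb, yb))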
We can also define a helper function to calculate accuracy:
# Function to calculate accuracy
function accuracy(data_loader, model)
acc = 0
num_samples = 0
for (x_batch, y_batch) in data_loader
# Get model predictions (logits)
y_hat_batch = model(x_batch)
# Convert logits to class predictions (index of max logit)
# and one-hot labels to class indices
acc += sum(Flux.onecold(y_hat_batch) .== Flux.onecold(y_batch))
num_samples += size(x_batch, ndims(x_batch)) # Count samples in the batch
end
return acc / num_samples
end
Flux.onecold maps a one-hot vector, or a prediction vector of logits or probabilities, to the index of its maximum entry; if you pass a label set as a second argument, it returns the corresponding label instead.
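A couple of small illustrative calls (the exact outputs assume the label set 0:9 used above):
Flux.onecold([0.1, 0.9, 0.0]) # 2: index of the maximum entry
Flux.onecold(Flux.onehotbatch([3, 7], 0:9)) # [4, 8]: 1-based column indices
Flux.onecold(Flux.onehotbatch([3, 7], 0:9), 0:9) # [3, 7]: mapped back to digit labels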
Let's check the accuracy of our untrained model. It should be around 10% (random guessing for 10 classes).
@printf("Initial test accuracy: %.2f%%\n", accuracy(test_loader, model) * 100)
You should see an output similar to:
Initial test accuracy: 9.80%
(The exact value might vary slightly due to random weight initialization).
The actual training loop involves iterating over the data multiple times (epochs), calculating the loss, computing gradients, and updating the model parameters using the optimizer. Chapter 4 will cover training loops, evaluation, and fine-tuning in detail. For now, let's write a basic manual loop using Flux.withgradient and Flux.update!, the same steps that Flux.train! performs for you:
# A simplified training example for one epoch.
# In Chapter 4, we'll build more comprehensive training loops.
# We use a manual loop here to illustrate each component;
# in practice, Flux.train! or a custom loop over many epochs is used.
num_epochs = 1 # For this practice, just one epoch to illustrate
@printf("Starting training for %d epoch(s)...\n", num_epochs)
for epoch in 1:num_epochs
epoch_loss = T(0.0)
batches_processed = 0
for (x_batch, y_batch) in train_loader
# Calculate loss and gradients
batch_loss, grads = Flux.withgradient(model) do m
loss(m, x_batch, y_batch)
end
# Update model parameters
Flux.update!(opt_state, model, grads[1])
epoch_loss += batch_loss
batches_processed += 1
# Optional: Print progress for a few batches
if batches_processed % 100 == 0
@printf("Epoch %d, Batch %d/%d: Batch Loss: %.4f\n",
epoch, batches_processed, length(train_loader), batch_loss)
end
end
avg_epoch_loss = epoch_loss / batches_processed
train_acc = accuracy(train_loader, model)
test_acc = accuracy(test_loader, model)
@printf("Epoch %d finished. Avg Loss: %.4f, Train Acc: %.2f%%, Test Acc: %.2f%%\n",
epoch, avg_epoch_loss, train_acc*100, test_acc*100)
end
Running this simplified training snippet for one epoch should show an increase in accuracy and a decrease in loss, indicating that the model is learning. For example, after one epoch, you might see something like:
Epoch 1 finished. Avg Loss: 0.2831, Train Acc: 94.75%, Test Acc: 94.99%
These results are promising for just one epoch! With more epochs and potentially hyperparameter tuning (covered later), performance can be further improved.
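For reference, the gradient and update steps inside the loop above are exactly what Flux.train! automates; a single epoch could also be written as the one-liner below (a minimal sketch using the loss, model, train_loader, and opt_state defined earlier):
Flux.train!(loss, model, train_loader, opt_state) # calls loss(model, x_batch, y_batch) for each batch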
Once you have a trained model, you'll often want to save its structure and learned parameters. BSON.jl is commonly used for this in the Julia ecosystem.
using BSON: @save, @load
# To save the model:
# Note: Saving only `model` saves the structure and parameters.
# For some optimizers or more complex scenarios, you might save opt_state too.
@save "mnist_cnn_model.bson" model
# To load the model later (the variable name must match the key it was saved under):
# @load "mnist_cnn_model.bson" model
# `model` will now contain the structure and parameters of your saved model.
This allows you to reuse your trained model for inference or further training without having to retrain from scratch.
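With a saved or freshly trained model, running inference on a single image is straightforward. An illustrative example using the first test image (relying on the arrays defined earlier):
x1 = test_x[:, :, :, 1:1] # keep the batch dimension: 28x28x1x1
pred = Flux.onecold(model(x1), 0:9)[1] # map the largest logit back to a digit label
@printf("Predicted digit: %d, true digit: %d\n", pred, test_y_raw[1])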
This practice session guided you through the core steps of implementing a CNN in Flux.jl: loading and preparing data with MLUtils.jl, defining a Chain of layers including Conv, MaxPool, and Dense, and setting up the loss function and optimizer. While we only touched upon the training process, you now have a solid foundation for building various neural network architectures. The next chapter explores the mechanics of training, evaluating, and refining these models in much greater depth.