Now that we've explored the theoretical underpinnings of disentangled representation learning, including various model modifications and evaluation metrics, it's time to put this knowledge into practice. This hands-on section will guide you through training a VAE variant aimed at achieving better disentanglement, specifically, the β-VAE, and then evaluating its success using some of the metrics we discussed earlier, such as the Mutual Information Gap (MIG) and Separated Attribute Predictability (SAP).
The goal here isn't to provide a complete, copy-paste codebase, but rather to outline the essential steps and considerations, allowing you to experiment and deepen your understanding. We assume you're comfortable implementing a standard VAE in a framework like PyTorch or TensorFlow.
For disentanglement experiments, synthetic datasets with known ground-truth factors of variation are invaluable. The dSprites dataset is a popular choice. It consists of 2D shapes (squares, ellipses, hearts) generated from 6 independent latent factors: color (always white), shape, scale, orientation, X-position, and Y-position. Having access to these true factors allows us to quantitatively measure how well our model disentangles them.
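For reference, the dataset is distributed as a single NumPy archive; a minimal loading sketch might look like the following (the filename and array keys below follow the official dSprites release, but treat the exact path as a placeholder for wherever you store the file):

import numpy as np

# Archive available from https://github.com/deepmind/dsprites-dataset
dataset_path = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
data = np.load(dataset_path)

images = data['imgs']                      # (737280, 64, 64) binary images
latents_values = data['latents_values']    # (737280, 6) ground-truth factor values
latents_classes = data['latents_classes']  # (737280, 6) factors as discrete class indices

# Factor column order: color, shape, scale, orientation, position X, position Y
print(images.shape, latents_classes.shape)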
You'll need your standard deep learning toolkit:
- PyTorch or TensorFlow for building and training the VAE.
- NumPy for array handling.
- scikit-learn for the evaluation metrics (e.g., mutual_info_regression, mutual_info_score, or for training simple classifiers).

Recall from our discussion that the β-VAE modifies the standard VAE objective by introducing a coefficient β on the KL divergence term:
$$\mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \beta \, D_{KL}\!\left(q_\phi(z|x)\,\|\,p(z)\right)$$

A $\beta > 1$ places a stronger constraint on the KL divergence, encouraging the approximate posterior $q_\phi(z|x)$ to be closer to the prior $p(z)$ (typically an isotropic Gaussian $\mathcal{N}(0, I)$). This pressure can encourage the model to find more disentangled representations.
1. Model Architecture: Your VAE architecture can be a standard convolutional setup for image data like dSprites.
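For example, a small convolutional encoder/decoder along the lines of the following sketch works for 64×64 dSprites images. This is only a reference point: the layer widths and the 10-dimensional latent space are assumptions, not a tuned design, and the decoder outputs logits so binary cross-entropy can be used for reconstruction.

import torch
import torch.nn as nn

class ConvVAE(nn.Module):
    def __init__(self, latent_dim=10):
        super().__init__()
        # Encoder: 64x64x1 -> 32 -> 16 -> 8 -> 4 spatial resolution
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(64 * 4 * 4, latent_dim)
        self.fc_log_var = nn.Linear(64 * 4 * 4, latent_dim)
        # Decoder mirrors the encoder with transposed convolutions
        self.fc_decode = nn.Linear(latent_dim, 64 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),  # logits for BCE
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_log_var(h)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        h = self.fc_decode(z).view(-1, 64, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var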
2. The β-VAE Loss Function:
The implementation change from a standard VAE is minimal. Assuming reconstruction_loss is your negative log-likelihood term (e.g., binary cross-entropy or mean squared error) and kl_divergence is the KL term, your combined loss calculation becomes:
# beta-VAE loss computation (PyTorch-style; reconstruction_criterion is a placeholder
# for your chosen negative log-likelihood, e.g. binary cross-entropy or MSE)
import torch

# mu, log_var are outputs of the encoder
# x_reconstructed is the output of the decoder
# x_original is the input image
# beta is the hyperparameter

# Reconstruction term: it is common to sum over pixels per image and average over
# the batch, so that the scale of beta is comparable to the KL term below
reconstruction_loss = reconstruction_criterion(x_reconstructed, x_original)

# KL divergence between N(mu, sigma^2) and N(0, I), summed over latent dimensions:
# -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
kl_divergence = torch.mean(kl_divergence)  # Average over the batch

total_loss = reconstruction_loss + beta * kl_divergence
# Backpropagate total_loss as usual
3. Training Insights: Train several models with different values of β (for example β = 1, which recovers the standard VAE, alongside larger values such as 4 or 8) so that you can compare how their representations differ; a minimal training loop sketch follows.
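The sketch below assumes the hypothetical ConvVAE from earlier and a DataLoader that yields batches of dSprites images as float tensors in [0, 1]; both names are illustrative rather than fixed.

import torch
import torch.nn.functional as F

def train_beta_vae(model, data_loader, beta, epochs=20, lr=1e-3, device='cpu'):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        for x in data_loader:  # x: (batch, 1, 64, 64) float tensor in [0, 1]
            x = x.to(device)
            x_logits, mu, log_var = model(x)
            # Reconstruction: BCE summed over pixels, averaged over the batch
            recon = F.binary_cross_entropy_with_logits(
                x_logits, x, reduction='none').sum(dim=[1, 2, 3]).mean()
            # KL divergence summed over latent dims, averaged over the batch
            kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=1)).mean()
            loss = recon + beta * kl
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model

# Illustrative usage: train one model per beta value for later comparison
# models = {beta: train_beta_vae(ConvVAE(), loader, beta) for beta in [1, 4, 8]}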
Once your models are trained, you need to evaluate how disentangled their learned representations are.
1. Qualitative Evaluation: Latent Traversal A simple yet insightful way to qualitatively assess disentanglement is by performing latent traversals: encode an input image, vary one latent dimension at a time over a range of values (for example from -3 to 3) while holding the other dimensions fixed, and decode each modified latent vector. If the representation is disentangled, each dimension should change only a single factor, such as position, scale, or orientation, in the generated images.
You can create a grid of images where each row (or column) corresponds to traversing a different latent dimension. This visual inspection can be very revealing.
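One possible sketch of such a traversal grid, assuming the hypothetical ConvVAE from earlier and using matplotlib for display (the ±3 range and 8 steps are arbitrary choices worth tuning):

import torch
import matplotlib.pyplot as plt

@torch.no_grad()
def latent_traversal(model, x, n_steps=8, traversal_range=3.0):
    """Decode images obtained by varying one latent dimension at a time."""
    model.eval()
    mu, _ = model.encode(x.unsqueeze(0))   # x: a single image of shape (1, 64, 64)
    latent_dim = mu.shape[1]
    values = torch.linspace(-traversal_range, traversal_range, n_steps)
    fig, axes = plt.subplots(latent_dim, n_steps, figsize=(n_steps, latent_dim))
    for j in range(latent_dim):            # One row per latent dimension
        for i, v in enumerate(values):
            z = mu.clone()
            z[0, j] = v                    # Vary only dimension j
            img = torch.sigmoid(model.decode(z))[0, 0]
            axes[j, i].imshow(img.cpu(), cmap='gray')
            axes[j, i].axis('off')
    plt.show()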
2. Quantitative Metrics For a more rigorous assessment, we use quantitative metrics. These typically require the ground-truth factor labels from the dataset.
Mutual Information Gap (MIG) MIG attempts to measure the extent to which each ground-truth factor is captured by a single latent dimension. For each ground-truth factor $y_k$, you estimate the mutual information $I(z_j; y_k)$ with every latent dimension $z_j$ (discretizing the latents if necessary), take the gap between the two largest values, and normalize this gap by the entropy $H(y_k)$. The final MIG score is the average of these normalized gaps over all factors.
A higher MIG score suggests better disentanglement, as it implies that each factor is primarily represented by one latent dimension, with a clear "gap" to the next most informative one.
# Sketch of MIG computation (simplified). Assumes discrete ground-truth factors,
# as in dSprites.
# latents: (N_samples, N_latents) - encoded means from the VAE
# factors: (N_samples, N_factors) - ground-truth factor values
import numpy as np
from sklearn.metrics import mutual_info_score

def discretize_latent(z_j, n_bins=20):
    # Bin a continuous latent dimension into equal-width bins
    edges = np.histogram_bin_edges(z_j, bins=n_bins)
    return np.digitize(z_j, edges[1:-1])

def calculate_entropy(y_k):
    # Entropy (in nats) of a discrete factor, estimated from empirical frequencies
    _, counts = np.unique(y_k, return_counts=True)
    probs = counts / counts.sum()
    return -np.sum(probs * np.log(probs))

def calculate_mig(latents, factors, n_bins=20):
    num_latents = latents.shape[1]
    num_factors = factors.shape[1]
    mig_scores_per_factor = []
    for k in range(num_factors):  # For each ground-truth factor y_k
        y_k = factors[:, k]
        # For dSprites, factors are discrete, so H(y_k) can be computed directly
        h_y_k = calculate_entropy(y_k)
        mutual_informations = []
        for j in range(num_latents):  # For each latent dimension z_j
            z_j_discretized = discretize_latent(latents[:, j], n_bins)
            # mutual_info_score computes MI (in nats) between two discrete variables
            mi_zj_yk = mutual_info_score(z_j_discretized, y_k)
            mutual_informations.append(mi_zj_yk)
        sorted_mi = sorted(mutual_informations, reverse=True)
        if len(sorted_mi) < 2:
            continue  # Not enough latent dimensions to compute a gap
        gap_k = (sorted_mi[0] - sorted_mi[1]) / h_y_k if h_y_k > 0 else 0.0
        mig_scores_per_factor.append(gap_k)
    return float(np.mean(mig_scores_per_factor)) if mig_scores_per_factor else 0.0

# If a factor were continuous, sklearn.feature_selection.mutual_info_regression
# could be used to estimate the mutual information instead of mutual_info_score.
Separated Attribute Predictability (SAP) SAP measures disentanglement by assessing how well each latent dimension predicts a single ground-truth factor. For each ground-truth factor $y_k$, you train a simple predictor (for example, a linear classifier) on every individual latent dimension $z_j$, record the resulting prediction scores, and take the difference between the two highest scores. The final SAP score averages these differences over all factors.
A higher SAP score indicates that individual latent dimensions are predictive of individual factors of variation.
# Sketch of SAP computation (simplified).
# latents: (N_samples, N_latents) - encoded means from the VAE
# factors: (N_samples, N_factors) - ground-truth factor values
# (Assumes discrete factors so that classification accuracy can be used.)
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def calculate_sap(latents, factors, test_size=0.2, random_state=42):
    num_latents = latents.shape[1]
    num_factors = factors.shape[1]
    sap_scores_per_factor = []
    # Split the data for training the probe classifiers
    # Note: a more rigorous evaluation would use cross-validation.
    latents_train, latents_test, factors_train, factors_test = train_test_split(
        latents, factors, test_size=test_size, random_state=random_state
    )
    # Scale latents (optional, but good practice for linear models)
    scaler = StandardScaler()
    latents_train_scaled = scaler.fit_transform(latents_train)
    latents_test_scaled = scaler.transform(latents_test)
    for k in range(num_factors):  # For each ground-truth factor y_k
        y_k_train = factors_train[:, k]
        y_k_test = factors_test[:, k]
        prediction_scores = []
        for j in range(num_latents):  # For each latent dimension z_j
            z_j_train = latents_train_scaled[:, j].reshape(-1, 1)
            z_j_test = latents_test_scaled[:, j].reshape(-1, 1)
            if len(np.unique(y_k_train)) < 2:
                # Factor is constant in the training split: nothing to predict
                prediction_scores.append(0.0)
                continue
            # Train a simple probe classifier (logistic regression on one dimension)
            model = LogisticRegression(solver='liblinear', C=0.1)  # Keep the probe simple
            model.fit(z_j_train, y_k_train)
            prediction_scores.append(model.score(z_j_test, y_k_test))
        sorted_scores = sorted(prediction_scores, reverse=True)
        if len(sorted_scores) < 2:
            continue  # Not enough latent dimensions to compute a difference
        # SAP for this factor: difference between the top two prediction scores
        sap_scores_per_factor.append(sorted_scores[0] - sorted_scores[1])
    return float(np.mean(sap_scores_per_factor)) if sap_scores_per_factor else 0.0
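To tie the pieces together, you might encode a subset of the dataset, pair the encoder means with the ground-truth classes, and pass both to the metric functions. The sketch below assumes the hypothetical ConvVAE and the images and latents_classes arrays from the earlier loading snippet, and drops the constant color factor since it carries no information to recover; the subset size is arbitrary.

import numpy as np
import torch

@torch.no_grad()
def encode_dataset(model, images, batch_size=256, device='cpu'):
    """Return the encoder means for an array of images of shape (N, 64, 64)."""
    model.eval().to(device)
    means = []
    for i in range(0, len(images), batch_size):
        batch = torch.as_tensor(images[i:i + batch_size], dtype=torch.float32)
        batch = batch.unsqueeze(1).to(device)  # (B, 1, 64, 64)
        mu, _ = model.encode(batch)
        means.append(mu.cpu().numpy())
    return np.concatenate(means, axis=0)

# Evaluate on a random subset to keep the metric computation fast
idx = np.random.choice(len(images), size=10000, replace=False)
latents = encode_dataset(model, images[idx])
factors = latents_classes[idx, 1:]  # Drop the constant color factor

print('MIG:', calculate_mig(latents, factors))
print('SAP:', calculate_sap(latents, factors))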
Note on Metric Implementation: The pseudocode above simplifies certain aspects. Full implementations require careful handling of data splitting (train/validation/test sets for the metric classifiers), hyperparameter choices for the probe classifiers (though simple models are typically preferred, since the goal is to test how directly the factors can be read off the latents), and potentially averaging results over multiple random seeds. For dSprites, factor values are discrete, which simplifies entropy estimation and classification.
Figure: Mutual Information Gap (MIG) and Separated Attribute Predictability (SAP) scores increase with β, while reconstruction loss also tends to increase, illustrating the common trade-off in β-VAEs. Logarithmic scales are used for better visualization across orders of magnitude.
This hands-on exercise provides a starting point. You can extend it by sweeping a wider range of β values, trying the other disentanglement-oriented model modifications discussed earlier, experimenting with additional datasets that provide ground-truth factors, or implementing further evaluation metrics.
By actively engaging with these models and metrics, you'll build a much stronger intuition for the challenges and successes in the field of disentangled representation learning. Remember that this is an active research area, and perfect disentanglement, especially on complex datasets without supervision, remains an open problem.