Practice and reinforce the concepts from Lesson 10
In this activity, you'll build a complete Variational Autoencoder (VAE) from scratch for MNIST digit generation. You'll implement the encoder, decoder, reparameterization trick, and ELBO loss, then explore the learned latent space through interpolation and arithmetic.
By completing this activity, you will:
Download the activity template from the Templates folder:
Templates/AI25-Template-activity-10-variational-autoencoders.zip
Upload activity-10-variational-autoencoders.ipynb to Google Colab
Execute the first few cells to:
TODO 1: Implement encoder to output μ and log σ²
class Encoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # TODO 1: Define encoder architecture
        # Input: Flattened image (784)
        # Hidden layer: (hidden_dim) with ReLU
        # Outputs: mu (latent_dim), logvar (latent_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # Your code here: Define fc_mu and fc_logvar

    def forward(self, x):
        # TODO 1: Implement forward pass
        # x: (batch, 784)
        # Returns: mu, logvar (both batch × latent_dim)
        # Your code here
        pass
TODO 2: Implement z = μ + σ * ε (where ε ~ N(0,1))
def reparameterize(self, mu, logvar):
    """
    Reparameterization trick: z = mu + sigma * epsilon
    Args:
        mu: Mean (batch, latent_dim)
        logvar: Log variance (batch, latent_dim)
    Returns:
        z: Sampled latent code (batch, latent_dim)
    """
    # TODO 2: Implement reparameterization
    # Step 1: Compute std = exp(0.5 * logvar)
    # Step 2: Sample epsilon ~ N(0, 1) using torch.randn_like
    # Step 3: Return mu + std * epsilon
    # Your code here
    pass
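The three steps listed in TODO 2 can be checked in isolation. The sketch below is a standalone function (the name `reparameterize_sketch` is hypothetical and not part of the template) that follows those steps and then confirms that gradients reach both `mu` and `logvar`, which is the point of the trick.

```python
import torch

def reparameterize_sketch(mu, logvar):
    # Convert log-variance to standard deviation: std = exp(0.5 * logvar)
    std = torch.exp(0.5 * logvar)
    # eps carries all the randomness and has no learnable parameters
    eps = torch.randn_like(std)
    # z is a deterministic, differentiable function of mu and logvar
    return mu + std * eps

torch.manual_seed(0)
mu = torch.zeros(4, 20, requires_grad=True)
logvar = torch.zeros(4, 20, requires_grad=True)

z = reparameterize_sketch(mu, logvar)
z.sum().backward()

print(z.shape)                                    # shape of the sampled codes
print(mu.grad is not None, logvar.grad is not None)
```

Because the noise enters only through `eps`, the gradient of `z` with respect to `mu` is exactly 1 for every element, regardless of the sample drawn.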
TODO 3: Implement decoder to reconstruct images from latent codes
class Decoder(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super().__init__()
        # TODO 3: Define decoder architecture
        # Input: Latent code (latent_dim)
        # Hidden layer: (hidden_dim) with ReLU
        # Output: Reconstructed image (output_dim) with Sigmoid
        # Your code here
        pass

    def forward(self, z):
        # TODO 3: Implement forward pass
        # z: (batch, latent_dim)
        # Returns: x_recon (batch, 784) in range [0, 1]
        # Your code here
        pass
TODO 4: Implement reconstruction + KL divergence loss
def vae_loss(x, x_recon, mu, logvar):
    """
    Compute ELBO loss (negative ELBO) = Reconstruction Loss + KL Divergence
    Args:
        x: Original images (batch, 784)
        x_recon: Reconstructed images (batch, 784)
        mu: Encoder mean (batch, latent_dim)
        logvar: Encoder log variance (batch, latent_dim)
    Returns:
        loss: Scalar loss value
        recon_loss: Reconstruction component
        kl_loss: KL divergence component
    """
    # TODO 4a: Compute reconstruction loss (Binary Cross-Entropy)
    # Use F.binary_cross_entropy(x_recon, x, reduction='sum')
    recon_loss = None  # Your code here

    # TODO 4b: Compute KL divergence
    # KL(N(μ,σ²) || N(0,1)) = -0.5 * sum(1 + log(σ²) - μ² - σ²)
    # Use: -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    kl_loss = None  # Your code here

    # Total loss
    loss = recon_loss + kl_loss
    return loss, recon_loss, kl_loss
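One way to gain confidence in the closed-form KL term before training is to compare it with the value computed by `torch.distributions`. The sketch below uses random `mu` and `logvar` tensors as stand-ins for encoder outputs (the shapes are illustrative assumptions) and checks that the `-0.5 * torch.sum(...)` expression agrees with `kl_divergence` between a diagonal Gaussian and the standard normal.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(8, 20)       # stand-in for encoder means
logvar = torch.randn(8, 20)   # stand-in for encoder log-variances

# Closed form for KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dims
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Reference value: element-wise KL from torch.distributions, then summed
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum()

print(kl_closed.item(), kl_ref.item())
assert torch.allclose(kl_closed, kl_ref, rtol=1e-4)
```

A KL value that comes out negative is a quick signal that the leading minus sign was dropped.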
TODO 5: Combine encoder and decoder into full VAE
class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim
        # TODO 5: Initialize encoder and decoder
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        # TODO 5: Implement full forward pass
        # 1. Encode: x → mu, logvar
        # 2. Reparameterize: mu, logvar → z
        # 3. Decode: z → x_recon
        # Returns: x_recon, mu, logvar
        # Your code here
        pass

    def sample(self, num_samples):
        """
        Generate new samples from prior N(0, I)
        Args:
            num_samples: Number of samples to generate
        Returns:
            Generated images (num_samples, 784)
        """
        # TODO 5: Implement sampling
        # 1. Sample z from N(0, I)
        # 2. Decode z to get generated images
        # Your code here
        pass
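The two sampling steps can be tried without a trained model. The sketch below uses an untrained stand-in decoder built with `nn.Sequential` (an assumption for illustration; its layer sizes follow the Decoder comments above, but it is not the template's class), so the outputs are noise. It only demonstrates the shapes and value range to expect.

```python
import torch
import torch.nn as nn

latent_dim = 20

# Untrained stand-in decoder: latent_dim -> 400 -> 784 with a Sigmoid output
decoder = nn.Sequential(
    nn.Linear(latent_dim, 400),
    nn.ReLU(),
    nn.Linear(400, 784),
    nn.Sigmoid(),
)

with torch.no_grad():                  # no gradients needed for generation
    z = torch.randn(16, latent_dim)    # 1. sample z from the prior N(0, I)
    samples = decoder(z)               # 2. decode z into flattened images

print(samples.shape)
print(samples.min().item() >= 0.0, samples.max().item() <= 1.0)
```

To view the results as images, each row can be reshaped with `samples.view(-1, 28, 28)`.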
Pre-built training loop with:
Features:
TODO 6: Implement latent space interpolation
def interpolate(vae, img1, img2, num_steps=10):
    """
    Interpolate between two images in latent space
    Args:
        vae: Trained VAE model
        img1, img2: Input images (1, 784)
        num_steps: Number of interpolation steps
    Returns:
        Interpolated images (num_steps, 784)
    """
    # TODO 6: Implement interpolation
    # 1. Encode img1 and img2 to get mu1, mu2
    # 2. Interpolate: z_interp = (1-alpha) * mu1 + alpha * mu2
    #    for alpha in linspace(0, 1, num_steps)
    # 3. Decode each z_interp
    # Your code here
    pass
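The interpolation in step 2 can be done for all values of alpha at once with broadcasting. The sketch below works on latent means only (the helper name `interpolate_latents` and the constant stand-in tensors are assumptions); encoding the images and decoding the result are left to the TODO.

```python
import torch

def interpolate_latents(mu1, mu2, num_steps=10):
    # alphas has shape (num_steps, 1) so it broadcasts across the latent dimension
    alphas = torch.linspace(0, 1, num_steps).unsqueeze(1)
    # Row i is (1 - alpha_i) * mu1 + alpha_i * mu2
    return (1 - alphas) * mu1 + alphas * mu2

mu1 = torch.zeros(1, 20)   # stand-in for the encoded mean of img1
mu2 = torch.ones(1, 20)    # stand-in for the encoded mean of img2

z_interp = interpolate_latents(mu1, mu2, num_steps=10)
print(z_interp.shape)
print(z_interp[:, 0])      # first latent coordinate along the path
```

The first row equals `mu1` and the last row equals `mu2`, so the decoded sequence starts and ends at the two reconstructions.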
TODO 7: Implement latent space traversal (vary one dimension)
def traverse_latent_dim(vae, dim, num_samples=10, range_vals=(-3, 3)):
    """
    Traverse a single latent dimension
    Args:
        vae: Trained VAE model
        dim: Dimension to vary (0 to latent_dim-1)
        num_samples: Number of samples along traversal
        range_vals: (min, max) values for traversal
    Returns:
        Generated images (num_samples, 784)
    """
    # TODO 7: Implement traversal
    # 1. Create base latent code z = zeros(latent_dim)
    # 2. Vary z[dim] from range_vals[0] to range_vals[1]
    # 3. Decode each z
    # Your code here
    pass
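Steps 1 and 2 of the traversal amount to building a batch of latent codes that differ in a single coordinate. A minimal sketch, assuming a hypothetical helper `traversal_codes` that returns only the codes (decoding them with the trained model is step 3):

```python
import torch

def traversal_codes(latent_dim, dim, num_samples=10, range_vals=(-3, 3)):
    # Start from the origin of latent space, one row per traversal step
    z = torch.zeros(num_samples, latent_dim)
    # Sweep only the chosen coordinate; all others stay at zero
    z[:, dim] = torch.linspace(range_vals[0], range_vals[1], num_samples)
    return z

z = traversal_codes(latent_dim=20, dim=5)
print(z.shape)
print(z[:, 5])   # values swept along dimension 5
```

Keeping every other coordinate fixed at zero means any change in the decoded images can be attributed to the swept dimension alone.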
TODO 8: Implement β-VAE for disentangled representations
def beta_vae_loss(x, x_recon, mu, logvar, beta=4.0):
    """
    Compute β-VAE loss = Reconstruction Loss + β * KL Divergence
    Args:
        beta: Weight on KL term (β > 1 encourages disentanglement)
    Returns:
        loss, recon_loss, kl_loss
    """
    # TODO 8: Modify VAE loss to include beta
    # Same as vae_loss, but multiply kl_loss by beta
    # Your code here
    pass
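To see how much β shifts the balance between the two terms, the arithmetic can be worked through on plain numbers. The sketch below uses a hypothetical helper `weighted_total` and borrows the Epoch 10 values from the expected training output (Recon 83.4, KL 6.3) purely as illustrative inputs; a model actually trained with β = 4 would produce different component values.

```python
def weighted_total(recon_loss, kl_loss, beta=4.0):
    # beta scales only the KL term; the reconstruction term is unchanged
    return recon_loss + beta * kl_loss

print(round(weighted_total(83.4, 6.3, beta=1.0), 1))  # 89.7
print(round(weighted_total(83.4, 6.3, beta=4.0), 1))  # 108.6
```

With β = 4 the KL term contributes 25.2 instead of 6.3, so the optimizer has a stronger incentive to pull the posterior toward the prior.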
Training Progress:
Epoch 1/10:
Loss: 154.2 | Recon: 142.5 | KL: 11.7
Epoch 5/10:
Loss: 95.3 | Recon: 88.1 | KL: 7.2
Epoch 10/10:
Loss: 89.7 | Recon: 83.4 | KL: 6.3
✓ Training complete
✓ Model saved to vae_mnist.pth
Reconstruction Quality:
✓ Original digits clearly recognizable
✓ Reconstructions slightly blurry but accurate
✓ All 10 digit classes reconstructed well
Generated Samples (from random z ~ N(0,I)):
✓ 100 generated digits
✓ All digit classes represented (0-9)
✓ Smooth, realistic digit shapes
✓ Some variation within each class
Interpolation between digit "3" and "8":
Step 0: Clear "3"
Step 2: "3" morphing
Step 5: Hybrid (between 3 and 8)
Step 7: "8" forming
Step 10: Clear "8"
✓ Smooth transition
✓ All intermediate images valid digits
✓ No abrupt changes
Latent Dimension Traversal:
Dimension 5:
- Low values (-3): Thin, left-leaning digits
- Medium values (0): Normal digits
- High values (+3): Thick, right-leaning digits
✓ Dimension controls specific attribute
✓ Smooth changes across traversal
Comparing β = 1 (standard VAE) vs β = 4 (β-VAE):
β = 1:
- Better reconstructions
- Entangled latent dimensions
- Dimension 3 controls thickness + rotation
β = 4:
- Slightly blurrier reconstructions
- More disentangled dimensions
- Dimension 3: thickness only
- Dimension 7: rotation only
✓ Trade-off: Disentanglement vs reconstruction quality
Your implementation is complete when:
Common Issues:
1. Reconstruction Loss Doesn't Decrease:
2. KL Divergence Collapses to 0:
3. Generated Samples are Blurry:
Good latent space:
# Test: Interpolation should be smooth
img1 = mnist_test[0]  # Digit "7"
img2 = mnist_test[1]  # Digit "2"
interpolated = interpolate(vae, img1, img2, num_steps=10)
# All intermediate steps should be valid digits
for step in interpolated:
    assert is_valid_digit(step), "Interpolation jumped!"
Poor latent space:
| Parameter | Recommended | Effect |
|---|---|---|
| Latent dim | 10-20 | Smaller: compressed, Larger: expressive |
| Hidden dim | 400-512 | More capacity, slower training |
| Learning rate | 1e-3 | Standard for VAEs |
| Beta (β-VAE) | 1-10 | Higher: more disentanglement, worse reconstruction |
| Epochs | 10-20 | More: better quality, diminishing returns |
Add class conditioning to generate specific digits:
class CVAE(nn.Module):
    def __init__(self, num_classes=10, ...):
        super().__init__()
        # Encoder: concat image + one-hot label
        # Decoder: concat z + one-hot label

    def generate_digit(self, digit_class):
        """Generate specific digit (0-9)"""
        pass
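The concatenation described in the comments can be sketched on its own. The helper `concat_label` below is hypothetical and the tensors are random stand-ins; it shows only how a one-hot label widens the encoder input from 784 to 794 features and the decoder input from 20 to 30.

```python
import torch
import torch.nn.functional as F

def concat_label(x, labels, num_classes=10):
    # One-hot encode integer class labels and append them as extra features
    one_hot = F.one_hot(labels, num_classes=num_classes).float()
    return torch.cat([x, one_hot], dim=1)

x = torch.rand(4, 784)                  # stand-in batch of flattened images
labels = torch.tensor([3, 8, 0, 7])     # digit class for each image
z = torch.randn(4, 20)                  # stand-in latent codes

enc_in = concat_label(x, labels)        # encoder input: image + label
dec_in = concat_label(z, labels)        # decoder input: z + label
print(enc_in.shape, dec_in.shape)
```

The first linear layer of each network then needs `input_dim + num_classes` and `latent_dim + num_classes` input features respectively.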
Replace fully-connected layers with convolutions:
class ConvEncoder(nn.Module):
    def __init__(self):
        super().__init__()  # Required before registering submodules
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        # ...
Benefit: Better image quality, fewer parameters
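Before wiring the convolutions into an encoder, it helps to confirm the feature-map sizes they produce on 28×28 MNIST inputs. The sketch below applies the two layers shown above to a zero tensor; the batch size of 2 and the flattening step are illustrative assumptions.

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)

x = torch.zeros(2, 1, 28, 28)   # (batch, channels, height, width)
h1 = conv1(x)                   # each stride-2 layer halves height and width
h2 = conv2(h1)
flat = h2.flatten(start_dim=1)  # what a following nn.Linear would receive

print(h1.shape, h2.shape, flat.shape)
```

Each layer maps a spatial size n to floor((n + 2 - 4) / 2) + 1, giving 28 → 14 → 7, so the heads for μ and log σ² would take 64 × 7 × 7 = 3136 input features.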
Improve ELBO estimate with multiple samples:
def iwae_loss(x, x_recon_samples, mu, logvar, k=5):
    """
    IWAE with k samples per datapoint
    Better gradient estimates than standard VAE
    """
    pass
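The core of the importance-weighted bound is a log-mean-exp over the k samples. The sketch below assumes the log importance weights log p(x, z_i) − log q(z_i | x) have already been computed and stacked into a `(k, batch)` tensor; the function name `iwae_bound` and that input layout are assumptions and differ from the `iwae_loss` stub signature above.

```python
import math
import torch

def iwae_bound(log_w):
    """log_w: (k, batch) tensor of log importance weights."""
    k = log_w.shape[0]
    # log( (1/k) * sum_i exp(log_w_i) ), computed stably with logsumexp
    return torch.logsumexp(log_w, dim=0) - math.log(k)

torch.manual_seed(0)
log_w = torch.randn(5, 8)      # random stand-in: k=5 samples, batch of 8

bound = iwae_bound(log_w)      # per-datapoint bound, shape (batch,)
loss = -bound.mean()           # minimize the negative bound
print(bound.shape, loss.item())
```

With k = 1 the expression reduces to the single-sample ELBO estimate, and by Jensen's inequality the k-sample value is never below the average of the individual log weights.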
Multi-level latent variables:
z2 (global) → z1 (local) → x (image)
Use case: Capture structure at multiple scales
Completed Notebook: activity-10-variational-autoencoders.ipynb
Generated Artifacts:
Analysis (5-7 sentences):
pytorch/examples
Next Activity: Activity 11 - Build a GAN for higher-quality image generation
This activity is graded on:
Passing Grade: 70% or higher
Excellent work on your first deep generative model! 🎉🎨