Saturday, December 13, 2025

Implementing a VAE for Image Generation: A Hands-On Example


A Variational Autoencoder (VAE) is a generative deep learning model that learns a probabilistic latent space and can generate new images similar to the training data. VAEs are widely used for image generation, anomaly detection, and representation learning.


This hands-on example walks through building and training a VAE for image generation using PyTorch.


1. What Is a VAE? (Quick Intuition)


A VAE consists of two neural networks:


Encoder


Compresses input images into a latent distribution, parameterized by a mean μ and a variance σ² (predicted as log σ² for numerical stability)


Decoder


Reconstructs images from latent vectors sampled from that distribution


Unlike a normal autoencoder, a VAE:


Learns a continuous, smooth latent space


Enables random sampling for image generation


2. VAE Architecture Overview

Input Image
     ↓
 Encoder → μ, log(σ²)
     ↓
 Reparameterization Trick
     ↓
 Decoder
     ↓
 Reconstructed Image


3. Environment Setup

pip install torch torchvision matplotlib


4. Import Libraries

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


5. Prepare Dataset (MNIST)

transform = transforms.Compose([
    transforms.ToTensor()
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transform
    ),
    batch_size=128,
    shuffle=True
)
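Before training, it helps to see the tensor shapes this loader produces. A quick sketch on a dummy batch (MNIST batches come out as [batch, 1, 28, 28] after ToTensor(), with pixel values in [0, 1]):

```python
import torch

# Dummy batch with MNIST's shape: [batch, channels, height, width].
batch = torch.rand(128, 1, 28, 28)

# The training loop later flattens each image into a 784-dim vector
# for the fully connected encoder.
flat = batch.view(-1, 28 * 28)
print(flat.shape)  # torch.Size([128, 784])
```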


6. Define the VAE Model

class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()

        # Encoder
        self.fc1 = nn.Linear(28 * 28, 400)
        self.fc_mu = nn.Linear(400, latent_dim)
        self.fc_logvar = nn.Linear(400, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, 400)
        self.fc4 = nn.Linear(400, 28 * 28)

    def encode(self, x):
        h = torch.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


7. Loss Function (Reconstruction + KL Divergence)

def vae_loss(recon_x, x, mu, logvar):
    recon_loss = nn.functional.binary_cross_entropy(
        recon_x, x, reduction='sum'
    )

    kl_div = -0.5 * torch.sum(
        1 + logvar - mu.pow(2) - logvar.exp()
    )

    return recon_loss + kl_div
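As a sanity check (a standalone sketch, not part of the tutorial pipeline), the closed-form KL term above can be compared against torch.distributions for a diagonal Gaussian versus a standard normal prior:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 20)
logvar = torch.randn(4, 20)

# Closed-form KL term, exactly as written in vae_loss.
analytic = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Reference: KL( N(mu, sigma^2) || N(0, I) ), summed over all elements.
reference = kl_divergence(
    Normal(mu, torch.exp(0.5 * logvar)),
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),
).sum()

print(torch.allclose(analytic, reference, atol=1e-4))  # True
```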


8. Training the VAE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

epochs = 10

for epoch in range(epochs):
    model.train()
    total_loss = 0

    for data, _ in train_loader:
        data = data.view(-1, 28 * 28).to(device)
        optimizer.zero_grad()

        recon, mu, logvar = model(data)
        loss = vae_loss(recon, data, mu, logvar)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader.dataset):.4f}")


9. Generate New Images

model.eval()
with torch.no_grad():
    z = torch.randn(64, 20).to(device)
    samples = model.decode(z).cpu()


10. Visualize Generated Images

fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i].view(28, 28), cmap="gray")
    ax.axis("off")
plt.show()


11. Key Concepts to Understand

Reparameterization Trick


Allows backpropagation through randomness:


z = μ + σ ⊙ ε, where ε ~ N(0,1)
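A minimal sketch of why the trick matters: all the randomness lives in ε, so the sample z stays differentiable with respect to μ and log σ², and gradients can reach the encoder:

```python
import torch

mu = torch.zeros(5, requires_grad=True)
logvar = torch.zeros(5, requires_grad=True)

std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)   # randomness is isolated in eps
z = mu + eps * std            # z is differentiable w.r.t. mu and logvar

z.sum().backward()
print(mu.grad)  # tensor([1., 1., 1., 1., 1.]): dz/dmu = 1 per element
```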


KL Divergence


Encourages the latent distribution to follow a standard normal prior N(0, I). For a diagonal Gaussian it has the closed form used in vae_loss:


KL = −½ Σ (1 + log σ² − μ² − σ²)


12. Improvements & Extensions


Use CNN-based VAE for higher-quality images


Increase latent dimension


Train on CIFAR-10 or CelebA


Add β-VAE for disentangled representations


Save & interpolate latent vectors
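For the last item, a hedged sketch of latent interpolation: blend two latent codes, then decode each blend to morph smoothly between two generated digits. The decode call is commented out because it needs the trained model from Section 8.

```python
import torch

# Blend two latent codes with linearly spaced weights.
z_a, z_b = torch.randn(20), torch.randn(20)
steps = torch.linspace(0, 1, 8)
blends = torch.stack([torch.lerp(z_a, z_b, t) for t in steps])
print(blends.shape)  # torch.Size([8, 20])

# images = model.decode(blends.to(device))  # with the trained VAE above
```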


13. Common Issues

Issue | Fix
Blurry images | Use CNNs
Posterior collapse | Reduce the KL weight
Poor generation | Train longer
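One common way to reduce the KL weight against posterior collapse (a sketch; the schedule and names here are illustrative) is to warm the weight up over the first few epochs instead of applying it in full from the start:

```python
def kl_weight(epoch, warmup_epochs=5):
    # Linearly ramp the KL weight up to 1.0 over warmup_epochs.
    return min(1.0, (epoch + 1) / warmup_epochs)

# In the training loop: loss = recon_loss + kl_weight(epoch) * kl_div
print([kl_weight(e) for e in range(6)])  # [0.2, 0.4, 0.6, 0.8, 1.0, 1.0]
```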

Conclusion


This hands-on example demonstrated how to implement a Variational Autoencoder from scratch and use it to generate images. VAEs provide a principled probabilistic framework for generative modeling and form the foundation for more advanced models like β-VAE, VQ-VAE, and diffusion models.
