How to Build Your Own Generative Adversarial Network (GAN)


📌 What is a GAN?


A Generative Adversarial Network (GAN) consists of two neural networks:


Generator (G) – learns to create fake data (e.g. fake images).


Discriminator (D) – learns to distinguish between real and fake data.


They play a game:


The Generator tries to fool the Discriminator into classifying its output as real.


The Discriminator tries to catch the Generator by flagging that output as fake.
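
The game above is usually written as a minimax objective. The notation below follows the standard GAN formulation and does not appear elsewhere in this post: p_data is the distribution of real data, p_z is the noise distribution, and D(x) is the probability the Discriminator assigns to x being real.

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The Discriminator pushes V up by scoring real samples near 1 and fakes near 0; the Generator pushes V down by making D(G(z)) close to 1.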


🧰 Tools You’ll Use


Python


PyTorch


NumPy, Matplotlib (for data and visualization)


Dataset: MNIST (handwritten digits)


🛠️ Step-by-Step: Build a GAN from Scratch

🔹 Step 1: Install Required Libraries

pip install torch torchvision matplotlib


🔹 Step 2: Import Libraries

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


🔹 Step 3: Load the Dataset (MNIST)

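# ToTensor() scales pixels to [0, 1]; Normalize then maps them to [-1, 1],
# matching the Tanh output range of the Generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataloader = DataLoader(
    datasets.MNIST(root='./data', train=True, download=True, transform=transform),
    batch_size=128,
    shuffle=True
)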
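
To see what the normalization step does without downloading MNIST, here is a small sketch that applies the same arithmetic, (x - mean) / std, to a synthetic tensor. The `image` tensor is a made-up stand-in for a `ToTensor()` output, not real data.

```python
import torch

# Stand-in for a ToTensor() output: one 1x28x28 image with pixel values in [0, 1].
image = torch.linspace(0.0, 1.0, steps=28 * 28).view(1, 28, 28)

# Normalize((0.5,), (0.5,)) computes (x - mean) / std for each channel.
normalized = (image - 0.5) / 0.5

lo, hi = normalized.min().item(), normalized.max().item()
print(lo, hi)  # the [0, 1] range is stretched to [-1, 1]
```

This is the same range the Generator's final Tanh layer produces, so real and fake images reach the Discriminator on the same scale.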


🔹 Step 4: Define the Generator Network

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Maps a 100-dimensional noise vector to a flattened 28x28 image (784 values).
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 784),
            nn.Tanh()  # Output range [-1, 1]
        )

    def forward(self, x):
        return self.main(x)
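
A quick shape check can catch wiring mistakes before training. The sketch below rebuilds the same layer stack as a plain `nn.Sequential` (a stand-in with random, untrained weights) so that it runs on its own.

```python
import torch
import torch.nn as nn

# Stand-in with the same layer sizes as the Generator above (untrained weights).
generator = nn.Sequential(
    nn.Linear(100, 256), nn.ReLU(True),
    nn.Linear(256, 512), nn.ReLU(True),
    nn.Linear(512, 784), nn.Tanh(),
)

z = torch.randn(4, 100)              # 4 latent vectors of size 100
with torch.no_grad():
    flat = generator(z)              # one flattened image per latent vector
images = flat.view(-1, 1, 28, 28)    # reshape to (N, C, H, W) for plotting

print(flat.shape, images.shape)
```

Each noise vector becomes 784 values, which is 28 x 28, the size of one MNIST digit.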


🔹 Step 5: Define the Discriminator Network

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Maps a flattened 28x28 image (784 values) to a single probability.
        self.main = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Probability that the input is real
        )

    def forward(self, x):
        return self.main(x)
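
The same kind of check works for the Discriminator. This sketch again uses a stand-in `nn.Sequential` with the layer sizes above and random inputs in place of real images, so the scores it prints are meaningless; only the shape and range matter.

```python
import torch
import torch.nn as nn

# Stand-in with the same layer sizes as the Discriminator above (untrained weights).
discriminator = nn.Sequential(
    nn.Linear(784, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)

batch = torch.rand(4, 784) * 2 - 1   # 4 random "images" in [-1, 1]
with torch.no_grad():
    scores = discriminator(batch)    # one probability per image

print(scores.shape, scores.min().item(), scores.max().item())
```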


🔹 Step 6: Initialize Networks and Optimizers

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = Generator().to(device)
D = Discriminator().to(device)

# Binary cross-entropy matches the Discriminator's sigmoid output.
loss_function = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=0.0002)
optimizer_D = optim.Adam(D.parameters(), lr=0.0002)
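
It can help to see what `nn.BCELoss` computes. The sketch below compares it with the formula written out by hand; the `probs` and `targets` values are invented for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical discriminator outputs (probabilities) and targets, for illustration only.
probs = torch.tensor([[0.9], [0.2], [0.6]])
targets = torch.tensor([[1.0], [0.0], [1.0]])

# By default nn.BCELoss averages -[y*log(p) + (1-y)*log(1-p)] over the batch.
manual = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs)).mean()
builtin = nn.BCELoss()(probs, targets)

print(manual.item(), builtin.item())
```

With a target of 1 the loss reduces to -log(p), and with a target of 0 it reduces to -log(1 - p), which is why the training loop uses labels of ones and zeros to steer each network.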


🔹 Step 7: Train the GAN

epochs = 50

for epoch in range(epochs):
    for batch_idx, (real_images, _) in enumerate(dataloader):
        batch_size = real_images.size(0)

        # Prepare real data and target labels (1 = real, 0 = fake)
        real_images = real_images.view(batch_size, -1).to(device)
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # 1. Train Discriminator
        z = torch.randn(batch_size, 100, device=device)
        fake_images = G(z)

        real_loss = loss_function(D(real_images), real_labels)
        # detach() keeps this step from backpropagating into the Generator
        fake_loss = loss_function(D(fake_images.detach()), fake_labels)
        d_loss = real_loss + fake_loss

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # 2. Train Generator
        z = torch.randn(batch_size, 100, device=device)
        fake_images = G(z)
        g_loss = loss_function(D(fake_images), real_labels)  # Want D to think fakes are real

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    # Losses reported here are from the last batch of the epoch
    print(f"Epoch [{epoch+1}/{epochs}]  D Loss: {d_loss.item():.4f}  G Loss: {g_loss.item():.4f}")
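
The `detach()` call in the Discriminator step is easy to overlook, so here is a minimal sketch of its effect. The two networks are tiny stand-ins with made-up sizes, chosen only so the demo runs instantly; they are not the models defined above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in networks (hypothetical sizes, chosen only to keep the demo fast).
G = nn.Linear(4, 8)
D = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())
loss_function = nn.BCELoss()

z = torch.randn(5, 4)
fake = G(z)

# Discriminator step: detach() cuts the graph, so no gradient reaches G.
d_fake_loss = loss_function(D(fake.detach()), torch.zeros(5, 1))
d_fake_loss.backward()
g_grad_after_d_step = G.weight.grad                  # G was never touched
d_grad_after_d_step = D[0].weight.grad is not None   # D did receive gradients

# Generator step: no detach, so the loss backpropagates through D into G.
D.zero_grad()
g_loss = loss_function(D(fake), torch.ones(5, 1))
g_loss.backward()
g_has_grad = G.weight.grad is not None

print(g_grad_after_d_step, d_grad_after_d_step, g_has_grad)
```

In the Generator step, gradients also flow into the Discriminator's parameters, but only `optimizer_G.step()` is called, so the Discriminator's weights stay unchanged until its own `zero_grad()` clears them.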


🔹 Step 8: Visualize Generated Images

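from torchvision.utils import make_grid  # not covered by the Step 2 imports


def show_generated_images(generator, num_images=16):
    generator.eval()
    # Sample latent vectors and generate images without tracking gradients.
    with torch.no_grad():
        z = torch.randn(num_images, 100, device=device)
        fake_images = generator(z).view(-1, 1, 28, 28).cpu()

    # normalize=True rescales pixel values to [0, 1] for display.
    grid = make_grid(fake_images, nrow=4, normalize=True)
    plt.imshow(grid.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for Matplotlib
    plt.axis("off")
    plt.show()


show_generated_images(G)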


📊 What You Just Built


A basic GAN that learns to generate handwritten digits from random noise.


You trained a Generator to make digits.


You trained a Discriminator to distinguish real vs. fake.


Over time, the Generator gets better at "fooling" the Discriminator.


🪄 How to Improve It


Use a convolutional architecture (DCGAN) for image data.


Add training tricks such as label smoothing or noise injection.


Try conditional GANs (cGANs) to generate specific digit classes.


Use Wasserstein GAN (WGAN) for more stable training.
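
As one example of the tricks listed above, here is a rough sketch of one-sided label smoothing. The target value 0.9 is a common but arbitrary choice, and `d_out` is an invented, overconfident Discriminator output used only to show the effect on the loss.

```python
import torch
import torch.nn as nn

loss_function = nn.BCELoss()
batch_size = 8

# One-sided label smoothing: real targets become 0.9 instead of 1.0.
smooth_real_labels = torch.full((batch_size, 1), 0.9)
hard_real_labels = torch.ones(batch_size, 1)

# Hypothetical, very confident discriminator outputs on real images.
d_out = torch.full((batch_size, 1), 0.99)

hard_loss = loss_function(d_out, hard_real_labels)
smooth_loss = loss_function(d_out, smooth_real_labels)
print(hard_loss.item(), smooth_loss.item())
```

With smoothed targets, an output of 0.99 is penalized more than it is with hard labels, which discourages the Discriminator from becoming overconfident. In the training loop from Step 7, this would mean swapping `real_labels` for the smoothed version in the Discriminator's `real_loss` only.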
