Generative Adversarial Networks using PyTorch

In this tutorial, we will learn about Generative Adversarial Networks (GANs) and implement a simple one with PyTorch in Python.

What is a GAN?

A GAN (Generative Adversarial Network) is a type of neural network used to generate new data that closely resembles real data. The data can be images, video, or audio. A GAN consists of two neural networks that are trained against each other:

  • Generator: This neural network creates data, often called fake data, that is meant to resemble real data. It starts from random noise, so its first outputs look random. The discriminator scores each generated sample, the generator’s weights are adjusted based on that feedback, and another set of fake data is generated. This process goes on until the discriminator can no longer reliably differentiate between real data and fake data.
  • Discriminator: This neural network works as a binary classifier, and its purpose is to tell real data apart from the data generated by the generator. Both real data and fake data are fed into the discriminator so that it learns the differences. Its predictions are compared with the true labels by a loss function, and the gradients of that loss are used to update its weights. This is what we call backpropagation, and it makes the classifier more accurate. Because the two networks are trained in alternation, both tend to improve over the epochs, and the generator eventually produces fake data that looks very similar to real data.
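
To make the opposing goals of the two networks concrete, here is a minimal sketch in plain Python that computes binary cross-entropy by hand for one real and one fake sample. The function names and the example scores are illustrative assumptions, not part of PyTorch.

```python
import math

def bce(prediction, target):
    """Binary cross-entropy for one prediction in (0, 1) and a 0/1 target."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))

def discriminator_loss(d_real, d_fake):
    # The discriminator wants D(x) close to 1 for real data and D(G(z)) close to 0 for fake data
    return bce(d_real, 1) + bce(d_fake, 0)

def generator_loss(d_fake):
    # The generator wants the discriminator to output 1 for its fake data
    return bce(d_fake, 1)

# Hypothetical discriminator scores: (D(x), D(G(z)))
confident = (0.9, 0.1)  # discriminator separates real from fake well
fooled = (0.5, 0.5)     # discriminator cannot tell them apart

for name, (d_real, d_fake) in [("confident", confident), ("fooled", fooled)]:
    print(name, round(discriminator_loss(d_real, d_fake), 4), round(generator_loss(d_fake), 4))
# confident 0.2107 2.3026
# fooled 1.3863 0.6931
```

When the discriminator is confident, its loss is low and the generator’s loss is high. As the generator improves and the scores drift toward 0.5, the generator’s loss falls while the discriminator’s loss rises.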

Python Code: Generative Adversarial Network

I am running the code on Google Colab, which is a convenient choice because it offers access to GPUs, and training is much faster on a GPU than on a CPU.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_size = 64
hidden_size = 256
image_size = 784  # 28x28
num_epochs = 100
batch_size = 100
learning_rate = 0.0002

# MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        output = self.model(x)
        return output

# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, image_size),
            nn.Tanh()
        )

    def forward(self, z):
        output = self.model(z)
        output = output.view(output.size(0), 1, 28, 28)
        return output

# Create models
D = Discriminator().to(device)
G = Generator().to(device)

# Loss and optimizer
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)

# Training loop
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        images = images.to(device)
        # Use the actual batch size so a smaller final batch does not break the loss
        n = images.size(0)
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # Train Discriminator: real images should score 1, generated images 0
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(n, latent_size, device=device)
        fake_images = G(z)
        # detach() keeps this step from backpropagating into the Generator
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Train Generator: it is rewarded when D labels its images as real
        z = torch.randn(n, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}')
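
The training loop above only prints losses and scores. To look at what the generator produces, one option is to sample latent vectors and map the tanh output back to the [0, 1] pixel range. The sketch below is one possible way to do that, not part of the original script: `sample_images` is a hypothetical helper, and a small untrained stand-in network is used so that the snippet runs on its own. In this tutorial you would pass the trained `G` instead.

```python
import torch
import torch.nn as nn

def sample_images(generator, num_images, latent_size, device):
    """Draw random latent vectors and return generated images scaled to [0, 1]."""
    was_training = generator.training
    generator.eval()
    with torch.no_grad():  # no gradients are needed for sampling
        z = torch.randn(num_images, latent_size, device=device)
        images = generator(z)
    generator.train(was_training)  # restore the previous mode
    # Undo Normalize(mean=0.5, std=0.5): tanh output in [-1, 1] becomes [0, 1]
    images = (images + 1) / 2
    return images.view(num_images, 1, 28, 28).cpu()

# Stand-in for the trained Generator (hypothetical and untrained)
stand_in = nn.Sequential(nn.Linear(64, 784), nn.Tanh())
samples = sample_images(stand_in, num_images=16, latent_size=64, device=torch.device('cpu'))
print(samples.shape)  # torch.Size([16, 1, 28, 28])
```

A single sample can then be displayed with matplotlib, for example `plt.imshow(samples[0, 0].numpy(), cmap='gray')`.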

Code Explanation

  • Device Configuration: This selects the device PyTorch runs on: the GPU if one is available (preferred for faster computation), otherwise the CPU.
  • Hyperparameters: The values that control the architecture and training of the networks (latent size, hidden size, image size, number of epochs, batch size, and learning rate) are defined in one place.
  • Dataset: This loads MNIST, a dataset of handwritten digit images available through torchvision. The images are converted into tensors and normalized to the range [-1, 1].
  • Discriminator: A neural network consisting of three linear layers, LeakyReLU activations, and a final sigmoid that outputs the probability that an image is real.
  • Generator: A neural network consisting of three linear layers, ReLU activations, and a final tanh that maps a latent vector to a 28x28 image.
  • BCE Loss function: Binary cross-entropy is a standard loss for binary classification. It measures how far the discriminator's predicted probabilities are from the target labels (1 for real, 0 for fake).
  • Adam Optimizer: Each network has its own Adam optimizer, so the weights of the two networks are updated independently.
  • Training: In each step, the discriminator is trained on a batch of real and fake images, and then the generator is trained to make the discriminator label its images as real. Every 200 steps, the losses and the mean discriminator scores D(x) and D(G(z)) are printed.
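
The normalization used in the dataset transform can be checked with a little arithmetic. `transforms.Normalize(mean=[0.5], std=[0.5])` computes (x - mean) / std for each pixel, so pixel values in [0, 1] end up in [-1, 1], which is the range a tanh output layer can produce. The helper below is a plain-Python illustration of that formula, written for this explanation rather than taken from torchvision.

```python
import math

def normalize(pixel, mean=0.5, std=0.5):
    # Same per-pixel formula that Normalize applies: (x - mean) / std
    return (pixel - mean) / std

normalized = [normalize(p) for p in (0.0, 0.5, 1.0)]
print(normalized)  # [-1.0, 0.0, 1.0]

# tanh squashes any real input into (-1, 1), the same range as the normalized pixels
generated = [math.tanh(v) for v in (-5.0, 0.0, 5.0)]
print(all(-1.0 < g < 1.0 for g in generated))  # True
```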
