Generative Adversarial Networks using PyTorch
In this tutorial, we will learn about Generative Adversarial Networks (GANs) using PyTorch in Python.
What is a GAN?
A GAN (Generative Adversarial Network) is a type of neural network architecture used to generate synthetic data that is close to real data. The data can be images, videos, or audio. A GAN consists of two neural networks that are trained in competition with each other:
- Generator: This neural network creates data, often called fake data, that resembles real data. Initially it produces essentially random outputs; the discriminator evaluates them, and based on the discriminator's feedback, the generator's weights are adjusted and a new batch of fake data is produced. This process continues until the discriminator can no longer reliably tell real data from fake data.
- Discriminator: This neural network works as a binary classifier: its purpose is to decide whether a sample is real or generated. Both real data and fake data from the generator are fed into the discriminator so it learns the differences between them. The classification error is then propagated backward through the networks (this is backpropagation), and each network updates its own weights based on its own loss. This way, both networks get better with each epoch, until the generated fake data looks almost the same as real data. (The objective the two networks optimize is written out just below.)
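Formally, the two networks play a minimax game. This is the standard GAN objective from Goodfellow et al., where D(x) is the discriminator's probability that x is real and G(z) maps a noise vector z to a fake sample:

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

The discriminator tries to push D(x) toward 1 and D(G(z)) toward 0, while the generator tries to push D(G(z)) toward 1. The training loop in the code below implements exactly this alternating optimization.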
Python Code: Generative Adversarial Network
I am running this code on Google Colab, which is advisable because Colab provides free GPU access and noticeably faster training.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_size = 64
hidden_size = 256
image_size = 784  # 28x28
num_epochs = 100
batch_size = 100
learning_rate = 0.0002

# MNIST dataset: images are scaled to [-1, 1] to match the generator's Tanh output
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transform,
                                           download=True)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

# Discriminator: maps a flattened 28x28 image to a real/fake probability
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten (batch, 1, 28, 28) -> (batch, 784)
        output = self.model(x)
        return output

# Generator: maps a latent noise vector to a 28x28 image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, image_size),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized data
        )

    def forward(self, z):
        output = self.model(z)
        output = output.view(output.size(0), 1, 28, 28)
        return output

# Create models
D = Discriminator().to(device)
G = Generator().to(device)

# Loss and optimizers (one optimizer per network)
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)

# Training loop
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        images = images.to(device)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Train Discriminator: real images should score 1, fake images 0
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images.detach())  # detach so G is not updated here
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Train Generator: try to make D classify fake images as real
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], '
                  f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
                  f'D(x): {real_score.mean().item():.2f}, '
                  f'D(G(z)): {fake_score.mean().item():.2f}')
```
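After training, you can sample from the generator and look at the results. Here is a minimal sketch, assuming the training loop above has finished and `G` is still in memory (it also puts the `matplotlib` and `numpy` imports to use):

```python
# Sample a 4x4 grid of digits from the trained generator
G.eval()
with torch.no_grad():
    z = torch.randn(16, latent_size).to(device)
    samples = G(z).cpu()          # shape (16, 1, 28, 28), values in [-1, 1]
    samples = (samples + 1) / 2   # rescale to [0, 1] for display

fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for ax, img in zip(axes.flat, samples):
    ax.imshow(np.squeeze(img.numpy()), cmap='gray')
    ax.axis('off')
plt.show()
```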
Code Explanation
- Device Configuration: Sets the device PyTorch will use to the GPU if one is available, otherwise the CPU (a GPU is preferred for faster training).
- Hyperparameters: The variables that control the network architecture and the training behavior (latent size, hidden size, number of epochs, batch size, learning rate) are defined in one place so they are easy to adjust.
- Dataset: Loads the MNIST images, a built-in dataset available through torchvision. The images are converted into tensors and normalized to the range [-1, 1], matching the generator's Tanh output.
- Discriminator: A neural network consisting of linear layers with LeakyReLU activations and a final Sigmoid, which outputs the probability that an input image is real.
- Generator: A neural network consisting of linear layers with ReLU activations and a final Tanh, which maps a random latent vector to a 28x28 image with pixel values in [-1, 1].
- BCE Loss function: Binary cross-entropy is the standard loss for binary classification. It measures how far the discriminator's predicted probabilities are from the true labels (1 for real, 0 for fake); a small numeric check is shown after this list.
- Adam Optimizer: Each network gets its own Adam optimizer, so the discriminator's and the generator's weights are updated independently.
- Training: Finally, the loop alternates between the two networks: the discriminator is trained to separate real images from fakes, then the generator is trained to fool it. After enough epochs, the generated images become hard to distinguish from real MNIST digits.
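To make the BCE loss concrete, here is a quick standalone check (not part of the tutorial code) that `nn.BCELoss` computes -[y·log(p) + (1-y)·log(1-p)] averaged over the batch:

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()
p = torch.tensor([[0.9], [0.2]])   # discriminator outputs (probabilities)
y = torch.tensor([[1.0], [0.0]])   # labels: 1 = real, 0 = fake

loss = criterion(p, y)
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()
print(loss.item(), manual.item())  # both print roughly 0.1643
```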