A Generative Adversarial Network (GAN) is a self-supervised learning technique used to train a variety of sophisticated deep learning architectures. It learns intrinsic correlations and patterns in unlabeled data without explicit training labels. Self-supervised learning still measures results against a ground truth, but one derived from the training data itself: the method edits the given dataset and asks the model to predict the edited part.
Self-supervised models adjust their weights with gradient descent during backpropagation. They use loss functions to measure the divergence ("loss") between the ground truth and the model's predictions, driving the model weights toward convergence.
GANs consist of two neural networks: a Generator and a Discriminator.
The Generator takes random noise and produces synthetic data that resembles a given dataset. The Discriminator is a classifier trained to distinguish the real dataset from the synthetic samples produced by the Generator. The Generator is trained against the Discriminator to minimize its divergence from the ground truth, while the Discriminator is trained simultaneously to become more adept at telling real data from generated data.
This adversarial process repeats until the Generator produces samples that are, ideally, indistinguishable from the given dataset.
This project will use a GAN to recreate the Kuzushiji-MNIST (KMNIST) handwritten images. The GAN consists of two neural networks, one for the Generator and the other for the Discriminator.
The Generator produces synthetic images, which are used as training data with the class label "Fake". The Discriminator is trained on both the given ground-truth dataset, labeled "Real", and the synthetic images, labeled "Fake".
Through backpropagation, the network weights are optimized so that the Discriminator minimizes its classification loss and the Generator minimizes the image divergence loss.
In this project, the Fréchet Inception Distance (FID) is used to produce an index for evaluating GAN performance.
Rather than comparing pixel values directly, features extracted from the images are used to quantify the difference between two image sets.
FID calculates the "distance" between real and generated images in the feature space of a pre-trained Inception model such as InceptionV3. A low FID score indicates higher quality and greater similarity to the real images.
The Inception network expects 3-channel RGB images, so the single-channel grayscale KMNIST images must be converted by repeating the single channel three times. The pretrained InceptionV3 model expects 299x299 images; the FrechetInceptionDistance() metric automatically resizes its inputs to 299x299.
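As a minimal sketch (the tensor contents are placeholders, not real KMNIST data), the grayscale-to-RGB conversion can be done with Tensor.repeat along the channel dimension:

```python
import torch

# A batch of 8 single-channel 28x28 "images" (random values as placeholders)
gray = torch.rand(8, 1, 28, 28)

# Repeat the channel dimension three times to obtain 3-channel RGB images
rgb = gray.repeat(1, 3, 1, 1)

print(rgb.shape)  # torch.Size([8, 3, 28, 28])
```

All three output channels hold identical copies of the grayscale values, which is exactly what the Inception network needs to accept the images.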
The KMNIST images may not have a wide data distribution, so FID may not be fully effective here. It is nevertheless used as an evaluation example to visualize that the GAN model improves at each epoch.
Step 1. Pass the real and fake images to the pretrained InceptionV3 network to retrieve the feature sets.
Step 2. Calculate the mean ${\mu}$ and covariance ${\Sigma}$ of each feature set.
Step 3. Compute FID scores.
The formula to compute the FID score is:
\begin{equation} FID = \left\| \mu_{r} - \mu_{g} \right\|^{2} + Tr\left( \Sigma_{r} + \Sigma_{g} - 2\left( \Sigma_{r}\Sigma_{g} \right)^{1/2} \right) \end{equation}
where $\mu_{r}, \Sigma_{r}$ are the mean and covariance of the real-image features and $\mu_{g}, \Sigma_{g}$ are those of the generated-image features.
Note that the number of samples used to calculate FID can influence the score; use a sufficient and consistent number of samples.
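The three steps above can be sketched in NumPy/SciPy (the feature arrays here are random stand-ins for Inception features, and scipy is assumed to be available):

```python
import numpy as np
from scipy import linalg

def fid_from_features(real_feats, fake_feats):
    """Compute FID from two (n_samples, n_features) feature arrays."""
    mu_r, mu_g = real_feats.mean(axis=0), fake_feats.mean(axis=0)
    sigma_r = np.cov(real_feats, rowvar=False)
    sigma_g = np.cov(fake_feats, rowvar=False)
    # Matrix square root of the covariance product
    covmean = linalg.sqrtm(sigma_r @ sigma_g)
    if np.iscomplexobj(covmean):  # drop tiny numerical imaginary parts
        covmean = covmean.real
    return float(np.sum((mu_r - mu_g) ** 2)
                 + np.trace(sigma_r + sigma_g - 2.0 * covmean))

rng = np.random.default_rng(0)
feats = rng.normal(size=(500, 8))
print(fid_from_features(feats, feats))        # identical sets -> ~0.0
print(fid_from_features(feats, feats + 1.0))  # shifted mean -> larger score
```

Identical feature sets give a score near zero, while shifting every feature by 1.0 leaves the covariances unchanged and adds the squared mean distance (8.0 here, one per feature).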
This project will utilize PyTorch libraries to train the GAN and to evaluate model performance.
import torch
import torch.nn as nn
When generating images, pixel values are typically normalized to either [0, 1] or [-1, 1].
The tanh activation function is frequently used in the output layer of a generator because its output range of (-1, 1) matches data normalized to [-1, 1], and its zero-centered outputs tend to make training more stable.
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) # Normalize to [-1, 1]
])
# Load Kuzushiji-MNIST dataset directly from torchvision datasets
train_dataset = datasets.KMNIST(root='./data', train=True, download=True, transform=transform)
# Pass the dataset to a dataloader for batch operations
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
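As a quick sanity check of the normalization above: Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, mapping ToTensor()'s [0, 1] range onto [-1, 1]. The tensor below is a stand-in for real pixel data:

```python
import torch

x = torch.tensor([0.0, 0.25, 0.5, 1.0])  # pixel values as produced by ToTensor()
y = (x - 0.5) / 0.5                       # what Normalize((0.5,), (0.5,)) applies
print(y)                                  # -1.0, -0.5, 0.0, 1.0
```

This matches the (-1, 1) output range of the generator's tanh layer, so real and generated images live on the same scale.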
The Generator is a neural network that takes noise as input and generates an image as output.
Latent dimension
Latent dimension specifies the feature dimension of the noise vector that is fed into the first layer of the generator.
Number of features
128 is a common choice for the number of feature maps (channels) in the first layer of the generator. Together with the latent-space dimensionality, it balances expressive power (the ability to generate diverse images) against computational complexity, and it determines the richness of the features available at that stage to learn a good representation of the images.
Leaky ReLU activation
The Leaky ReLU function helps prevent the "dying ReLU" problem and promotes more stable training. It allows a small, non-zero gradient for negative inputs, which matters when activations (like the normalized image values in [-1, 1]) can be negative. This overcomes the ReLU limitation of permanently deactivating neurons that receive negative values.
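A quick check of LeakyReLU with the negative slope of 0.2 used in this project (the input values are arbitrary):

```python
import torch
import torch.nn as nn

act = nn.LeakyReLU(0.2)
x = torch.tensor([-2.0, -0.5, 0.0, 1.0])
print(act(x))  # negative inputs are scaled by 0.2 instead of being zeroed
```

A plain ReLU would map both negative inputs to exactly zero, killing their gradients; here they keep a small slope and continue to learn.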
Batch normalization
The batch normalization layer in the generator helps stabilize training by reducing the change in the distribution of activations and improving gradient flow.
In GANs, the generator and discriminator are constantly competing, and the generator's output distribution changes rapidly as it learns. This can cause the discriminator to have to constantly adapt to new input distributions, making training unstable.
Batch normalization normalizes the activations within each mini-batch, ensuring they have a zero mean and unit variance, i.e. consistent distribution.
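This behavior can be verified directly with nn.BatchNorm1d; the synthetic activations below are illustrative:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)          # 4 features; affine params start at gamma=1, beta=0
x = torch.randn(64, 4) * 5 + 3  # mini-batch of activations with mean ~3, std ~5
y = bn(x)                       # normalized per feature within the mini-batch
print(y.mean(dim=0))            # each entry close to 0
print(y.std(dim=0))             # each entry close to 1
```

After normalization, each feature column of the mini-batch has approximately zero mean and unit variance, regardless of the input's original scale.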
Number of network layers
The depth of the neural network depends on the complexity of the data distribution in an image. The generator and discriminator should be relatively balanced in their capacity. If one is significantly more powerful than the other, it can lead to training instability. Overfitting can occur with too many layers or neurons, leading to decreased performance on unseen data.
Output image shape
The output is reshaped to the KMNIST dataset image shape of (1, 28, 28).
import numpy as np
import torch.nn as nn

# Generator Network
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, int(np.prod(img_shape))),
            nn.Tanh()  # Output pixel values in [-1, 1]
        )

    def forward(self, z):
        # Reshape the flat output vector into an image of shape img_shape
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)
The difference between the Generator and the Discriminator lies in their objectives: the Generator learns to mimic the distribution of the real data, while the Discriminator learns to detect generated data, creating an ongoing "arms race" that leads to the generation of increasingly realistic data.
import numpy as np
import torch.nn as nn

# Discriminator Network
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output probability in [0, 1]
        )

    def forward(self, img):
        # Flatten the image before the first linear layer
        return self.model(img.view(img.size(0), -1))
In addition to setting up the layers in neural networks, it is also important to set up their optimizers and loss functions.
Adam's combination of adaptive learning rates, momentum, and stability makes it a popular and effective choice for training both the generator and discriminator.
Momentum:
Adam incorporates momentum by maintaining exponentially decaying averages of past gradients (the first moment) and past squared gradients (the second moment).
Adaptive learning rates:
Adam uses these moving averages of the first and second moments of the gradients to scale the learning rate for each parameter. Parameters with consistently large gradients have their effective learning rates reduced, while parameters with smaller, more consistent gradients have theirs increased. This adaptivity can lead to faster convergence and better performance, especially in the complex, non-convex optimization landscapes of GANs.
Bias correction:
In the early stages of training, the moving averages can be biased towards zero. Adam includes a bias correction step to mitigate this.
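The mechanics above can be traced by hand for a single scalar parameter at step t = 1, using the hyperparameters chosen later in this project (lr=0.0002, betas=(0.5, 0.999)); the gradient value is made up for illustration:

```python
lr, beta1, beta2, eps = 0.0002, 0.5, 0.999, 1e-8
m, v, t = 0.0, 0.0, 1
g = 0.1                                   # hypothetical gradient at step 1

m = beta1 * m + (1 - beta1) * g           # first moment (momentum term)
v = beta2 * v + (1 - beta2) * g * g       # second moment
m_hat = m / (1 - beta1 ** t)              # bias correction: m starts biased toward 0
v_hat = v / (1 - beta2 ** t)              # bias correction for the second moment
step = lr * m_hat / (v_hat ** 0.5 + eps)  # per-parameter adaptive update
print(step)                               # approximately lr itself
```

Without the bias correction, m and v would be close to zero at the first steps and the update would be far smaller than intended; with it, the first step size is roughly the learning rate regardless of the gradient's scale.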
Reference:
https://docs.pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
\begin{equation} BCE_{Loss}=-(Y \cdot Log(D)+(1-Y) \cdot Log(1-D)) \end{equation}
where Y: Ground Truth Label and D: Discriminator Output Label
When Y=1,
\begin{equation} BCE_{Loss}=-(Log(D_{Real})) \end{equation}
When Y=0,
\begin{equation}BCE_{Loss}=-(Log(1-D_{Fake})) \end{equation}
where D_Real: Discriminator Output Label with Real image as an input
D_Fake: Discriminator Output Label with Fake image as an input
The discriminator's overall loss is the sum of the BCE loss calculated for real samples and fake samples. The discriminator aims to minimize this combined BCE loss, indicating its success in correctly classifying both real and fake inputs.
The generator's goal is to produce fake data that is indistinguishable from real data, effectively "fooling" the discriminator. In terms of BCE loss, the generator tries to push the discriminator's output for fake images as close to the "Real" label (1) as possible; equivalently, it tries to maximize the discriminator's error. In practice, the generator is trained to minimize the BCE loss of the discriminator's output on fake images measured against the "Real" label. This creates the adversarial dynamic in which the generator and discriminator play a "minimax game."
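The two labeled cases can be checked numerically with nn.BCELoss; the 0.9 output below is an arbitrary example of a discriminator that is fairly confident its input is real:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()
d_out = torch.tensor([0.9])  # discriminator's estimated probability of "Real"
real = torch.tensor([1.0])   # label Y = 1
fake = torch.tensor([0.0])   # label Y = 0

print(bce(d_out, real))  # -log(0.9)  ~ 0.105: small loss when the "Real" guess is right
print(bce(d_out, fake))  # -log(0.1)  ~ 2.303: large loss if the input was actually fake
```

The generator exploits exactly this asymmetry: by making the discriminator output high probabilities for fake images, it drives the discriminator's loss on those samples up while its own loss (against the "Real" label) goes down.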
Reference:
https://docs.pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
import torch.optim as optim
# Model Initialization and Training Setup
latent_dim = 100
img_shape = (1, 28, 28) # KMNIST image is grayscale with 28x28 pixels
generator = Generator(latent_dim, img_shape)
discriminator = Discriminator(img_shape)
# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Loss function
adversarial_loss = nn.BCELoss()
References:
https://docs.pytorch.org/docs/stable/optim.html
from torcheval.metrics import FrechetInceptionDistance
# Initialize FID metric
metric = FrechetInceptionDistance()
Adversarial Training:
The generator and discriminator are trained in an alternating fashion.
# For each epoch, train both generator and discriminator and calculate the losses
for i, (imgs, _) in enumerate(train_loader):
    # Adversarial ground truths
    valid = torch.ones(imgs.size(0), 1)
    fake = torch.zeros(imgs.size(0), 1)

    # Train Discriminator on real images against the "Real" labels
    optimizer_d.zero_grad()
    real_loss = adversarial_loss(discriminator(imgs), valid)

    # z is the noise input, whose feature dimension is specified by latent_dim
    z = torch.randn(imgs.size(0), latent_dim)
    # Generate fake images using Generator and noise
    gen_imgs = generator(z)

    # Train Discriminator on fake images against the "Fake" labels
    # (detach so this pass does not update the Generator)
    fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
    # Combine into the Discriminator loss
    d_loss = (real_loss + fake_loss) / 2
    # Backpropagate and update the Discriminator weights
    d_loss.backward()
    optimizer_d.step()

    # Train Generator: its loss is low when the Discriminator labels fakes as "Real"
    optimizer_g.zero_grad()
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)
    # Backpropagate and update the Generator weights
    g_loss.backward()
    optimizer_g.step()
The FID metric uses the Inception network, which expects 3-channel RGB images in the range [0, 1]. Therefore, the real and generated images must be rescaled from [-1, 1] to [0, 1], and their single grayscale channel must be repeated three times:
# Update metric inputs: rescale from [-1, 1] to [0, 1] and repeat the channel to RGB
if imgs.shape[1] == 1:
    tempimgs = (imgs + 1) / 2
    fidRealImg = tempimgs.repeat(1, 3, 1, 1)
if gen_imgs.shape[1] == 1:
    tempimgs = (gen_imgs + 1) / 2
    fidFakeImg = tempimgs.repeat(1, 3, 1, 1)

# Add images to FID metric
metric.update(fidRealImg, is_real=True)
metric.update(fidFakeImg, is_real=False)

# Compute the FID score
fid_score = metric.compute()
The following chart shows the history of FID scores at each epoch; the FID score decreases as the synthetic image quality improves.