Using PyTorch to implement a Spectral Normalized Generative Adversarial Network (SN-GAN)

Source: DeepHub IMBA
This article is about 3,800 words long; the estimated reading time is 5 minutes.
Since the arrival of diffusion models, GANs have received less and less attention and fewer papers, but some of their ideas are still worth understanding and learning from. In this article, we use PyTorch to implement SN-GAN.

A spectral normalized generative adversarial network (SN-GAN) is a type of generative adversarial network that uses spectral normalization to stabilize the training of the discriminator. Spectral normalization is a weight normalization technique that constrains the spectral norm of each layer's weight matrix in the discriminator, which in turn bounds the discriminator's Lipschitz constant. This keeps the discriminator from becoming too powerful, which would otherwise lead to unstable training and poor results.

SN-GAN was proposed by Miyato et al. (2018) in the paper “Spectral Normalization for Generative Adversarial Networks”, where the authors showed that SN-GAN performs better than other GANs on a variety of image generation tasks.

SN-GAN is trained in the same way as other GANs. The generator network learns to generate images that are indistinguishable from real images, while the discriminator network learns to distinguish real images from generated ones. The two networks are trained against each other, and ideally they reach a point where the generator produces realistic images that fool the discriminator.

The advantages of SN-GAN over other GANs can be summarized as follows:

  • More stable and easier to train

  • Can produce higher-quality images

  • More versatile: it can be used to generate a wider range of content

Mode Collapse

Mode collapse is a common problem in the training of generative adversarial networks (GANs). Mode collapse occurs when a GAN’s generator network fails to produce diverse outputs and instead gets stuck in a specific pattern. This results in generated output that is repetitive, lacks variety and detail, and is sometimes completely irrelevant to the training data.

There are several reasons why mode collapse occurs in GANs. One reason is that the generator network may overfit the training data. This can happen if the training data is not diverse enough, or if the generator network is too complex. Another reason is that the generator network can get stuck in a local minimum of the loss function. This can happen if the learning rate is too high, or if the loss function is not well defined.

Many techniques have been used to prevent mode collapse, such as using a more diverse training dataset, or applying regularization techniques such as dropout or batch normalization. It is also important to choose appropriate learning rates and loss functions.

Wasserstein loss

Wasserstein loss, also known as the Wasserstein GAN (WGAN) loss, is a loss function for generative adversarial networks based on the Wasserstein distance, also called the Earth Mover’s Distance (EMD). It was introduced to solve problems associated with traditional GAN loss functions, which are based on divergences such as the Jensen-Shannon divergence and the Kullback-Leibler divergence.

Wasserstein loss measures the difference between the probability distribution of real data and that of generated data, while retaining certain useful mathematical properties. The idea is to minimize the Wasserstein distance (the Earth Mover’s Distance) between these two distributions. The Wasserstein distance can be thought of as the minimum “cost” required to transform one distribution into the other, where “cost” is the “work” required to move probability mass from one location to another.
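
To make the “moving probability mass” picture concrete, here is a minimal sketch for the simplest case: two 1-D empirical distributions with the same number of equally weighted samples. In that special case the optimal transport plan just matches sorted samples, so the Wasserstein-1 distance is the mean absolute difference between the sorted values. The helper name `wasserstein_1d` is our own, and the sketch does not cover the general multi-dimensional case, which WGAN approximates with a critic network instead.

```python
import torch

def wasserstein_1d(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Wasserstein-1 (Earth Mover's) distance between two 1-D empirical
    distributions with the same number of equally weighted samples."""
    if x.dim() != 1 or x.shape != y.shape:
        raise ValueError("expected two 1-D tensors of equal length")
    # In 1-D, the cheapest way to move mass pairs the i-th smallest sample
    # of x with the i-th smallest sample of y.
    x_sorted, _ = torch.sort(x)
    y_sorted, _ = torch.sort(y)
    return (x_sorted - y_sorted).abs().mean()

real = torch.tensor([0.0, 1.0, 2.0, 3.0])
fake = real + 0.5  # same shape, shifted by 0.5
print(wasserstein_1d(real, fake))  # → tensor(0.5000)
```

Shifting every sample by 0.5 costs 0.5 units of “work” per unit of probability mass, which is exactly what the function reports.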

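Mathematically, for a generator G and a discriminator (critic) D, the Wasserstein loss (Wasserstein distance) can be written in its dual form as:

$$W(P_r, P_g) = \sup_{\|D\|_L \le 1} \; \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim p(z)}[D(G(z))]$$

where P_r is the distribution of real data, P_g is the distribution of generated samples G(z) with noise z ~ p(z), and the supremum is taken over all 1-Lipschitz functions D.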
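
In practice, the two expectations in the Wasserstein loss are estimated with mini-batch averages of the critic's raw scores (no sigmoid). A minimal sketch of the resulting critic and generator losses, assuming the Lipschitz constraint on the critic is enforced separately (for example with one of the methods described later in this article); the function names are our own:

```python
import torch

def critic_loss(critic_real: torch.Tensor, critic_fake: torch.Tensor) -> torch.Tensor:
    # The critic maximizes E[D(x)] - E[D(G(z))]; we minimize the negative of it.
    return critic_fake.mean() - critic_real.mean()

def generator_loss(critic_fake: torch.Tensor) -> torch.Tensor:
    # The generator tries to raise the critic's score on generated samples.
    return -critic_fake.mean()

# Raw, unbounded critic scores for a batch of real and generated samples
critic_real = torch.tensor([2.0, 1.5, 2.5])
critic_fake = torch.tensor([-1.0, 0.0, -0.5])
print(critic_loss(critic_real, critic_fake))  # → tensor(-2.5000)
print(generator_loss(critic_fake))            # → tensor(0.5000)
```

With a 1-Lipschitz critic, the negated critic loss is the mini-batch estimate of the Wasserstein distance.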

Jensen-Shannon Divergence (JSD): the Jensen-Shannon divergence is a symmetric measure used to quantify the difference between two probability distributions.

For probability distributions P and Q, JSD is defined as follows:

$$\mathrm{JSD}(P \| Q) = \frac{1}{2}\left(\mathrm{KL}(P \| M) + \mathrm{KL}(Q \| M)\right), \qquad M = \frac{1}{2}(P + Q)$$

Here M is the mean (mixture) distribution of P and Q, KL is the Kullback-Leibler divergence, and JSD(P∥Q) is the JSD between distribution P and distribution Q.

JSD is always non-negative, bounded above by log 2 (that is, between 0 and 1 when base-2 logarithms are used), and symmetric: JSD(P∥Q) = JSD(Q∥P). It can be interpreted as a “smoothed”, symmetrized version of the KL divergence.

Kullback-Leibler divergence (KL divergence): the Kullback-Leibler divergence, often called KL divergence or relative entropy, measures the difference between two probability distributions by quantifying the “extra information” needed to encode samples from one distribution when a code built for the other distribution is used as the reference.

For two probability distributions P and Q, the KL divergence from Q to P is defined as:

$$\mathrm{KL}(P \| Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}$$

KL divergence is non-negative and asymmetric, that is, in general KL(P∥Q) ≠ KL(Q∥P). It is zero if and only if P and Q are equal. KL divergence is unbounded and can be used to measure the dissimilarity between distributions.
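
The definitions above can be checked numerically. A minimal sketch for discrete distributions stored as probability vectors; the helper names are our own, and natural logarithms are used, so the JSD upper bound is log 2 rather than 1:

```python
import math
import torch

def kl_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # KL(P || Q) = sum_x P(x) * log(P(x) / Q(x)), natural log.
    # torch.xlogy treats 0 * log(...) as 0, so entries with P(x) = 0 contribute nothing.
    return (torch.xlogy(p, p) - torch.xlogy(p, q)).sum()

def js_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # JSD(P || Q) = 1/2 * (KL(P || M) + KL(Q || M)) with M = (P + Q) / 2
    m = 0.5 * (p + q)
    return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))

p = torch.tensor([0.7, 0.2, 0.1])
q = torch.tensor([0.1, 0.3, 0.6])
print(kl_divergence(p, q), kl_divergence(q, p))  # different values: KL is asymmetric
print(js_divergence(p, q), js_divergence(q, p))  # equal values: JSD is symmetric

# Distributions with disjoint supports reach the JSD upper bound log(2)
disjoint = js_divergence(torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0]))
print(disjoint.item(), math.log(2))
```

The disjoint case also hints at why JSD-based GAN losses struggle: once real and generated distributions stop overlapping, JSD sits at its constant maximum.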

1-Lipschitz Continuity

A 1-Lipschitz function is a function whose slope is bounded by 1 in absolute value. This means that for any two inputs x and y, the difference between the function's outputs is no greater than the difference between the inputs.

Mathematically a function f is 1-Lipschitz if for all x and y within the domain of f the following inequality holds:

|f(x) - f(y)| <= |x - y|
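
The inequality can be probed numerically by sampling pairs of inputs and computing the ratio |f(x) - f(y)| / |x - y|. This is only a sanity check: the largest sampled ratio is a lower bound on the true Lipschitz constant, not a proof. The helper `empirical_lipschitz` is hypothetical, written just for this illustration.

```python
import torch

def empirical_lipschitz(f, n_pairs=100_000, min_gap=1e-2, seed=0):
    """Largest observed |f(x) - f(y)| / |x - y| over random scalar pairs.

    This is a lower bound on the Lipschitz constant of f, not a certificate.
    """
    rng = torch.Generator().manual_seed(seed)
    x = torch.randn(n_pairs, generator=rng) * 3
    y = torch.randn(n_pairs, generator=rng) * 3
    gap = (x - y).abs()
    keep = gap > min_gap  # skip near-identical pairs where rounding error dominates
    ratio = (f(x) - f(y)).abs()[keep] / gap[keep]
    return ratio.max().item()

print(empirical_lipschitz(torch.tanh))       # at most 1: tanh is 1-Lipschitz
print(empirical_lipschitz(lambda t: 3 * t))  # about 3: 3x is 3-Lipschitz, not 1-Lipschitz
```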

Enforcing Lipschitz continuity in generative adversarial networks (GANs) is a technique used to stabilize training and prevent some of the problems associated with traditional GANs, such as mode collapse and training instability. The main way to achieve Lipschitz continuity in GANs is through the use of Lipschitz constraints or regularization. One commonly used method is Wasserstein GAN (WGAN).

In a standard GAN, the discriminator (called the critic in WGAN) is trained to distinguish between real and fake data. To enforce Lipschitz continuity, WGAN adds the constraint that the discriminator function must be Lipschitz continuous, which means that its gradient cannot grow too large. Mathematically, the critic is restricted to satisfy:

$$|D(x) - D(y)| \le K \, \|x - y\|$$

where D(x) is the critic’s output for data point x, D(y) is its output for y, and K is the Lipschitz constant.

Weight clipping of WGAN: In the original WGAN, this constraint is enforced by clipping the weights of the discriminator network to a small range (e.g., [-0.01, 0.01]) after each training step. Weight clipping ensures that the gradient of the discriminator remains within a certain range and enforces Lipschitz continuity.
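
Weight clipping amounts to clamping every parameter in place. A minimal sketch, assuming a toy fully connected critic as a stand-in for a real discriminator and using the example range [-0.01, 0.01] mentioned above; in a training loop, the clipping would run right after each critic optimizer step:

```python
import torch
from torch import nn

# Toy stand-in critic; any nn.Module is clipped the same way.
critic = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
clip_value = 0.01

def clip_weights(model: nn.Module, c: float) -> None:
    # Clamp every parameter into [-c, c] in place, outside of autograd.
    with torch.no_grad():
        for p in model.parameters():
            p.clamp_(-c, c)

clip_weights(critic, clip_value)
print(max(p.abs().max().item() for p in critic.parameters()))  # no larger than 0.01
```

The crudeness of this approach (every weight pinned to a fixed box regardless of layer shape) is what motivated the alternatives below.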

Gradient penalty for WGAN: A variant of WGAN, called WGAN-GP, uses gradient penalty instead of weight clipping to enforce Lipschitz constraints. WGAN-GP adds a penalty term to the loss function based on the gradient of the discriminator’s output with respect to a random point between real and fake data. This penalty encourages Lipschitz constraints without requiring weight clipping.
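
A minimal sketch of the WGAN-GP penalty, assuming a toy fully connected critic and random stand-in batches for the real and generated data. The helper `gradient_penalty` is our own name; the penalty weight of 10 is the value commonly used with WGAN-GP, but treat it as a tunable hyperparameter:

```python
import torch
from torch import nn

def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Mean of (||grad_x D(x_hat)||_2 - 1)^2 at random points between real and fake."""
    real, fake = real.detach(), fake.detach()
    batch_size = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dimensions
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can itself be backpropagated
    )
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

critic = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
real = torch.randn(8, 2)
fake = torch.randn(8, 2)
gp = gradient_penalty(critic, real, fake)
lambda_gp = 10.0
critic_loss = critic(fake).mean() - critic(real).mean() + lambda_gp * gp
critic_loss.backward()
```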

Spectral norm

Symbolically, the spectral norm of a matrix W is usually written as σ(W). In a neural network, W is the weight matrix of a layer. The spectral norm of a matrix is its largest singular value, which can be obtained through singular value decomposition (SVD). Equivalently, it is the largest factor by which W can stretch a vector:

$$\sigma(W) = \max_{h \ne 0} \frac{\|W h\|_2}{\|h\|_2}$$

Singular value decomposition is a generalization of eigendecomposition and is used to decompose a matrix W into

$$W = U \Sigma V^\top$$

where U and V are orthogonal matrices, and Σ is a matrix whose diagonal holds the singular values of W and whose other entries are zero. Note that Σ is not necessarily square.

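$$\Sigma = \operatorname{diag}(\sigma_1, \sigma_2, \ldots, \sigma_n), \qquad \sigma_1 \ge \sigma_2 \ge \cdots \ge \sigma_n \ge 0$$

(padded with zero rows or columns when W is not square). Here σ_1 and σ_n are the largest and smallest singular values, respectively. Larger singular values correspond to a greater amount of stretching that the matrix can apply to a vector. It follows that the spectral norm is σ(W) = σ_1.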
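
The relationship between the SVD and the spectral norm can be checked numerically. A small sketch using `torch.linalg.svd` on a random stand-in weight matrix; note that PyTorch returns the singular values as a vector `S` rather than the full, possibly non-square Σ matrix:

```python
import torch

torch.manual_seed(0)
W = torch.randn(5, 3)  # stand-in for a layer's weight matrix

# Reduced SVD: W = U @ diag(S) @ Vh, with S sorted in descending order
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
print(torch.allclose(U @ torch.diag(S) @ Vh, W, atol=1e-5))  # reconstruction check

sigma_1, sigma_n = S[0], S[-1]  # largest and smallest singular values
print(sigma_1, torch.linalg.matrix_norm(W, ord=2))  # the spectral norm equals sigma_1

# No input vector is stretched by more than a factor of sigma_1
x = torch.randn(3)
print(torch.linalg.vector_norm(W @ x) <= sigma_1 * torch.linalg.vector_norm(x) + 1e-6)
```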

Application of SVD in spectral normalization

To spectrally normalize a weight matrix, divide each entry of the matrix by its spectral norm. The spectrally normalized matrix can be expressed as

$$\bar{W}_{SN} = \frac{W}{\sigma(W)}$$

so that σ(W̄_SN) = 1.

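Computing the SVD of W at every training step is very expensive, so the authors of the SN-GAN paper made a simplification. They approximate the left and right singular vectors ũ and ṽ by power iteration:

$$\tilde{v} \leftarrow \frac{W^\top \tilde{u}}{\|W^\top \tilde{u}\|_2}, \qquad \tilde{u} \leftarrow \frac{W \tilde{v}}{\|W \tilde{v}\|_2}$$

and then estimate the spectral norm as

$$\sigma(W) \approx \tilde{u}^\top W \tilde{v}$$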
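
A minimal sketch of this power-iteration estimate for a plain 2-D matrix. The function name `estimate_spectral_norm` is our own; PyTorch's built-in spectral_norm does the equivalent internally after reshaping a convolution weight to 2-D. In SN-GAN the vector ũ is kept between training steps, so the paper uses a single iteration per step; here we run many iterations from scratch to compare against the exact value:

```python
import torch
import torch.nn.functional as F

def estimate_spectral_norm(W, n_iters=1, u=None, eps=1e-12):
    """Estimate sigma(W), the largest singular value of a 2-D matrix W, by
    power iteration. Returns the estimate and the updated u so a caller can
    reuse u on the next training step."""
    if n_iters < 1:
        raise ValueError("n_iters must be at least 1")
    if u is None:
        # Random unit-length starting guess for the left singular vector
        u = F.normalize(torch.randn(W.shape[0], dtype=W.dtype, device=W.device), dim=0, eps=eps)
    for _ in range(n_iters):
        v = F.normalize(W.t() @ u, dim=0, eps=eps)  # v <- W^T u / ||W^T u||
        u = F.normalize(W @ v, dim=0, eps=eps)      # u <- W v / ||W v||
    sigma = torch.dot(u, W @ v)                     # sigma(W) ≈ u^T W v
    return sigma, u

torch.manual_seed(0)
W = torch.randn(64, 32)
sigma_est, u = estimate_spectral_norm(W, n_iters=100)
sigma_exact = torch.linalg.matrix_norm(W, ord=2)
W_sn = W / sigma_est  # spectrally normalized weight
print(sigma_est.item(), sigma_exact.item())
print(torch.linalg.matrix_norm(W_sn, ord=2).item())  # close to 1
```

Passing the returned `u` back in on the next call is what makes one iteration per step sufficient in practice: the weights change only slightly between steps, so the previous estimate is already a good starting point.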

Code implementation

Now let's implement SN-GAN in PyTorch, starting with the imports and a helper for displaying images:


import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)  # for reproducibility

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    # Map images from [-1, 1] (the generator's Tanh range) back to [0, 1]
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    # Arrange the first num_images images in a grid with 5 per row
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

Generator:

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Four transposed-conv blocks upsample a 1x1 noise "image" to 28x28
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        if not final_layer:
            # ConvTranspose -> BatchNorm -> ReLU
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            # Final layer: Tanh maps outputs to [-1, 1], matching the normalized data
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def unsqueeze_noise(self, noise):
        # Reshape (n_samples, z_dim) noise to (n_samples, z_dim, 1, 1)
        return noise.view(len(noise), self.z_dim, 1, 1)

    def forward(self, noise):
        x = self.unsqueeze_noise(noise)
        return self.gen(x)

def get_noise(n_samples, z_dim, device='cpu'):
    # Sample latent vectors from a standard normal distribution
    return torch.randn(n_samples, z_dim, device=device)

Spectral normalized discriminator

For the discriminator, we wrap each Conv2d layer in spectral_norm. In addition to the original weight W (stored as weight_orig), spectral_norm registers the power-iteration vectors ũ and ṽ (the buffers weight_u and weight_v), so that the normalized weight W̄_SN = W / σ(W) can be computed at runtime before each forward pass.

PyTorch provides the nn.utils.spectral_norm and nn.utils.remove_spectral_norm functions, which makes this very convenient.

We only apply nn.utils.remove_spectral_norm to convolutional layers during inference to improve runtime speed.
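
A short sketch of how these two functions behave on a single convolution. It assumes the hook-based `nn.utils.spectral_norm` API used in this article (newer PyTorch releases also offer `torch.nn.utils.parametrizations.spectral_norm`), and the layer sizes are just an example matching the first discriminator block:

```python
import torch
from torch import nn

conv = nn.utils.spectral_norm(nn.Conv2d(1, 16, kernel_size=4, stride=2))

# The raw weight is now the parameter `weight_orig`; `weight_u` and `weight_v`
# are buffers holding the power-iteration vectors, and `weight` is recomputed
# as weight_orig / sigma before every forward pass.
print([name for name, _ in conv.named_parameters()])  # includes 'weight_orig'
print([name for name, _ in conv.named_buffers()])     # includes 'weight_u', 'weight_v'

x = torch.randn(8, 1, 28, 28)
conv.train()
for _ in range(20):
    conv(x)  # in training mode, each forward runs one power iteration

# Rough check: the normalized weight, reshaped to 2-D, has spectral norm near 1
w_mat = conv.weight.reshape(conv.weight.size(0), -1)
print(torch.linalg.matrix_norm(w_mat, ord=2))

# For inference, fold the normalization into a plain `weight` parameter
conv.eval()
nn.utils.remove_spectral_norm(conv)
print([name for name, _ in conv.named_parameters()])  # 'weight' and 'bias' again
out = conv(x)
```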

It is worth noting that spectral normalization does not eliminate the need for batch normalization: spectral normalization acts on the weights of each layer, while batch normalization acts on each layer's activations.


class Discriminator(nn.Module):
    def __init__(self, im_chan=1, hidden_dim=16):
        super(Discriminator, self).__init__()
        # Three strided conv blocks reduce a 28x28 image to a single logit
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        if not final_layer:
            # Spectral-normalized Conv -> BatchNorm -> LeakyReLU
            return nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(input_channels, output_channels, kernel_size, stride)),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            # Final layer outputs raw logits; BCEWithLogitsLoss applies the sigmoid
            return nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(input_channels, output_channels, kernel_size, stride)),
            )

    def forward(self, image):
        disc_pred = self.disc(image)
        # Flatten (N, 1, 1, 1) predictions to (N, 1)
        return disc_pred.view(len(disc_pred), -1)

Training

We use the MNIST dataset here. nn.BCEWithLogitsLoss() computes the binary cross-entropy loss between the raw logits and the target labels; it applies the sigmoid internally, which is more numerically stable than a separate Sigmoid followed by BCELoss. Binary cross-entropy measures how different two distributions are. In binary classification, these two distributions are the predicted distribution (the sigmoid of the logits) and the target label distribution.
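
To see what `BCEWithLogitsLoss` computes, here is a small check with made-up logits and labels (the values are arbitrary). It follows the label convention of the training loop: when training the discriminator, real images get target 1 and generated images get target 0.

```python
import torch
from torch import nn

logits = torch.tensor([2.0, -1.0, 0.5, -3.0])  # raw discriminator outputs
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])   # 1 = real, 0 = fake

loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)

# The same quantity by hand: -mean[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))]
p = torch.sigmoid(logits)
manual = -(targets * torch.log(p) + (1 - targets) * torch.log(1 - p)).mean()

# And via the two-step Sigmoid + BCELoss route
loss_two_step = nn.BCELoss()(p, targets)

print(loss_with_logits, manual, loss_two_step)  # all three agree up to rounding
```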


criterion = nn.BCEWithLogitsLoss()
n_epochs = 50
z_dim = 64
display_step = 500
batch_size = 128
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
# Fall back to the CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Scale pixels from [0, 1] to [-1, 1] to match the generator's Tanh output
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST(".", download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)

Create generator and discriminator


gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

# initialize the weights to the normal distribution
# with mean 0 and standard deviation 0.02
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # Spectral-normalized layers keep their trainable weight in `weight_orig`
        weight = getattr(m, 'weight_orig', m.weight)
        torch.nn.init.normal_(weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

The training loop is as follows:


cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update Discriminator ##
        disc_opt.zero_grad()
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        fake = gen(fake_noise)
        # Detach so the discriminator update does not backpropagate into the generator
        disc_fake_pred = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real)
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step
        # Update gradients
        disc_loss.backward(retain_graph=True)
        # Update optimizer
        disc_opt.step()

        ## Update Generator ##
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        disc_fake_pred = disc(fake_2)
        # The generator wants its samples classified as real (label 1)
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ## Visualization code ##
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1

The training results are as follows:

[Figure: training results]

Summary

In this article, we introduced the principle of SN-GAN and a simple code implementation. SN-GAN has been widely used in image generation tasks, including image synthesis, style transfer and super-resolution, and it has significantly improved the performance and stability of generative models, so studying its code is a good way to deepen our understanding.

Editor: Wen Jing
