GAN: Generative Adversarial Networks

Preface: Because the project requires using a GAN for image inpainting, the author wrote this blog post to study generative adversarial networks in depth.

1. Introduction to GAN

GAN (Generative Adversarial Network) is a deep learning model consisting of a pair of neural networks: a generator and a discriminator.

The generator takes random noise as input and learns to produce output that resembles real data. The discriminator learns to distinguish the data generated by the generator from real data. The two compete against each other during training: the generator tries to generate increasingly realistic data to fool the discriminator, while the discriminator tries to identify which data is real and which is generated.

During the training process, the generator and the discriminator compete with each other. By repeatedly adjusting the parameters, the generator learns to produce high-quality data that the discriminator finds difficult to tell apart from real data. GANs have been widely used in image generation, speech synthesis, natural language processing, and other fields, and have produced many interesting applications, such as DeepFake.
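For completeness, this adversarial game is usually written as the minimax objective from the original GAN paper, where the generator G tries to minimize the value and the discriminator D tries to maximize it:

\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]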

2. Main components of GAN network

The main components of a GAN network include:

  1. Generator: A generator is a neural network model that receives random noise as input and outputs some data similar to the original data. The goal of the generator is to simulate the distribution of the original data as accurately as possible.

  2. Discriminator: The discriminator is also a neural network model. Its input is a set of data and it outputs a probability value to determine whether the data is real data or data generated by the generator. The goal of the discriminator is to differentiate between real data and data generated by the generator as accurately as possible.

  3. Loss Function: The loss function of a GAN consists of two parts. One part is the generator's loss, which measures how far the generator's output is from fooling the discriminator; the other part is the discriminator's loss, which measures how accurately the discriminator's predicted probabilities separate real data from generated data (the standard forms of these two losses are written out after this list).

  4. Optimizer: The GAN network uses the backpropagation algorithm to train the model. The optimizer's job is to update the model's parameters based on the value of the loss function, so that the generator and discriminator are continuously optimized: the generator produces increasingly realistic data, and the discriminator gets better at distinguishing real data from generated data.

  5. Dataset: A GAN requires a large amount of data for training, and the quality and quantity of the training dataset have a great impact on the GAN's performance. Typically, the dataset contains the real samples against which the generator's output is compared.
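For reference, with binary cross entropy the two losses from item 3 take the standard (non-saturating) forms, which are also what the training code below implements. Here x is a real sample, z is random noise, and D outputs the probability that its input is real:

L_D = -\left[ \log D(x) + \log(1 - D(G(z))) \right]

L_G = -\log D(G(z))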

3. Simplified network structure

The author hand-drew the network structure while studying; here is an outline of the overall structure. First, be clear that under this network there are two kinds of data: real data and fake data. The discriminator D should recognize the real data as reliably as possible, while the generator should try its best to make the fake data be identified by D as real data. In general, the two are in a game relationship. First comes the generator (G). Its function is to accept the noise Z and use it to generate data similar to the real data. Note: the size of the data produced by the generator is consistent with the size of the real data. Then comes the discriminator (D). In the simplest GAN, the discriminator performs a binary classification task: judging authenticity.

To explain GAN more vividly, the author gives an example. Suppose you want to buy a famous watch, but you have never bought one before, so it is hard for you to judge whether a watch is genuine; only experience with buying famous watches can keep you from being cheated by dishonest sellers. As you start to recognize most fake watches as fake (after being cheated a few times), sellers start producing higher-quality imitations, and then you go buy watches again. The two sides compete with each other: your experience keeps growing, and the sellers' counterfeiting skill also improves. In the end, what the generator produces is as close as possible to the real thing.

4. Various components

4.1 Generator

The generator of a GAN is used to produce data similar to real data. Specifically, the generator receives a random noise vector as input and, through a neural network, generates data with the same characteristics, distribution, and patterns as the real data. During training, the generator gradually learns the distribution of the real data. Its output is sent to the discriminator, which judges whether it has the characteristics and distribution of real data; if it is judged to be real, the generator has achieved its goal. Throughout training, the generator's objective is to make its output approximate the real data distribution as closely as possible, so that the discriminator finds it difficult to distinguish real data from generated data.

The author simply wrote a G network here:

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):  # Initialization
        super(Generator, self).__init__()  # Call the parent class constructor
        self.fc1 = nn.Linear(100, 256)     # Fully connected layer: input 100 dimensions, output 256 dimensions
        self.fc2 = nn.Linear(256, 512)     # Fully connected layer: input 256 dimensions, output 512 dimensions
        self.fc3 = nn.Linear(512, 1024)    # Fully connected layer: input 512 dimensions, output 1024 dimensions
        self.fc4 = nn.Linear(1024, 28*28)  # Fully connected layer: input 1024 dimensions, output 28*28 dimensions
        self.relu = nn.ReLU()              # ReLU activation, used to add nonlinearity to the network
        self.tanh = nn.Tanh()              # Tanh activation, squashes the output into [-1, 1]

    def forward(self, x):  # Forward propagation
        x = self.relu(self.fc1(x))  # Fully connected layer followed by ReLU
        x = self.relu(self.fc2(x))  # Fully connected layer followed by ReLU
        x = self.relu(self.fc3(x))  # Fully connected layer followed by ReLU
        return self.tanh(self.fc4(x))  # Fully connected layer followed by Tanh

The specific dimensions depend on your image size (28*28 here, matching MNIST). Note also that Tanh outputs values in [-1, 1], while torchvision's ToTensor produces pixels in [0, 1]; in practice the real images should be normalized to [-1, 1] so the generator's output range matches the real data.
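As a quick sanity check, here is a minimal usage sketch (the batch size of 16 is just an illustrative assumption):

import torch

G = Generator()
z = torch.randn(16, 100)         # a batch of 16 noise vectors, 100 dimensions each
fake = G(z)                      # output shape: (16, 784)
imgs = fake.view(-1, 1, 28, 28)  # reshape into 16 single-channel 28x28 images
print(imgs.shape)                # torch.Size([16, 1, 28, 28])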

4.2 Discriminator

The discriminator of a GAN is used to judge whether its input is real data or fake data produced by the generator. Specifically, the discriminator receives data as input and, through a neural network, outputs a judgment of whether it is real. During training, the discriminator gradually learns the distribution and patterns of the real data so that it can tell real data apart from data generated by the generator. Its training objective is to maximize the probability of correctly classifying both real data and generated data. The generator and discriminator compete with each other: the generator keeps optimizing its output to bring it closer to the real data distribution, while the discriminator keeps learning the real distribution so it can discriminate more accurately. Through this game, with the training objectives of both sides continuously optimized, the generator's output can eventually be made close to the distribution of real data.

class Discriminator(nn.Module):
    def __init__(self):  # Initialization
        super(Discriminator, self).__init__()  # Call the parent class constructor
        self.fc1 = nn.Linear(28*28, 1024)  # Fully connected layer: input 28*28 dimensions, output 1024 dimensions
        self.fc2 = nn.Linear(1024, 512)    # Fully connected layer: input 1024 dimensions, output 512 dimensions
        self.fc3 = nn.Linear(512, 256)     # Fully connected layer: input 512 dimensions, output 256 dimensions
        self.fc4 = nn.Linear(256, 1)       # Fully connected layer: input 256 dimensions, output 1 dimension
        self.relu = nn.ReLU()              # ReLU activation, used to add nonlinearity to the network
        self.sigmoid = nn.Sigmoid()        # Sigmoid activation, squashes the output into a probability in (0, 1)

    def forward(self, x):  # Forward propagation
        x = self.relu(self.fc1(x))  # Fully connected layer followed by ReLU
        x = self.relu(self.fc2(x))  # Fully connected layer followed by ReLU
        x = self.relu(self.fc3(x))  # Fully connected layer followed by ReLU
        return self.sigmoid(self.fc4(x))  # Fully connected layer followed by Sigmoid; the output is a probability between 0 and 1 for binary classification, with values above 0.5 read as real and below 0.5 as fake

The author also wrote a class for the discriminator. In this example, the final output is a single value per sample: the probability that the input is real.
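Similarly, a minimal usage sketch for the discriminator (the random input here is only a placeholder for a real or generated image batch):

import torch

D = Discriminator()
x = torch.rand(16, 28*28)  # a batch of 16 flattened 28x28 "images"
p = D(x)                   # output shape: (16, 1), each entry a probability in (0, 1)
print(p.shape)             # torch.Size([16, 1])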

4.3 Testing and training functions

Here the author briefly explains what the training (and testing) functions do.

# Training function
def train(G, D, G_optimizer, D_optimizer, loss_func, train_loader, epoch):
    G_losses = []  # Record the generator's losses
    D_losses = []  # Record the discriminator's losses
    for step, (x, y) in enumerate(train_loader):  # Enumerate the data
        b_x = x.view(-1, 28*28)  # Flatten each image into a 28*28-dimensional vector
        b_y = y  # Labels (not actually used by this unconditional GAN)
        b_z = torch.randn((x.shape[0], 100))  # Generate random noise
        G_result = G(b_z)  # Samples produced by the generator
        D_real = D(b_x)  # Discriminator's output on real data
        D_fake = D(G_result.detach())  # Discriminator's output on generated data; detach() keeps this update from backpropagating into the generator
        D_real_loss = loss_func(D_real, torch.ones_like(D_real))  # Discriminator's loss on real data
        D_fake_loss = loss_func(D_fake, torch.zeros_like(D_fake))  # Discriminator's loss on generated data
        D_loss = D_real_loss + D_fake_loss  # Total discriminator loss
        D_optimizer.zero_grad()  # Clear the discriminator's gradients
        D_loss.backward()  # Backward propagation
        D_optimizer.step()  # Update the discriminator's parameters
        G_result = G(b_z)  # Samples produced by the generator
        D_fake = D(G_result)  # Discriminator's output on the generated data
        G_loss = loss_func(D_fake, torch.ones_like(D_fake))  # Generator's loss: try to make D output "real"
        G_optimizer.zero_grad()  # Clear the generator's gradients
        G_loss.backward()  # Backward propagation
        G_optimizer.step()  # Update the generator's parameters
        G_losses.append(G_loss.item())  # Record the generator's loss
        D_losses.append(D_loss.item())  # Record the discriminator's loss
        if step % 100 == 0:  # Print the result every 100 steps
            print('Epoch: ', epoch, '| Step: ', step, '| G loss: ', G_loss.item(), '| D loss: ', D_loss.item())
    # Save the models (the ./model directory must already exist)
    torch.save(G.state_dict(), './model/G.pth')
    torch.save(D.state_dict(), './model/D.pth')
    return G_losses, D_losses  # Return the losses of the generator and discriminator
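The testing function is not shown in the original draft, although the main function below calls test(G, epoch). Here is a minimal sketch of what it might do, assuming it simply samples a grid of images from the generator and saves them to a hypothetical ./img directory:

import os
import torch
import torchvision

# Testing function (a sketch; the batch size of 64 and output path are assumptions)
def test(G, epoch):
    G.eval()  # Switch the generator to evaluation mode
    with torch.no_grad():  # No gradients are needed for sampling
        z = torch.randn(64, 100)  # A batch of 64 random noise vectors
        fake = G(z).view(-1, 1, 28, 28)  # Reshape the flat outputs into 1x28x28 images
        fake = (fake + 1) / 2  # Map the Tanh output from [-1, 1] back to [0, 1] for saving
        os.makedirs('./img', exist_ok=True)  # Hypothetical output directory
        torchvision.utils.save_image(fake, './img/epoch_{}.png'.format(epoch), nrow=8)
    G.train()  # Switch back to training mode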

4.4 Main function (including the loss function)

Because the data loader and loss function were not specified above, the author puts them together in the main function. The loss function here is the commonly used BCELoss.

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from torch.optim import Adam

if __name__ == '__main__':
    G = Generator()  # Generator
    D = Discriminator()  # Discriminator
    G_optimizer = Adam(G.parameters(), lr=0.0001)  # Generator optimizer
    D_optimizer = Adam(D.parameters(), lr=0.0001)  # Discriminator optimizer
    loss_func = nn.BCELoss()  # Loss function
    train_loader = Data.DataLoader(
        dataset=torchvision.datasets.MNIST(root='./mnist/', train=True,
                                           transform=torchvision.transforms.ToTensor(), download=True),
        batch_size=64, shuffle=True)  # Training data
    for epoch in range(10):  # Train for 10 epochs
        G_losses, D_losses = train(G, D, G_optimizer, D_optimizer, loss_func, train_loader, epoch)  # Train
        test(G, epoch)  # Test
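The train function returns the per-step losses, although the main function above does not use them further. A small sketch of how they might be plotted (assuming matplotlib is installed; the output path is hypothetical):

import matplotlib.pyplot as plt

plt.plot(G_losses, label='G loss')  # generator loss at each step of the last epoch
plt.plot(D_losses, label='D loss')  # discriminator loss at each step of the last epoch
plt.xlabel('step')
plt.ylabel('loss')
plt.legend()
plt.savefig('./loss_curve.png')     # hypothetical output file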

One more note here.

BCELoss stands for Binary Cross Entropy Loss and is used for binary classification problems. For each data point, BCELoss calculates the cross-entropy loss between the true label (0 or 1) and the model's predicted probability (a value between 0 and 1). Specifically, the formula for BCELoss is as follows:

BCELoss\left( o,t \right) = -\frac{1}{n}\sum_{i}\left( t_i \log o_i + \left( 1-t_i \right) \log\left( 1-o_i \right) \right)
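As a quick numerical check of this formula, the following sketch compares nn.BCELoss against the hand-computed expression (the three sample values are made up for illustration):

import torch
import torch.nn as nn

o = torch.tensor([0.9, 0.2, 0.7])  # predicted probabilities
t = torch.tensor([1.0, 0.0, 1.0])  # true labels
manual = -(t * torch.log(o) + (1 - t) * torch.log(1 - o)).mean()
print(manual.item())               # approximately 0.2284
print(nn.BCELoss()(o, t).item())   # matches the manual computation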

5. Summary

Here is a brief introduction to some of the structure and components of GAN. The author cannot yet use this model to build anything complete; hopefully, once the dataset is prepared, the author can present a GAN network for image inpainting. All the code written here assumes the data format of the handwritten-digit dataset, i.e. the MNIST dataset. What this article aims to explain is the composition and basics of GAN adversarial networks, so it does not give more elaborate examples; it is simply a study of what the parameters mean and what needs to be done.

It's rare that the author has started writing articles to improve himself again. Please support him!!!