Preface: Because the project requires learning GANs for image repair, the author wrote this blog post to study Generative Adversarial Networks in depth.
1. Introduction to GAN
GAN (Generative Adversarial Network) is a deep learning model consisting of a pair of neural networks: a generator and a discriminator.
The generator takes random noise as input and learns to turn it into outputs that resemble real data; in other words, it learns the distribution of the real data. The discriminator learns to distinguish the generator's outputs from real data. The two compete during training: the generator tries to produce increasingly realistic data to fool the discriminator, and the discriminator tries to tell which data is real and which is generated.
By repeatedly adjusting the parameters of both networks, the generator eventually produces high-quality data that the discriminator finds hard to tell apart from real data. GANs have been widely used in image generation, speech synthesis, natural language processing, and other fields, and have led to many interesting applications, such as deepfakes.
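For reference, the original GAN paper (Goodfellow et al., 2014) writes this competition as a minimax game. Here $D(x)$ is the discriminator's probability that $x$ is real, $G(z)$ is the sample the generator makes from noise $z$, $p_{\text{data}}$ is the real-data distribution and $p_z$ is the noise distribution:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

D maximizes $V$ by pushing $D(x)$ toward 1 and $D(G(z))$ toward 0; G minimizes it by pushing $D(G(z))$ toward 1. In practice the generator is usually trained to maximize $\log D(G(z))$ instead (the "non-saturating" loss), which gives it stronger gradients early in training; the training code later in this post follows that variant by giving generated samples the target 1 in the generator's loss.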
2. Main components of a GAN
The main components of a GAN are:
- Generator: a neural network that takes random noise as input and outputs data resembling the real data. Its goal is to model the distribution of the real data as accurately as possible.
- Discriminator: also a neural network. It takes a data sample as input and outputs a probability that the sample is real rather than produced by the generator. Its goal is to tell real and generated data apart as accurately as possible.
- Loss function: a GAN has two losses. The discriminator's loss measures how well it classifies real samples as real and generated samples as fake. The generator's loss measures how well its samples fool the discriminator, i.e. how close the discriminator's output on generated samples is to "real". The generator never compares its output with real data directly; it learns only through the discriminator's feedback.
- Optimizer: backpropagation computes the gradients of each loss, and the optimizer (for example Adam) uses them to update the parameters. Alternating these updates makes the generator's output more realistic and makes the discriminator better at separating real data from generated data.
- Dataset: training a GAN requires a large amount of real data, and the quality and quantity of the training set strongly affect the results. In a basic (unconditional) GAN, class labels are not needed: real samples are simply given the target "real" (1) and generated samples the target "fake" (0).
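To make the two losses concrete, here is a minimal pure-Python sketch. It assumes D outputs a single probability per sample and uses binary cross-entropy (the same choice as the BCELoss used later in this post); the helper names `bce`, `discriminator_loss` and `generator_loss` are hypothetical, chosen for illustration.

```python
import math


def bce(p, target):
    """Binary cross-entropy for one predicted probability p in (0, 1) and a 0/1 target."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


def discriminator_loss(d_real, d_fake):
    # D should output 1 for a real sample and 0 for a generated one
    return bce(d_real, 1) + bce(d_fake, 0)


def generator_loss(d_fake):
    # G is rewarded when D labels its sample as real (target 1)
    return bce(d_fake, 1)


# A confident, correct discriminator has a small loss ...
print(round(discriminator_loss(0.9, 0.1), 4))  # 0.2107
# ... while the generator it keeps catching has a large one
print(round(generator_loss(0.1), 4))  # 2.3026
```

The two losses pull in opposite directions on the same quantity, D(G(z)): the discriminator wants it near 0, the generator wants it near 1.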
3. Simplified network structure
The author hand-drew the network structure while studying; here is an outline of the overall structure. First, note that the network works with two kinds of data: real data and fake (generated) data. The discriminator D should recognize real data as real as reliably as possible, while the generator G tries its best to get its fake data classified by D as real. In short, the two play a game against each other.
The generator G takes noise Z as input and uses it to produce data similar to the real data. Note: the data produced by the generator has the same size as the real data. The discriminator D, in the simplest GAN, performs a binary classification task: deciding whether its input is real or fake.
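The data flow just described (noise Z into G, G's output the same size as real data, D producing one real/fake probability per sample) can be checked with a tiny PyTorch sketch. The one-layer `G` and `D` and the sizes below are hypothetical stand-ins, not the networks used later in the post.

```python
import torch
import torch.nn as nn

# Hypothetical toy sizes: 100-dim noise, 28x28 images flattened to 784 values
NOISE_DIM, IMG_DIM, BATCH = 100, 28 * 28, 16

G = nn.Sequential(nn.Linear(NOISE_DIM, IMG_DIM), nn.Tanh())  # noise -> fake image
D = nn.Sequential(nn.Linear(IMG_DIM, 1), nn.Sigmoid())       # image -> P(real)

z = torch.randn(BATCH, NOISE_DIM)  # noise Z
fake = G(z)                        # same shape as a batch of flattened real images
real = torch.rand(BATCH, IMG_DIM)  # stand-in for a batch of real images
p_real, p_fake = D(real), D(fake)  # binary classifier: one probability per sample

print(fake.shape, p_real.shape, p_fake.shape)
# torch.Size([16, 784]) torch.Size([16, 1]) torch.Size([16, 1])
```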
To explain GANs more vividly, here is an example. Suppose you want to buy a luxury watch but have never bought one before, so it is hard for you to tell a genuine watch from a fake; only experience protects you from dishonest sellers. After being fooled a few times, you learn to flag most fakes, so the sellers start producing higher-quality replicas. Then you go and buy a watch again. The two sides push each other: your eye for fakes keeps improving, and so does the sellers' counterfeiting skill. You play the role of the discriminator and the seller plays the generator; in the end, what the generator produces becomes as close as possible to the real thing.
4. Various components
4.1 Generator
The generator of a GAN produces data similar to real data. Specifically, it receives a random noise vector as input and passes it through a neural network to produce data with the same characteristics, distribution, and patterns as the real data. During training, the generator gradually learns the distribution and patterns of the real data. Its outputs are sent to the discriminator, which judges whether they look like real data; the generator has succeeded when the discriminator judges its output to be real. The generator's objective is therefore to make its outputs approximate the real data distribution so closely that the discriminator cannot tell real and generated data apart.
Here is a simple G network (a multilayer perceptron) written by the author:
```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()  # initialize the nn.Module base class
        self.fc1 = nn.Linear(100, 256)       # fully connected: 100-dim noise -> 256
        self.fc2 = nn.Linear(256, 512)       # 256 -> 512
        self.fc3 = nn.Linear(512, 1024)      # 512 -> 1024
        self.fc4 = nn.Linear(1024, 28 * 28)  # 1024 -> 784, one value per pixel of a 28x28 image
        self.relu = nn.ReLU()  # hidden-layer activation, adds nonlinearity
        self.tanh = nn.Tanh()  # output activation, squashes pixel values into [-1, 1]

    def forward(self, x):  # forward pass
        x = self.relu(self.fc1(x))  # fully connected layer, then ReLU
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        return self.tanh(self.fc4(x))  # flattened fake image with values in [-1, 1]
```
The layer sizes here assume 100-dimensional noise and 28×28 grayscale images (MNIST, 28*28 = 784 pixels); adjust them to match your own images.
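One way to avoid hard-coding 784 is to pass the noise size and image shape into the constructor. The sketch below keeps the same MLP layout as the Generator above; the class name `ConfigurableGenerator` and its parameters are hypothetical, and the 3×32×32 example is only an illustration of a non-MNIST image size.

```python
import math

import torch
import torch.nn as nn


class ConfigurableGenerator(nn.Module):
    """Same MLP layout as the Generator above, with the noise size and
    image shape passed in instead of hard-coded (hypothetical helper)."""

    def __init__(self, noise_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape
        out_dim = math.prod(img_shape)  # e.g. 1*28*28 = 784 for MNIST
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 1024), nn.ReLU(),
            nn.Linear(1024, out_dim), nn.Tanh(),
        )

    def forward(self, z):
        # reshape the flat output into (batch, channels, height, width)
        return self.net(z).view(z.shape[0], *self.img_shape)


# Example: 64-dim noise and 3-channel 32x32 images
G = ConfigurableGenerator(noise_dim=64, img_shape=(3, 32, 32))
imgs = G(torch.randn(8, 64))
print(imgs.shape)  # torch.Size([8, 3, 32, 32])
```

If you reshape the output like this, the discriminator must flatten its input again (for example with `x.view(x.shape[0], -1)`) before its first `nn.Linear` layer.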
4.2 Discriminator
The discriminator of a GAN decides whether its input is real data or data produced by the generator (fake data). Specifically, it receives a data sample as input and passes it through a neural network that outputs a judgment of whether the sample is real. During training, the discriminator gradually learns the distribution and patterns of the real data so that it can separate real data from generated data. Its goal is to maximize the probability of classifying both kinds of data correctly: labeling real data as real and generated data as fake. This in turn gives the generator a training signal that pushes its outputs toward the real data distribution. The generator keeps improving its outputs to bring them closer to the real data distribution, while the discriminator keeps learning the real data's patterns so that it can tell the two apart more accurately. Through this game, with the two objectives optimized in turn, the generator's outputs eventually approach the distribution of real data.
```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()  # initialize the nn.Module base class
        self.fc1 = nn.Linear(28 * 28, 1024)  # fully connected: flattened 784-pixel image -> 1024
        self.fc2 = nn.Linear(1024, 512)      # 1024 -> 512
        self.fc3 = nn.Linear(512, 256)       # 512 -> 256
        self.fc4 = nn.Linear(256, 1)         # 256 -> a single score per image
        self.relu = nn.ReLU()        # hidden-layer activation, adds nonlinearity
        self.sigmoid = nn.Sigmoid()  # output activation, turns the score into a probability

    def forward(self, x):  # forward pass
        x = self.relu(self.fc1(x))  # fully connected layer, then ReLU
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        # Probability in (0, 1) that x is real; as a hard decision,
        # above 0.5 means "real" (1) and below 0.5 means "fake" (0)
        return self.sigmoid(self.fc4(x))
```
The discriminator ends with a single output unit, so in this example it returns one probability per image: an N×1 tensor for a batch of N images.
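Reading that probability as a hard decision at 0.5 gives a simple number to monitor during training. The sketch below is a hypothetical helper (`discriminator_accuracy` is not part of PyTorch); as a rough rule of thumb, an accuracy stuck near 1.0 suggests D is overpowering G, while an accuracy near 0.5 means D can no longer tell the two apart.

```python
import torch


def discriminator_accuracy(d_real, d_fake, threshold=0.5):
    """Fraction of samples D classifies correctly (hypothetical monitoring helper).

    d_real, d_fake: (N, 1) tensors of sigmoid outputs for real and generated samples.
    A score above `threshold` is read as "real" (1), otherwise "fake" (0).
    """
    correct_real = (d_real > threshold).float().sum()   # real samples called real
    correct_fake = (d_fake <= threshold).float().sum()  # fake samples called fake
    total = d_real.numel() + d_fake.numel()
    return ((correct_real + correct_fake) / total).item()


d_real = torch.tensor([[0.9], [0.7], [0.4]])  # D misses the last real sample
d_fake = torch.tensor([[0.2], [0.6], [0.1]])  # D is fooled by the middle fake
print(round(discriminator_accuracy(d_real, d_fake), 3))  # 0.667
```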
4.3 Testing and training functions
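Here the author briefly explains what the training function does. In each step it first updates the discriminator on a batch of real images and a batch of generated images, and then updates the generator so that the discriminator classifies its images as real.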
```python
import os


# Training function: one epoch of alternating discriminator / generator updates
def train(G, D, G_optimizer, D_optimizer, loss_func, train_loader, epoch):
    G_losses = []  # generator loss at each step
    D_losses = []  # discriminator loss at each step
    for step, (x, y) in enumerate(train_loader):  # y (the digit label) is not used by this unconditional GAN
        b_x = x.view(-1, 28 * 28)  # flatten each 1x28x28 image into a 784-dim vector
        b_z = torch.randn((x.shape[0], 100))  # random noise, one 100-dim vector per image

        # 1) Update the discriminator
        G_result = G(b_z)  # images generated by the generator
        D_real = D(b_x)    # discriminator's scores for real images
        D_fake = D(G_result.detach())  # scores for generated images; detach() keeps this loss out of G
        D_real_loss = loss_func(D_real, torch.ones_like(D_real))   # real images should score 1
        D_fake_loss = loss_func(D_fake, torch.zeros_like(D_fake))  # generated images should score 0
        D_loss = D_real_loss + D_fake_loss  # total discriminator loss
        D_optimizer.zero_grad()  # clear the discriminator's gradients
        D_loss.backward()        # backpropagation
        D_optimizer.step()       # discriminator parameter update

        # 2) Update the generator
        G_result = G(b_z)   # images generated by the generator
        D_fake = D(G_result)  # re-scored by the freshly updated discriminator
        G_loss = loss_func(D_fake, torch.ones_like(D_fake))  # G wants its images classified as real
        G_optimizer.zero_grad()  # clear the generator's gradients
        G_loss.backward()        # backpropagation
        G_optimizer.step()       # generator parameter update

        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())
        if step % 100 == 0:  # print progress every 100 steps
            print('Epoch: ', epoch, '| Step: ', step, '| G loss: ', G_loss.item(), '| D loss: ', D_loss.item())

    # Save the model weights at the end of the epoch (create the folder if it is missing)
    os.makedirs('./model', exist_ok=True)
    torch.save(G.state_dict(), './model/G.pth')
    torch.save(D.state_dict(), './model/D.pth')
    return G_losses, D_losses  # losses of the generator and discriminator
```
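In the discriminator update, the generated batch can be passed through `.detach()` so that `D_loss.backward()` does not compute gradients for G's parameters (they would be thrown away anyway, because `G_optimizer.zero_grad()` runs before the generator update). The toy check below, with hypothetical one-layer networks and random data, shows that detaching stops the discriminator's loss at D, while the generator's loss, computed without detaching, flows back through D into G.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-ins for G and D (hypothetical sizes: 8-dim noise, 4-dim "images")
G = nn.Sequential(nn.Linear(8, 4), nn.Tanh())
D = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
loss_func = nn.BCELoss()

real = torch.rand(5, 4) * 2 - 1  # stand-in real batch in [-1, 1]
fake = G(torch.randn(5, 8))

# Discriminator loss: the fake batch is detached, so backward() stops at D
d_loss = loss_func(D(real), torch.ones(5, 1)) + loss_func(D(fake.detach()), torch.zeros(5, 1))
d_loss.backward()
g_had_no_grad = all(p.grad is None for p in G.parameters())
print(g_had_no_grad)  # True: G received no gradient from D's loss

# Generator loss: no detach, so gradients flow through D back into G
g_loss = loss_func(D(fake), torch.ones(5, 1))
g_loss.backward()
g_has_grad = all(p.grad is not None for p in G.parameters())
print(g_has_grad)  # True
```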
4.4 Main function (with the loss function)
The training function does not create the data loader or the loss function itself, so both are set up in the main function. The loss function is BCELoss (binary cross-entropy).
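```python
import torch.utils.data as Data
import torchvision
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import save_image

if __name__ == '__main__':
    G = Generator()  # generator
    D = Discriminator()  # discriminator
    G_optimizer = Adam(G.parameters(), lr=0.0001)  # generator optimizer
    D_optimizer = Adam(D.parameters(), lr=0.0001)  # discriminator optimizer
    loss_func = nn.BCELoss()  # binary cross-entropy loss
    # Scale pixels from [0, 1] to [-1, 1] so real images match the generator's tanh output range
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_loader = Data.DataLoader(
        dataset=torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True),
        batch_size=64,
        shuffle=True,
    )  # training data
    os.makedirs('./samples', exist_ok=True)
    for epoch in range(10):  # train for 10 epochs
        G_losses, D_losses = train(G, D, G_optimizer, D_optimizer, loss_func, train_loader, epoch)
        # The post calls test(G, epoch) here but does not show it; as a stand-in,
        # save a grid of 64 generated digits, mapped back from [-1, 1] to [0, 1]
        with torch.no_grad():
            samples = G(torch.randn(64, 100)).view(-1, 1, 28, 28)
        save_image(samples * 0.5 + 0.5, f'./samples/epoch_{epoch}.png', nrow=8)
```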
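A note on the loss function:
BCELoss is the binary cross-entropy loss, used for binary classification problems. For each data point it compares the true label y (0 or 1) with the model's predicted probability p (a value between 0 and 1). For a batch of N data points, with `nn.BCELoss`'s default averaging over the batch, the formula is:

$$\ell = -\frac{1}{N}\sum_{n=1}^{N}\left[y_n \log p_n + (1 - y_n)\log(1 - p_n)\right]$$

When $y_n = 1$ only the first term remains, so the loss is small when $p_n$ is close to 1; when $y_n = 0$ only the second term remains, so the loss is small when $p_n$ is close to 0.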
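The formula can be checked directly against PyTorch: computing it by hand on a few example probabilities gives the same value as `nn.BCELoss` with its default `reduction='mean'`. The probabilities and targets below are made-up illustration values. (PyTorch additionally clamps the log terms at -100 so that a prediction of exactly 0 or 1 does not give an infinite loss.)

```python
import torch
import torch.nn as nn

p = torch.tensor([0.9, 0.2, 0.7])  # predicted probabilities (e.g. D's outputs)
y = torch.tensor([1.0, 0.0, 0.0])  # targets: 1 = real, 0 = fake

# Per-sample loss from the formula, averaged over the batch
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()
builtin = nn.BCELoss()(p, y)

print(round(manual.item(), 4), round(builtin.item(), 4))  # 0.5108 0.5108
```

Most of the loss here comes from the third sample: a fake scored 0.7 contributes -log(0.3) ≈ 1.204.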
5. Summary
This post has briefly introduced the structure and components of a GAN. The author cannot yet use this model to build anything and hopes to present a GAN for image repair next time, once the dataset is complete. All the code here is written for the data format of the handwritten digit dataset, MNIST. However, the goal of this article is to explain the composition of a generative adversarial network, so it does not give more elaborate examples; it focuses on what the parameters mean and what each part should do.
It's not often that the author gets back to writing articles to improve himself. Please show your support!!!