This article is a learning-record blog post from the 365-day deep learning training camp
Reference article: 365-day deep learning training camp – Week G7: Semi-Supervised GAN theory and practice (readable by internal members of the training camp)
Original author: K student | Tutoring, project customization
Operating environment:
Computer system: Windows 10
Language environment: Python 3.10
IDE: PyCharm 2022.1.1
Deep learning environment: PyTorch
Table of Contents
1. Explanation of theoretical knowledge
2. Code implementation
1. Configuration code
2. Initialize weights
3. Define algorithm model
4. Configure the model
5. Training model
1. Explanation of theoretical knowledge
This algorithm extends the Generative Adversarial Network (GAN) to semi-supervised learning by forcing the discriminator D to output class labels. We train a generator G and a discriminator D on a dataset in which every input belongs to one of N classes. During training, D is asked to predict which of N + 1 classes the input belongs to, where the extra (N + 1)-th class corresponds to the outputs of the generator G. The discriminator D therefore also acts as a classifier C. This method can be used to train a better discriminator D and can produce higher-quality samples than an ordinary GAN. Semi-Supervised GAN has the following advantages:
(1) The paper's authors propose a new extension to GANs that allows them to learn a generative model and a classifier simultaneously. This extension is called Semi-Supervised GAN, or SGAN.
(2) The experimental results of the paper show that, on limited datasets, SGAN improves classification performance compared with a baseline classifier that has no generative component.
(3) The experimental results of the paper show that SGAN can significantly improve the quality of generated samples and reduce the training time of the generator.
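To make the N + 1 labelling scheme described above concrete, here is a minimal pure-Python sketch. It is not taken from the original article: the helper names `make_targets`, `softmax`, and `cross_entropy` are illustrative, and N = 10 is an assumption chosen to match the MNIST setting used later.

```python
import math

NUM_CLASSES = 10          # N real classes (assumed; MNIST digits 0-9)
FAKE_CLASS = NUM_CLASSES  # index N is reserved for generated samples


def make_targets(real_labels, num_fake):
    """Targets for a mixed batch: real samples keep their label, fakes get class N."""
    return list(real_labels) + [FAKE_CLASS] * num_fake


def softmax(logits):
    """Numerically stable softmax over a list of N + 1 logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability assigned to the target class."""
    return -math.log(softmax(logits)[target])


targets = make_targets([3, 7], num_fake=2)
uniform_logits = [0.0] * (NUM_CLASSES + 1)
print(targets)                                    # [3, 7, 10, 10]
print(cross_entropy(uniform_logits, FAKE_CLASS))  # log(11), about 2.398
```

An untrained classifier that spreads its probability evenly over all N + 1 classes pays a loss of log(N + 1) per sample, which gives a rough reference point when reading classification losses.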
2. Code Implementation
1. Configuration code
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# Directory for the sample images saved during training
os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--num_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
# args=[] ignores the command line, so the defaults are used (convenient in notebooks)
opt = parser.parse_args(args=[])
print(opt)

cuda = torch.cuda.is_available()
Namespace(n_epochs=2, batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_cpu=2, latent_dim=100, num_classes=10, img_size=32, channels=1, sample_interval=400)
2. Initialize weights
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        # Convolution weights: normal distribution with mean 0.0, std 0.02
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        # BatchNorm scale: mean 1.0, std 0.02; bias: constant 0
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
3. Define algorithm model
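Before reading the model definition, it may help to check the shape arithmetic it relies on. The sketch below is an illustration in plain Python rather than part of the original code: the function names are hypothetical, and it assumes the default `img_size=32` from the configuration, the 3x3 stride-2 padding-1 convolutions of the discriminator, and the two `Upsample(scale_factor=2)` layers of the generator.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output height/width of a square Conv2d with the given hyperparameters."""
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_feature_size(img_size, num_blocks=4, channels=128):
    """Spatial size and flattened feature count after the stride-2 blocks."""
    size = img_size
    for _ in range(num_blocks):
        size = conv_out(size)
    return size, channels * size ** 2


def generator_sizes(img_size, channels=128, num_upsamples=2):
    """Initial feature-map size, linear-layer width, and final image size."""
    init_size = img_size // 4
    linear_out = channels * init_size ** 2
    final_size = init_size * 2 ** num_upsamples
    return init_size, linear_out, final_size


print(discriminator_feature_size(32))  # (2, 512)
print(generator_sizes(32))             # (8, 8192, 32)
```

For a 32x32 input, four stride-2 blocks leave a 2x2 map, which agrees with `opt.img_size // 2 ** 4` in the discriminator, and the generator's 8x8 starting map is upsampled back to 32x32.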
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # Label embedding kept from the reference code; it is not used in forward()
        self.label_emb = nn.Embedding(opt.num_classes, opt.latent_dim)

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4

        # Output layers: real/fake score and (num_classes + 1)-way class prediction
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        # dim=1 makes the softmax axis explicit (the class dimension)
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.num_classes + 1), nn.Softmax(dim=1))

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label
4. Configure the model
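The data loader in this step applies `transforms.Normalize([0.5], [0.5])` after `ToTensor()`, which maps pixel values from [0, 1] to [-1, 1], the same range as the generator's `Tanh` output. The arithmetic can be checked with a small pure-Python sketch; the helper names `normalize` and `denormalize` are illustrative and are not torchvision functions.

```python
import math


def normalize(pixel, mean=0.5, std=0.5):
    """Per-channel normalization: (x - mean) / std."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, taking values in [-1, 1] back to [0, 1]."""
    return value * std + mean


pixels = [0.0, 0.25, 0.5, 1.0]
scaled = [normalize(p) for p in pixels]
print(scaled)                            # [-1.0, -0.5, 0.0, 1.0]
print([denormalize(v) for v in scaled])  # [0.0, 0.25, 0.5, 1.0]

# tanh never leaves [-1, 1], so generated and real images share one value range
print(all(-1.0 <= math.tanh(x) <= 1.0 for x in (-5.0, 0.0, 5.0)))  # True
```

Keeping real and generated images in the same range means the discriminator cannot separate them by value range alone.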
# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw
5. Training model
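The training loop combines four loss terms for the discriminator and reports a classification accuracy over real and generated samples together. The following plain-Python sketch illustrates only that bookkeeping; the function names and all numeric values are hypothetical and chosen for easy arithmetic, not measured results.

```python
def discriminator_loss(adv_real, aux_real, adv_fake, aux_fake):
    """Average adversarial and auxiliary terms per half, then average the halves."""
    d_real_loss = (adv_real + aux_real) / 2
    d_fake_loss = (adv_fake + aux_fake) / 2
    return (d_real_loss + d_fake_loss) / 2


def argmax(values):
    """Index of the largest entry in a list."""
    return max(range(len(values)), key=values.__getitem__)


def accuracy(class_scores, targets):
    """Fraction of rows whose highest-scoring class equals the target."""
    hits = sum(argmax(row) == t for row, t in zip(class_scores, targets))
    return hits / len(targets)


# Hypothetical loss values: each of the four terms ends up weighted by 1/4.
print(discriminator_loss(1.0, 2.0, 0.5, 1.5))  # 1.25

# Toy scores over 3 classes (2 real classes + 1 fake class), not model output.
scores = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.2, 0.1, 0.7], [0.5, 0.4, 0.1]]
targets = [0, 1, 2, 2]  # two real samples followed by two generated samples
print(accuracy(scores, targets))  # 0.5
```

Because real and generated samples are concatenated before the accuracy is computed, the reported figure mixes the classification of real digits with the detection of generated images.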
# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
        # Generated samples are assigned the extra class index num_classes
        fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images (detach so no gradient flows into the generator)
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Calculate discriminator accuracy
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

    # Log once per epoch, using the values from the last batch
    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
        % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
    )

[Epoch 0/2] [Batch 937/938] [D loss: 1.358861, acc: 50%] [G loss: 0.671799]
[Epoch 1/2] [Batch 937/938] [D loss: 1.343094, acc: 50%] [G loss: 0.681119]