Table of Contents
- 1. Saving & Loading Model
  - 1. Reason
  - 2. Serialization and deserialization
  - 3. PyTorch serialization and deserialization
  - 4. Save the model
  - 5. Model loading
- 2. Resuming training from a checkpoint
  - 1. Reason
  - 2. Parameters for saving the model
  - 3. Breakpoint training
1. Saving & Loading Model
The saving and loading of models can also be called serialization and deserialization.
1. Reason
A model is trained so that it can be used conveniently later. During training the model lives in memory, which does not keep data once the process ends, whereas a hard disk can store data for a long time. So after training, the model needs to be transferred from memory to the hard disk for long-term storage.
2. Serialization and deserialization
Serialization and deserialization mainly describe a conversion relationship between the memory and the hard disk. The trained model is stored in the form of an object in the memory, and stored in the form of a binary sequence in the hard disk.
- Serialization (pickling): The process of converting model objects into binary data and storing it on disk.
- Deserialization (unpickling): The process of reading the serialized binary data from the hard disk and restoring it as a model object in memory.
The two operations are inverses of each other.
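The round trip described above can be sketched with Python's standard pickle module, which torch.save uses by default. This is a minimal illustration only: the dictionary below is a hypothetical stand-in for a trained model and is not part of the original post.

```python
import os
import pickle
import tempfile

# Hypothetical stand-in for a trained model held in memory.
model = {"weights": [0.1, 0.2, 0.3], "bias": 0.5}

# Serialization: in-memory object -> binary sequence on disk.
path = os.path.join(tempfile.mkdtemp(), "tiny_model.pkl")
with open(path, "wb") as f:
    pickle.dump(model, f)

# Deserialization: binary sequence on disk -> new in-memory object.
with open(path, "rb") as f:
    restored = pickle.load(f)

print(restored == model)   # same content
print(restored is model)   # but a different object in memory
```

The restored object is equal in content to the original but is a separate object, which is what lets a model outlive the process that trained it.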
3. PyTorch serialization and deserialization
- Serialization:
    # Serialization
    # Main parameters: obj - the object to save; f - the output path
    torch.save(obj, f)
Description:
obj: the data you want to save, such as a model or a tensor;
f: the path on the hard disk to write to.
- Deserialization:
    # Deserialization
    # Main parameters: f - the file path; map_location - where to place the loaded data (CPU or GPU)
    torch.load(f, map_location=None)
Description:
map_location: When a model was saved from the GPU, it cannot be loaded directly on a machine without that GPU, so map_location must be set (for example map_location="cpu"); a model saved from the CPU can be loaded directly.
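As a minimal sketch of the two calls, the following saves a small tensor and loads it back with map_location="cpu". The tensor and the temporary file path are illustrative assumptions; only torch.save, torch.load, and map_location come from the text above.

```python
import os
import tempfile

import torch

# A small tensor stands in for "the data you want to save".
obj = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Serialization: write the object to a file on disk.
path = os.path.join(tempfile.mkdtemp(), "tensor.pkl")
torch.save(obj, path)

# Deserialization: map_location="cpu" places the loaded tensor on the CPU,
# regardless of the device it was saved from.
loaded = torch.load(path, map_location="cpu")

print(loaded.device)
print(torch.equal(obj, loaded))
```

Passing map_location="cpu" is a safe default when the loading machine may not have the GPU that the saving machine used.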
4. Save the model
- Module data structure
Description: A Module uses 8 ordered dictionaries to manage its parameters and other attributes. The purpose of saving a model is to use it again later. What training actually produces is a set of learnable parameters, so besides saving the entire Module, another approach is to save only the learnable parameters. The next time, build the model first and then load the saved learnable parameters into the new model, which completes the saving and loading of the model.
- Save the entire Module
    # Save the entire Module
    torch.save(net, path)
Advantages: Saves the entire net, so there is no need to decide which parameters to save
Disadvantages: Uses more storage and takes more time
- Save model parameters
    # state_dict() returns the learnable parameters of the model as a dictionary
    state_dict = net.state_dict()
    torch.save(state_dict, path)
Description: This is the officially recommended method.
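To make the contents of a state_dict concrete, here is a small sketch. The three-layer network is a hypothetical toy, not the LeNet2 used in the post; it only illustrates that state_dict() returns an ordered dictionary keyed by layer name, and that layers without learnable parameters (such as ReLU) do not appear in it.

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical toy network; only the Linear layers have learnable parameters.
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

state_dict = net.state_dict()
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))

# Save only the learnable parameters.
path = os.path.join(tempfile.mkdtemp(), "model_state_dict.pkl")
torch.save(state_dict, path)
```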
- Example demonstration:
Run the model_save.py file
    # -*- coding: utf-8 -*-
    """
    # @brief : Save the model
    """
    import torch
    import numpy as np
    import torch.nn as nn
    from tools.common_tools2 import set_seed


    class LeNet2(nn.Module):
        def __init__(self, classes):
            super(LeNet2, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 6, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Conv2d(6, 16, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2)
            )
            self.classifier = nn.Sequential(
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, classes)
            )

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size()[0], -1)
            x = self.classifier(x)
            return x

        def initialize(self):
            # Fill every parameter with a recognizable constant to simulate training
            for p in self.parameters():
                p.data.fill_(20191104)


    net = LeNet2(classes=2019)

    # "train"
    print("Before training: ", net.features[0].weight[0, ...])
    net.initialize()
    print("After training: ", net.features[0].weight[0, ...])

    # Paths for the entire model and for the model parameters
    path_model = "./model.pkl"
    path_state_dict = "./model_state_dict.pkl"

    # Save the entire model
    torch.save(net, path_model)

    # Save the model parameters; state_dict() returns them as a dictionary
    net_state_dict = net.state_dict()
    torch.save(net_state_dict, path_state_dict)
After running:
Two files are saved: one holds the entire model, and the other holds only the learnable parameters of the model.
5. Model loading
- Load the whole model
    # Load the entire model
    # Note: on PyTorch 2.6 and later, torch.load defaults to weights_only=True,
    # so loading a whole pickled Module needs weights_only=False
    path_model = "./model.pkl"
    net_load = torch.load(path_model)
    print(net_load)
- Load model parameters
    # Load model parameters
    path_state_dict = "./model_state_dict.pkl"
    state_dict_load = torch.load(path_state_dict)
    print(state_dict_load.keys())
    net_new = LeNet2(classes=2019)
    net_new.load_state_dict(state_dict_load)
- Example demo:
Run the model_load.py file
    # -*- coding: utf-8 -*-
    """
    # @file name : model_load.py
    # @brief : Model loading
    """
    import torch
    import numpy as np
    import torch.nn as nn
    from tools.common_tools import set_seed


    class LeNet2(nn.Module):
        def __init__(self, classes):
            super(LeNet2, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 6, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Conv2d(6, 16, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2)
            )
            self.classifier = nn.Sequential(
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, classes)
            )

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size()[0], -1)
            x = self.classifier(x)
            return x

        def initialize(self):
            for p in self.parameters():
                p.data.fill_(20191104)


    # ============================ load net ============================
    # flag = 1
    flag = 0
    if flag:
        # Read the path and load the entire saved model
        path_model = "./model.pkl"
        net_load = torch.load(path_model)
        print(net_load)

    # ============================ load state_dict ============================
    flag = 1
    # flag = 0
    if flag:
        path_state_dict = "./model_state_dict.pkl"
        state_dict_load = torch.load(path_state_dict)
        # Keys of the parameter dictionary ('features.0.weight', 'features.0.bias', etc.)
        print(state_dict_load.keys())

    # ============================ update state_dict ============================
    flag = 1
    # flag = 0
    if flag:
        # Create a new model (LeNet2) with the same structure as the saved one
        net_new = LeNet2(classes=2019)

        print("Before loading: ", net_new.features[0].weight[0, ...])
        # Load state_dict_load into the new model
        net_new.load_state_dict(state_dict_load)
        print("After loading: ", net_new.features[0].weight[0, ...])
After running:
- Loading the model: (You can single-step in a debugger to inspect the structure of the model.)
- Printing the keys:
Description:
features.0 and features.3: weights and biases of the first and second convolutional layers;
classifier.0, classifier.2 and classifier.4: weights and biases of the fully connected layers.
- Printing the model parameters "before loading" and "after loading": (can be viewed in Debug)
Description: After loading, the weights of the convolutional layer in the new model are all 20191104, the value written before saving, which shows that the model was saved and loaded successfully.
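The same check can be done programmatically instead of by inspecting printed tensors. The sketch below uses a hypothetical toy architecture (build_net is an illustrative helper, not from the post) and compares every saved tensor against the reloaded model with torch.equal.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_net():
    """Hypothetical toy architecture; saver and loader must build the same structure."""
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


net = build_net()
path = os.path.join(tempfile.mkdtemp(), "model_state_dict.pkl")
torch.save(net.state_dict(), path)

net_new = build_net()  # randomly initialized, independent of net
net_new.load_state_dict(torch.load(path))

# Every parameter tensor in the new model should now match the saved one exactly.
new_state = net_new.state_dict()
all_match = all(
    torch.equal(saved, new_state[name])
    for name, saved in net.state_dict().items()
)
print(all_match)
```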
2. Resuming training from a checkpoint
1. Reason
Training can terminate unexpectedly for various reasons, such as a power failure or a long run on a large model. Resuming from a checkpoint ensures that, after an interruption, training can continue from the last saved point instead of starting over. For this to work, the model parameters need to be saved during training.
2. Parameters for saving the model
The data and the loss function stay fixed during training, while the parameters in the model and the state in the optimizer (a momentum optimizer, for example, uses information from previous steps to compute the current update) keep changing with each iteration. Therefore, what needs to be saved is: the parameters of the model, the state of the optimizer, and the epoch (or number of iterations).
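To see why the optimizer has state worth saving, the following sketch inspects an SGD optimizer with momentum before and after a single update. The one-layer model and the dummy input are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(3, 1)  # hypothetical one-layer model
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# Before any update the optimizer holds no per-parameter state.
state_before = len(optimizer.state_dict()["state"])

loss = net(torch.ones(2, 3)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# After one step, each parameter (weight and bias) has a momentum buffer
# that later updates depend on.
state_after = optimizer.state_dict()["state"]
print(state_before, len(state_after))
print(sorted(state_after[0].keys()))
```

If only the model parameters were restored, these buffers would restart from nothing, so the first updates after resuming would differ from those of an uninterrupted run.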
- Code snippet to save parameters:
    # State that needs to be saved
    checkpoint = {
        "model_state_dict": net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
    }
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
    torch.save(checkpoint, path_checkpoint)
Description: The checkpoint should be written in the epoch loop, not in the inner iteration loop.
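A minimal sketch of where the save sits relative to the two loops is shown below. The save_checkpoint helper, the toy model, the interval of 2, and the temporary directory are all assumptions made for illustration; the dictionary keys and file name pattern follow the snippet above.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(net, optimizer, epoch, out_dir):
    """Hypothetical helper: bundle the changing state into one file."""
    checkpoint = {
        "model_state_dict": net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
    }
    path = os.path.join(out_dir, "checkpoint_{}_epoch.pkl".format(epoch))
    torch.save(checkpoint, path)
    return path


net = nn.Linear(3, 1)
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
out_dir = tempfile.mkdtemp()
checkpoint_interval = 2
saved = []

for epoch in range(4):
    for _ in range(3):  # inner iteration loop: no checkpointing here
        optimizer.zero_grad()
        net(torch.ones(2, 3)).sum().backward()
        optimizer.step()
    # Epoch loop: checkpoint once every checkpoint_interval epochs.
    if (epoch + 1) % checkpoint_interval == 0:
        saved.append(save_checkpoint(net, optimizer, epoch, out_dir))

print([os.path.basename(p) for p in saved])
```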
- Example demo:
Run the save_checkpoint.py file
    # -*- coding: utf-8 -*-
    """
    # @file name : save_checkpoint.py
    # @brief : Simulate training that stops unexpectedly
    """
    import os
    import random
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    import torchvision.transforms as transforms
    import torch.optim as optim
    from PIL import Image
    from matplotlib import pyplot as plt
    from model.lenet import LeNet
    from tools.my_dataset import RMBDataset
    from tools.common_tools import set_seed
    import torchvision

    set_seed(1)  # set random seed
    rmb_label = {"1": 0, "100": 1}

    # parameter settings
    checkpoint_interval = 5  # save a checkpoint every 5 epochs
    MAX_EPOCH = 10  # 10 epochs in total
    BATCH_SIZE = 16
    LR = 0.01
    log_interval = 10
    val_interval = 1

    # ============================ step 1/5 data ============================
    split_dir = os.path.join("..", "..", "data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")

    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomGrayscale(p=0.8),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    valid_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    # build the dataset instances
    train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
    valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

    # build the DataLoaders
    # shuffle=False keeps the data in a fixed order; True is recommended
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

    # ============================ step 2/5 model ============================
    net = LeNet(classes=2)
    net.initialize_weights()

    # ============================ step 3/5 loss function ============================
    criterion = nn.CrossEntropyLoss()  # choose loss function

    # ============================ step 4/5 optimizer ============================
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # select optimizer
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # learning rate decay strategy

    # ============================ step 5/5 training ============================
    train_curve = list()
    valid_curve = list()

    start_epoch = -1
    for epoch in range(start_epoch + 1, MAX_EPOCH):

        loss_mean = 0.
        correct = 0.
        total = 0.

        net.train()
        for i, data in enumerate(train_loader):

            # forward
            inputs, labels = data
            outputs = net(inputs)

            # backward
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()

            # update weights
            optimizer.step()

            # classification statistics
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).squeeze().sum().numpy()

            # print training information
            loss_mean += loss.item()
            train_curve.append(loss.item())
            if (i + 1) % log_interval == 0:
                loss_mean = loss_mean / log_interval
                print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                    epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
                loss_mean = 0.

        scheduler.step()  # update learning rate

        # save a checkpoint in the epoch loop
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)

        # simulate an unexpected interruption once epoch exceeds 5
        if epoch > 5:
            print("Training was interrupted unexpectedly...")
            break

        # validate the model
        if (epoch + 1) % val_interval == 0:

            correct_val = 0.
            total_val = 0.
            loss_val = 0.
            net.eval()
            with torch.no_grad():
                for j, data in enumerate(valid_loader):
                    inputs, labels = data
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)

                    _, predicted = torch.max(outputs.data, 1)
                    total_val += labels.size(0)
                    correct_val += (predicted == labels).squeeze().sum().numpy()

                    loss_val += loss.item()

                # record the mean validation loss of this epoch
                valid_curve.append(loss_val / len(valid_loader))
                # fixed: report validation accuracy, not the training counters
                print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                    epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val / len(valid_loader), correct_val / total_val))

    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(train_loader)
    # valid records one loss per epoch, so convert the record points to iterations
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.show()
After running:
- Training is interrupted after the checkpoint, and the stored parameters are from the 5th epoch:
Explanation: epoch starts from 0, so the 5th epoch is saved as checkpoint_4_epoch.pkl. With the condition epoch > 5 as listed, the loop breaks at epoch index 6; resumed training starts from the epoch after the checkpoint, epoch=5 (the 6th epoch).
- Save file for the 5th epoch:
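The epoch bookkeeping is easy to get wrong by one, so here is a small trace of the loop indices using the constants from save_checkpoint.py (MAX_EPOCH = 10, checkpoint_interval = 5, and the break condition epoch > 5 as listed). No training happens; the sketch only records which epochs would write a checkpoint, where the loop breaks, and where a resumed run would start.

```python
MAX_EPOCH = 10
checkpoint_interval = 5

saved_epochs = []
interrupted_at = None
for epoch in range(0, MAX_EPOCH):
    # ... one epoch of training would run here ...
    if (epoch + 1) % checkpoint_interval == 0:
        saved_epochs.append(epoch)  # file name would be checkpoint_{epoch}_epoch.pkl
    if epoch > 5:  # same condition as in the listing
        interrupted_at = epoch
        break

last_saved = saved_epochs[-1]
resume_from = last_saved + 1  # the resume script uses range(start_epoch + 1, MAX_EPOCH)

print(saved_epochs, interrupted_at, resume_from)  # [4] 6 5
```

With the condition as written, the break fires at index 6, while the only checkpoint on disk is from index 4, so a resumed run repeats epochs 5 and 6.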
3. Breakpoint training
In the resume script, the model, data, loss function, and optimizer are built exactly as before; only the saved state needs to be loaded into the corresponding objects.
- Code snippet to restore parameters:
    # Resume from the checkpoint
    path_checkpoint = "./checkpoint_4_epoch.pkl"  # path of the checkpoint to restore
    checkpoint = torch.load(path_checkpoint)  # load the file
    net.load_state_dict(checkpoint['model_state_dict'])  # restore model parameters
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # restore optimizer state
    start_epoch = checkpoint['epoch']  # epoch at which the checkpoint was saved
    scheduler.last_epoch = start_epoch  # align the learning rate schedule with that epoch
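The following is a self-contained sketch of the full save-then-resume round trip on a toy setup. It differs from the post in one respect, which is an assumption of this sketch rather than the author's method: the scheduler is restored through its own state_dict() and load_state_dict() (stored under a hypothetical extra key "scheduler_state_dict") instead of assigning scheduler.last_epoch by hand. The build helper and the toy model are also illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def build():
    """Hypothetical setup shared by the training run and the resumed run."""
    net = nn.Linear(3, 1)
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
    return net, optimizer, scheduler


# "First run": train for 3 epochs (indices 0..2), then checkpoint.
net, optimizer, scheduler = build()
for epoch in range(3):
    optimizer.zero_grad()
    net(torch.ones(2, 3)).sum().backward()
    optimizer.step()
    scheduler.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
torch.save({
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),  # assumption: extra key
    "epoch": epoch,
}, path)

# "Second run": rebuild everything the same way, then restore the saved state.
net2, optimizer2, scheduler2 = build()
checkpoint = torch.load(path)
net2.load_state_dict(checkpoint["model_state_dict"])
optimizer2.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler2.load_state_dict(checkpoint["scheduler_state_dict"])
start_epoch = checkpoint["epoch"]

# Training would continue with: for epoch in range(start_epoch + 1, MAX_EPOCH)
print(start_epoch + 1, scheduler2.last_epoch, optimizer2.param_groups[0]["lr"])
```

Restoring the scheduler from its own state keeps its internal counters consistent with the learning rate already stored in the optimizer, which avoids having to reason about last_epoch manually.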
- Example demo:
Run the checkpoint_resume.py file
    # -*- coding: utf-8 -*-
    """
    # @file name : checkpoint_resume.py
    # @brief : Resume training from a saved checkpoint
    """
    import os
    import random
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    import torchvision.transforms as transforms
    import torch.optim as optim
    from PIL import Image
    from matplotlib import pyplot as plt
    from model.lenet import LeNet
    from tools.my_dataset import RMBDataset
    from tools.common_tools import set_seed
    import torchvision

    set_seed(1)  # set random seed
    rmb_label = {"1": 0, "100": 1}

    # parameter settings
    checkpoint_interval = 5
    MAX_EPOCH = 10
    BATCH_SIZE = 16
    LR = 0.01
    log_interval = 10
    val_interval = 1

    # ============================ step 1/5 data ============================
    split_dir = os.path.join("..", "..", "data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")

    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomGrayscale(p=0.8),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    valid_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    # build the dataset instances
    train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
    valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

    # build the DataLoaders
    # shuffle=False keeps the data in a fixed order; True is recommended
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

    # ============================ step 2/5 model ============================
    net = LeNet(classes=2)
    net.initialize_weights()

    # ============================ step 3/5 loss function ============================
    criterion = nn.CrossEntropyLoss()  # choose loss function

    # ============================ step 4/5 optimizer ============================
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # select optimizer
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # learning rate decay strategy

    # ============================ step 5+/5 checkpoint recovery ============================
    path_checkpoint = "./checkpoint_4_epoch.pkl"
    checkpoint = torch.load(path_checkpoint)

    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    start_epoch = checkpoint['epoch']
    scheduler.last_epoch = start_epoch

    # ============================ step 5/5 training ============================
    train_curve = list()
    valid_curve = list()

    for epoch in range(start_epoch + 1, MAX_EPOCH):

        loss_mean = 0.
        correct = 0.
        total = 0.

        net.train()
        for i, data in enumerate(train_loader):

            # forward
            inputs, labels = data
            outputs = net(inputs)

            # backward
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()

            # update weights
            optimizer.step()

            # classification statistics
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).squeeze().sum().numpy()

            # print training information
            loss_mean += loss.item()
            train_curve.append(loss.item())
            if (i + 1) % log_interval == 0:
                loss_mean = loss_mean / log_interval
                print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                    epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
                loss_mean = 0.

        scheduler.step()  # update learning rate

        # save a checkpoint in the epoch loop
        # (fixed typos: key "optimizer_state_dict" and file name "checkpoint_...")
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "loss": loss,
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)

        # if epoch > 5:
        #     print("Training was interrupted unexpectedly...")
        #     break

        # validate the model
        if (epoch + 1) % val_interval == 0:

            correct_val = 0.
            total_val = 0.
            loss_val = 0.
            net.eval()
            with torch.no_grad():
                for j, data in enumerate(valid_loader):
                    inputs, labels = data
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)

                    _, predicted = torch.max(outputs.data, 1)
                    total_val += labels.size(0)
                    correct_val += (predicted == labels).squeeze().sum().numpy()

                    loss_val += loss.item()

                # record the mean validation loss of this epoch
                valid_curve.append(loss_val / len(valid_loader))
                # fixed: report validation accuracy, not the training counters
                print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                    epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val / len(valid_loader), correct_val / total_val))

    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(train_loader)
    # valid records one loss per epoch, so convert the record points to iterations
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.show()
After running:
- Load the saved parameter file checkpoint_4_epoch.pkl and start training from the 6th epoch:
Explanation: epoch starts from 0, so the fifth epoch=checkpoint_4_epoch.pkl.
- Comparison of loss curves:
Explanation: Before the breakpoint, the DataLoader was created with shuffle=False, so the data was not shuffled. This hurts optimization and results in a high loss of about 3.5.
Suggestion: Set shuffle=True for the DataLoader in both save_checkpoint.py and checkpoint_resume.py.
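The effect of the shuffle flag can be seen on a toy dataset. In this sketch the ten-element dataset, the batch size, and the seeded generator are illustrative assumptions; only the shuffle parameter of DataLoader comes from the text above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset of ten samples, ordered by index.
dataset = TensorDataset(torch.arange(10))

ordered = DataLoader(dataset, batch_size=5, shuffle=False)
shuffled = DataLoader(dataset, batch_size=5, shuffle=True,
                      generator=torch.Generator().manual_seed(0))

ordered_batches = [batch[0].tolist() for batch in ordered]
shuffled_batches = [batch[0].tolist() for batch in shuffled]

print(ordered_batches)   # the same fixed order on every pass
print(shuffled_batches)  # a permutation of the same samples
```

With shuffle=False every epoch presents the samples in the same order, so if the data on disk is grouped by class, each batch may contain a single class, which is one plausible way for the optimization to suffer as described.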