[PyTorch Framework] Model saving, loading, and resuming training from checkpoints

Table of Contents

  • 1. Saving & Loading a Model
    • 1. Reason
    • 2. Serialization and deserialization
    • 3. PyTorch serialization and deserialization
    • 4. Saving the model
    • 5. Loading the model
  • 2. Resuming training from a checkpoint
    • 1. Reason
    • 2. Parameters to save
    • 3. Resuming training

1. Saving & Loading a Model

Saving and loading a model is also referred to as serialization and deserialization.

1. Reason

We train a model so that it can be reused conveniently later. During training the model lives in memory, and memory does not provide long-term storage, whereas the hard disk does. So after training we need to move the model from memory to the hard disk for long-term storage.

2. Serialization and deserialization

Serialization and deserialization describe the conversion between memory and the hard disk: in memory the trained model exists as an object, while on the hard disk it is stored as a binary sequence.

  • Serialization (pickling): converting the model object into binary data and storing it on disk.
  • Deserialization (unpickling): reading the binary data from disk and restoring it as a model object in memory.

The two operations are inverses of each other:
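A minimal round-trip sketch (a toy tensor stands in for the model object; the file name is arbitrary):

# serialization / deserialization round trip
import torch

obj = torch.arange(4.)                   # any picklable object: a tensor, a dict, an nn.Module ...
torch.save(obj, "./obj.pkl")             # serialization: object in memory -> binary file on disk
obj_restored = torch.load("./obj.pkl")   # deserialization: binary file -> object in memory
print(torch.equal(obj, obj_restored))    # True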

3. PyTorch serialization and deserialization

  • Serialization: torch.save
# Serialization
# Main parameters: obj - the object to save, f - the output path
torch.save(obj, f)

Description:
obj: the data to save, such as a model, a tensor, or any other picklable object;
f: the output path on the hard disk.

  • Deserialization: torch.load
# Deserialization
# Main parameters: f - the file path, map_location - where to map the loaded tensors (CPU or GPU)
torch.load(f, map_location=None)

Description:
map_location: a model saved from the GPU cannot always be loaded directly (for example, on a CPU-only machine); in that case map_location must be set to remap the storage. A model saved from the CPU can be loaded directly.
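For example (a minimal sketch; the file path matches the examples below), a state_dict saved from a GPU run can be loaded on a CPU-only machine by remapping its storage:

# loading with map_location
import torch

path_state_dict = "./model_state_dict.pkl"
# remap all tensors to the CPU (needed when the file was saved from a GPU and no GPU is available here)
state_dict = torch.load(path_state_dict, map_location="cpu")
# or remap to a specific device instead:
# state_dict = torch.load(path_state_dict, map_location=torch.device("cuda:0"))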

4. Saving the model

  • Module data structure

    Description: an nn.Module manages its parameters and attributes through eight ordered dictionaries. The purpose of saving a model is to reuse it later, and what training actually produces is a set of learnable parameters. So besides saving the whole module, another option is to save only these learnable parameters; the next time the model is needed, we rebuild it and load the saved parameters into the new model, which completes the save/load cycle.
  • Save the entire Module
# Save the entire Module
torch.save(net, path)

Advantage: the whole net is saved, so there is no need to think about which parameters to keep.
Disadvantages: the file is larger and saving/loading is slower; loading also requires the original class definition to be importable, because the module is pickled as an object.

  • Save only the model parameters
# state_dict() collects the learnable parameters of the model and returns them as a dictionary
state_dict = net.state_dict()
torch.save(state_dict, path)

Description: this is the officially recommended approach.

  • Example demonstration:
    Run model_save.py:
# -*- coding: utf-8 -*-
"""
# @brief : Save the model
"""
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools2 import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


net = LeNet2(classes=2019)

# "train"
print("Before training: ", net.features[0].weight[0, ...])
net.initialize()
print("After training: ", net.features[0].weight[0, ...])

# Set the path to save the entire model and save model parameters
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# save the entire model
torch.save(net, path_model)

# Save the model parameters, call the state_dict() method to get the model parameters
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)


After running:
Two files are produced: model.pkl, which stores the entire model, and model_state_dict.pkl, which stores only the learnable parameters.
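An optional sanity check (a sketch, assuming the script has just been run in the current working directory) lists the two files and their sizes:

# check that both files were written
import os

for fname in ("model.pkl", "model_state_dict.pkl"):
    print(fname, os.path.getsize(fname), "bytes")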

5. Loading the model

  • Load the whole model
# load the entire model
path_model = "./model.pkl"
net_load = torch.load(path_model)

print(net_load)
  • Load model parameters
# load model parameters
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)

print(state_dict_load.keys())

net_new = LeNet2(classes=2019)
net_new.load_state_dict(state_dict_load)
  • Example demo:
    Run model_load.py:
# -*- coding: utf-8 -*-
"""
# @file name : model_load.py
# @brief : Model loading
"""
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


# ============================ load net ============================
# flag = 1
flag = 0
if flag:
    # Read the path and load the entire saved model
    path_model = "./model.pkl"
    net_load = torch.load(path_model)

    print(net_load)

# ============================ load state_dict ============================

flag = 1
# flag = 0
if flag:

    path_state_dict = "./model_state_dict.pkl"
    state_dict_load = torch.load(path_state_dict)

    print(state_dict_load.keys())

# ============================ update state_dict ============================
flag = 1
# flag = 0
if flag:
    # Need to create a new model (LeNet2) with the same structure as the saved model (parameters)
    net_new = LeNet2(classes=2019)
    # Print a parameter of the new (randomly initialized) model before loading
    print("Before loading: ", net_new.features[0].weight[0, ...])
    # Load state_dict_load and put it in the new model
    net_new.load_state_dict(state_dict_load)
    print("After loading: ", net_new.features[0].weight[0, ...])

After running:

  • Loading the whole model prints the model structure (you can also inspect it by single-stepping in a debugger).

  • Printing the state_dict keys:

    Description:
    features.0 and features.3 are the weights and biases of the first and second convolutional layers;
    classifier.0, classifier.2 and classifier.4 are the weights and biases of the fully connected layers.

  • Printing the parameters "before loading" and "after loading" (also viewable in a debugger):
    Description: after loading, the weights of the convolutional layer are all 20191104, the value written by initialize() before saving, which proves that the model was saved and loaded successfully.
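Note that model_load.py re-defines the LeNet2 class even when loading the whole model: torch.load() unpickles the saved object, so the class must be defined or importable in the loading script. A sketch of the failure mode (hypothetical; run in a fresh interpreter that does not define LeNet2):

# loading the whole model without the class definition available
import torch

try:
    net_load = torch.load("./model.pkl")   # fails if the LeNet2 class cannot be found
except AttributeError as e:
    print("Loading the whole model failed:", e)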

2. Resuming training from a checkpoint

1. Reason

Training can terminate unexpectedly for various reasons, such as a power failure or a very large model exhausting resources. Resuming from a checkpoint ensures that after an interruption, training can continue from the saved checkpoint (the interruption point) instead of starting over. This requires saving the relevant state periodically during training.

2. Parameters to save

The data and the loss function do not change during training, whereas the parameters of the model and of the optimizer (a momentum optimizer uses information from previous steps to update the current values) keep changing with the iterations. Therefore, what needs to be saved is: the model parameters, the optimizer state, and the epoch (or iteration count).

  • Code snippet for saving the checkpoint:
# state that needs to be saved for resuming training
checkpoint = {
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch
}
path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
torch.save(checkpoint, path_checkpoint)

Description: the checkpoint should not be saved inside the iteration (batch) loop; it should be saved inside the epoch loop.

  • Example demo:
    Run save_checkpoint.py:
# -*- coding: utf-8 -*-
"""
# @file name : save_checkpoint.py
# @brief : Simulation training stopped unexpectedly
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
import torchvision


set_seed(1)  # set random seed
rmb_label = {"1": 0, "100": 1}

# parameter settings
checkpoint_interval = 5  # save a checkpoint every 5 epochs
MAX_EPOCH = 10           # 10 epochs in total
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1


# ============================ step 1/5 data ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build MyDataset instance
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)  # shuffle=True re-shuffles the data every epoch; shuffle=False would keep a fixed order (True is recommended)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss() # choose loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # select optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1) # Set learning rate drop strategy

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

start_epoch = -1
for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        #forward
        inputs, labels = data
        outputs = net(inputs)

        #backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # Statistical classification
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # print training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2 %}". format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step() # update learning rate

    if (epoch + 1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
    # simulate an unexpected interruption once the epoch index exceeds 5
    if epoch > 5:
        print("Training was interrupted unexpectedly...")
        break

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{: .2%}".format(
                epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val/len(valid_loader), correct / total))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters*val_interval  # validation loss is recorded once per epoch, so convert its x-axis to iteration counts
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

After running:

  • Training is interrupted partway through (once epoch > 5), and the last saved checkpoint is the one from the 5th epoch:
    Explanation: epochs are counted from 0, so the 5th epoch corresponds to checkpoint_4_epoch.pkl; continuation training will therefore start from epoch = 5.
  • The saved file for the 5th epoch is checkpoint_4_epoch.pkl.
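The contents of the saved checkpoint can be inspected directly (a sketch, assuming checkpoint_4_epoch.pkl was produced by the run above):

# inspect the checkpoint file
import torch

checkpoint = torch.load("./checkpoint_4_epoch.pkl")
print(checkpoint.keys())    # dict_keys(['model_state_dict', 'optimizer_state_dict', 'epoch'])
print(checkpoint["epoch"])  # 4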

3. Resuming training

In the resume-training code, the model, data, loss function, and optimizer are built exactly as before; the only extra step is loading the saved state back into the corresponding objects.

  • Code snippet to restore the checkpoint:
# breakpoint recovery
path_checkpoint = "./checkpoint_4_epoch.pkl"                    # path of the saved checkpoint
checkpoint = torch.load(path_checkpoint)                        # load the checkpoint file

net.load_state_dict(checkpoint['model_state_dict'])             # restore the model parameters

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])   # restore the optimizer state

start_epoch = checkpoint['epoch']                               # epoch to resume from

scheduler.last_epoch = start_epoch                              # keep the learning-rate schedule in step

(A further refinement, checkpointing the scheduler state as well, is sketched after the example below.)
  • Example demo:
    Run checkpoint_resume.py:
# -*- coding: utf-8 -*-
"""
# @file name : checkpoint_resume.py
# @brief : Resume training from the saved checkpoint
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
import torchvision


set_seed(1)  # set random seed
rmb_label = {"1": 0, "100": 1}

# parameter settings
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1


# ============================ step 1/5 data ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build MyDataset instance
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)  # shuffle=True re-shuffles the data every epoch; shuffle=False would keep a fixed order (True is recommended)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss() # choose loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # select optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1) # Set learning rate drop strategy


# ============================ step 5+/5 breakpoint recovery ============================

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        #forward
        inputs, labels = data
        outputs = net(inputs)

        #backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # Statistical classification
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # print training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2 %}". format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step() # update learning rate

    if (epoch + 1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "loss": loss,
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    # if epoch > 5:
    #     print("Training was interrupted unexpectedly...")
    #     break

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{: .2%}".format(
                epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val/len(valid_loader), correct / total))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters*val_interval  # validation loss is recorded once per epoch, so convert its x-axis to iteration counts
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

After running:

  • The saved checkpoint_4_epoch.pkl is loaded and training resumes from the 6th epoch (epoch = 5):
    Explanation: epochs are counted from 0, so the 5th epoch corresponds to checkpoint_4_epoch.pkl.

  • Comparison of the loss curves:

    Explanation: with shuffle=False in the training DataLoader before the breakpoint, the data is never shuffled, which hurts optimization and leaves the loss stuck at a high value of about 3.5.
    Suggestion: set shuffle=True for the training DataLoader in both save_checkpoint.py and checkpoint_resume.py.
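Finally, as noted after the restore snippet above, one further refinement (not used in the scripts above, so only a suggested sketch with a stand-in model): the learning-rate scheduler also exposes state_dict()/load_state_dict(), so its state can be stored in the checkpoint and restored directly instead of only resetting scheduler.last_epoch.

# checkpointing the scheduler state as well (a sketch; nn.Linear stands in for the real model)
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(2, 2)                                   # stand-in model for the sketch
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)
epoch = 4

# save: include the scheduler state alongside the model and optimizer
checkpoint = {
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),
    "epoch": epoch,
}
torch.save(checkpoint, "./checkpoint_{}_epoch.pkl".format(epoch))

# resume: restore the scheduler state instead of setting last_epoch by hand
checkpoint = torch.load("./checkpoint_{}_epoch.pkl".format(epoch))
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])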
