This article introduces how to use, modify, save, and load PyTorch models. Taking image processing with torchvision as an example, PyTorch provides many pre-trained models through the torchvision.models module.
Article directory

- Use and modification of network models
  - VGG16 model use
  - VGG16 model modification
- Saving and loading of network models
  - Saving the model
  - Loading the network model
Use and modification of network models

For image classification, PyTorch provides many ready-to-use models.
```python
import torch
import torchvision
import warnings

warnings.filterwarnings("ignore")
```
VGG16 model usage
This article takes VGG16 as an example to demonstrate how to use and modify an existing model in PyTorch.
VGG16 is a classic convolutional neural network proposed by the Visual Geometry Group at the University of Oxford for the 2014 ImageNet image-classification challenge.
The defining feature of VGG is that it builds the network by stacking 3×3 convolution kernels, which also makes the whole network much deeper. These two changes influenced how later convolutional architectures were designed: they showed that using smaller convolution kernels while increasing network depth is an effective way to improve model performance.
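The small-kernel point can be checked by counting parameters (a sketch not taken from the original article; the channel count 64 is an arbitrary choice): two stacked 3×3 convolutions cover the same 5×5 receptive field as a single 5×5 convolution, yet use fewer parameters.

```python
import torch.nn as nn

# One 5x5 convolution vs. two stacked 3x3 convolutions.
# Both map 64 channels to 64 channels over a 5x5 receptive field.
conv5 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
conv3_stack = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
)

def n_params(module):
    return sum(p.numel() for p in module.parameters())

print(n_params(conv5))        # 64*64*5*5 + 64       = 102464
print(n_params(conv3_stack))  # 2*(64*64*3*3 + 64)   = 73856
```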
```python
torchvision.models.vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any)
```
- weights (optional): the pre-trained weights to load. It can be None (the default), meaning no pre-trained weights are loaded, or a predefined weight identifier such as VGG16_Weights.DEFAULT.
- progress: whether to display a download progress bar; defaults to True.
- kwargs: other optional parameters, passed to torchvision.models.VGG, the base class of the VGG16 model.
```python
vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, progress=True)
print(vgg16)
```
```
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
```
The output above shows that the VGG16 network consists of 13 convolutional layers and 3 fully connected layers, and that the network finally outputs 1,000 classification scores.
VGG16 model modification
Modify the VGG16 model:

- Take CIFAR10 as an example.
- Use the add_module() method to append a linear layer after the VGG16 model, mapping VGG16's 1,000 outputs to CIFAR10's 10 categories. The code is as follows:
```python
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, progress=True)
vgg16.add_module("add_linear", nn.Linear(1000, 10))
print(vgg16)
```
```
VGG(
  (features): Sequential(
    ... identical to the printout above ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
  (add_linear): Linear(in_features=1000, out_features=10, bias=True)
)
```
From the output above, add_linear sits outside the classifier, as a top-level submodule of the model. Note that add_module only registers the layer; VGG's forward() will not call it automatically. Since classifier is an nn.Sequential, whose submodules are executed in order, it is usually better to add the layer inside classifier by replacing

```python
vgg16.add_module("add_linear", nn.Linear(1000, 10))
```

with

```python
vgg16.classifier.add_module("add_linear", nn.Linear(1000, 10))
```
```python
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, progress=True)
vgg16.classifier.add_module("add_linear", nn.Linear(1000, 10))
print(vgg16)
```
```
VGG(
  (features): Sequential(
    ... identical to the printout above ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
    (add_linear): Linear(in_features=1000, out_features=10, bias=True)
  )
)
```
You can also modify an existing layer in place. For example, to change `(6): Linear(in_features=4096, out_features=1000, bias=True)` in classifier so that it outputs 10 features:

```python
vgg16.classifier[6] = nn.Linear(in_features=4096, out_features=10, bias=True)
```
Saving and loading network models
Saving the model
```python
torch.save(obj, f, pickle_protocol=DEFAULT_PROTOCOL)
```
| Parameter | Description |
|---|---|
| obj | The object to save; can be a model, tensor, dictionary, etc. |
| f | The file path or file object to save to |
| pickle_protocol | The version of the serialization protocol; defaults to DEFAULT_PROTOCOL |
Method 1: save the entire model, including its structure and all parameters, with torch.save():

```python
import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, progress=True)
torch.save(vgg16, "vgg16_model_true.pth")  # PyTorch models are conventionally saved with the .pth suffix
```
Method 2 (officially recommended): save only the model parameters, by calling the .state_dict() method on the vgg16 object:

```python
import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT, progress=True)
torch.save(vgg16.state_dict(), "vgg16_model_true_2.pth")
```
After running, the files vgg16_model_true.pth and vgg16_model_true_2.pth are saved in the current working directory. The second method, vgg16.state_dict(), generally takes up less space because only the parameters are stored, not the model structure.
Loading the network model

After a model is saved, torch.load() can load the saved .pth file.
```python
torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
```
| Parameter | Description |
|---|---|
| f | The file path or file object to load |
| map_location | Optionally specifies the device onto which tensors are loaded. If not provided, tensors are restored to the devices they were saved from |
| pickle_module | The module used for deserialization; defaults to pickle |
| pickle_load_args | Additional keyword arguments passed to the deserialization module |
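The two save methods are loaded back differently: a whole pickled model comes back ready to use, while a state_dict must be loaded into an existing model instance. The sketch below uses a small stand-in model (an `nn.Linear`, an assumption made here so it runs quickly); the same calls apply to the vgg16 files saved above:

```python
import torch
from torch import nn

# Small stand-in model; substitute the vgg16 objects and .pth files from above.
model = nn.Linear(4, 2)

# Method 1: load the whole pickled model.
# weights_only=False is required for full-model loads on recent PyTorch versions.
torch.save(model, "model_full.pth")
loaded_full = torch.load("model_full.pth", weights_only=False)

# Method 2: load a state_dict into a freshly constructed instance.
torch.save(model.state_dict(), "model_state.pth")
loaded_sd = nn.Linear(4, 2)
loaded_sd.load_state_dict(torch.load("model_state.pth"))

# Both copies now hold the same parameters as the original.
print(torch.equal(model.weight, loaded_full.weight))  # True
print(torch.equal(model.weight, loaded_sd.weight))    # True
```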