Unexpected key(s) in state_dict: “module.backbone.bn1.num_batches_tracked”

When using PyTorch for model training and inference, we often use `state_dict` to save and load model parameters. Sometimes, however, loading a saved `state_dict` fails with the error “Unexpected key(s) in state_dict”, followed by the names of the offending keys. This article explains the causes of this error and how to fix it.

Cause of the error

When we load model parameters with `load_state_dict()`, the key names in the `state_dict` must, by default, exactly match the key names of the current model. If the saved `state_dict` contains keys that the model does not have, PyTorch raises a RuntimeError reporting `Unexpected key(s) in state_dict`. This is usually caused by one of the following:

  1. Model structure changes: When we modify the structure of the model (for example by adding, deleting, renaming or reordering layers), the key names of the model change as well. Loading an old `state_dict` into the new model then produces mismatched key names and the load fails.
  2. Key name prefix from multi-GPU training: When a model is wrapped in `torch.nn.DataParallel` (or `DistributedDataParallel`) for multi-GPU training, the original model becomes the wrapper's `module` attribute, so every key in the wrapper's `state_dict` starts with the prefix `module.`. If a `state_dict` saved from the wrapped model is loaded into an unwrapped model, every `module.`-prefixed key is reported as unexpected; the reverse case reports missing keys instead.
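
The prefix described in item 2 is easy to observe. The following is a minimal sketch, assuming a small two-layer network invented purely for illustration (it is not the model from the error message); it compares the key names before and after wrapping.

```python
import torch.nn as nn

# Hypothetical two-layer network, used only to show the key names.
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

# Key names of the plain model and of the same model wrapped in DataParallel.
plain_keys = list(net.state_dict().keys())
wrapped_keys = list(nn.DataParallel(net).state_dict().keys())

print(plain_keys[0])    # 0.weight
print(wrapped_keys[0])  # module.0.weight
```

Every key of the wrapped model is the plain key with `module.` in front, which is why a checkpoint saved from the wrapper does not load into the plain model without renaming.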

Solution

Here are a few possible solutions:

1. Compare the key names and rename the mismatched ones

In PyTorch, calling `.keys()` on the dictionary returned by the model's `state_dict()` method lists all key names of the current model. We can then compare them with the key names of the saved `state_dict`, find the mismatches, and rename them. Here is some sample code:

import torch

# Load the saved state_dict (map_location='cpu' avoids needing the original GPU)
saved_state_dict = torch.load('model.pth', map_location='cpu')
# View the state_dict key names of the current model
model = YourModel()
current_state_dict = model.state_dict()
print("Current model keys:", list(current_state_dict.keys()))
# Rename mismatched keys
prefix = "module."
for key in list(saved_state_dict.keys()):
    if key not in current_state_dict and key.startswith(prefix):
        new_key = key[len(prefix):]  # Remove only the leading multi-GPU prefix
        saved_state_dict[new_key] = saved_state_dict.pop(key)
# Load the modified state_dict
model.load_state_dict(saved_state_dict)
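
The comparison step can also be written as two small helpers that work on any dictionary, which makes it easy to check the result before calling `load_state_dict()`. This is only a sketch: `strip_prefix` and `diff_keys` are hypothetical names introduced here, not PyTorch functions, and the integer values are placeholders for tensors.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def diff_keys(saved_keys, model_keys):
    """Return (missing, unexpected) key lists, using the same terms as PyTorch."""
    saved, current = set(saved_keys), set(model_keys)
    return sorted(current - saved), sorted(saved - current)


# Placeholder values stand in for tensors; only the key names matter here.
saved = {
    "module.backbone.conv1.weight": 0,
    "module.backbone.bn1.num_batches_tracked": 0,
}
model_keys = ["backbone.conv1.weight", "backbone.bn1.num_batches_tracked"]

missing, unexpected = diff_keys(saved, model_keys)
print(unexpected)  # both saved keys, because of the "module." prefix

missing_after, unexpected_after = diff_keys(strip_prefix(saved), model_keys)
print(missing_after, unexpected_after)  # [] []
```

Removing only a leading prefix, rather than every occurrence of the string, avoids damaging a key that happens to contain `module.` somewhere in the middle.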

2. Modify the model code to match the saved `state_dict`

If the structure of the model was changed, we can instead change the model code so that it matches the layout of the saved `state_dict`. Before loading, adjust the model definition (layer names, nesting, and order) until its key names are the same as those in the saved `state_dict`.
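
Before editing the model, it can help to see exactly which keys disagree. One way, sketched below, is to call `load_state_dict()` with `strict=False`, which loads the matching keys and returns the mismatched ones instead of raising. The two small `nn.Sequential` models are hypothetical stand-ins for an old and a new version of the same network.

```python
import torch.nn as nn

# Hypothetical old and new versions of a model; the new one inserts a Dropout.
old_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
new_model = nn.Sequential(
    nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 2)
)

# strict=False loads the keys that match and reports the rest.
result = new_model.load_state_dict(old_model.state_dict(), strict=False)
print("Missing keys:", result.missing_keys)        # in the model, not in the file
print("Unexpected keys:", result.unexpected_keys)  # in the file, not in the model
```

Note that `strict=False` silently leaves the mismatched layers at their initial values, so it is better suited to diagnosis than to a final fix.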

3. Wrap the model in `torch.nn.DataParallel` before loading

If the `state_dict` was saved from a model wrapped in `torch.nn.DataParallel`, we can wrap the new model the same way with `model = torch.nn.DataParallel(model)` before loading. The wrapped model's key names then carry the same `module.` prefix as the saved ones, so they match without any renaming.

import torch

model = YourModel()
model = torch.nn.DataParallel(model)  # Wrap the model so its keys gain the "module." prefix
model.load_state_dict(torch.load('model.pth'))  # Load the state_dict
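
If you control the training script, the prefix can be avoided at the source by saving the `state_dict` of the inner model, which `DataParallel` exposes as its `module` attribute. The sketch below assumes a tiny `nn.Linear` model and uses an in-memory buffer in place of a real file such as 'model.pth'.

```python
import io

import torch
import torch.nn as nn

# Hypothetical model wrapped for multi-GPU training.
model = nn.DataParallel(nn.Linear(4, 2))

# Save the inner model's state_dict, so the keys carry no "module." prefix.
buffer = io.BytesIO()  # stands in for a file such as 'model.pth'
torch.save(model.module.state_dict(), buffer)

# Later: load into a plain, unwrapped model without renaming anything.
buffer.seek(0)
plain_model = nn.Linear(4, 2)
plain_model.load_state_dict(torch.load(buffer, map_location="cpu"))
print(list(plain_model.state_dict().keys()))  # ['weight', 'bias']
```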

Summary

When loading a saved `state_dict`, the `Unexpected key(s) in state_dict` error is caused by key names in the saved file that the current model does not have. We can find them by comparing the model's key names with those of the saved `state_dict` and renaming them accordingly. In addition, wrapping the model in `torch.nn.DataParallel` makes the `module.` prefix from multi-GPU training match. I hope this article helps you resolve the error and load model parameters smoothly.

Example code

Suppose we have an image classification model that distinguishes cats from dogs. We first trained the model and saved its `state_dict` to the file “model.pth”. Then we modified the structure of the model by adding a new layer to the classifier, and we want to load the previously saved `state_dict` into it. First, we define a model class `AnimalClassifier`, which contains a convolutional feature extractor and a fully connected classifier:

import torch
import torch.nn as nn
class AnimalClassifier(nn.Module):
    def __init__(self):
        super(AnimalClassifier, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 16 * 16, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 2)
        )
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

Then, we trained the model and saved the `state_dict`:

# Create a model instance
model = AnimalClassifier()
# Train the model...
# ...
# Save the state_dict
torch.save(model.state_dict(), 'model.pth')

Next, we modified the structure of the model and appended a new ReLU layer after the last fully connected layer:

import torch
import torch.nn as nn
class AnimalClassifier(nn.Module):
    def __init__(self):
        super(AnimalClassifier, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 16 * 16, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 2),
            nn.ReLU(inplace=True) # Add a new ReLU layer
        )
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

Now we want to load the previously saved `state_dict` and continue training the new model. The following code loads the `state_dict` and renames any keys that no longer match:

import torch
import torch.nn as nn
class AnimalClassifier(nn.Module):
    def __init__(self):
        super(AnimalClassifier, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 16 * 16, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 2),
            nn.ReLU(inplace=True) # Add a new ReLU layer
        )
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
# Create a new model instance
model = AnimalClassifier()
# Load the saved state_dict
saved_state_dict = torch.load('model.pth')
# View the state_dict key names of the current model
current_state_dict = model.state_dict()
print("Current model keys:", list(current_state_dict.keys()))
# Rename saved keys that the current model does not have.
# Note: the appended ReLU has no parameters, so in this example every saved key
# still matches and this loop renames nothing. A rename is only needed when a
# parameterized layer changes position, e.g. classifier.2 moving to classifier.3.
for key in list(saved_state_dict.keys()):
    if key not in current_state_dict:
        new_key = key.replace("classifier.2.", "classifier.3.", 1)
        saved_state_dict[new_key] = saved_state_dict.pop(key)
# Load the modified state_dict
model.load_state_dict(saved_state_dict)
# Continue training the new model...
# ...

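With the code above, the previously saved `state_dict` loads into the new model and training can continue. Note that a ReLU layer has no parameters, so appending it adds no keys and nothing actually needs renaming in this particular case; key names change only when a layer that owns parameters is added, removed, renamed, or moved to a different index.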
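
To see a rename that is actually required, consider a variant in which a Dropout layer is inserted before the last Linear layer, which shifts that layer's index from 2 to 3. The following is a minimal sketch of that case; the layer sizes and the names `old_head` and `new_head` are assumptions made for illustration and are not part of the `AnimalClassifier` example.

```python
import torch
import torch.nn as nn

# Hypothetical classifier heads: the new one inserts Dropout at index 2,
# so the final Linear layer moves from index 2 to index 3.
old_head = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
new_head = nn.Sequential(
    nn.Linear(16, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 2)
)

saved = old_head.state_dict()  # keys: 0.weight, 0.bias, 2.weight, 2.bias

# Build a renamed copy: keys of the shifted layer get the new index.
renamed = {}
for key, tensor in saved.items():
    if key.startswith("2."):
        key = "3." + key[len("2."):]
    renamed[key] = tensor

# Strict loading now succeeds because every key matches.
new_head.load_state_dict(renamed)
print(list(new_head.state_dict().keys()))
```

Building a new dictionary instead of renaming in place avoids overwriting an entry when the old and new index ranges overlap.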

`state_dict` is a dictionary object used in PyTorch to save and load model parameters. It contains the tensors of all learnable parameters of the model (such as the weights and biases of the network) as well as its persistent buffers (such as the `running_mean` and `num_batches_tracked` entries of BatchNorm layers), but not the structure of the model. The optimizer keeps its state in a separate `state_dict` of its own. The structure of a `state_dict` is as follows:

{
    'key1': tensor1,
    'key2': tensor2,
    ...
}

Here, each key is a string naming a parameter or buffer of the model, and each value is the corresponding tensor. A model's `state_dict` is obtained by calling its `state_dict()` method and can then be saved:

model = MyModel()
...
state_dict = model.state_dict()
torch.save(state_dict, 'model.pth')
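
The dictionary layout described above can be inspected directly. The sketch below assumes a small hypothetical model with one Linear and one BatchNorm layer and prints each key with the shape of its tensor.

```python
import torch.nn as nn

# Hypothetical model: one layer with parameters only, one with buffers too.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

names = []
for name, tensor in model.state_dict().items():
    names.append(name)
    print(name, tuple(tensor.shape))
```

The BatchNorm layer contributes `1.running_mean`, `1.running_var` and `1.num_batches_tracked` in addition to its weight and bias, which shows that buffers are stored alongside learnable parameters.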

A saved `state_dict` is read back with `torch.load()` and applied to a model with `load_state_dict()`:

state_dict = torch.load('model.pth')
model = MyModel()
model.load_state_dict(state_dict)

`state_dict` is used in the following common scenarios:

  1. Saving and loading models: By saving and loading the `state_dict`, you can write a model's parameters to a file and restore them when needed.
  2. Transfer learning and fine-tuning: The `state_dict` of a pre-trained model can be loaded into the corresponding layers of a new model, so that the pre-trained parameters speed up training of the new model or improve its performance.
  3. Sharing and copying model parameters: You can copy the `state_dict` of one model to another model to share or reuse parameters.
  4. Saving and loading optimizer state: The optimizer's state (such as momentum buffers and the current learning rate) lives in the optimizer's own `state_dict`, not in the model's; the two are commonly saved together in one checkpoint file so that training can resume.

Note that when loading a `state_dict`, the structure of the model should be consistent with the structure it had when the `state_dict` was saved; otherwise loading may fail with errors such as the one described above.
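
In PyTorch the optimizer has its own `state_dict()`, so a checkpoint meant for resuming training usually stores both dictionaries side by side. The following is a minimal sketch of that pattern; the dictionary keys `"model"` and `"optimizer"` are an arbitrary naming choice, the model is a hypothetical `nn.Linear`, and an in-memory buffer stands in for a checkpoint file.

```python
import io

import torch
import torch.nn as nn

# Hypothetical model and optimizer.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Store both state_dicts in one checkpoint object.
checkpoint = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save(checkpoint, buffer)

# Restore into freshly created objects of the same structure.
buffer.seek(0)
restored = torch.load(buffer, map_location="cpu")
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
new_model.load_state_dict(restored["model"])
new_optimizer.load_state_dict(restored["optimizer"])
```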
