Building a CAM (class activation map) heatmap from scratch with PyTorch

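CNNs such as ResNet, which already end with global average pooling followed by a single fully connected layer, can use CAM without any change to the network structure. This article therefore uses AlexNet, whose classifier has to be replaced by global average pooling and a single fully connected layer before CAM can be applied. The training code, using Kaggle's cat-and-dog dataset, is as follows: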

import glob
import os
import time
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvision.models import alexnet

class train_ImageDataset(Dataset):
    """Labelled cat/dog images for training.

    Every file matching ``root + "/*.*"`` is one sample. The label is read
    from the file's base name: the last three characters before its first
    ``.`` must be ``"cat"`` or ``"dog"`` (e.g. ``cat.0.jpg``).

    Each item is ``(img, label)``: ``img`` is a normalized 3x256x256 float
    tensor and ``label`` a one-hot float tensor, ``[0, 1]`` for cat and
    ``[1, 0]`` for dog.

    Raises:
        ValueError: from ``__getitem__`` when a file name yields neither
            ``"cat"`` nor ``"dog"``.
    """

    def __init__(self, root):
        # Resize, convert the PIL image to a [0, 1] tensor (Normalize only
        # accepts tensors), then normalize with the standard ImageNet
        # per-channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.files = sorted(glob.glob(root + "/*.*"))

    def __getitem__(self, index):
        path = self.files[index % len(self.files)]
        # Parse the base name only: searching the whole path for "." breaks
        # on roots such as "./data/train" (the first "." is at index 0).
        name = os.path.basename(path)
        label = name.split(".", 1)[0][-3:]
        if label == "cat":
            target = [0, 1]
        elif label == "dog":
            target = [1, 0]
        else:
            raise ValueError("cannot infer a cat/dog label from file name %r" % name)
        # convert("RGB") guarantees the 3 channels Normalize expects; the
        # context manager closes the file handle Image.open keeps open.
        with Image.open(path) as im:
            img = im.convert("RGB")
        img = self.transform(img)
        return img, torch.Tensor(target)

    def __len__(self):
        return len(self.files)

class test_ImageDataset(Dataset):
    """Labelled cat/dog images for evaluation.

    Every file matching ``root + "/*.*"`` is one sample. The label is read
    from the file's base name: the last three characters before its first
    ``.`` must be ``"cat"`` or ``"dog"`` (e.g. ``dog.12.jpg``).

    Each item is ``(img, label)``: ``img`` is a normalized 3x256x256 float
    tensor and ``label`` a one-hot float tensor, ``[0, 1]`` for cat and
    ``[1, 0]`` for dog.

    Raises:
        ValueError: from ``__getitem__`` when a file name yields neither
            ``"cat"`` nor ``"dog"``.
    """

    def __init__(self, root):
        # Resize, convert the PIL image to a [0, 1] tensor (Normalize only
        # accepts tensors), then normalize with the standard ImageNet
        # per-channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.files = sorted(glob.glob(root + "/*.*"))

    def __getitem__(self, index):
        path = self.files[index % len(self.files)]
        # Parse the base name only: searching the whole path for "." breaks
        # on roots such as "./data/test" (the first "." is at index 0).
        name = os.path.basename(path)
        label = name.split(".", 1)[0][-3:]
        if label == "cat":
            target = [0, 1]
        elif label == "dog":
            target = [1, 0]
        else:
            raise ValueError("cannot infer a cat/dog label from file name %r" % name)
        # convert("RGB") guarantees the 3 channels Normalize expects; the
        # context manager closes the file handle Image.open keeps open.
        with Image.open(path) as im:
            img = im.convert("RGB")
        img = self.transform(img)
        return img, torch.Tensor(target)

    def __len__(self):
        return len(self.files)

def train(model, device, train_loader, optimizer, epoch, loss):
    """Run one training epoch and report its mean loss and accuracy.

    Args:
        model: network whose output has shape (batch, num_classes).
        device: torch device each batch is moved to.
        train_loader: iterable of ``(images, one_hot_labels)`` batches that
            supports ``len()``.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: 1-based epoch number, used in the printed/logged line.
        loss: criterion called as ``loss(output, target)``.

    Prints the epoch summary; every 5th epoch it is also appended, with the
    epoch's wall-clock time, to the file named by the module-level
    ``Log_txt``.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    start_epoch = time.time()
    model.train()
    error_num = 0
    sum_num = 0
    sum_loss = 0.0
    for images, label in train_loader:
        images = images.to(device)
        target = label.to(device)
        optimizer.zero_grad()
        output = model(images)
        batch_loss = loss(output, target)
        batch_loss.backward()
        optimizer.step()
        # Predicted / true class = index of the largest logit / one-hot entry.
        pred = output.argmax(dim=1).cpu().numpy()
        true = target.argmax(dim=1).cpu().numpy()
        # Counting mismatches works for any number of classes.
        error_num += np.count_nonzero(pred != true)
        sum_num += pred.shape[0]
        sum_loss += batch_loss.item()
    if sum_num == 0:
        raise ValueError("train_loader produced no samples")
    acc = 1.0 - error_num / sum_num
    loss_avg = sum_loss / len(train_loader)
    print("Train:[Epoch %d] [Loss: %f] [Acc: %f]" % (epoch, loss_avg, acc))
    times = time.time() - start_epoch
    if epoch % 5 == 0:
        with open(Log_txt, "a") as f:
            # Explicit "\n": one log entry per line.
            f.write("Train:[Iterations %d] [Loss: %f] [Acc: %f] [Epoch Time: %f]\n"
                    % (epoch, loss_avg, acc, times))
# test
def test(model, device, test_loader, epoch, loss):
    """Evaluate ``model`` on ``test_loader`` and report mean loss and accuracy.

    Args:
        model: network whose output has shape (batch, num_classes).
        device: torch device each batch is moved to.
        test_loader: iterable of ``(images, one_hot_labels)`` batches that
            supports ``len()``.
        epoch: 1-based epoch number, used in the printed/logged line.
        loss: criterion called as ``loss(output, target)``.

    Puts ``model`` in eval mode and runs without gradients, so no parameter
    changes. Prints the summary; every 5th epoch it is also appended to the
    file named by the module-level ``Log_txt``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    error_num = 0
    sum_num = 0
    sum_loss = 0.0
    with torch.no_grad():
        for images, label in test_loader:
            images = images.to(device)
            target = label.to(device)
            output = model(images)
            sum_loss += loss(output, target).item()
            # Predicted / true class = index of the largest logit / one-hot entry.
            pred = output.argmax(dim=1).cpu().numpy()
            true = target.argmax(dim=1).cpu().numpy()
            error_num += np.count_nonzero(pred != true)
            sum_num += pred.shape[0]
    if sum_num == 0:
        raise ValueError("test_loader produced no samples")
    acc = 1.0 - error_num / sum_num
    loss_avg = sum_loss / len(test_loader)
    print("Test:[Epoch %d] [Loss: %f] [Acc: %f]" % (epoch, loss_avg, acc))
    if epoch % 5 == 0:
        with open(Log_txt, "a") as f:
            # Explicit "\n": one log entry per line.
            f.write("Test:[Iterations %d] [Loss: %f] [Acc: %f]\n" % (epoch, loss_avg, acc))

class cnn(nn.Module):
    """Pretrained AlexNet conv layers -> global average pooling -> 2-way linear layer.

    Replacing AlexNet's classifier with pooling plus a single linear layer
    lets the linear layer's weights be used to weight the final conv feature
    maps for CAM.
    """

    def __init__(self):
        super().__init__()
        self.model = alexnet(pretrained=True)
        # AlexNet's first child is its convolutional `features` stack; only
        # those layers are used (swap alexnet for vgg16 to use VGG instead).
        self.encoder = nn.Sequential(*self.model.features)
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 2)

    def forward(self, x):
        feature_maps = self.encoder(x)                  # batch x 256 x 7 x 7 for 256x256 input
        pooled = torch.flatten(self.avg(feature_maps), 1)  # batch x 256
        return self.fc(pooled)

if __name__ == '__main__':
    os.makedirs("saved_model", exist_ok=True)
    # open(Log_txt, "a") does not create missing directories.
    os.makedirs("log", exist_ok=True)
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # One log file per run, named after the start time with spaces and colons
    # replaced so the name is valid on every OS. train()/test() read this
    # module-level name.
    Log_txt = "./log/" + str(datetime.now()).replace(" ", "_").replace(":", "_") + "_result.txt"
    start_all_time = time.time()
    # NOTE(review): the original DataLoader arguments were lost. The folders
    # and batch size below are placeholders; both folders must contain files
    # named like cat.N.jpg / dog.N.jpg.
    BATCH_SIZE = 32
    train_loader = DataLoader(train_ImageDataset("./data/train"), batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_ImageDataset("./data/test"), batch_size=BATCH_SIZE, shuffle=False)
    model = cnn().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    loss = torch.nn.CrossEntropyLoss().to(DEVICE)
    EPOCH = 50
    for epoch in range(1, EPOCH + 1):
        train(model, DEVICE, train_loader, optimizer, epoch, loss)
        test(model, DEVICE, test_loader, epoch, loss)
        # Save the weights after every epoch. The file name is unchanged so the
        # CAM script's torch.load path still matches.
        torch.save(model.state_dict(),
                   './saved_model/AlexNet_oriCAM_CAT & amp;DOG_iteration_' + str(epoch) + '.pth')
    time_all = time.time() - start_all_time
    with open(Log_txt, "a") as f:
        f.write("All Time: %f" % (time_all))

To generate class activation heatmaps with the trained model, use the following code:

import os
from import Dataset

import glob
from import DataLoader
import pandas as pd

import cv2
import torch
import torch.nn as nn
from torchvision.models import alexnet
from torchvision import datasets, transforms
import numpy as np
from PIL import Image
# Per-channel normalization statistics; identical to the values passed to
# transforms.Normalize in the dataset below.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

class test_ImageDataset(Dataset):
    """Labelled cat/dog images to run CAM on.

    Every file matching ``root + "/*.*"`` is one sample. The label is read
    from the file's base name: the last three characters before its first
    ``.`` must be ``"cat"`` or ``"dog"`` (e.g. ``dog.12.jpg``).

    Each item is ``(img, label)``: ``img`` is a normalized 3x256x256 float
    tensor and ``label`` a one-hot float tensor, ``[0, 1]`` for cat and
    ``[1, 0]`` for dog.

    Raises:
        ValueError: from ``__getitem__`` when a file name yields neither
            ``"cat"`` nor ``"dog"``.
    """

    def __init__(self, root):
        # Resize, convert the PIL image to a [0, 1] tensor (Normalize only
        # accepts tensors), then normalize with the standard ImageNet
        # per-channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.files = sorted(glob.glob(root + "/*.*"))

    def __getitem__(self, index):
        path = self.files[index % len(self.files)]
        # Parse the base name only: searching the whole path for "." breaks
        # on roots such as "./data/test" (the first "." is at index 0).
        name = os.path.basename(path)
        label = name.split(".", 1)[0][-3:]
        if label == "cat":
            target = [0, 1]
        elif label == "dog":
            target = [1, 0]
        else:
            raise ValueError("cannot infer a cat/dog label from file name %r" % name)
        # convert("RGB") guarantees the 3 channels Normalize expects; the
        # context manager closes the file handle Image.open keeps open.
        with Image.open(path) as im:
            img = im.convert("RGB")
        img = self.transform(img)
        return img, torch.Tensor(target)

    def __len__(self):
        return len(self.files)

def denormalize(tensors, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo ``transforms.Normalize`` on image tensors.

    Args:
        tensors: float tensor of shape (C, H, W) or (N, C, H, W) that was
            normalized with ``mean``/``std``.
        mean, std: per-channel statistics used for normalization. The
            defaults equal the module-level ``mean``/``std`` lists.

    Returns:
        A NEW tensor ``tensors * std + mean`` clamped to [0, 1]. The input is
        left untouched; callers may pass a view (e.g. ``data.squeeze()``) that
        shares storage with the batch fed to the model.
    """
    # Shape (C, 1, 1) broadcasts over H, W and over a leading batch dimension.
    mean_t = torch.as_tensor(mean, dtype=tensors.dtype, device=tensors.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensors.dtype, device=tensors.device).view(-1, 1, 1)
    # ToTensor() produced values in [0, 1]; clamping there (rather than to
    # 255) removes round-off so that scaling by 255 gives valid 8-bit pixels.
    return torch.clamp(tensors * std_t + mean_t, 0.0, 1.0)

class cnn(nn.Module):
    """Same layers as the training network, with ``forward`` also exposing the conv feature maps.

    ``forward(x)`` returns ``(logits, feature_maps)``: the 2-way linear output
    and the final convolutional activations that CAM weights by ``fc.weight``.
    """

    def __init__(self):
        super().__init__()
        self.model = alexnet(pretrained=True)
        # AlexNet's first child is its convolutional `features` stack.
        self.encoder = nn.Sequential(*self.model.features)
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 2)

    def forward(self, x):
        feature_maps = self.encoder(x)  # batch x 256 x 7 x 7 for 256x256 input
        logits = self.fc(torch.flatten(self.avg(feature_maps), 1))
        return logits, feature_maps

# test
def _to_bgr_uint8(image_chw, size):
    """Turn a de-normalized (3, H, W) RGB tensor in [0, 1] into a size x size BGR uint8 image."""
    # Reverse the channel axis (RGB -> BGR for OpenCV), move channels last,
    # scale to [0, 255], and make the array contiguous for cv2.
    bgr = np.ascontiguousarray(image_chw.cpu().numpy()[::-1].transpose(1, 2, 0) * 255.0)
    return cv2.resize(bgr, (size, size)).astype(np.uint8)


def _cam(featuremap, weight, size):
    """Return sum_c weight[c] * featuremap[c], min-max scaled to [0, 255], as a size x size uint8 image."""
    cam = np.tensordot(weight, featuremap, axes=1)  # (C,) . (C, H, W) -> (H, W)
    cam = cam - cam.min()
    peak = cam.max()
    # A constant map would divide by zero and produce NaNs; leave it at 0.
    if peak > 0:
        cam = cam / peak * 255.0
    return cv2.resize(cam, (size, size)).astype(np.uint8)


def test(model, device, test_loader, out_dir="./oriCAM_output/test", size=500):
    """Save the input, its class activation map and an overlay for every test image.

    For the image with 1-based running number ``index`` a directory
    ``<out_dir>/<index>_attentionMap_<predicted class>_<True|False>`` is
    created (True when the prediction matches the label). It holds:
    ``<index>_ORIImg.jpg`` (de-normalized input), ``<index>_AttImg.jpg`` (JET
    colour map of the predicted class's CAM), and ``<index>_AddImg.jpg``
    (0.7 * heatmap + 0.3 * input). Any batch size is supported.

    Args:
        model: network whose forward returns ``(logits, feature_maps)`` and
            whose final linear layer is ``model.fc``.
        device: torch device the images are moved to.
        test_loader: iterable of ``(images, one_hot_labels)`` batches.
        out_dir: root output directory (default is the original path).
        size: side length in pixels of the saved square images.
    """
    model.eval()
    # Row k of fc.weight holds the per-channel CAM weights for class k.
    weights = model.fc.weight.detach().cpu().numpy()
    index = 1
    with torch.no_grad():
        for data, target in test_loader:
            logits, featuremaps = model(data.to(device))
            preds = logits.argmax(dim=1).cpu().numpy()
            trues = target.argmax(dim=1).cpu().numpy()
            featuremaps = featuremaps.cpu().numpy()
            for i in range(preds.shape[0]):
                prediction = int(preds[i])
                correct = bool(prediction == trues[i])
                img_dir = os.path.join(out_dir, "%d_attentionMap_%d_%s" % (index, prediction, correct))
                os.makedirs(img_dir, exist_ok=True)

                # `data` is still the CPU batch; de-normalize one image of it.
                original_img = _to_bgr_uint8(denormalize(data[i]), size)
                heat = _cam(featuremaps[i], weights[prediction], size)
                image = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
                add_img = cv2.addWeighted(image, 0.7, original_img, 0.3, 0)
                cv2.imwrite("%s/%d_ORIImg.jpg" % (img_dir, index), original_img)
                cv2.imwrite("%s/%d_AttImg.jpg" % (img_dir, index), image)
                cv2.imwrite("%s/%d_AddImg.jpg" % (img_dir, index), add_img)
                index += 1

if __name__ == '__main__':
    BATCH_SIZE = 1  # original author's estimate: about 2 GB of GPU memory is needed
    # Use the GPU when available; CAM generation is much faster there.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the test set. NOTE(review): the original DataLoader arguments were
    # lost; "./data/test" is a placeholder for a folder of cat.N.jpg / dog.N.jpg.
    test_loader = DataLoader(test_ImageDataset("./data/test"), batch_size=BATCH_SIZE, shuffle=False)
    model = cnn().to(DEVICE)
    # map_location lets a checkpoint saved during a GPU run load on a CPU-only machine.
    model.load_state_dict(torch.load("./saved_model/AlexNet_oriCAM_CAT & amp;DOG_iteration_50.pth",
                                     map_location=DEVICE))
    test(model, DEVICE, test_loader)

As the pictures above show, the results are quite good.