pytorch TensorRT PQT, QAT + micro version (small video memory)

content struct

code

vgg

"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce


class VGG(nn.Module):
    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()

        layers = []
        in_channels = 3
        for l in layer_spec:
            if l == "pool":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers + = [
                    nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                    nn.BatchNorm2d(l),
                    nn.ReLU(),
                ]
                in_channels = l

        self. features = nn. Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn. Linear(32, 10),
            # nn. Linear(512 * 1 * 1, 4096),
            #nn.ReLU(),
            # nn. Dropout(),
            # nn. Linear(4096, 4096),
            #nn.ReLU(),
            # nn. Dropout(),
            # nn. Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self. modules():
            if isinstance(m, nn. Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn. BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn. Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self. features(x)
        x = self.avgpool(x)
        x = torch. flatten(x, 1)
        x = self. classifier(x)
        return x


def vgg16(num_classes=1000, init_weights=False):
    vgg16_cfg = [
        32,
        32,
        "pool",
        32,
        32,
        "pool",
        32,
        32,
        32,
        "pool",
        32,
        32,
        32,
        "pool",
        32,
        32,
        32,
        "pool",
    ]
    return VGG(vgg16_cfg, num_classes, init_weights)

qat

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch_tensorrt

from torch.utils.tensorboard import SummaryWriter

import pytorch_quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import calib
from tqdm import tqdm

print(pytorch_quantization.__version__)

import os
import sys
# sys.path.insert(0, "../examples/int8/training/vgg16")
# print(sys.path)
from vgg import vgg16
#import torchvision
# vgg16=torchvision.models.vgg16()
print(vgg16)




#vgg16.classifier.add_module("add_linear",nn.Linear(1000,10)) #Add a layer to the classfier of vgg16
# vgg16_true.classifier[6] = nn.Linear(in_features=4096, out_features=10)


classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse' , 'ship', 'truck')

# =========== Define Training dataset and dataloaders ==============#
training_dataset = datasets.CIFAR10(root='./data',
                                        train=True,
                                        download=True,
                                        transform = transforms. Compose([
                                            transforms.RandomCrop(32, padding=4),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                        ]))

training_dataloader = torch.utils.data.DataLoader(training_dataset,
                                                      batch_size=32, # 32
                                                      shuffle=True,
                                                      num_workers=2)

# =========== Define Testing dataset and dataloaders ==============#
testing_dataset = datasets.CIFAR10(root='./data',
                                   train=False,
                                   download=True,
                                   transform = transforms. Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                   ]))

testing_dataloader = torch.utils.data.DataLoader(testing_dataset,
                                                 batch_size=16,
                                                 shuffle=False,
                                                 num_workers=2)


def train(model, dataloader, crit, opt, epoch):
    # global writer
    model. train()
    running_loss = 0.0
    for batch, (data, labels) in enumerate(dataloader):
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        opt.zero_grad()
        out = model(data)
        loss = crit(out, labels)
        loss. backward()
        opt. step()

        running_loss += loss.item()
        if batch % 500 == 499:
            print("Batch: [] | ]] loss: %.3f" % (batch + 1, len(dataloader), running_loss / 100))
            running_loss = 0.0


def test(model, dataloader, crit, epoch):
    global writer
    global classes
    total = 0
    correct = 0
    loss = 0.0
    class_probs = []
    class_preds = []
    model.eval()
    with torch.no_grad():
        for data, labels in dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = model(data)
            loss + = crit(out, labels)
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total + = labels. size(0)
            correct + = (preds == labels).sum().item()

    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
    test_preds = torch.cat(class_preds)

    return loss / total, correct / total


def save_checkpoint(state, ckpt_path="checkpoint.pth"):
    torch.save(state, ckpt_path)
    print("Checkpoint saved")



# CIFAR 10 has 10 classes
model = vgg16()#(num_classes=len(classes), init_weights=False)
model = model. cuda()


# Declare Learning rate
lr = 0.1
state = {}
state["lr"] = lr

# Use cross entropy loss for classification and SGD optimizer
crit = nn.CrossEntropyLoss()
opt = optim.SGD(model.parameters(), lr=state["lr"], momentum=0.9, weight_decay=1e-4)


# Adjust learning rate based on epoch number
def adjust_lr(optimizer, epoch):
    global state
    new_lr = lr * (0.5**(epoch // 12)) if state["lr"] > 1e-7 else state["lr"]
    if new_lr != state["lr"]:
        state["lr"] = new_lr
        print("Updating learning rate: {}". format(state["lr"]))
        for param_group in optimizer.param_groups:
            param_group["lr"] = state["lr"]


# Train the model for 25 epochs to get ~80% accuracy.
if not os.path.exists('vgg16_base_ckpt'):
    num_epochs = 2
    for epoch in range(num_epochs):
        adjust_lr(opt, epoch)
        print('Epoch: [] / ]] LR: %f' % (epoch + 1, num_epochs, state["lr"]))

        train(model, training_dataloader, crit, opt, epoch)
        test_loss, test_acc = test(model, testing_dataloader, crit, epoch)

        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

    save_checkpoint({'epoch': epoch + 1,
                     'model_state_dict': model.state_dict(),
                     'acc': test_acc,
                     'opt_state_dict': opt.state_dict(),
                     'state': state},
                    ckpt_path="vgg16_base_ckpt")


quant_modules. initialize()
# All the regular conv, FC layers will be converted to their quantozed counterparts due to quant_modules.initialize()
qat_model = vgg16()#(num_classes=len(classes), init_weights=False)
qat_model = qat_model.cuda()

# vgg16_base_ckpt is the checkpoint generated from Step 3 : Training a baseline VGG16 model.
ckpt = torch.load("./vgg16_base_ckpt")
modified_state_dict={}
for key, val in ckpt["model_state_dict"].items():
    # Remove 'module.' from the key names
    if key.startswith('module'):
        modified_state_dict[key[7:]] = val
    else:
        modified_state_dict[key] = val

# Load the pre-trained checkpoint
qat_model.load_state_dict(modified_state_dict)
opt.load_state_dict(ckpt["opt_state_dict"])

def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn. TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
            print(F"{name:40}: {module}")
    model.cuda()

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistics"""
    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn. TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module. disable()

    # Feed data to the network for collecting stats
    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        model(image. cuda())
        if i >= num_batches:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn. TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module. disable_calib()
            else:
                module. enable()

def calibrate_model(model, model_name, data_loader, num_calib_batch, calibrator, hist_percentile, out_dir):# Calibration select the real value range before quantization
    """
        Feed data to the network and calibrate.
        Arguments:
            model: classification model
            model_name: name to use when creating state files
            data_loader: calibration data set
            num_calib_batch: amount of calibration passes to perform
            calibrator: type of calibration to use (max/histogram)
            hist_percentile: percentiles to be used for histogram calibration
            out_dir: dir to save state files in
    """

    if num_calib_batch > 0:
        print("Calibrating model")
        with torch.no_grad():
            collect_stats(model, data_loader, num_calib_batch)

        if not calibrator == "histogram":
            compute_amax(model, method="max")
            calib_output = os.path.join(
                out_dir,
                F"{model_name}-max-{num_calib_batch*data_loader.batch_size}.pth")
            torch.save(model.state_dict(), calib_output)
        else:
            for percentile in hist_percentile:
                print(F"{percentile} percentile calibration")
                compute_amax(model, method="percentile")
                calib_output = os.path.join(
                    out_dir,
                    F"{model_name}-percentile-{percentile}-{num_calib_batch*data_loader.batch_size}.pth")
                torch.save(model.state_dict(), calib_output)

            for method in ["mse", "entropy"]:
                print(F"{method} calibration")
                compute_amax(model, method=method)
                calib_output = os.path.join(
                    out_dir,
                    F"{model_name}-{method}-{num_calib_batch*data_loader.batch_size}.pth")
                torch.save(model.state_dict(), calib_output)

#Calibrate the model using max calibration technique.
with torch.no_grad():
    calibrate_model(
        model=qat_model,
        model_name="vgg16",
        data_loader=training_dataloader,
        num_calib_batch=32,
        calibrator="max",
        hist_percentile=[99.9, 99.99, 99.999, 99.9999],
        out_dir="./")

# Finetune the QAT model for 1 epoch
num_epochs = 1
for epoch in range(num_epochs):
    adjust_lr(opt, epoch)
    print('Epoch: [] / ]] LR: %f' % (epoch + 1, num_epochs, state["lr"]))

    train(qat_model, training_dataloader, crit, opt, epoch)
    test_loss, test_acc = test(qat_model, testing_dataloader, crit, epoch)

    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

save_checkpoint({'epoch': epoch + 1,
                 'model_state_dict': qat_model.state_dict(),
                 'acc': test_acc,
                 'opt_state_dict': opt.state_dict(),
                 'state': state},
                ckpt_path="vgg16_qat_ckpt")

output

2.1.2
<function vgg16 at 0x7fec91019d30>
Files already downloaded and verified
Files already downloaded and verified
Calibrating model
100%|██████████| 32/32 [00:01<00:00, 26.74it/s]
WARNING: Logging before flag parsing goes to stderr.
W0312 23:16:56.688076 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688241 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688282 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688309 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688339 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688364 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688394 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688419 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688446 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688470 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688497 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688521 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688547 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688571 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688597 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688621 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688653 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688676 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688703 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688727 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688754 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688777 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688802 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688826 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688852 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688875 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688904 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688932 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.688956 140658743240512 tensor_quantizer.py:173] Disable MaxCalibrator
W0312 23:16:56.689098 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.689127 140658743240512 tensor_quantizer.py:238] Call .cuda() if running on GPU after loading calibrated amax.
W0312 23:16:56.689297 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.689497 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.689561 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.689650 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.689702 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.689786 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.689837 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.689918 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.689969 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.690048 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.690099 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.690190 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.690244 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.690324 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.690373 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.690452 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.690502 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.690580 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.690632 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.690712 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.690762 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.690840 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.690891 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.690968 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.691018 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([32, 1, 1, 1]).
W0312 23:16:56.691096 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.691149 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0312 23:16:56.691196 140658743240512 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([10, 1]).
features.0._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=2.7537 calibrator=MaxCalibrator scale=1.0 quant)
features.0._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.2916, 1.3594](32) calibrator=MaxCalibrator scale=1.0 quant)
features.3._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=23.5690 calibrator=MaxCalibrator scale=1.0 quant)
features.3._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.2789, 1.6201](32) calibrator=MaxCalibrator scale=1.0 quant)
features.7._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=15.5116 calibrator=MaxCalibrator scale=1.0 quant)
features.7._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.2731, 0.7190](32) calibrator=MaxCalibrator scale=1.0 quant)
features.10._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=9.8704 calibrator=MaxCalibrator scale=1.0 quant)
features.10._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.2451, 0.6903](32) calibrator=MaxCalibrator scale=1.0 quant)
features.14._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=8.8258 calibrator=MaxCalibrator scale=1.0 quant)
features.14._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.2536, 0.5359](32) calibrator=MaxCalibrator scale=1.0 quant)
features.17._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=7.3961 calibrator=MaxCalibrator scale=1.0 quant)
features.17._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.2086, 0.5812](32) calibrator=MaxCalibrator scale=1.0 quant)
features.20._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=7.0251 calibrator=MaxCalibrator scale=1.0 quant)
features.20._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.1935, 0.4652](32) calibrator=MaxCalibrator scale=1.0 quant)
features.24._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=6.4505 calibrator=MaxCalibrator scale=1.0 quant)
features.24._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.1842, 0.4343](32) calibrator=MaxCalibrator scale=1.0 quant)
features.27._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=5.8092 calibrator=MaxCalibrator scale=1.0 quant)
features.27._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.1743, 0.4571](32) calibrator=MaxCalibrator scale=1.0 quant)
features.30._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=5.4201 calibrator=MaxCalibrator scale=1.0 quant)
features.30._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.1543, 0.4066](32) calibrator=MaxCalibrator scale=1.0 quant)
features.34._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=7.1276 calibrator=MaxCalibrator scale=1.0 quant)
features.34._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.1258, 0.4596](32) calibrator=MaxCalibrator scale=1.0 quant)
features.37._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=6.0650 calibrator=MaxCalibrator scale=1.0 quant)
features.37._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.1472, 0.5321](32) calibrator=MaxCalibrator scale=1.0 quant)
features.40._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=7.2193 calibrator=MaxCalibrator scale=1.0 quant)
features.40._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.1540, 0.9711](32) calibrator=MaxCalibrator scale=1.0 quant)
avgpool._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=4.2832 calibrator=MaxCalibrator scale=1.0 quant)
classifier.0._input_quantizer : TensorQuantizer(8bit fake per-tensor amax=4.2832 calibrator=MaxCalibrator scale=1.0 quant)
classifier.0._weight_quantizer : TensorQuantizer(8bit fake axis=0 amax=[0.6562, 1.3407](10) calibrator=MaxCalibrator scale=1.0 quant)
Epoch: [ 1 / 1] LR: 0.100000
Batch: [ 500 | 1563] loss: 7.695
Batch: [ 1000 | 1563] loss: 7.762
Batch: [ 1500 | 1563] loss: 7.725
Test Loss: 0.09093 Test Acc: 42.48%
Checkpoint saved

Process finished with exit code 0