"""Fine-tune the classification head of a pretrained ResNet-34.

The script reads ``train``/``val``/``test`` image folders, loads pretrained
ResNet-34 weights, freezes every pretrained parameter, replaces the final
fully connected layer, trains only that layer, keeps the checkpoint with the
best validation accuracy, and finally plots the per-epoch validation accuracy
and training loss.
"""
import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model import resnet34, resnet101
import matplotlib.pyplot as plt
import numpy as np
from osgeo import gdal
from torchvision.transforms import functional as F


def plot_loss(arry_train):
    """Plot a sequence of training-loss values as a red dotted line.

    Args:
        arry_train: Sequence of numeric loss values, one per recorded step.
    """
    line1, = plt.plot(range(0, len(arry_train)), arry_train, 'r.-')
    plt_title = 'BATCH_SIZE = 16; EPOCH = 5'
    plt.title(plt_title)
    # Exactly one label for the one handle that is drawn.
    plt.legend(handles=[line1], labels=["train_loss"], loc="upper right", fontsize=7)
    plt.ylabel('LOSS')
    plt.show()


def _build_transforms():
    """Return the per-split preprocessing pipelines keyed by split name."""
    # NOTE(review): these look like dataset-specific per-channel mean/std
    # values -- confirm they match the data under ``image_path``.
    mean = [0.3971, 0.4091, 0.3681]
    std = [0.2169, 0.1943, 0.1917]

    def _eval_pipeline():
        # Deterministic preprocessing so evaluation results are repeatable.
        return transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean, std)])

    return {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean, std)]),
        "val": _eval_pipeline(),
        # Random crop/flip at test time would make the test metrics change
        # from run to run, so the test split uses the validation pipeline.
        "test": _eval_pipeline(),
    }


def _train_one_epoch(net, loader, loss_function, optimizer, device,
                     epoch, epochs, step_losses):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Args:
        net: Model to train (switched to train mode here).
        loader: Iterable of ``(images, labels)`` batches.
        loss_function: Callable mapping ``(logits, labels)`` to a scalar loss.
        optimizer: Optimizer stepping the trainable parameters.
        device: Device the batches are moved to.
        epoch: Zero-based index of the current epoch (for the progress bar).
        epochs: Total number of epochs (for the progress bar).
        step_losses: List extended in place with each batch loss as a float.

    Returns:
        The sum of batch losses divided by the number of batches.
    """
    net.train()
    running_loss = 0.0
    train_bar = tqdm(loader, file=sys.stdout)
    for images, labels in train_bar:
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # Convert to a plain float once: storing the tensor itself would keep
        # every step's autograd graph alive for the whole run.
        loss_value = loss.item()
        running_loss += loss_value
        step_losses.append(loss_value)

        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                 epochs,
                                                                 loss_value)
    return running_loss / len(loader)


def _evaluate(net, loader, device, epoch, epochs):
    """Return the number of correctly classified samples in ``loader``.

    The model is switched to eval mode and gradients are disabled.
    """
    net.eval()
    correct = 0.0
    with torch.no_grad():
        val_bar = tqdm(loader, file=sys.stdout)
        for val_images, val_labels in val_bar:
            outputs = net(val_images.to(device))
            # Index of the largest logit per sample is the predicted class.
            predict_y = torch.max(outputs, dim=1)[1]
            correct += torch.eq(predict_y, val_labels.to(device)).sum().item()

            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
    return correct


def _plot_history(accuracy_history, loss_history):
    """Plot per-epoch validation accuracy and training loss, save, then show."""
    # Derive the x axes from the data so they always match its length.
    x1 = range(len(accuracy_history))
    x2 = range(len(loss_history))
    plt.subplot(2, 1, 1)
    plt.plot(x1, accuracy_history, 'o-')
    plt.title('val accuracy')
    plt.ylabel('val accuracy')
    plt.subplot(2, 1, 2)
    plt.plot(x2, loss_history, '.-')
    plt.xlabel('training loss')
    plt.ylabel('training')
    # Save before show(): once a blocking show() returns the figure has been
    # released and savefig() would write an empty image.
    plt.savefig("accuracy_loss.jpg")
    plt.show()


def main():
    """Train the replaced ResNet-34 head and save the best checkpoint.

    Side effects: writes ``class_indices.json`` (index -> class name),
    ``./best.pth`` (state dict of the best model by validation accuracy) and
    ``accuracy_loss.jpg`` (training curves) into the working directory.

    Raises:
        AssertionError: If the dataset directory or the pretrained weight
            file does not exist.
    """
    # Pinned to CPU; switch to torch.device("cuda:0") to train on a GPU.
    device = torch.device("cpu")
    print("using {} device.".format(device))

    data_transform = _build_transforms()

    data_root = os.path.abspath(os.path.join(os.getcwd(), "G:/splitdata"))  # get data root path
    image_path = os.path.join(data_root, "data")  # data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # Invert class-name -> index so predictions can be mapped back to names.
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    # os.cpu_count() may return None; fall back to 1 so min() cannot fail.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    test_dataset = datasets.ImageFolder(root=os.path.join(image_path, "test"),
                                        transform=data_transform["test"])
    test_num = len(test_dataset)
    # The test loader must iterate the test split, not the training split.
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=nw)
    print("using {} images for training, {} images for validation, {} images for testing".format(
        train_num, val_num, test_num))

    # Per-step training losses (plain floats), suitable for plot_loss().
    arry_train = []

    net = resnet34()
    model_weight_path = "./resnet34-333f7ec4.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # Freeze every pretrained parameter; only the new head below is trained.
    for param in net.parameters():
        param.requires_grad = False

    in_channel = net.fc.in_features
    # NOTE(review): 30 outputs presumably equals the number of class folders
    # in the training set -- confirm against the dataset.
    net.fc = nn.Linear(in_channel, 30)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer over the trainable (unfrozen) parameters only
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.001)

    epochs = 20
    best_acc = 0.0
    save_path = './best.pth'
    Loss_list = []
    Accuracy_list = []

    for epoch in range(epochs):
        # train
        train_loss = _train_one_epoch(net, train_loader, loss_function, optimizer,
                                      device, epoch, epochs, arry_train)

        # validate
        acc = _evaluate(net, validate_loader, device, epoch, epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, train_loss, val_accurate))

        Loss_list.append(train_loss)
        Accuracy_list.append(val_accurate)

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished')

    _plot_history(Accuracy_list, Loss_list)


if __name__ == '__main__':
    main()