PyTorch Training with Early Stopping and a Progress Bar

cifar_model.py

[ ]:
# The content of cifar_model.py

'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class FakeReLU(torch.autograd.Function):
    """ReLU on the forward pass, identity on the backward pass.

    The forward output is ``max(x, 0)``, but the backward pass returns the
    upstream gradient untouched, so gradient also flows where ``x < 0``
    (a straight-through estimator for ReLU).
    """

    @staticmethod
    def forward(ctx, x):
        return torch.clamp(x, min=0)

    @staticmethod
    def backward(ctx, grad):
        # No ReLU mask here: the incoming gradient passes straight through.
        return grad


class SequentialWithArgs(torch.nn.Sequential):
    """``nn.Sequential`` that forwards extra call arguments to its last module only.

    Every module except the last is called with just the running value; the
    final module additionally receives ``*args`` / ``**kwargs`` (this is how
    ``fake_relu`` reaches the last residual block of a ResNet stage).
    An empty container returns its input unchanged.
    """

    def forward(self, input, *args, **kwargs):
        modules = list(self._modules.values())
        if not modules:
            return input
        *leading, last = modules
        for module in leading:
            input = module(input)
        return last(input, *args, **kwargs)


class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv-BN layers plus a residual shortcut.

    Args:
        in_planes: Channels of the incoming tensor.
        planes: Channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution (and of the projection shortcut).
    """
    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the stride or channel count changes; then a
        # strided 1x1 conv + BN projects the input to the output shape.
        needs_projection = stride != 1 or in_planes != out_planes
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x, fake_relu=False):
        """Apply the block; with ``fake_relu`` the final ReLU uses ``FakeReLU``."""
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h = h + self.shortcut(x)
        if not fake_relu:
            return F.relu(h)
        return FakeReLU.apply(h)


class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 conv-BN with a residual shortcut.

    Args:
        in_planes: Channels of the incoming tensor.
        planes: Channels of the two inner convolutions.
        stride: Stride of the 3x3 convolution (and of the projection shortcut).
    """
    # The last 1x1 conv widens to ``planes * expansion`` channels.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        # 1x1 conv to ``planes`` channels.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 conv; this is where the spatial stride is applied.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 conv expanding to the block's output width.
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            # Project the input so it can be added to the block output.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x, fake_relu=False):
        """Apply the block; with ``fake_relu`` the final ReLU uses ``FakeReLU``."""
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = F.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out)) + self.shortcut(x)
        return FakeReLU.apply(out) if fake_relu else F.relu(out)


class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, four residual stages, average pool, linear head.

    Args:
        block: Residual block class (e.g. ``BasicBlock`` or ``Bottleneck``). It must
            expose an ``expansion`` class attribute and be constructible as
            ``block(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Output size of the final linear layer.
        feat_scale: Multiplier on the linear layer's input size. Set it to the
            number of spatial positions left after pooling when that is not 1x1
            (e.g. for non-32x32 inputs).
        wm: Width multiplier applied to the base stage widths (64, 128, 256, 512).
        dataset: ``'cifar10'`` or ``'celeba128'``; selects the stride of the stem
            convolution and of the first stage.
    """

    def __init__(self, block, num_blocks, num_classes=10, feat_scale=1, wm=1, dataset='cifar10'):
        super(ResNet, self).__init__()

        assert dataset in ['cifar10', 'celeba128']
        # 'celeba128' downsamples by 2 in both the stem and the first stage.
        first_stride = {'cifar10': 1, 'celeba128': 2}[dataset]

        widths = [int(w * wm) for w in [64, 128, 256, 512]]

        # _make_layer reads and advances self.in_planes as stages are stacked.
        self.in_planes = widths[0]
        self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=first_stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        self.layer1 = self._make_layer(block, widths[0], num_blocks[0], stride=first_stride)
        self.layer2 = self._make_layer(block, widths[1], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, widths[2], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, widths[3], num_blocks[3], stride=2)
        # Fixed 4x4 average pool: with the strides above a 32x32 input reaches
        # layer4 at 4x4, so the pooled map is 1x1. If an input leaves a larger
        # map, the flattened size is C * H * W and ``feat_scale`` must equal H * W.
        self.pool = nn.AvgPool2d(4)
        self.linear = nn.Linear(feat_scale * widths[3] * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks.

        Only the first block uses ``stride``; the rest use stride 1.
        Updates ``self.in_planes`` to the stage's output width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return SequentialWithArgs(*layers)

    def forward(self, x, with_latent=False, fake_relu=False, no_relu=False):
        """Return class logits, or ``(logits, pooled_features)`` if ``with_latent``.

        ``fake_relu`` is forwarded to the last block of ``layer4`` only.
        ``no_relu`` is not supported and must be False.
        """
        assert (not no_relu), \
            "no_relu not yet supported for this architecture"
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out, fake_relu=fake_relu)
        out = self.pool(out)
        pre_out = out.view(out.size(0), -1)
        final = self.linear(pre_out)
        if with_latent:
            return final, pre_out
        return final


def ResNet18(**kwargs):
    """ResNet-18: ``BasicBlock`` stages of depth (2, 2, 2, 2); kwargs go to ``ResNet``."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2], **kwargs)


def ResNet18Wide(**kwargs):
    """ResNet-18 with every stage 1.5x wider (``ResNet`` width multiplier ``wm=1.5``)."""
    # ResNet's width-multiplier keyword is ``wm``; ``wd`` raised TypeError.
    return ResNet(BasicBlock, [2, 2, 2, 2], wm=1.5, **kwargs)


def ResNet18Thin(**kwargs):
    """ResNet-18 with every stage 0.75x as wide (``ResNet`` width multiplier ``wm=0.75``)."""
    return ResNet(BasicBlock, [2, 2, 2, 2], wm=.75, **kwargs)


def ResNet34(**kwargs):
    """ResNet-34: ``BasicBlock`` stages of depth (3, 4, 6, 3); kwargs go to ``ResNet``."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3], **kwargs)


def ResNet50(**kwargs):
    """ResNet-50: ``Bottleneck`` stages of depth (3, 4, 6, 3); kwargs go to ``ResNet``."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3], **kwargs)


def ResNet101(**kwargs):
    """ResNet-101: ``Bottleneck`` stages of depth (3, 4, 23, 3); kwargs go to ``ResNet``."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3], **kwargs)


def ResNet152(**kwargs):
    """ResNet-152: ``Bottleneck`` stages of depth (3, 8, 36, 3); kwargs go to ``ResNet``."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3], **kwargs)


# Lower-case aliases for the factory functions (the notebook imports ``resnet50``).
resnet18, resnet50, resnet101, resnet152 = ResNet18, ResNet50, ResNet101, ResNet152

Training notebook

[1]:
# IPython autoreload: re-import edited modules (such as cifar_model) before running code.
%load_ext autoreload
%autoreload 2
[2]:
import torch
import torchvision.transforms as transforms
import numpy as np
import torchvision
from cifar_model import resnet50
from tqdm import tqdm
import copy
import torch.nn.functional as F
[3]:
# Library versions used for this run (displayed as the cell output).
(torch.__version__, torchvision.__version__)
[3]:
('1.4.0', '0.5.0')
[4]:
# Use the first GPU when present; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
[5]:
def _mean(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero (empty loader)."""
    return total / count if count else float('nan')


def _run_train_epoch(model, optimizer, loader, device, epoch):
    """Train ``model`` for one pass over ``loader``; return (mean loss, accuracy).

    Both metrics are per-sample averages kept as running totals, so updating
    the progress bar costs O(1) per batch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    pbar = tqdm(loader, position=0, leave=True)
    for data, label in pbar:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        logits = model(data)
        loss = F.cross_entropy(logits, label)
        loss.backward()
        optimizer.step()
        batch_size = label.size(0)
        # cross_entropy returns the batch mean; re-weight by batch size so a
        # short final batch counts in proportion to its size.
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == label).sum().item()
        n_seen += batch_size
        pbar.set_postfix({'epoch': epoch,
                          'loss': f'{loss_sum / n_seen:.2f}',
                          'acc': f'{n_correct / n_seen:.2f}'})
    return _mean(loss_sum, n_seen), _mean(n_correct, n_seen)


def _run_val_epoch(model, loader, device, epoch):
    """Evaluate ``model`` on ``loader`` without gradients; return (mean loss, accuracy)."""
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    pbar = tqdm(loader, position=0, leave=True)
    with torch.no_grad():
        for data, label in pbar:
            data, label = data.to(device), label.to(device)
            logits = model(data)
            # reduction='sum' gives an exact per-sample total across uneven batches.
            loss_sum += F.cross_entropy(logits, label, reduction='sum').item()
            n_correct += (logits.argmax(dim=1) == label).sum().item()
            n_seen += label.size(0)
            pbar.set_postfix({'epoch': epoch,
                              'val loss': f'{loss_sum / n_seen:.2f}',
                              'val acc': f'{n_correct / n_seen:.2f}'})
    return _mean(loss_sum, n_seen), _mean(n_correct, n_seen)


def train_val(model, optimizer, train_loader, val_loader, epochs, earlystop_patience=-1,
              device=None):
    """Train ``model`` for up to ``epochs`` epochs, validating after each one.

    Args:
        model: Module mapping a batch of inputs to class logits.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        train_loader, val_loader: Iterables of ``(inputs, class_index_labels)`` batches.
        epochs: Maximum number of epochs to run.
        earlystop_patience: If >= 0, stop once validation loss has not improved
            on its best value for more than this many consecutive epochs, then
            restore the weights from the best epoch. A negative value disables
            early stopping: all epochs run and the final weights are kept.
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter (CPU for a parameterless model).

    Returns:
        None. ``model`` and ``optimizer`` are updated in place.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    restore_best = earlystop_patience >= 0
    best_val_loss = np.inf
    best_model_state = None
    patience_counter = 0
    for epoch in range(epochs):
        _run_train_epoch(model, optimizer, train_loader, device, epoch)
        val_loss, _ = _run_val_epoch(model, val_loader, device, epoch)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            if restore_best:
                # A snapshot is only needed if we may roll back to it later.
                best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if restore_best and patience_counter > earlystop_patience:
                break

    # No snapshot exists if no epoch ran or every validation loss was NaN;
    # keep the current weights in that case instead of failing.
    if restore_best and best_model_state is not None:
        model.load_state_dict(best_model_state)
[11]:
# Per-channel normalisation shared by the train and test pipelines.
# NOTE(review): this std triple is the widely copied (0.2023, 0.1994, 0.2010);
# CIFAR-10's per-pixel std is often quoted as ~(0.247, 0.243, 0.261). Confirm
# before changing it, because the trained weights depend on these values.
cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: random 32x32 crop from a 4-pixel-padded image, plus a random horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar_normalize,
])
# Evaluation pipeline: no augmentation.
transform_test = transforms.Compose([transforms.ToTensor(), cifar_normalize])

data_root = './data'
trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True,
                                        transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True,
                                           num_workers=2)

# NOTE(review): the "validation" loader is built from the test split (train=False).
# train_val selects/restores weights by this loss, so the reported accuracy is not a
# held-out estimate; consider carving the validation set out of the training split.
testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True,
                                       transform=transform_test)
val_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False,
                                         num_workers=2)
Files already downloaded and verified
Files already downloaded and verified
[8]:
# ResNet-50 with one output per CIFAR-10 class, wrapped in DataParallel for
# multi-GPU data parallelism and moved to `device`. The assignment keeps the
# notebook from echoing the module repr.
model = torch.nn.DataParallel(resnet50(num_classes=len(trainset.classes))).to(device)
[12]:
# Stage 1: 20 epochs of Adam at lr=5e-4 (earlystop_patience is left at its default of -1).
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-4, weight_decay=5e-4)
train_val(model, optimizer, train_loader=train_loader, val_loader=val_loader, epochs=20)
100%|██████████| 391/391 [01:09<00:00,  5.60it/s, epoch=0, loss=1.53, acc=0.44]
100%|██████████| 100/100 [00:07<00:00, 13.40it/s, epoch=0, val loss=1.16, val acc=0.60]
100%|██████████| 391/391 [01:00<00:00,  6.43it/s, epoch=1, loss=1.03, acc=0.64]
100%|██████████| 100/100 [00:07<00:00, 13.16it/s, epoch=1, val loss=0.91, val acc=0.68]
100%|██████████| 391/391 [01:01<00:00,  6.33it/s, epoch=2, loss=0.81, acc=0.72]
100%|██████████| 100/100 [00:07<00:00, 13.12it/s, epoch=2, val loss=0.81, val acc=0.72]
100%|██████████| 391/391 [01:00<00:00,  6.42it/s, epoch=3, loss=0.70, acc=0.76]
100%|██████████| 100/100 [00:07<00:00, 13.21it/s, epoch=3, val loss=0.80, val acc=0.74]
100%|██████████| 391/391 [01:00<00:00,  6.43it/s, epoch=4, loss=0.62, acc=0.79]
100%|██████████| 100/100 [00:07<00:00, 13.09it/s, epoch=4, val loss=0.90, val acc=0.70]
100%|██████████| 391/391 [01:00<00:00,  6.44it/s, epoch=5, loss=0.57, acc=0.80]
100%|██████████| 100/100 [00:07<00:00, 12.90it/s, epoch=5, val loss=0.64, val acc=0.78]
100%|██████████| 391/391 [01:01<00:00,  6.35it/s, epoch=6, loss=0.54, acc=0.82]
100%|██████████| 100/100 [00:07<00:00, 12.97it/s, epoch=6, val loss=0.55, val acc=0.81]
100%|██████████| 391/391 [01:01<00:00,  6.31it/s, epoch=7, loss=0.52, acc=0.83]
100%|██████████| 100/100 [00:07<00:00, 12.87it/s, epoch=7, val loss=0.85, val acc=0.75]
100%|██████████| 391/391 [01:02<00:00,  6.31it/s, epoch=8, loss=0.49, acc=0.83]
100%|██████████| 100/100 [00:08<00:00, 12.50it/s, epoch=8, val loss=0.62, val acc=0.79]
100%|██████████| 391/391 [01:03<00:00,  6.20it/s, epoch=9, loss=0.48, acc=0.84]
100%|██████████| 100/100 [00:07<00:00, 13.07it/s, epoch=9, val loss=0.64, val acc=0.79]
100%|██████████| 391/391 [01:01<00:00,  6.41it/s, epoch=10, loss=0.45, acc=0.85]
100%|██████████| 100/100 [00:07<00:00, 13.14it/s, epoch=10, val loss=0.65, val acc=0.78]
100%|██████████| 391/391 [01:01<00:00,  6.38it/s, epoch=11, loss=0.43, acc=0.85]
100%|██████████| 100/100 [00:07<00:00, 13.51it/s, epoch=11, val loss=0.49, val acc=0.84]
100%|██████████| 391/391 [01:00<00:00,  6.45it/s, epoch=12, loss=0.42, acc=0.86]
100%|██████████| 100/100 [00:07<00:00, 13.61it/s, epoch=12, val loss=0.48, val acc=0.84]
100%|██████████| 391/391 [01:00<00:00,  6.46it/s, epoch=13, loss=0.41, acc=0.86]
100%|██████████| 100/100 [00:07<00:00, 12.56it/s, epoch=13, val loss=0.49, val acc=0.83]
100%|██████████| 391/391 [01:00<00:00,  6.42it/s, epoch=14, loss=0.39, acc=0.87]
100%|██████████| 100/100 [00:07<00:00, 13.28it/s, epoch=14, val loss=0.45, val acc=0.86]
100%|██████████| 391/391 [01:00<00:00,  6.44it/s, epoch=15, loss=0.37, acc=0.87]
100%|██████████| 100/100 [00:07<00:00, 14.01it/s, epoch=15, val loss=0.44, val acc=0.86]
100%|██████████| 391/391 [01:01<00:00,  6.36it/s, epoch=16, loss=0.36, acc=0.88]
100%|██████████| 100/100 [00:07<00:00, 12.87it/s, epoch=16, val loss=0.44, val acc=0.86]
100%|██████████| 391/391 [01:00<00:00,  6.42it/s, epoch=17, loss=0.35, acc=0.88]
100%|██████████| 100/100 [00:07<00:00, 13.25it/s, epoch=17, val loss=0.41, val acc=0.86]
100%|██████████| 391/391 [01:00<00:00,  6.46it/s, epoch=18, loss=0.34, acc=0.88]
100%|██████████| 100/100 [00:07<00:00, 13.40it/s, epoch=18, val loss=0.42, val acc=0.86]
100%|██████████| 391/391 [01:00<00:00,  6.43it/s, epoch=19, loss=0.33, acc=0.89]
100%|██████████| 100/100 [00:07<00:00, 13.58it/s, epoch=19, val loss=0.49, val acc=0.84]
[13]:
# Stage 2: drop the learning rate to 1e-4 with a freshly constructed Adam
# (its per-parameter state starts empty again), for 2 epochs.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=5e-4)
train_val(model, optimizer, train_loader=train_loader, val_loader=val_loader, epochs=2)
100%|██████████| 391/391 [01:00<00:00,  6.42it/s, epoch=0, loss=0.21, acc=0.93]
100%|██████████| 100/100 [00:07<00:00, 12.76it/s, epoch=0, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00,  6.44it/s, epoch=1, loss=0.17, acc=0.94]
100%|██████████| 100/100 [00:07<00:00, 13.12it/s, epoch=1, val loss=0.25, val acc=0.92]
[14]:
# Stage 3: 10 more epochs at lr=1e-4. Re-creating the optimizer here discards
# the Adam state from stage 2 even though the learning rate is unchanged.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=5e-4)
train_val(model, optimizer, train_loader=train_loader, val_loader=val_loader, epochs=10)
100%|██████████| 391/391 [01:00<00:00,  6.42it/s, epoch=0, loss=0.15, acc=0.95]
100%|██████████| 100/100 [00:07<00:00, 13.56it/s, epoch=0, val loss=0.24, val acc=0.92]
100%|██████████| 391/391 [01:01<00:00,  6.41it/s, epoch=1, loss=0.14, acc=0.95]
100%|██████████| 100/100 [00:07<00:00, 13.62it/s, epoch=1, val loss=0.24, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00,  6.44it/s, epoch=2, loss=0.13, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 13.27it/s, epoch=2, val loss=0.24, val acc=0.93]
100%|██████████| 391/391 [01:00<00:00,  6.46it/s, epoch=3, loss=0.12, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 13.43it/s, epoch=3, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:01<00:00,  6.40it/s, epoch=4, loss=0.12, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 13.44it/s, epoch=4, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00,  6.45it/s, epoch=5, loss=0.11, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 12.52it/s, epoch=5, val loss=0.24, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00,  6.43it/s, epoch=6, loss=0.11, acc=0.97]
100%|██████████| 100/100 [00:07<00:00, 12.90it/s, epoch=6, val loss=0.26, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00,  6.42it/s, epoch=7, loss=0.10, acc=0.97]
100%|██████████| 100/100 [00:07<00:00, 12.74it/s, epoch=7, val loss=0.24, val acc=0.93]
100%|██████████| 391/391 [01:01<00:00,  6.36it/s, epoch=8, loss=0.10, acc=0.97]
100%|██████████| 100/100 [00:07<00:00, 13.13it/s, epoch=8, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:03<00:00,  6.16it/s, epoch=9, loss=0.10, acc=0.97]
100%|██████████| 100/100 [00:07<00:00, 13.08it/s, epoch=9, val loss=0.25, val acc=0.93]