PyTorch Training With Early Stopping and Progress Bar
cifar_model.py
[ ]:
# The content of cifar_model.py
'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class FakeReLU(torch.autograd.Function):
    """ReLU on the forward pass, identity on the backward pass.

    The output is ``input`` clamped at zero, but the incoming gradient is passed
    back unchanged -- including at positions the forward pass zeroed out.
    """

    @staticmethod
    def forward(ctx, input):
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: no ReLU mask is applied to the gradient.
        return grad_output
class SequentialWithArgs(torch.nn.Sequential):
    """Sequential container that forwards extra arguments to its final module only.

    Every module except the last is called with the running value alone; the last
    one additionally receives ``*args`` and ``**kwargs``. ResNet.forward uses this
    to hand ``fake_relu`` to the final block of ``layer4``. An empty container
    returns ``input`` unchanged.
    """

    def forward(self, input, *args, **kwargs):
        modules = list(self._modules.values())
        if not modules:
            return input
        *body, tail = modules
        for module in body:
            input = module(input)
        return tail(input, *args, **kwargs)
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions, each followed by BatchNorm.

    The first convolution carries ``stride``. When the stride or the channel count
    changes, the shortcut is a 1x1 conv + BatchNorm projection; otherwise it is an
    empty ``nn.Sequential`` (identity). ``expansion`` is 1: the block outputs
    ``planes`` channels.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        if stride == 1 and in_planes == out_planes:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x, fake_relu=False):
        """Apply the block; ``fake_relu`` swaps the final ReLU for FakeReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        merged = self.bn2(self.conv2(hidden)) + self.shortcut(x)
        return FakeReLU.apply(merged) if fake_relu else F.relu(merged)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block, each conv followed by BatchNorm.

    The first 1x1 conv reduces to ``planes`` channels, the 3x3 conv carries
    ``stride``, and the last 1x1 conv expands to ``planes * expansion`` (= 4x).
    A 1x1 conv + BatchNorm projection shortcut is used when the stride or channel
    count changes; otherwise the shortcut is an empty ``nn.Sequential`` (identity).
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        needs_projection = not (stride == 1 and in_planes == out_planes)
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x, fake_relu=False):
        """Apply the block; ``fake_relu`` swaps the final ReLU for FakeReLU."""
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = F.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out)) + self.shortcut(x)
        if fake_relu:
            return FakeReLU.apply(out)
        return F.relu(out)
class ResNet(nn.Module):
    """ResNet backbone (stem conv + four residual stages) with a linear classifier.

    Args:
        block: residual block class (e.g. ``BasicBlock`` / ``Bottleneck``); must
            expose an ``expansion`` class attribute.
        num_blocks: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        feat_scale: multiplier on the flattened feature size fed to ``linear``.
            It lets us deal with CelebA and other non-32x32 datasets whose pooled
            feature map is larger than 1x1 (flattening a CxHxW map needs
            feat_scale == H*W).
        wm: width multiplier applied to the base stage widths (64, 128, 256, 512).
        dataset: 'cifar10' or 'celeba128'. 'celeba128' uses stride 2 in both the
            stem convolution and the first stage.
    """

    def __init__(self, block, num_blocks, num_classes=10, feat_scale=1, wm=1, dataset='cifar10'):
        super(ResNet, self).__init__()
        assert dataset in ['cifar10', 'celeba128'], \
            f"unsupported dataset {dataset!r}; expected 'cifar10' or 'celeba128'"
        first_stride = {'cifar10': 1, 'celeba128': 2}[dataset]
        widths = [int(w * wm) for w in (64, 128, 256, 512)]
        self.in_planes = widths[0]
        self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=first_stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        self.layer1 = self._make_layer(block, widths[0], num_blocks[0], stride=first_stride)
        self.layer2 = self._make_layer(block, widths[1], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, widths[2], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, widths[3], num_blocks[3], stride=2)
        # Exactly the four stages that forward() runs are built. A fifth stage
        # would only hold parameters that never receive gradients.
        # Fixed 4x4 average pool: assumes the layer4 map is 4x4 (32x32 input for
        # cifar10, 128x128 for celeba128), giving a 1x1 map; larger maps leave a
        # spatial grid that feat_scale must account for.
        self.pool = nn.AvgPool2d(4)
        self.linear = nn.Linear(feat_scale * widths[3] * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1.

        Advances ``self.in_planes`` to the stage's output width so the next
        stage is wired correctly.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return SequentialWithArgs(*layers)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Checkpoints saved from an earlier version of this class contain an
        # unused 'layer5' stage. Drop those entries (from the private copy that
        # load_state_dict passes down) so such checkpoints still load with
        # strict=True instead of failing on unexpected keys.
        stale = [key for key in state_dict if key.startswith(prefix + 'layer5.')]
        for key in stale:
            del state_dict[key]
        super(ResNet, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, x, with_latent=False, fake_relu=False, no_relu=False):
        """Return class logits, or ``(logits, pooled_features)`` if ``with_latent``.

        ``fake_relu`` is forwarded to the last block of ``layer4`` only.
        ``no_relu`` is not supported and must be False.
        """
        assert (not no_relu), \
            "no_relu not yet supported for this architecture"
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out, fake_relu=fake_relu)
        out = self.pool(out)
        pre_out = out.view(out.size(0), -1)
        final = self.linear(pre_out)
        if with_latent:
            return final, pre_out
        return final
def ResNet18(**kwargs):
    """ResNet-18: BasicBlock stages with (2, 2, 2, 2) blocks."""
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def ResNet18Wide(**kwargs):
    """ResNet-18 with every stage 1.5x wider."""
    # ResNet's width-multiplier keyword is `wm`; it has no `wd` parameter.
    return ResNet(BasicBlock, [2, 2, 2, 2], wm=1.5, **kwargs)


def ResNet18Thin(**kwargs):
    """ResNet-18 with every stage 0.75x as wide."""
    return ResNet(BasicBlock, [2, 2, 2, 2], wm=.75, **kwargs)
def ResNet34(**kwargs):
    """ResNet-34: BasicBlock stages with (3, 4, 6, 3) blocks."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3], **kwargs)


def ResNet50(**kwargs):
    """ResNet-50: Bottleneck stages with (3, 4, 6, 3) blocks."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3], **kwargs)


def ResNet101(**kwargs):
    """ResNet-101: Bottleneck stages with (3, 4, 23, 3) blocks."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3], **kwargs)


def ResNet152(**kwargs):
    """ResNet-152: Bottleneck stages with (3, 8, 36, 3) blocks."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3], **kwargs)


# Lower-case aliases (the training notebook imports `resnet50`).
resnet18, resnet50 = ResNet18, ResNet50
resnet101, resnet152 = ResNet101, ResNet152
Training notebook
[1]:
%load_ext autoreload
%autoreload 2
[2]:
import torch
import torchvision.transforms as transforms
import numpy as np
import torchvision
from cifar_model import resnet50
from tqdm import tqdm
import copy
import torch.nn.functional as F
[3]:
# Library versions this notebook was run with (displayed as the cell output).
(torch.__version__, torchvision.__version__)
[3]:
('1.4.0', '0.5.0')
[4]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
[5]:
def _running_means(loss_total, n_batches, n_correct, n_seen):
    """Return (mean of per-batch losses, accuracy), using NaN where nothing was seen."""
    nan = float('nan')
    mean_loss = loss_total / n_batches if n_batches else nan
    accuracy = n_correct / n_seen if n_seen else nan
    return mean_loss, accuracy


def _train_one_epoch(model, optimizer, loader, device, epoch):
    """Run one optimisation pass over `loader`; return (mean batch loss, accuracy).

    The progress bar shows running averages for the epoch so far.
    """
    model.train()
    loss_total, n_batches, n_correct, n_seen = 0.0, 0, 0, 0
    pbar = tqdm(loader, position=0, leave=True)
    for data, label in pbar:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        logits = model(data)
        loss = F.cross_entropy(logits, label)
        loss.backward()
        optimizer.step()
        # Running totals keep each progress-bar update O(1); re-concatenating all
        # predictions seen so far made an epoch quadratic in its batch count.
        loss_total += loss.item()
        n_batches += 1
        n_correct += (torch.argmax(logits, dim=1) == label).sum().item()
        n_seen += label.numel()
        mean_loss, accuracy = _running_means(loss_total, n_batches, n_correct, n_seen)
        pbar.set_postfix({'epoch': epoch,
                          'loss': f'{mean_loss:.2f}',
                          'acc': f'{accuracy:.2f}'})
    return _running_means(loss_total, n_batches, n_correct, n_seen)


def _validate_one_epoch(model, loader, device, epoch):
    """Evaluate `model` on `loader` without gradients; return (mean batch loss, accuracy)."""
    model.eval()
    loss_total, n_batches, n_correct, n_seen = 0.0, 0, 0, 0
    pbar = tqdm(loader, position=0, leave=True)
    with torch.no_grad():
        for data, label in pbar:
            data, label = data.to(device), label.to(device)
            logits = model(data)
            loss_total += F.cross_entropy(logits, label).item()
            n_batches += 1
            n_correct += (torch.argmax(logits, dim=1) == label).sum().item()
            n_seen += label.numel()
            mean_loss, accuracy = _running_means(loss_total, n_batches, n_correct, n_seen)
            pbar.set_postfix({'epoch': epoch,
                              'val loss': f'{mean_loss:.2f}',
                              'val acc': f'{accuracy:.2f}'})
    return _running_means(loss_total, n_batches, n_correct, n_seen)


def train_val(model, optimizer, train_loader, val_loader, epochs, earlystop_patience=-1,
              device=None):
    """Train `model` for up to `epochs` epochs, validating after each one.

    Args:
        model: module trained in place; called as ``model(data)`` and expected to
            return class logits suitable for ``F.cross_entropy``.
        optimizer: optimizer over `model`'s parameters.
        train_loader, val_loader: iterables of ``(data, label)`` batches.
        epochs: maximum number of epochs.
        earlystop_patience: -1 (default) disables early stopping. With N >= 0,
            training stops once the validation loss has failed to improve for more
            than N consecutive epochs, and the weights from the epoch with the
            lowest validation loss are loaded back into `model` before returning.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if it has none); assumes all parameters share
            that device.

    After at least one epoch, `model` is left in eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'
    early_stopping = earlystop_patience >= 0
    best_val_loss = np.inf
    # Defined up front: the val loss may never improve (e.g. NaN) or epochs may be 0.
    best_model_state = None
    patience_counter = 0
    for epoch in range(epochs):
        _train_one_epoch(model, optimizer, train_loader, device, epoch)
        val_loss, _ = _validate_one_epoch(model, val_loader, device, epoch)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            if early_stopping:
                # Snapshot only when the weights may be restored later.
                best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if early_stopping and patience_counter > earlystop_patience:
                break
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
[11]:
# Per-channel normalisation constants shared by the train and test pipelines.
# NOTE(review): these std values are the ones circulated in many CIFAR-10 examples;
# they appear to differ from the per-channel pixel std of the training set
# (~0.247, 0.243, 0.261) -- confirm they are intentional.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training images: random 32x32 crops from a 4px-padded image, random horizontal
# flips, then tensor conversion and normalisation.
augment_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
]
transform_train = transforms.Compose(
    augment_steps + [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
)
# Evaluation images: tensor conversion and normalisation only.
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                        transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True,
                                           num_workers=2)
# NOTE(review): the CIFAR-10 test split doubles as the validation set here, so any
# model selection on it (e.g. early stopping in train_val) makes the reported
# validation accuracy optimistic. A held-out slice of the training set avoids this.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=transform_test)
val_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False,
                                         num_workers=2)
Files already downloaded and verified
Files already downloaded and verified
[8]:
# ResNet-50 with one output per class in `trainset.classes`, wrapped in
# DataParallel and moved to `device`. DataParallel.to returns the same wrapper,
# so chaining it into the assignment also suppresses the cell's repr output.
backbone = resnet50(num_classes=len(trainset.classes))
model = torch.nn.DataParallel(backbone).to(device)
[12]:
# Stage 1: Adam (lr=5e-4, weight_decay=5e-4) for 20 epochs. earlystop_patience is
# not passed, so train_val runs every epoch with early stopping disabled.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=5e-4)
train_val(model=model, optimizer=optimizer, train_loader=train_loader,
          val_loader=val_loader, epochs=20)
100%|██████████| 391/391 [01:09<00:00, 5.60it/s, epoch=0, loss=1.53, acc=0.44]
100%|██████████| 100/100 [00:07<00:00, 13.40it/s, epoch=0, val loss=1.16, val acc=0.60]
100%|██████████| 391/391 [01:00<00:00, 6.43it/s, epoch=1, loss=1.03, acc=0.64]
100%|██████████| 100/100 [00:07<00:00, 13.16it/s, epoch=1, val loss=0.91, val acc=0.68]
100%|██████████| 391/391 [01:01<00:00, 6.33it/s, epoch=2, loss=0.81, acc=0.72]
100%|██████████| 100/100 [00:07<00:00, 13.12it/s, epoch=2, val loss=0.81, val acc=0.72]
100%|██████████| 391/391 [01:00<00:00, 6.42it/s, epoch=3, loss=0.70, acc=0.76]
100%|██████████| 100/100 [00:07<00:00, 13.21it/s, epoch=3, val loss=0.80, val acc=0.74]
100%|██████████| 391/391 [01:00<00:00, 6.43it/s, epoch=4, loss=0.62, acc=0.79]
100%|██████████| 100/100 [00:07<00:00, 13.09it/s, epoch=4, val loss=0.90, val acc=0.70]
100%|██████████| 391/391 [01:00<00:00, 6.44it/s, epoch=5, loss=0.57, acc=0.80]
100%|██████████| 100/100 [00:07<00:00, 12.90it/s, epoch=5, val loss=0.64, val acc=0.78]
100%|██████████| 391/391 [01:01<00:00, 6.35it/s, epoch=6, loss=0.54, acc=0.82]
100%|██████████| 100/100 [00:07<00:00, 12.97it/s, epoch=6, val loss=0.55, val acc=0.81]
100%|██████████| 391/391 [01:01<00:00, 6.31it/s, epoch=7, loss=0.52, acc=0.83]
100%|██████████| 100/100 [00:07<00:00, 12.87it/s, epoch=7, val loss=0.85, val acc=0.75]
100%|██████████| 391/391 [01:02<00:00, 6.31it/s, epoch=8, loss=0.49, acc=0.83]
100%|██████████| 100/100 [00:08<00:00, 12.50it/s, epoch=8, val loss=0.62, val acc=0.79]
100%|██████████| 391/391 [01:03<00:00, 6.20it/s, epoch=9, loss=0.48, acc=0.84]
100%|██████████| 100/100 [00:07<00:00, 13.07it/s, epoch=9, val loss=0.64, val acc=0.79]
100%|██████████| 391/391 [01:01<00:00, 6.41it/s, epoch=10, loss=0.45, acc=0.85]
100%|██████████| 100/100 [00:07<00:00, 13.14it/s, epoch=10, val loss=0.65, val acc=0.78]
100%|██████████| 391/391 [01:01<00:00, 6.38it/s, epoch=11, loss=0.43, acc=0.85]
100%|██████████| 100/100 [00:07<00:00, 13.51it/s, epoch=11, val loss=0.49, val acc=0.84]
100%|██████████| 391/391 [01:00<00:00, 6.45it/s, epoch=12, loss=0.42, acc=0.86]
100%|██████████| 100/100 [00:07<00:00, 13.61it/s, epoch=12, val loss=0.48, val acc=0.84]
100%|██████████| 391/391 [01:00<00:00, 6.46it/s, epoch=13, loss=0.41, acc=0.86]
100%|██████████| 100/100 [00:07<00:00, 12.56it/s, epoch=13, val loss=0.49, val acc=0.83]
100%|██████████| 391/391 [01:00<00:00, 6.42it/s, epoch=14, loss=0.39, acc=0.87]
100%|██████████| 100/100 [00:07<00:00, 13.28it/s, epoch=14, val loss=0.45, val acc=0.86]
100%|██████████| 391/391 [01:00<00:00, 6.44it/s, epoch=15, loss=0.37, acc=0.87]
100%|██████████| 100/100 [00:07<00:00, 14.01it/s, epoch=15, val loss=0.44, val acc=0.86]
100%|██████████| 391/391 [01:01<00:00, 6.36it/s, epoch=16, loss=0.36, acc=0.88]
100%|██████████| 100/100 [00:07<00:00, 12.87it/s, epoch=16, val loss=0.44, val acc=0.86]
100%|██████████| 391/391 [01:00<00:00, 6.42it/s, epoch=17, loss=0.35, acc=0.88]
100%|██████████| 100/100 [00:07<00:00, 13.25it/s, epoch=17, val loss=0.41, val acc=0.86]
100%|██████████| 391/391 [01:00<00:00, 6.46it/s, epoch=18, loss=0.34, acc=0.88]
100%|██████████| 100/100 [00:07<00:00, 13.40it/s, epoch=18, val loss=0.42, val acc=0.86]
100%|██████████| 391/391 [01:00<00:00, 6.43it/s, epoch=19, loss=0.33, acc=0.89]
100%|██████████| 100/100 [00:07<00:00, 13.58it/s, epoch=19, val loss=0.49, val acc=0.84]
[13]:
# Stage 2: lower the learning rate to 1e-4 for 2 epochs. A new Adam instance is
# created, so its moment estimates start fresh rather than carrying over.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)
train_val(model=model, optimizer=optimizer, train_loader=train_loader,
          val_loader=val_loader, epochs=2)
100%|██████████| 391/391 [01:00<00:00, 6.42it/s, epoch=0, loss=0.21, acc=0.93]
100%|██████████| 100/100 [00:07<00:00, 12.76it/s, epoch=0, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00, 6.44it/s, epoch=1, loss=0.17, acc=0.94]
100%|██████████| 100/100 [00:07<00:00, 13.12it/s, epoch=1, val loss=0.25, val acc=0.92]
[14]:
# Stage 3: 10 more epochs at lr=1e-4, again with a fresh Adam instance; early
# stopping stays disabled because earlystop_patience is not passed.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)
train_val(model=model, optimizer=optimizer, train_loader=train_loader,
          val_loader=val_loader, epochs=10)
100%|██████████| 391/391 [01:00<00:00, 6.42it/s, epoch=0, loss=0.15, acc=0.95]
100%|██████████| 100/100 [00:07<00:00, 13.56it/s, epoch=0, val loss=0.24, val acc=0.92]
100%|██████████| 391/391 [01:01<00:00, 6.41it/s, epoch=1, loss=0.14, acc=0.95]
100%|██████████| 100/100 [00:07<00:00, 13.62it/s, epoch=1, val loss=0.24, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00, 6.44it/s, epoch=2, loss=0.13, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 13.27it/s, epoch=2, val loss=0.24, val acc=0.93]
100%|██████████| 391/391 [01:00<00:00, 6.46it/s, epoch=3, loss=0.12, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 13.43it/s, epoch=3, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:01<00:00, 6.40it/s, epoch=4, loss=0.12, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 13.44it/s, epoch=4, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00, 6.45it/s, epoch=5, loss=0.11, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 12.52it/s, epoch=5, val loss=0.24, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00, 6.43it/s, epoch=6, loss=0.11, acc=0.97]
100%|██████████| 100/100 [00:07<00:00, 12.90it/s, epoch=6, val loss=0.26, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00, 6.42it/s, epoch=7, loss=0.10, acc=0.97]
100%|██████████| 100/100 [00:07<00:00, 12.74it/s, epoch=7, val loss=0.24, val acc=0.93]
100%|██████████| 391/391 [01:01<00:00, 6.36it/s, epoch=8, loss=0.10, acc=0.97]
100%|██████████| 100/100 [00:07<00:00, 13.13it/s, epoch=8, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:03<00:00, 6.16it/s, epoch=9, loss=0.10, acc=0.97]
100%|██████████| 100/100 [00:07<00:00, 13.08it/s, epoch=9, val loss=0.25, val acc=0.93]