
Commit 2f3dda7

cifar10 example
1 parent de8135b commit 2f3dda7

File tree

10 files changed: +380 -0 lines changed


example/README.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
# Cifar10 with WRN 🌁

This folder contains a simple Wide-ResNet implementation that can be trained on Cifar10 with SAM. Start the training by running `python3 train.py`.
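For orientation, train.py below drives SAM with two forward-backward passes per batch. A minimal sketch of one such step, assuming the `SAM` wrapper from the repository root (`sam.py`) and the variable names used in train.py:

# Sketch of a single SAM training step (not a standalone script).
loss = smooth_crossentropy(model(inputs), targets)
loss.mean().backward()
optimizer.first_step(zero_grad=True)   # perturb the weights toward the local worst case

smooth_crossentropy(model(inputs), targets).mean().backward()
optimizer.second_step(zero_grad=True)  # restore the weights and apply the base SGD update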

example/data/cifar.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from utility.cutout import Cutout


class Cifar:
    def __init__(self, batch_size, threads):
        mean, std = self._get_statistics()

        train_transform = transforms.Compose([
            torchvision.transforms.RandomCrop(size=(32, 32), padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            Cutout()  # applied to the normalized CHW tensor
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

        self.train = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=threads)
        self.test = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=threads)

        self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    def _get_statistics(self):
        # Per-channel mean and std computed over the full (train + test) dataset.
        train_set = torchvision.datasets.CIFAR10(root='./cifar', train=True, download=True, transform=transforms.ToTensor())
        test_set = torchvision.datasets.CIFAR10(root='./cifar', train=False, download=True, transform=transforms.ToTensor())

        data = torch.cat([d[0] for d in DataLoader(train_set)] + [d[0] for d in DataLoader(test_set)])
        return data.mean(dim=[0, 2, 3]), data.std(dim=[0, 2, 3])
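A minimal usage sketch of the class above (hypothetical values, and it assumes the CIFAR-10 download succeeds); the loaders yield normalized, augmented CHW batches:

dataset = Cifar(batch_size=128, threads=2)
images, labels = next(iter(dataset.train))
print(images.shape)   # torch.Size([128, 3, 32, 32])
print(labels.shape)   # torch.Size([128])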

example/model/smooth_cross_entropy.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def smooth_crossentropy(pred, gold, smoothing=0.1):
    n_class = pred.size(1)

    # Label-smoothed target distribution: 1 - smoothing on the gold class,
    # smoothing / (n_class - 1) spread over the remaining classes.
    one_hot = torch.full_like(pred, fill_value=smoothing / (n_class - 1))
    one_hot.scatter_(dim=1, index=gold.unsqueeze(1), value=1.0 - smoothing)
    log_prob = F.log_softmax(pred, dim=1)

    # Per-sample KL(one_hot || softmax(pred)); differs from the smoothed cross-entropy
    # only by a constant (the entropy of the target distribution).
    return F.kl_div(input=log_prob, target=one_hot, reduction='none').sum(-1)
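A quick, hypothetical sanity check of the function above: it returns one value per sample, so the caller chooses the reduction (train.py below uses .mean() before backward()).

logits = torch.randn(4, 10)
labels = torch.tensor([3, 1, 0, 9])
loss = smooth_crossentropy(logits, labels)
print(loss.shape)          # torch.Size([4]): one smoothed loss per sample
print(loss.mean().item())  # scalar that would be passed to backward()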

example/model/wide_res_net.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicUnit(nn.Module):
    # Pre-activation residual unit: BN, ReLU, Conv, BN, ReLU, Dropout, Conv around an identity shortcut.
    def __init__(self, channels: int, dropout: float):
        super(BasicUnit, self).__init__()
        self.block = nn.Sequential(OrderedDict([
            ("0_normalization", nn.BatchNorm2d(channels)),
            ("1_activation", nn.ReLU(inplace=True)),
            ("2_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)),
            ("3_normalization", nn.BatchNorm2d(channels)),
            ("4_activation", nn.ReLU(inplace=True)),
            ("5_dropout", nn.Dropout(dropout, inplace=True)),
            ("6_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)),
        ]))

    def forward(self, x):
        return x + self.block(x)  # residual connection


class DownsampleUnit(nn.Module):
    # Residual unit that changes width/resolution; a 1x1 convolution matches the shortcut to the new shape.
    def __init__(self, in_channels: int, out_channels: int, stride: int, dropout: float):
        super(DownsampleUnit, self).__init__()
        self.norm_act = nn.Sequential(OrderedDict([
            ("0_normalization", nn.BatchNorm2d(in_channels)),
            ("1_activation", nn.ReLU(inplace=True)),
        ]))
        self.block = nn.Sequential(OrderedDict([
            ("0_convolution", nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1, bias=False)),
            ("1_normalization", nn.BatchNorm2d(out_channels)),
            ("2_activation", nn.ReLU(inplace=True)),
            ("3_dropout", nn.Dropout(dropout, inplace=True)),
            ("4_convolution", nn.Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1, bias=False)),
        ]))
        self.downsample = nn.Conv2d(in_channels, out_channels, (1, 1), stride=stride, padding=0, bias=False)

    def forward(self, x):
        x = self.norm_act(x)
        return self.block(x) + self.downsample(x)


class Block(nn.Module):
    # One group of the network: a downsampling unit followed by `depth` basic units.
    def __init__(self, in_channels: int, out_channels: int, stride: int, depth: int, dropout: float):
        super(Block, self).__init__()
        self.block = nn.Sequential(
            DownsampleUnit(in_channels, out_channels, stride, dropout),
            *(BasicUnit(out_channels, dropout) for _ in range(depth))
        )

    def forward(self, x):
        return self.block(x)


class WideResNet(nn.Module):
    def __init__(self, depth: int, width_factor: int, dropout: float, in_channels: int, labels: int):
        super(WideResNet, self).__init__()

        # Group widths scale with the width factor; three groups of residual units.
        self.filters = [16, 1 * 16 * width_factor, 2 * 16 * width_factor, 4 * 16 * width_factor]
        self.block_depth = (depth - 4) // (3 * 2)

        self.f = nn.Sequential(OrderedDict([
            ("0_convolution", nn.Conv2d(in_channels, self.filters[0], (3, 3), stride=1, padding=1, bias=False)),
            ("1_block", Block(self.filters[0], self.filters[1], 1, self.block_depth, dropout)),
            ("2_block", Block(self.filters[1], self.filters[2], 2, self.block_depth, dropout)),
            ("3_block", Block(self.filters[2], self.filters[3], 2, self.block_depth, dropout)),
            ("4_normalization", nn.BatchNorm2d(self.filters[3])),
            ("5_activation", nn.ReLU(inplace=True)),
            ("6_pooling", nn.AvgPool2d(kernel_size=8)),
            ("7_flattening", nn.Flatten()),
            ("8_classification", nn.Linear(in_features=self.filters[3], out_features=labels)),
        ]))

        self._initialize()

    def _initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.zero_()
                m.bias.data.zero_()

    def forward(self, x):
        return self.f(x)
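A hypothetical shape check for the model above, using the defaults from train.py (depth 28, width factor 10): the three blocks take a 32x32 input down to 8x8, the 8x8 average pool reduces it to 1x1, and the classifier sees a 640-dimensional vector.

model = WideResNet(depth=28, width_factor=10, dropout=0.0, in_channels=3, labels=10)
x = torch.randn(8, 3, 32, 32)   # dummy CIFAR-sized batch
print(model(x).shape)           # torch.Size([8, 10])
print(model.filters[-1])        # 640 features feed the final linear layer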

example/train.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
import argparse
import torch

from model.wide_res_net import WideResNet
from model.smooth_cross_entropy import smooth_crossentropy
from data.cifar import Cifar
from utility.log import Log
from utility.initialize import initialize
from utility.step_lr import StepLR
import sys; sys.path.append("..")
from sam import SAM


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=128, type=int, help="Batch size used in the training and validation loop.")
    parser.add_argument("--depth", default=28, type=int, help="Number of layers.")
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
    parser.add_argument("--epochs", default=200, type=int, help="Total number of epochs.")
    parser.add_argument("--label_smoothing", default=0.1, type=float, help="Use 0.0 for no label smoothing.")
    parser.add_argument("--learning_rate", default=0.1, type=float, help="Base learning rate at the start of the training.")
    parser.add_argument("--momentum", default=0.9, type=float, help="SGD momentum.")
    parser.add_argument("--threads", default=2, type=int, help="Number of CPU threads for dataloaders.")
    parser.add_argument("--rho", default=0.05, type=float, help="Rho parameter for SAM.")
    parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
    parser.add_argument("--width_factor", default=10, type=int, help="How many times wider compared to normal ResNet.")
    args = parser.parse_args()

    initialize(args, seed=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dataset = Cifar(args.batch_size, args.threads)
    log = Log(log_each=10)
    model = WideResNet(args.depth, args.width_factor, args.dropout, in_channels=3, labels=10).to(device)

    base_optimizer = torch.optim.SGD
    optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)

    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step: gradient at the current weights, then climb to the perturbed point
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step: gradient at the perturbed weights drives the actual update
            smooth_crossentropy(model(inputs), targets, smoothing=args.label_smoothing).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

        model.eval()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for batch in dataset.test:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())

    log.flush()
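A hedged invocation example; every flag is defined by the parser above, and the values shown are simply its defaults:

python3 train.py --batch_size 128 --depth 28 --width_factor 10 --dropout 0.0 --epochs 200 --learning_rate 0.1 --momentum 0.9 --weight_decay 0.0005 --rho 0.05 --label_smoothing 0.1 --threads 2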

example/utility/cutout.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
import torch


class Cutout:
    def __init__(self, size=16, p=0.5):
        self.size = size
        self.half_size = size // 2
        self.p = p

    def __call__(self, image):
        # Expects a CHW tensor, since Cutout runs after ToTensor in the training transform.
        if torch.rand([1]).item() > self.p:
            return image

        left = torch.randint(-self.half_size, image.shape[1] - self.half_size, [1]).item()
        top = torch.randint(-self.half_size, image.shape[2] - self.half_size, [1]).item()
        right = min(image.shape[1], left + self.size)
        bottom = min(image.shape[2], top + self.size)

        # Zero out a square patch across all channels.
        image[:, max(0, left):right, max(0, top):bottom] = 0
        return image
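A hypothetical quick check of the transform above: with p=1.0 a square patch (clipped at the image borders) is always zeroed.

img = torch.ones(3, 32, 32)                # CHW tensor, as produced by ToTensor
out = Cutout(size=16, p=1.0)(img.clone())
print((out == 0).any().item())             # True: some pixels were cut out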

example/utility/initialize.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
import random
import torch


def initialize(args, seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN autotuning stays on: runs are seeded but not bit-exact reproducible.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

example/utility/loading_bar.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
class LoadingBar:
    def __init__(self, length: int = 40):
        self.length = length
        self.symbols = ['┈', '░', '▒', '▓']

    def __call__(self, progress: float) -> str:
        # Quantize progress to quarter-cells: full cells plus one partial-fill symbol.
        p = int(progress * self.length * 4 + 0.5)
        d, r = p // 4, p % 4
        return '┠┈' + d * '█' + ((self.symbols[r]) + max(0, self.length - 1 - d) * '┈' if p < self.length * 4 else '') + "┈┨"
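A hypothetical rendering check for the bar above, using a shorter length than the default purely for illustration:

bar = LoadingBar(length=10)
for progress in (0.0, 0.33, 1.0):
    print(bar(progress))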

example/utility/log.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
from utility.loading_bar import LoadingBar
import time


class Log:
    def __init__(self, log_each: int, initial_epoch=-1):
        self.loading_bar = LoadingBar(length=27)
        self.best_accuracy = 0.0
        self.log_each = log_each
        self.epoch = initial_epoch

    def train(self, len_dataset: int) -> None:
        self.epoch += 1
        if self.epoch == 0:
            self._print_header()
        else:
            self.flush()

        self.is_train = True
        self.last_steps_state = {"loss": 0.0, "accuracy": 0.0, "steps": 0}
        self._reset(len_dataset)

    def eval(self, len_dataset: int) -> None:
        self.flush()
        self.is_train = False
        self._reset(len_dataset)

    def __call__(self, model, loss, accuracy, learning_rate: float = None) -> None:
        if self.is_train:
            self._train_step(model, loss, accuracy, learning_rate)
        else:
            self._eval_step(loss, accuracy)

    def flush(self) -> None:
        if self.is_train:
            loss = self.epoch_state["loss"] / self.epoch_state["steps"]
            accuracy = self.epoch_state["accuracy"] / self.epoch_state["steps"]

            print(
                f"\r{self.epoch:12d}{loss:12.4f}{100*accuracy:10.2f} % ┃{self.learning_rate:12.3e}{self._time():>12} ┃",
                end="",
                flush=True,
            )
        else:
            loss = self.epoch_state["loss"] / self.epoch_state["steps"]
            accuracy = self.epoch_state["accuracy"] / self.epoch_state["steps"]

            print(f"{loss:12.4f}{100*accuracy:10.2f} % ┃", flush=True)

            if accuracy > self.best_accuracy:
                self.best_accuracy = accuracy

    def _train_step(self, model, loss, accuracy, learning_rate: float) -> None:
        self.learning_rate = learning_rate
        self.last_steps_state["loss"] += loss.sum().item()
        self.last_steps_state["accuracy"] += accuracy.sum().item()
        self.last_steps_state["steps"] += loss.size(0)
        self.epoch_state["loss"] += loss.sum().item()
        self.epoch_state["accuracy"] += accuracy.sum().item()
        self.epoch_state["steps"] += loss.size(0)
        self.step += 1

        if self.step % self.log_each == self.log_each - 1:
            loss = self.last_steps_state["loss"] / self.last_steps_state["steps"]
            accuracy = self.last_steps_state["accuracy"] / self.last_steps_state["steps"]

            self.last_steps_state = {"loss": 0.0, "accuracy": 0.0, "steps": 0}
            progress = self.step / self.len_dataset

            print(
                f"\r{self.epoch:12d}{loss:12.4f}{100*accuracy:10.2f} % ┃{learning_rate:12.3e}{self._time():>12} {self.loading_bar(progress)}",
                end="",
                flush=True,
            )

    def _eval_step(self, loss, accuracy) -> None:
        self.epoch_state["loss"] += loss.sum().item()
        self.epoch_state["accuracy"] += accuracy.sum().item()
        self.epoch_state["steps"] += loss.size(0)

    def _reset(self, len_dataset: int) -> None:
        self.start_time = time.time()
        self.step = 0
        self.len_dataset = len_dataset
        self.epoch_state = {"loss": 0.0, "accuracy": 0.0, "steps": 0}

    def _time(self) -> str:
        time_seconds = int(time.time() - self.start_time)
        return f"{time_seconds // 60:02d}:{time_seconds % 60:02d} min"

    def _print_header(self) -> None:
        print(f"┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓")
        print(f"┃              ┃              ╷              ┃              ╷              ┃              ╷              ┃")
        print(f"┃       epoch  ┃         loss │     accuracy ┃        l.r.  │      elapsed ┃         loss │     accuracy ┃")
        print(f"┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨")

example/utility/step_lr.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
class StepLR:
    def __init__(self, optimizer, learning_rate: float, total_epochs: int):
        self.optimizer = optimizer
        self.total_epochs = total_epochs
        self.base = learning_rate

    def __call__(self, epoch):
        # Piecewise-constant schedule: drop the learning rate by 5x at 30%, 60% and 80% of training.
        if epoch < self.total_epochs * 3/10:
            lr = self.base
        elif epoch < self.total_epochs * 6/10:
            lr = self.base * 0.2
        elif epoch < self.total_epochs * 8/10:
            lr = self.base * 0.2 ** 2
        else:
            lr = self.base * 0.2 ** 3

        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def lr(self) -> float:
        return self.optimizer.param_groups[0]["lr"]
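A hypothetical check of the schedule above over the 200 epochs used in train.py: the learning rate is 0.1, then 0.02, 0.004 and 0.0008 in the four phases (up to floating-point rounding).

import torch

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
scheduler = StepLR(optimizer, learning_rate=0.1, total_epochs=200)

for epoch in (0, 70, 130, 170):
    scheduler(epoch)
    print(epoch, scheduler.lr())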
