MCPcopy Index your code
hub / github.com/davda54/sam / Cifar

Class Cifar

example/data/cifar.py:9–38  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

7
8
class Cifar:
    """CIFAR-10 train/test data loaders with per-channel normalization.

    On construction this downloads CIFAR-10 (if missing) into ``root``. It then
    computes the per-channel mean/std of the training split (converted with
    ``ToTensor`` only, no augmentation). Finally it builds train/test
    ``DataLoader`` objects whose transforms normalize with those statistics.

    Args:
        batch_size: Samples per batch for both loaders (also used to batch the
            statistics pass).
        threads: ``num_workers`` passed to every ``DataLoader``.
        root: Directory CIFAR-10 is downloaded to / read from. The same
            directory is used for the statistics pass and for both loaders,
            so the dataset is only downloaded and stored once.

    Attributes:
        train: Shuffled loader over the augmented training split.
        test: Unshuffled loader over the test split.
        classes: Class names, indexed by integer label.
    """

    def __init__(self, batch_size, threads, root='./data'):
        mean, std = self._get_statistics(root=root, batch_size=batch_size, threads=threads)

        # Training pipeline: random crop (with 4px padding) and horizontal flip
        # on the PIL image, then tensor conversion and normalization, followed
        # by the project's Cutout transform (defined elsewhere).
        train_transform = transforms.Compose([
            torchvision.transforms.RandomCrop(size=(32, 32), padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            Cutout(),
        ])

        # Evaluation pipeline: no augmentation, only the same normalization.
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform)
        test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=test_transform)

        self.train = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=threads)
        self.test = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=threads)

        self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    def _get_statistics(self, root='./data', batch_size=1, threads=0):
        """Return the per-channel ``(mean, std)`` of the CIFAR-10 training split.

        Images are converted with ``ToTensor`` only (no augmentation or
        normalization). All batches are concatenated, and mean/std are reduced
        over every dimension except the channel dimension (dim 1).

        Args:
            root: Dataset directory; should match the one used by the loaders
                so the data is not downloaded twice.
            batch_size: Batch size for the read-only pass. It does not affect
                the result, because batches are concatenated in order before
                reducing.
            threads: ``num_workers`` for the read-only pass.

        Returns:
            Tuple ``(mean, std)`` of 1-D tensors, one entry per channel.
        """
        train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transforms.ToTensor())
        # shuffle=False keeps the pass deterministic; batching only speeds it up.
        loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=False, num_workers=threads)

        # Materializes the whole training split as one tensor before reducing.
        data = torch.cat([batch[0] for batch in loader])
        return data.mean(dim=[0, 2, 3]), data.std(dim=[0, 2, 3])

Callers 1

train.py · File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected