| 7 | |
| 8 | |
class Cifar:
    """CIFAR-10 train/test data loaders with per-channel normalization.

    Channel mean/std are computed from the CIFAR-10 training split and used to
    normalize both splits. Training images are additionally augmented with a
    padded random crop, a random horizontal flip and ``Cutout()`` (a transform
    defined elsewhere in the project, applied after normalization).

    Attributes:
        train: DataLoader over the shuffled, augmented training split.
        test: DataLoader over the unshuffled, non-augmented test split.
        classes: Short CIFAR-10 class names (presumably in label-index order).
    """

    def __init__(self, batch_size, threads, root='./data'):
        """Download CIFAR-10 if needed and build the train/test loaders.

        Args:
            batch_size: Number of images per batch for both loaders.
            threads: ``num_workers`` passed to both DataLoaders.
            root: Directory where CIFAR-10 is stored/downloaded. The same
                directory is used for the statistics pass and the loaders,
                so the dataset is only downloaded once.
        """
        mean, std = self._get_statistics(root)

        train_transform = transforms.Compose([
            torchvision.transforms.RandomCrop(size=(32, 32), padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            Cutout()
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform)
        test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=test_transform)

        self.train = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=threads)
        self.test = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=threads)

        self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    def _get_statistics(self, root='./data'):
        """Return the per-channel (mean, std) of the CIFAR-10 training images.

        Pixel values are scaled to [0, 1] (the same scaling ``ToTensor``
        applies to uint8 images) before the statistics are taken, so the
        result can be passed straight to ``transforms.Normalize``.

        Args:
            root: Dataset directory. It should match the one used for the
                loaders so the dataset is not downloaded a second time.

        Returns:
            Tuple ``(mean, std)`` of float32 tensors with one entry per
            color channel. ``std`` uses torch's default (unbiased) estimator.
        """
        train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True)
        # torchvision keeps the raw images in ``.data`` as a uint8 array laid
        # out (N, H, W, C). Reading it directly avoids converting every image
        # to PIL and back through a batch_size=1 DataLoader.
        data = torch.from_numpy(train_set.data).float().div(255)
        # Reduce over samples, height and width; keep the channel axis.
        return data.mean(dim=(0, 1, 2)), data.std(dim=(0, 1, 2))