| 116 | |
| 117 | |
def set_loader(opt):
    """Build the CIFAR training and validation data loaders.

    Args:
        opt: Options object (e.g. an argparse namespace). Reads
            ``dataset`` ('cifar10' or 'cifar100'), ``data_folder``,
            ``batch_size`` and ``num_workers``.

    Returns:
        A ``(train_loader, val_loader)`` tuple of
        ``torch.utils.data.DataLoader`` objects.

    Raises:
        ValueError: If ``opt.dataset`` is not 'cifar10' or 'cifar100'.
    """
    # Resolve normalization statistics AND the dataset class in a single
    # dispatch, so an unsupported name is rejected before any transform or
    # dataset is built (and there is no second, unreachable check later).
    if opt.dataset == 'cifar10':
        # NOTE(review): this CIFAR-10 std triple is widely copied; some
        # sources report (0.2470, 0.2435, 0.2616) as the per-channel std.
        # Kept unchanged for compatibility with existing checkpoints —
        # confirm before changing.
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
        dataset_cls = datasets.CIFAR10
    elif opt.dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
        dataset_cls = datasets.CIFAR100
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    # Training-time augmentation: random resized crop (output 32x32, area
    # scale in [0.2, 1.0]) plus horizontal flip, then normalize.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation: no augmentation, only tensor conversion + normalization.
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Only the training split requests a download. The test split is
    # presumably present once the train download has run — confirm.
    train_dataset = dataset_cls(root=opt.data_folder,
                                transform=train_transform,
                                download=True)
    val_dataset = dataset_cls(root=opt.data_folder,
                              train=False,
                              transform=val_transform)

    # No custom sampler is used, so the default random shuffling applies.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.num_workers, pin_memory=True)
    # NOTE(review): the validation batch size (256) and worker count (8) are
    # fixed here rather than taken from ``opt``.
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    return train_loader, val_loader
| 168 | |
| 169 | |
| 170 | def set_model(opt): |