(args)
| 192 | |
| 193 | |
| 194 | def main(args): |
| 195 | utils.init_distributed_mode(args) |
| 196 | |
| 197 | if args.distillation_type != 'none' and args.finetune and not args.eval: |
| 198 | raise NotImplementedError( |
| 199 | "Finetuning with distillation not yet supported") |
| 200 | |
| 201 | device = torch.device(args.device) |
| 202 | |
| 203 | # fix the seed for reproducibility |
| 204 | seed = args.seed + utils.get_rank() |
| 205 | torch.manual_seed(seed) |
| 206 | np.random.seed(seed) |
| 207 | # random.seed(seed) |
| 208 | |
| 209 | cudnn.benchmark = True |
| 210 | |
| 211 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) |
| 212 | dataset_val, _ = build_dataset(is_train=False, args=args) |
| 213 | |
| 214 | if True: # args.distributed: |
| 215 | num_tasks = utils.get_world_size() |
| 216 | global_rank = utils.get_rank() |
| 217 | if args.repeated_aug: |
| 218 | sampler_train = RASampler( |
| 219 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True |
| 220 | ) |
| 221 | else: |
| 222 | sampler_train = torch.utils.data.DistributedSampler( |
| 223 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True |
| 224 | ) |
| 225 | if args.dist_eval: |
| 226 | if len(dataset_val) % num_tasks != 0: |
| 227 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' |
| 228 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' |
| 229 | 'equal num of samples per-process.') |
| 230 | sampler_val = torch.utils.data.DistributedSampler( |
| 231 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) |
| 232 | else: |
| 233 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) |
| 234 | else: |
| 235 | sampler_train = torch.utils.data.RandomSampler(dataset_train) |
| 236 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) |
| 237 | |
| 238 | data_loader_train = torch.utils.data.DataLoader( |
| 239 | dataset_train, sampler=sampler_train, |
| 240 | batch_size=args.batch_size, |
| 241 | num_workers=args.num_workers, |
| 242 | pin_memory=args.pin_mem, |
| 243 | drop_last=True, |
| 244 | ) |
| 245 | |
| 246 | if args.ThreeAugment: |
| 247 | data_loader_train.dataset.transform = new_data_aug_generator(args) |
| 248 | |
| 249 | data_loader_val = torch.utils.data.DataLoader( |
| 250 | dataset_val, sampler=sampler_val, |
| 251 | batch_size=int(1.5 * args.batch_size), |
no test coverage detected