(args)
| 220 | |
| 221 | |
| 222 | def main(args): |
| 223 | utils.init_distributed_mode(args) |
| 224 | |
| 225 | print(args) |
| 226 | |
| 227 | if args.distillation_type != 'none' and args.finetune and not args.eval: |
| 228 | raise NotImplementedError("Finetuning with distillation not yet supported") |
| 229 | |
| 230 | device = torch.device(args.device) |
| 231 | |
| 232 | # fix the seed for reproducibility |
| 233 | seed = args.seed + utils.get_rank() |
| 234 | torch.manual_seed(seed) |
| 235 | np.random.seed(seed) |
| 236 | # random.seed(seed) |
| 237 | |
| 238 | cudnn.benchmark = True |
| 239 | |
| 240 | # start an MLflow run and log all command-line args as params (only when local_rank == 0 and gpu == 0) |
| 241 | run_name = args.output_dir.split("/")[-1] |
| 242 | if args.local_rank == 0 and args.gpu == 0: |
| 243 | mlflow.start_run(run_name=run_name) |
| 244 | for key, value in vars(args).items(): |
| 245 | mlflow.log_param(key, value) |
| 246 | |
| 247 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) |
| 248 | dataset_val, _ = build_dataset(is_train=False, args=args) |
| 249 | |
| 250 | if args.distributed: |
| 251 | num_tasks = utils.get_world_size() |
| 252 | global_rank = utils.get_rank() |
| 253 | if args.repeated_aug: |
| 254 | sampler_train = RASampler( |
| 255 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True |
| 256 | ) |
| 257 | else: |
| 258 | sampler_train = torch.utils.data.DistributedSampler( |
| 259 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True |
| 260 | ) |
| 261 | if args.dist_eval: |
| 262 | if len(dataset_val) % num_tasks != 0: |
| 263 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' |
| 264 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' |
| 265 | 'equal num of samples per-process.') |
| 266 | sampler_val = torch.utils.data.DistributedSampler( |
| 267 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) |
| 268 | else: |
| 269 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) |
| 270 | else: |
| 271 | sampler_train = torch.utils.data.RandomSampler(dataset_train) |
| 272 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) |
| 273 | |
| 274 | data_loader_train = torch.utils.data.DataLoader( |
| 275 | dataset_train, sampler=sampler_train, |
| 276 | batch_size=args.batch_size, |
| 277 | num_workers=args.num_workers, |
| 278 | pin_memory=args.pin_mem, |
| 279 | drop_last=True, |
no test coverage detected