build loaders for linear evaluation
(opt, ngpus_per_node)
| 374 | |
| 375 | |
| 376 | def build_linear_loader(opt, ngpus_per_node): |
| 377 | """build loaders for linear evaluation""" |
| 378 | # transform |
| 379 | if opt.modal == 'RGB': |
| 380 | mean = [0.485, 0.456, 0.406] |
| 381 | std = [0.229, 0.224, 0.225] |
| 382 | color_transfer = RGB2RGB() |
| 383 | else: |
| 384 | mean = [0.457, -0.082, -0.052] |
| 385 | std = [0.500, 1.331, 1.333] |
| 386 | color_transfer = RGB2YDbDr() |
| 387 | normalize = transforms.Normalize(mean=mean, std=std) |
| 388 | |
| 389 | if opt.aug_linear == 'NULL': |
| 390 | train_transform = transforms.Compose([ |
| 391 | transforms.RandomResizedCrop(224, scale=(opt.crop, 1.)), |
| 392 | transforms.RandomHorizontalFlip(), |
| 393 | color_transfer, |
| 394 | transforms.ToTensor(), |
| 395 | normalize, |
| 396 | ]) |
| 397 | elif opt.aug_linear == 'RA': |
| 398 | rgb_mean = (0.485, 0.456, 0.406) |
| 399 | ra_params = dict( |
| 400 | translate_const=100, |
| 401 | img_mean=tuple([min(255, round(255 * x)) for x in rgb_mean]), |
| 402 | ) |
| 403 | train_transform = transforms.Compose([ |
| 404 | transforms.RandomResizedCrop(224, scale=(opt.crop, 1.)), |
| 405 | transforms.RandomHorizontalFlip(), |
| 406 | rand_augment_transform('rand-n{}-m{}-mstd0.5'.format(2, 10), |
| 407 | ra_params, |
| 408 | use_cmc=(opt.modal == 'CMC')), |
| 409 | color_transfer, |
| 410 | transforms.ToTensor(), |
| 411 | normalize, |
| 412 | ]) |
| 413 | else: |
| 414 | raise NotImplementedError('aug not found: {}'.format(opt.aug_linear)) |
| 415 | |
| 416 | # dataset |
| 417 | data_folder = opt.data_folder |
| 418 | train_dir = os.path.join(data_folder, 'train') |
| 419 | val_dir = os.path.join(data_folder, 'val') |
| 420 | train_dataset = datasets.ImageFolder(train_dir, train_transform) |
| 421 | val_dataset = datasets.ImageFolder( |
| 422 | val_dir, |
| 423 | transforms.Compose([ |
| 424 | transforms.Resize(256), |
| 425 | transforms.CenterCrop(224), |
| 426 | color_transfer, |
| 427 | transforms.ToTensor(), |
| 428 | normalize, |
| 429 | ]) |
| 430 | ) |
| 431 | |
| 432 | # loader |
| 433 | batch_size = int(opt.batch_size / opt.world_size) |
no test coverage detected