MCPcopy Index your code
hub / github.com/HobbitLong/PyContrast / build_linear_loader

Function build_linear_loader

pycontrast/datasets/util.py:376–446  ·  view source on GitHub ↗

build loaders for linear evaluation

(opt, ngpus_per_node)

Source from the content-addressed store, hash-verified

374
375
376def build_linear_loader(opt, ngpus_per_node):
377 """build loaders for linear evaluation"""
378 # transform
379 if opt.modal == 'RGB':
380 mean = [0.485, 0.456, 0.406]
381 std = [0.229, 0.224, 0.225]
382 color_transfer = RGB2RGB()
383 else:
384 mean = [0.457, -0.082, -0.052]
385 std = [0.500, 1.331, 1.333]
386 color_transfer = RGB2YDbDr()
387 normalize = transforms.Normalize(mean=mean, std=std)
388
389 if opt.aug_linear == 'NULL':
390 train_transform = transforms.Compose([
391 transforms.RandomResizedCrop(224, scale=(opt.crop, 1.)),
392 transforms.RandomHorizontalFlip(),
393 color_transfer,
394 transforms.ToTensor(),
395 normalize,
396 ])
397 elif opt.aug_linear == 'RA':
398 rgb_mean = (0.485, 0.456, 0.406)
399 ra_params = dict(
400 translate_const=100,
401 img_mean=tuple([min(255, round(255 * x)) for x in rgb_mean]),
402 )
403 train_transform = transforms.Compose([
404 transforms.RandomResizedCrop(224, scale=(opt.crop, 1.)),
405 transforms.RandomHorizontalFlip(),
406 rand_augment_transform('rand-n{}-m{}-mstd0.5'.format(2, 10),
407 ra_params,
408 use_cmc=(opt.modal == 'CMC')),
409 color_transfer,
410 transforms.ToTensor(),
411 normalize,
412 ])
413 else:
414 raise NotImplementedError('aug not found: {}'.format(opt.aug_linear))
415
416 # dataset
417 data_folder = opt.data_folder
418 train_dir = os.path.join(data_folder, 'train')
419 val_dir = os.path.join(data_folder, 'val')
420 train_dataset = datasets.ImageFolder(train_dir, train_transform)
421 val_dataset = datasets.ImageFolder(
422 val_dir,
423 transforms.Compose([
424 transforms.Resize(256),
425 transforms.CenterCrop(224),
426 color_transfer,
427 transforms.ToTensor(),
428 normalize,
429 ])
430 )
431
432 # loader
433 batch_size = int(opt.batch_size / opt.world_size)

Callers 1

main_worker (function) · 0.90

Calls 3

RGB2RGB (class) · 0.85
RGB2YDbDr (class) · 0.85
rand_augment_transform (function) · 0.85

Tested by

no test coverage detected