MCPcopy
hub / github.com/microsoft/Swin-Transformer / build_loader

Function build_loader

data/build.py:44–95  ·  view source on GitHub ↗
(config)

Source from the content-addressed store, hash-verified

42
43
def build_loader(config):
    """Build the train/val datasets, their DataLoaders and the optional Mixup transform.

    Side effect: ``config.MODEL.NUM_CLASSES`` is overwritten with the class count
    returned by ``build_dataset`` for the training split. The config is defrosted
    only for that assignment and frozen again immediately afterwards.

    Args:
        config: Project config node. Fields read here: ``LOCAL_RANK``;
            ``DATA.ZIP_MODE``, ``DATA.CACHE_MODE``, ``DATA.BATCH_SIZE``,
            ``DATA.NUM_WORKERS``, ``DATA.PIN_MEMORY``; ``TEST.SEQUENTIAL``,
            ``TEST.SHUFFLE``; ``AUG.MIXUP``, ``AUG.CUTMIX``, ``AUG.CUTMIX_MINMAX``,
            ``AUG.MIXUP_PROB``, ``AUG.MIXUP_SWITCH_PROB``, ``AUG.MIXUP_MODE``;
            ``MODEL.LABEL_SMOOTHING``.

    Returns:
        ``(dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn)``,
        where ``mixup_fn`` is ``None`` unless mixup or cutmix is enabled.

    Note:
        ``dist.get_rank()`` / ``dist.get_world_size()`` are called unconditionally,
        so the distributed process group presumably has to be initialised before
        this function runs -- confirm against the callers.
    """
    config.defrost()
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
    config.freeze()

    # Query the distributed layout once and reuse it below, instead of re-calling
    # dist.get_rank()/dist.get_world_size() for every print and sampler.
    global_rank = dist.get_rank()
    print(f"local rank {config.LOCAL_RANK} / global rank {global_rank} successfully build train dataset")
    dataset_val, _ = build_dataset(is_train=False, config=config)
    print(f"local rank {config.LOCAL_RANK} / global rank {global_rank} successfully build val dataset")
    num_tasks = dist.get_world_size()

    if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
        # Each rank takes the strided slice rank, rank + world_size, ... of the
        # training indices and samples randomly within it. Presumably this matches
        # the per-rank shard that 'part' cache mode keeps -- confirm against
        # build_dataset.
        indices = np.arange(global_rank, len(dataset_train), num_tasks)
        sampler_train = SubsetRandomSampler(indices)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )

    if config.TEST.SEQUENTIAL:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        # Same replicas/rank as training (these are also DistributedSampler's
        # defaults). With its default drop_last=False, DistributedSampler pads the
        # index list by repeating samples so it divides evenly across replicas,
        # so a few validation samples may be evaluated more than once.
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=config.TEST.SHUFFLE
        )

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,  # every training batch has exactly BATCH_SIZE samples
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,  # ordering comes from sampler_val; DataLoader rejects shuffle=True with a sampler
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,  # keep the final partial batch so no sample is skipped
    )

    # Mixup/cutmix is enabled when either alpha is positive or CUTMIX_MINMAX is set.
    mixup_fn = None
    mixup_active = (
        config.AUG.MIXUP > 0
        or config.AUG.CUTMIX > 0
        or config.AUG.CUTMIX_MINMAX is not None
    )
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP,
            cutmix_alpha=config.AUG.CUTMIX,
            cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB,
            switch_prob=config.AUG.MIXUP_SWITCH_PROB,
            mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING,
            num_classes=config.MODEL.NUM_CLASSES,
        )

    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
96
97
98def build_dataset(is_train, config):

Callers 4

main (Function) · 0.90
main (Function) · 0.90
main (Function) · 0.90
main (Function) · 0.90

Calls 2

SubsetRandomSampler (Class) · 0.85
build_dataset (Function) · 0.70

Tested by

no test coverage detected