MCPcopy
hub / github.com/microsoft/Cream / build_loader

Function build_loader

TinyViT/data/build.py:48–112  ·  view source on GitHub ↗
(config)

Source from the content-addressed store, hash-verified

46
47
48def build_loader(config):
49 config.defrost()
50 dataset_train, config.MODEL.NUM_CLASSES = build_dataset(
51 is_train=True, config=config)
52 config.freeze()
53
54 print(
55 f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
56 dataset_val, _ = build_dataset(is_train=False, config=config)
57 print(
58 f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
59
60 mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
61
62 sampler_train = MyDistributedSampler(
63 dataset_train, shuffle=True,
64 drop_last=False, padding=True, pair=mixup_active and config.DISTILL.ENABLED,
65 )
66
67 sampler_val = MyDistributedSampler(
68 dataset_val, shuffle=False,
69 drop_last=False, padding=False, pair=False,
70 )
71
72 # TinyViT Dataset Wrapper
73 if config.DISTILL.ENABLED:
74 dataset_train = DatasetWrapper(dataset_train,
75 logits_path=config.DISTILL.TEACHER_LOGITS_PATH,
76 topk=config.DISTILL.LOGITS_TOPK,
77 write=config.DISTILL.SAVE_TEACHER_LOGITS,
78 )
79
80 data_loader_train = torch.utils.data.DataLoader(
81 dataset_train, sampler=sampler_train,
82 batch_size=config.DATA.BATCH_SIZE,
83 num_workers=config.DATA.NUM_WORKERS,
84 pin_memory=config.DATA.PIN_MEMORY,
85 # modified for TinyViT, we save logits of all samples
86 drop_last=not config.DISTILL.SAVE_TEACHER_LOGITS,
87 )
88
89 data_loader_val = torch.utils.data.DataLoader(
90 dataset_val, sampler=sampler_val,
91 batch_size=config.DATA.BATCH_SIZE,
92 shuffle=False,
93 num_workers=config.DATA.NUM_WORKERS,
94 pin_memory=config.DATA.PIN_MEMORY,
95 drop_last=False
96 )
97
98 # setup mixup / cutmix
99 mixup_fn = None
100 if mixup_active:
101 mixup_t = Mixup if not config.DISTILL.ENABLED else Mixup_record
102 if config.DISTILL.ENABLED and config.AUG.MIXUP_MODE != "pair2":
103 # change to pair2 mode for saving logits
104 config.defrost()
105 config.AUG.MIXUP_MODE = 'pair2'

Callers 2

main (Function) · 0.90
main (Function) · 0.90

Calls 4

DatasetWrapper (Class) · 0.85
build_dataset (Function) · 0.70
print (Function) · 0.50

Tested by

no test coverage detected