(config)
| 46 | |
| 47 | |
| 48 | def build_loader(config): |
| 49 | config.defrost() |
| 50 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset( |
| 51 | is_train=True, config=config) |
| 52 | config.freeze() |
| 53 | |
| 54 | print( |
| 55 | f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") |
| 56 | dataset_val, _ = build_dataset(is_train=False, config=config) |
| 57 | print( |
| 58 | f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") |
| 59 | |
| 60 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None |
| 61 | |
| 62 | sampler_train = MyDistributedSampler( |
| 63 | dataset_train, shuffle=True, |
| 64 | drop_last=False, padding=True, pair=mixup_active and config.DISTILL.ENABLED, |
| 65 | ) |
| 66 | |
| 67 | sampler_val = MyDistributedSampler( |
| 68 | dataset_val, shuffle=False, |
| 69 | drop_last=False, padding=False, pair=False, |
| 70 | ) |
| 71 | |
| 72 | # TinyViT Dataset Wrapper |
| 73 | if config.DISTILL.ENABLED: |
| 74 | dataset_train = DatasetWrapper(dataset_train, |
| 75 | logits_path=config.DISTILL.TEACHER_LOGITS_PATH, |
| 76 | topk=config.DISTILL.LOGITS_TOPK, |
| 77 | write=config.DISTILL.SAVE_TEACHER_LOGITS, |
| 78 | ) |
| 79 | |
| 80 | data_loader_train = torch.utils.data.DataLoader( |
| 81 | dataset_train, sampler=sampler_train, |
| 82 | batch_size=config.DATA.BATCH_SIZE, |
| 83 | num_workers=config.DATA.NUM_WORKERS, |
| 84 | pin_memory=config.DATA.PIN_MEMORY, |
| 85 | # modified for TinyViT, we save logits of all samples |
| 86 | drop_last=not config.DISTILL.SAVE_TEACHER_LOGITS, |
| 87 | ) |
| 88 | |
| 89 | data_loader_val = torch.utils.data.DataLoader( |
| 90 | dataset_val, sampler=sampler_val, |
| 91 | batch_size=config.DATA.BATCH_SIZE, |
| 92 | shuffle=False, |
| 93 | num_workers=config.DATA.NUM_WORKERS, |
| 94 | pin_memory=config.DATA.PIN_MEMORY, |
| 95 | drop_last=False |
| 96 | ) |
| 97 | |
| 98 | # setup mixup / cutmix |
| 99 | mixup_fn = None |
| 100 | if mixup_active: |
| 101 | mixup_t = Mixup if not config.DISTILL.ENABLED else Mixup_record |
| 102 | if config.DISTILL.ENABLED and config.AUG.MIXUP_MODE != "pair2": |
| 103 | # change to pair2 mode for saving logits |
| 104 | config.defrost() |
| 105 | config.AUG.MIXUP_MODE = 'pair2' |
no test coverage detected