(config)
| 42 | |
| 43 | |
def build_loader(config):
    """Build the train/val datasets, their samplers and loaders, and the Mixup transform.

    ``config`` is temporarily defrosted so that ``config.MODEL.NUM_CLASSES`` can be
    filled in from the value returned by ``build_dataset``. It is re-frozen
    afterwards, even if building the training dataset raises.

    Args:
        config: Experiment config node supporting ``defrost()`` / ``freeze()``
            (presumably a yacs ``CfgNode`` -- TODO confirm).

    Returns:
        A 5-tuple ``(dataset_train, dataset_val, data_loader_train,
        data_loader_val, mixup_fn)``. ``mixup_fn`` is ``None`` unless MIXUP > 0,
        CUTMIX > 0, or CUTMIX_MINMAX is set.
    """
    config.defrost()
    try:
        dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
    finally:
        # Re-freeze even on failure so the shared config is never left mutable.
        config.freeze()

    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()
    print(f"local rank {config.LOCAL_RANK} / global rank {global_rank} successfully build train dataset")
    dataset_val, _ = build_dataset(is_train=False, config=config)
    print(f"local rank {config.LOCAL_RANK} / global rank {global_rank} successfully build val dataset")

    if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
        # Each rank samples only the strided subset rank, rank + world_size, ...
        # (presumably the part of the zip cached on this rank -- confirm).
        indices = np.arange(global_rank, len(dataset_train), num_tasks)
        sampler_train = SubsetRandomSampler(indices)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )

    if config.TEST.SEQUENTIAL:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        # NOTE(review): per the PyTorch docs, DistributedSampler pads with repeated
        # indices when len(dataset_val) is not divisible by the world size, so a
        # few samples may be evaluated twice; TEST.SEQUENTIAL avoids this.
        sampler_val = torch.utils.data.distributed.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=config.TEST.SHUFFLE
        )

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        # Incomplete final batches are dropped during training.
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,
    )

    # Set up mixup / cutmix only when at least one of them is enabled.
    mixup_fn = None
    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
| 96 | |
| 97 | |
| 98 | def build_dataset(is_train, config): |
no test coverage detected