(is_train, config)
| 96 | |
| 97 | |
def build_dataset(is_train, config):
    """Build the dataset selected by ``config.DATA.DATASET``.

    Args:
        is_train: Truthy to build the training split, falsy for validation.
        config: Configuration object; this function reads ``DATA.DATASET``,
            ``DATA.DATA_PATH`` and, for ``'imagenet'``, ``DATA.ZIP_MODE`` and
            ``DATA.CACHE_MODE``.

    Returns:
        A ``(dataset, nb_classes)`` tuple. ``nb_classes`` is 1000 for
        ``'imagenet'`` and 21841 for ``'imagenet22K'``.

    Raises:
        NotImplementedError: If ``config.DATA.DATASET`` is neither
            ``'imagenet'`` nor ``'imagenet22K'``.
    """
    # The transform is built up front, before the dataset name is validated.
    transform = build_transform(is_train, config)
    dataset_name = config.DATA.DATASET

    if dataset_name == 'imagenet':
        split = 'train' if is_train else 'val'

        if not config.DATA.ZIP_MODE:
            # Plain directory layout: <DATA_PATH>/<split>/...
            image_root = os.path.join(config.DATA.DATA_PATH, split)
            return datasets.ImageFolder(image_root, transform=transform), 1000

        # Zip layout: an annotation map file plus a "<split>.zip@/" prefix.
        # Only the training split honours the configured cache mode; the
        # validation split always uses 'part'.
        zipped = CachedImageFolder(
            config.DATA.DATA_PATH,
            split + "_map.txt",
            split + ".zip@/",
            transform,
            cache_mode=config.DATA.CACHE_MODE if is_train else 'part',
        )
        return zipped, 1000

    if dataset_name == 'imagenet22K':
        map_suffix = "_map_train.txt" if is_train else "_map_val.txt"
        ann_file = 'ILSVRC2011fall_whole' + map_suffix
        return IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform), 21841

    raise NotImplementedError("We only support ImageNet Now.")
| 123 | |
| 124 | |
| 125 | def build_transform(is_train, config): |
no test coverage detected