| 62 | |
| 63 | |
def build_dataset(is_train, args):
    """Build the dataset selected by ``args.data_set``.

    Args:
        is_train: True to build the training split, False for the
            evaluation split.
        args: Namespace-like object. Reads ``data_set`` and ``data_path``;
            ``inat_category`` is read only for 'INAT' / 'INAT19'. The whole
            object is also forwarded to ``build_transform``.

    Returns:
        A ``(dataset, nb_classes)`` tuple.

    Raises:
        ValueError: if ``args.data_set`` is not one of the supported names.
    """
    name = args.data_set
    supported = ('CIFAR', 'IMNET', 'IMNETEE', 'FLOWERS', 'INAT', 'INAT19')
    if name not in supported:
        # Fail fast with a clear message; without this check an unknown name
        # fell through every branch and surfaced as an UnboundLocalError at
        # the return statement.
        raise ValueError(
            f"Unknown data_set {name!r}; expected one of: "
            f"{', '.join(supported)}")

    transform = build_transform(is_train, args)

    if name == 'CIFAR':
        dataset = datasets.CIFAR100(
            args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif name == 'IMNET':
        prefix = 'train' if is_train else 'val'
        # Use <data_path>/<split>.tar when it exists, otherwise fall back to
        # the <data_path>/<split> image-folder directory.
        data_dir = os.path.join(args.data_path, f'{prefix}.tar')
        if os.path.exists(data_dir):
            dataset = TimmDatasetTar(data_dir, transform=transform)
        else:
            root = os.path.join(args.data_path, prefix)
            dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif name == 'IMNETEE':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 10
    elif name == 'FLOWERS':
        # The evaluation split of this dataset lives under 'test', not 'val'.
        root = os.path.join(args.data_path, 'train' if is_train else 'test')
        dataset = datasets.ImageFolder(root, transform=transform)
        if is_train:
            # The training set is repeated 100 times; presumably to lengthen
            # an epoch on this small dataset -- TODO confirm.
            dataset = torch.utils.data.ConcatDataset(
                [dataset for _ in range(100)])
        nb_classes = 102
    else:
        # Only 'INAT' and 'INAT19' remain after the validation above; they
        # differ solely in the year passed to INatDataset.
        year = 2018 if name == 'INAT' else 2019
        dataset = INatDataset(args.data_path, train=is_train, year=year,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    return dataset, nb_classes
| 100 | |
| 101 | |
| 102 | def build_transform(is_train, args): |