| 55 | |
| 56 | |
def build_dataset(is_train, args):
    """Build the dataset selected by ``args.data_set``.

    Args:
        is_train: Truthy to build the training split, falsy for the
            validation split.
        args: Options object. ``data_set`` is always read; ``data_path``,
            ``use_mcloader`` and ``inat_category`` are read depending on
            the selected dataset. It is also forwarded unchanged to
            ``build_transform``.

    Returns:
        A ``(dataset, nb_classes)`` tuple.

    Raises:
        ValueError: If ``args.data_set`` is not one of ``'CIFAR'``,
            ``'IMNET'``, ``'INAT'`` or ``'INAT19'``.
    """
    supported = ('CIFAR', 'IMNET', 'INAT', 'INAT19')
    # Fail fast with a clear message: without this check an unknown name
    # falls through every branch and the final ``return`` hits unbound
    # locals.
    if args.data_set not in supported:
        raise ValueError(
            f"Unsupported data_set {args.data_set!r}; "
            f"expected one of {', '.join(supported)}"
        )

    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        split = 'train' if is_train else 'val'
        if not args.use_mcloader:
            # Plain folder layout: <data_path>/train or <data_path>/val.
            root = os.path.join(args.data_path, split)
            dataset = datasets.ImageFolder(root, transform=transform)
        else:
            dataset = ClassificationDataset(
                split,
                pipeline=transform
            )
        nb_classes = 1000
    else:
        # Only 'INAT' and 'INAT19' remain after the validation above; the
        # two variants differ solely in the year passed to INatDataset.
        year = 2018 if args.data_set == 'INAT' else 2019
        dataset = INatDataset(args.data_path, train=is_train, year=year,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes

    return dataset, nb_classes
| 83 | |
| 84 | |
| 85 | def build_transform(is_train, args): |