(dataset_name)
| 302 | raise ValueError(f"Unknown dataset: {name}") |
| 303 | |
# Number of target classes per supported dataset, keyed by lower-cased name.
# Several datasets are listed under both an underscored and an
# underscore-free alias, mirroring the names accepted previously.
_DATASET_NUM_CLASSES = {
    'cifar10': 10,
    'mnist': 10,
    'fashion_mnist': 10,
    'fashionmnist': 10,
    'stl10': 10,
    'cifar100': 100,
    'tiny_imagenet': 200,
    'tinyimagenet': 200,
    'imagenet': 1000,
    'food101': 101,
    # Caltech-256 has 256 object categories plus one "clutter" category.
    'caltech256': 257,
    'oxford_pets': 37,
    'oxfordpets': 37,
}


def get_num_classes(dataset_name, *, default=10):
    """Return the number of target classes for *dataset_name*.

    The lookup is case-insensitive.

    Args:
        dataset_name: Dataset identifier, e.g. ``'cifar10'`` or
            ``'Tiny_ImageNet'``.
        default: Value returned when the name is not recognised. It
            defaults to 10, which preserves the historical silent fallback.
            Pass ``None`` to raise instead, so that misspelled dataset names
            (e.g. ``'cifar-100'``) are caught rather than silently mapped
            to 10 classes.

    Returns:
        The class count for a known dataset; otherwise ``default``.

    Raises:
        ValueError: If the name is unknown and ``default`` is ``None``.
    """
    key = dataset_name.lower()
    if key in _DATASET_NUM_CLASSES:
        return _DATASET_NUM_CLASSES[key]
    if default is None:
        # Strict mode: use the same error style the dataset loader uses.
        raise ValueError(f"Unknown dataset: {dataset_name}")
    return default
| 322 | |
| 323 | |
| 324 | def get_class_names(dataset_name): |
no outgoing calls