Get class names for a dataset
(dataset_name)
| 322 | |
| 323 | |
def get_class_names(dataset_name):
    """Return the class names for a dataset.

    The lookup is case-insensitive and ignores '-', '_' and whitespace in
    the name, so 'cifar10', 'CIFAR-10', 'fashion_mnist', 'Fashion-MNIST' and
    'FashionMNIST' all resolve to their built-in label lists.

    Args:
        dataset_name: Dataset identifier (a string).

    Returns:
        A newly built list of class-name strings. For the datasets in the
        built-in table this is the stored label list; for any other dataset
        it is ['class_0', ..., 'class_{n-1}'], where n is the value returned
        by get_num_classes() for the lower-cased name.
    """
    dataset_name = dataset_name.lower()

    # Separator-insensitive key: turn '-' and '_' into spaces, then drop all
    # whitespace. One table entry per dataset then covers every spelling.
    lookup_key = ''.join(
        dataset_name.replace('-', ' ').replace('_', ' ').split()
    )

    class_names_map = {
        'cifar10': ['airplane', 'automobile', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck'],
        'mnist': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
        'fashionmnist': ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                         'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'],
        'stl10': ['airplane', 'bird', 'car', 'cat', 'deer',
                  'dog', 'horse', 'monkey', 'ship', 'truck'],
    }

    known_names = class_names_map.get(lookup_key)
    if known_names is not None:
        return known_names

    # For other datasets, return generic class names. The helper receives the
    # lower-cased (not separator-stripped) name, exactly as before.
    num_classes = get_num_classes(dataset_name)
    return [f'class_{i}' for i in range(num_classes)]