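Evaluate a model on a classification dataset. Args: eval_model (BaseEvalModel): model to evaluate; seed (int, optional): random seed, defaults to 42; num_shots (int, optional): number of shots to use, defaults to 8; no_kv_caching (bool): whether to disable key-value caching; dataset_name (str, optional): dataset name, defaults to "imagenet"; cached_features (tensor, optional): cached demonstration features for RICES, defaults to None.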
(
args: argparse.Namespace,
eval_model,
seed: int = 42,
num_shots: int = 8,
dataset_name: str = "imagenet",
cached_features=None,
no_kv_caching=False,
use_prompt_ensembling: bool = False,
)
| 1116 | |
| 1117 | |
| 1118 | def evaluate_classification( |
| 1119 | args: argparse.Namespace, |
| 1120 | eval_model, |
| 1121 | seed: int = 42, |
| 1122 | num_shots: int = 8, |
| 1123 | dataset_name: str = "imagenet", |
| 1124 | cached_features=None, |
| 1125 | no_kv_caching=False, |
| 1126 | use_prompt_ensembling: bool = False, |
| 1127 | ): |
| 1128 | """ |
| 1129 | Evaluate a model on classification dataset. |
| 1130 | |
| 1131 | Args: |
| 1132 | eval_model (BaseEvalModel): model to evaluate |
| 1133 | seed (int, optional): random seed. Defaults to 42. |
| 1134 | num_shots (int, optional): number of shots to use. Defaults to 8. |
| 1135 | no_kv_caching (bool): whether to disable key-value caching |
| 1136 | dataset_name (str, optional): dataset name. Defaults to "imagenet". |
| 1137 | cached_features (tensor, optional): cached demonstration features for RICES. Defaults to None. |
| 1138 | |
| 1139 | Returns: |
| 1140 | float: accuracy score |
| 1141 | """ |
| 1142 | if args.model != "open_flamingo": |
| 1143 | raise NotImplementedError( |
| 1144 | "evaluate_classification is currently only supported for OpenFlamingo" |
| 1145 | ) |
| 1146 | |
| 1147 | if dataset_name == "imagenet": |
| 1148 | train_dataset = ImageNetDataset(os.path.join(args.imagenet_root, "train")) |
| 1149 | test_dataset = ImageNetDataset(os.path.join(args.imagenet_root, "val")) |
| 1150 | prompt_fn = lambda x: eval_model.get_imagenet_prompt(label=x["class_name"]) |
| 1151 | all_class_names = IMAGENET_CLASSNAMES |
| 1152 | k = 5 |
| 1153 | elif dataset_name == "hateful_memes": |
| 1154 | train_dataset = HatefulMemesDataset( |
| 1155 | args.hateful_memes_image_dir_path, |
| 1156 | args.hateful_memes_train_annotations_json_path, |
| 1157 | ) |
| 1158 | test_dataset = HatefulMemesDataset( |
| 1159 | args.hateful_memes_image_dir_path, |
| 1160 | args.hateful_memes_test_annotations_json_path, |
| 1161 | ) |
| 1162 | prompt_fn = lambda x: eval_model.get_hateful_memes_prompt( |
| 1163 | text=x["ocr"], label=x["class_name"] |
| 1164 | ) |
| 1165 | all_class_names = HM_CLASSNAMES |
| 1166 | k = 1 |
| 1167 | else: |
| 1168 | raise ValueError(f"Unsupported dataset {dataset_name}") |
| 1169 | |
| 1170 | class_id_to_name = dict(zip(range(len(all_class_names)), all_class_names)) |
| 1171 | |
| 1172 | effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model) |
| 1173 | |
| 1174 | np.random.seed(seed) |
| 1175 | test_dataloader = utils.prepare_eval_samples( |
no test coverage detected