MCPcopy
hub / github.com/mlfoundations/open_flamingo / evaluate_classification

Function evaluate_classification

open_flamingo/eval/evaluate.py:1118–1297  ·  view source on GitHub ↗

Evaluate a model on a classification dataset. Args: eval_model (BaseEvalModel): model to evaluate; seed (int, optional): random seed, defaults to 42; num_shots (int, optional): number of shots to use, defaults to 8; no_kv_caching (bool): whether to disable key-value caching; …

(
    args: argparse.Namespace,
    eval_model,
    seed: int = 42,
    num_shots: int = 8,
    dataset_name: str = "imagenet",
    cached_features=None,
    no_kv_caching=False,
    use_prompt_ensembling: bool = False,
)

Source from the content-addressed store, hash-verified

1116
1117
1118def evaluate_classification(
1119 args: argparse.Namespace,
1120 eval_model,
1121 seed: int = 42,
1122 num_shots: int = 8,
1123 dataset_name: str = "imagenet",
1124 cached_features=None,
1125 no_kv_caching=False,
1126 use_prompt_ensembling: bool = False,
1127):
1128 """
1129 Evaluate a model on classification dataset.
1130
1131 Args:
1132 eval_model (BaseEvalModel): model to evaluate
1133 seed (int, optional): random seed. Defaults to 42.
1134 num_shots (int, optional): number of shots to use. Defaults to 8.
1135 no_kv_caching (bool): whether to disable key-value caching
1136 dataset_name (str, optional): dataset name. Defaults to "imagenet".
1137 cached_features (tensor, optional): cached demonstration features for RICES. Defaults to None.
1138
1139 Returns:
1140 float: accuracy score
1141 """
1142 if args.model != "open_flamingo":
1143 raise NotImplementedError(
1144 "evaluate_classification is currently only supported for OpenFlamingo"
1145 )
1146
1147 if dataset_name == "imagenet":
1148 train_dataset = ImageNetDataset(os.path.join(args.imagenet_root, "train"))
1149 test_dataset = ImageNetDataset(os.path.join(args.imagenet_root, "val"))
1150 prompt_fn = lambda x: eval_model.get_imagenet_prompt(label=x["class_name"])
1151 all_class_names = IMAGENET_CLASSNAMES
1152 k = 5
1153 elif dataset_name == "hateful_memes":
1154 train_dataset = HatefulMemesDataset(
1155 args.hateful_memes_image_dir_path,
1156 args.hateful_memes_train_annotations_json_path,
1157 )
1158 test_dataset = HatefulMemesDataset(
1159 args.hateful_memes_image_dir_path,
1160 args.hateful_memes_test_annotations_json_path,
1161 )
1162 prompt_fn = lambda x: eval_model.get_hateful_memes_prompt(
1163 text=x["ocr"], label=x["class_name"]
1164 )
1165 all_class_names = HM_CLASSNAMES
1166 k = 1
1167 else:
1168 raise ValueError(f"Unsupported dataset {dataset_name}")
1169
1170 class_id_to_name = dict(zip(range(len(all_class_names)), all_class_names))
1171
1172 effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
1173
1174 np.random.seed(seed)
1175 test_dataloader = utils.prepare_eval_samples(

Callers 1

main · Function · 0.85

Calls 7

find · Method · 0.95
ImageNetDataset · Class · 0.90
HatefulMemesDataset · Class · 0.90
RICES · Class · 0.90
get_imagenet_prompt · Method · 0.80

Tested by

no test coverage detected