MCPcopy
hub / github.com/XPixelGroup/DiffBIR / eval_model

Function eval_model

llava/eval/model_qa.py:14–55  ·  view source on GitHub ↗
(model_name, questions_file, answers_file)

Source from the content-addressed store, hash-verified

12
@torch.inference_mode()
def eval_model(model_name, questions_file, answers_file):
    """Answer every question in a JSONL file with a causal LM and write JSONL.

    Each non-blank line of ``questions_file`` must be a JSON object with at
    least ``"question_id"`` and ``"text"``. For every question a single-turn
    prompt is built from ``default_conversation``, a completion is sampled
    (temperature 0.7, up to 1024 new tokens), and one JSON line is written to
    ``answers_file`` with the keys ``question_id``, ``text``, ``answer_id``,
    ``model_id`` and ``metadata``.

    Args:
        model_name: Model identifier or checkpoint path passed to
            ``from_pretrained``; ``~`` is expanded. The expanded value is
            also recorded as ``model_id`` in every answer.
        questions_file: Path of the input JSONL file (``~`` is expanded).
        answers_file: Path of the output JSONL file (``~`` is expanded);
            it is truncated if it already exists.

    Raises:
        KeyError: if a question record lacks ``"question_id"`` or ``"text"``.
        json.JSONDecodeError: if a non-blank input line is not valid JSON.
    """
    # Model (loaded in fp16 and moved to the default CUDA device).
    disable_torch_init()
    model_name = os.path.expanduser(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16).cuda()

    # Context managers close both files even if generation fails part-way.
    # Explicit UTF-8 keeps JSONL I/O independent of the platform locale.
    with open(os.path.expanduser(questions_file), "r", encoding="utf-8") as ques_file, \
            open(os.path.expanduser(answers_file), "w", encoding="utf-8") as ans_file:
        for line in tqdm(ques_file):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline) in the JSONL.
                continue
            record = json.loads(line)  # parse each record exactly once
            idx = record["question_id"]
            qs = record["text"]

            conv = default_conversation.copy()
            conv.append_message(conv.roles[0], qs)
            prompt = conv.get_prompt()
            input_ids = torch.as_tensor(tokenizer([prompt]).input_ids).cuda()
            output_ids = model.generate(
                input_ids,
                do_sample=True,
                use_cache=True,
                temperature=0.7,
                max_new_tokens=1024,
            )
            decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

            # The decoded text is assumed to begin with the prompt; the answer
            # starts after the assistant role name and ends at the first
            # separator found after the prompt.
            # NOTE(review): the "+ 2" presumably skips a ": " after the role
            # name — confirm against the conversation template.
            start = len(prompt) + len(conv.roles[1]) + 2
            end = decoded.find(conv.sep, len(prompt))
            if end == -1:
                # No separator: take everything to the end. If the decoded
                # text is shorter than the prompt this yields "" instead of
                # raising and aborting the whole evaluation run.
                end = len(decoded)
            answer = decoded[start:end].strip()

            ans_file.write(json.dumps({"question_id": idx,
                                       "text": answer,
                                       "answer_id": shortuuid.uuid(),
                                       "model_id": model_name,
                                       "metadata": {}}) + "\n")
            # Flush after every answer so results written so far are on disk
            # if the run is interrupted.
            ans_file.flush()
56
57if __name__ == "__main__":
58 parser = argparse.ArgumentParser()

Callers 1

model_qa.py · File · 0.70

Calls 7

disable_torch_init · Function · 0.90
copy · Method · 0.80
append_message · Method · 0.80
get_prompt · Method · 0.80
write · Method · 0.80
flush · Method · 0.80
generate · Method · 0.45

Tested by

no test coverage detected