MCP copy
hub / github.com/XPixelGroup/DiffBIR / eval_model

Function eval_model

llava/eval/model_vqa_science.py:29–94  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

27
28
29def eval_model(args):
30 # Model
31 disable_torch_init()
32 model_path = os.path.expanduser(args.model_path)
33 model_name = get_model_name_from_path(model_path)
34 tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35
36 questions = json.load(open(os.path.expanduser(args.question_file), "r"))
37 questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38 answers_file = os.path.expanduser(args.answers_file)
39 os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40 ans_file = open(answers_file, "w")
41 for i, line in enumerate(tqdm(questions)):
42 idx = line["id"]
43 question = line['conversations'][0]
44 qs = question['value'].replace('<image>', '').strip()
45 cur_prompt = qs
46
47 if 'image' in line:
48 image_file = line["image"]
49 image = Image.open(os.path.join(args.image_folder, image_file))
50 image_tensor = process_images([image], image_processor, model.config)[0]
51 images = image_tensor.unsqueeze(0).half().cuda()
52 image_sizes = [image.size]
53 if getattr(model.config, 'mm_use_im_start_end', False):
54 qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
55 else:
56 qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
57 cur_prompt = '<image>' + '\n' + cur_prompt
58 else:
59 images = None
60 image_sizes = None
61
62 if args.single_pred_prompt:
63 qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
64 cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
65
66 conv = conv_templates[args.conv_mode].copy()
67 conv.append_message(conv.roles[0], qs)
68 conv.append_message(conv.roles[1], None)
69 prompt = conv.get_prompt()
70
71 input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
72
73 with torch.inference_mode():
74 output_ids = model.generate(
75 input_ids,
76 images=images,
77 image_sizes=image_sizes,
78 do_sample=True if args.temperature > 0 else False,
79 temperature=args.temperature,
80 max_new_tokens=1024,
81 use_cache=True,
82 )
83
84 outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
85
86 ans_id = shortuuid.uuid()

Callers 1

Calls 12

disable_torch_init · Function · 0.90
get_model_name_from_path · Function · 0.90
load_pretrained_model · Function · 0.90
process_images · Function · 0.90
tokenizer_image_token · Function · 0.90
copy · Method · 0.80
append_message · Method · 0.80
get_prompt · Method · 0.80
write · Method · 0.80
flush · Method · 0.80
get_chunk · Function · 0.70
generate · Method · 0.45

Tested by

no test coverage detected