MCPcopy
hub / github.com/XPixelGroup/DiffBIR / eval_model

Function eval_model

llava/eval/model_vqa_loader.py:79–126  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

77
78
def eval_model(args):
    """Run VQA inference over a JSON-lines question file and write answers.

    Loads the model at ``args.model_path`` (optionally with
    ``args.model_base``) and reads one JSON object per non-blank line of
    ``args.question_file``. It keeps only chunk ``args.chunk_idx`` of
    ``args.num_chunks`` (via ``get_chunk``) and generates one answer per
    question. Answers are written to ``args.answers_file`` as JSON lines with
    keys ``question_id``, ``prompt``, ``text``, ``answer_id``, ``model_id``
    and ``metadata``.

    Args:
        args: Parsed command-line namespace. Fields read here: model_path,
            model_base, question_file, num_chunks, chunk_idx, answers_file,
            conv_mode, image_folder, temperature, top_p, num_beams,
            max_new_tokens.

    Side effects:
        - May rewrite ``args.conv_mode`` in place (appending ``'_mmtag'``).
        - Creates the answers file's parent directory if it has one.
        - Overwrites the answers file.
        - Moves inputs to ``'cuda'``, so a CUDA device is required.
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path, args.model_base, model_name)

    # Use a context manager so the question file handle is closed, and skip
    # blank lines (json.loads raises on an empty string).
    with open(os.path.expanduser(args.question_file), "r") as question_f:
        questions = [json.loads(q) for q in question_f if q.strip()]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one (not for e.g. "answers.jsonl").
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
        args.conv_mode = args.conv_mode + '_mmtag'
        print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')

    # NOTE(review): conv_mode is not passed to create_data_loader; presumably
    # the dataset reads args.conv_mode from module scope. That is why the
    # switch above must happen before the loader is built. Confirm.
    data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)

    # zip() pairs each loader item with its source question. This assumes the
    # loader yields exactly one item per question, in order (batch size 1).
    with open(answers_file, "w") as ans_file:
        for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, questions), total=len(questions)):
            idx = line["question_id"]
            cur_prompt = line["text"]

            input_ids = input_ids.to(device='cuda', non_blocking=True)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
                    image_sizes=image_sizes,
                    do_sample=args.temperature > 0,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True)

            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            ans_id = shortuuid.uuid()
            ans_file.write(json.dumps({"question_id": idx,
                                       "prompt": cur_prompt,
                                       "text": outputs,
                                       "answer_id": ans_id,
                                       "model_id": model_name,
                                       "metadata": {}}) + "\n")
127
128if __name__ == "__main__":
129 parser = argparse.ArgumentParser()

Callers 1

Calls 7

disable_torch_init · Function · 0.90
get_model_name_from_path · Function · 0.90
load_pretrained_model · Function · 0.90
create_data_loader · Function · 0.85
write · Method · 0.80
get_chunk · Function · 0.70
generate · Method · 0.45

Tested by

no test coverage detected