(args)
| 27 | |
| 28 | |
def eval_model(args):
    """Answer a chunk of image questions with the model and write JSON lines.

    Each non-blank line of ``args.question_file`` is parsed as a JSON object
    with the keys ``question_id``, ``image`` and ``text``. One JSON object per
    question is written to ``args.answers_file`` with the keys
    ``question_id``, ``prompt``, ``text``, ``answer_id``, ``model_id`` and
    ``metadata``.

    Args:
        args: Namespace providing ``model_path``, ``model_base``,
            ``question_file``, ``answers_file``, ``image_folder``,
            ``conv_mode``, ``num_chunks``, ``chunk_idx``, ``temperature``,
            ``top_p`` and ``num_beams``.
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(model_path, args.model_base, model_name)

    # Read all questions up front and release the file handle immediately.
    # Blank lines (e.g. a trailing newline) are skipped instead of crashing
    # json.loads.
    with open(os.path.expanduser(args.question_file), "r") as question_file:
        questions = [json.loads(q) for q in question_file if q.strip()]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    answers_file = os.path.expanduser(args.answers_file)
    # dirname is '' for a bare file name, and os.makedirs('') raises.
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    # The context manager closes the answers file even if generation fails.
    with open(answers_file, "w") as ans_file:
        for line in tqdm(questions):
            idx = line["question_id"]
            image_file = line["image"]
            qs = line["text"]
            cur_prompt = qs
            if model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            # None leaves the second role's turn empty in the prompt.
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(
                prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
            ).unsqueeze(0).cuda()

            # convert() returns a new image, so the source file can be closed
            # as soon as the conversion is done.
            with Image.open(os.path.join(args.image_folder, image_file)) as raw_image:
                image = raw_image.convert('RGB')
            image_tensor = process_images([image], image_processor, model.config)[0]

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.unsqueeze(0).half().cuda(),
                    image_sizes=[image.size],
                    do_sample=args.temperature > 0,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=1024,
                    use_cache=True)

            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            ans_id = shortuuid.uuid()
            ans_file.write(json.dumps({"question_id": idx,
                                       "prompt": cur_prompt,
                                       "text": outputs,
                                       "answer_id": ans_id,
                                       "model_id": model_name,
                                       "metadata": {}}) + "\n")
            # Flush per answer so partial results survive an interrupted run.
            ans_file.flush()
| 85 | |
| 86 | if __name__ == "__main__": |
no test coverage detected