(args)
| 77 | |
| 78 | |
def eval_model(args):
    """Generate answers for one chunk of a JSONL question file.

    Loads the pretrained model found at ``args.model_path`` (with optional
    ``args.model_base``). It parses one JSON object per line from
    ``args.question_file`` and keeps the chunk selected by ``get_chunk`` with
    ``args.num_chunks`` / ``args.chunk_idx``. For each question it runs
    ``model.generate`` and writes one JSON record per line to
    ``args.answers_file``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``model_path``, ``model_base``, ``question_file``, ``num_chunks``,
            ``chunk_idx``, ``answers_file``, ``conv_mode``, ``image_folder``,
            ``temperature``, ``top_p``, ``num_beams``, ``max_new_tokens``.
            ``args.conv_mode`` may be rewritten in place (suffixed with
            ``'_mmtag'``) when a "plain" model is detected.
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path, args.model_base, model_name)

    # Close the question file as soon as it has been parsed.
    with open(os.path.expanduser(args.question_file), "r") as question_file:
        questions = [json.loads(q) for q in question_file]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the answers path actually contains one.
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
        args.conv_mode = args.conv_mode + '_mmtag'
        print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')

    # NOTE(review): conv_mode is not passed to create_data_loader; presumably
    # the dataset reads args.conv_mode elsewhere, so the switch above must
    # happen before the loader is iterated — confirm.
    data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)

    # The context manager flushes and closes the answers file even if
    # generation raises part-way, so answers written so far are kept.
    with open(answers_file, "w") as ans_file:
        # Batches are paired with questions by position. This is only correct
        # if the loader yields exactly one question per batch, in order —
        # NOTE(review): confirm create_data_loader's batch size is 1.
        for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, questions), total=len(questions)):
            idx = line["question_id"]
            cur_prompt = line["text"]

            input_ids = input_ids.to(device='cuda', non_blocking=True)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
                    image_sizes=image_sizes,
                    # Sample only when a positive temperature is requested;
                    # otherwise decode greedily / with beam search.
                    do_sample=args.temperature > 0,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True)

            # Only the first decoded sequence of the batch is kept.
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            ans_id = shortuuid.uuid()
            ans_file.write(json.dumps({"question_id": idx,
                                       "prompt": cur_prompt,
                                       "text": outputs,
                                       "answer_id": ans_id,
                                       "model_id": model_name,
                                       "metadata": {}}) + "\n")
| 127 | |
| 128 | if __name__ == "__main__": |
| 129 | parser = argparse.ArgumentParser() |
no test coverage detected