(args)
| 14 | |
@torch.inference_mode()
def main(args):
    """Generate one reply to ``args.message`` and print the exchange.

    Loads a model/tokenizer pair, wraps the message in the conversation
    template selected by ``args.model_path``, runs ``model.generate`` and
    prints the user message followed by the decoded reply.

    Args:
        args: Object exposing the attributes read below: ``model_path``,
            ``device``, ``num_gpus``, ``max_gpu_memory``, ``load_8bit``,
            ``cpu_offloading``, ``revision``, ``debug``, ``message``,
            ``temperature``, ``repetition_penalty`` and ``max_new_tokens``.
            (Presumably an ``argparse.Namespace`` -- confirm against the
            entry point.)
    """
    # Load model
    model, tokenizer = load_model(
        args.model_path,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        revision=args.revision,
        debug=args.debug,
    )

    # Build the prompt with a conversation template
    user_message = args.message
    conversation = get_conversation_template(args.model_path)
    user_role = conversation.roles[0]
    reply_role = conversation.roles[1]
    conversation.append_message(user_role, user_message)
    # The second turn is appended with no content; presumably this leaves
    # the reply slot open for the model to fill -- confirm in the template.
    conversation.append_message(reply_role, None)
    prompt = conversation.get_prompt()

    # Run inference
    encoded = tokenizer([prompt], return_tensors="pt").to(args.device)
    # Sampling is enabled only for temperatures above a small epsilon.
    use_sampling = args.temperature > 1e-5
    generated = model.generate(
        **encoded,
        do_sample=use_sampling,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        max_new_tokens=args.max_new_tokens,
    )

    # Only the first sequence of the batch is used.
    reply_ids = generated[0]
    if not model.config.is_encoder_decoder:
        # Drop as many leading tokens as the prompt had (presumably because
        # these models return prompt + continuation in one sequence).
        prompt_length = len(encoded["input_ids"][0])
        reply_ids = reply_ids[prompt_length:]
    reply_text = tokenizer.decode(
        reply_ids,
        skip_special_tokens=True,
        spaces_between_special_tokens=False,
    )

    # Print results
    print(f"{user_role}: {user_message}")
    print(f"{reply_role}: {reply_text}")
| 57 | |
| 58 | |
| 59 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…