(args)
| 24 | |
@torch.inference_mode()
def generate(args):
    """Run a single prompt through the model and return the decoded completion.

    The model and tokenizer are loaded on every call via ``load_model``, and the
    whole call runs under ``torch.inference_mode`` (no autograd tracking).

    Args:
        args: Parsed options object (presumably an ``argparse.Namespace`` --
            confirm against the CLI definition). Attributes read here:
            ``model_path``, ``device``, ``prompt_style``, ``input_txt``,
            ``max_new_tokens``, ``do_sample``, ``temperature``, ``top_k``,
            ``top_p``.

    Returns:
        Only the newly generated text: the prompt tokens are sliced off before
        decoding, and special tokens are skipped.
    """
    model, tokenizer = load_model(model_path=args.model_path, device=args.device)

    def _build_prompt():
        # Any style other than "sft" is a raw base-model prompt followed by a
        # fixed separator that cues the model to start answering.
        if args.prompt_style != "sft":
            base_inference_suffix = "\n\n->\n\n"
            return f"{args.input_txt}{base_inference_suffix}"
        # "sft": wrap the user text in the conversation template and leave an
        # open (None) Assistant turn for the model to complete.
        template = default_conversation.copy()
        template.append_message("Human", args.input_txt)
        template.append_message("Assistant", None)
        return template.get_prompt()

    prompt = _build_prompt()

    encoded = tokenizer(prompt, return_tensors="pt").to(args.device)
    # Remember how many tokens the prompt occupies so only new tokens are decoded.
    prompt_len = encoded["input_ids"].shape[-1]

    sampling_options = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        num_return_sequences=1,
    )
    generated_ids = model.generate(**encoded, **sampling_options)

    # A single string was tokenized and one sequence requested, so row 0 is the
    # only generated sequence.
    first_sequence = generated_ids.cpu()[0]
    response = tokenizer.decode(first_sequence[prompt_len:], skip_special_tokens=True)
    logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
    return response
| 52 | |
| 53 | |
| 54 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…