hub / github.com/hpcaitech/ColossalAI / generate

Function generate(args)

applications/Colossal-LLaMA/inference/inference_example.py:26–51

Source from the content-addressed store, hash-verified

@torch.inference_mode()
def generate(args):
    model, tokenizer = load_model(model_path=args.model_path, device=args.device)

    if args.prompt_style == "sft":
        conversation = default_conversation.copy()
        conversation.append_message("Human", args.input_txt)
        conversation.append_message("Assistant", None)
        input_txt = conversation.get_prompt()
    else:
        BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
        input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"

    inputs = tokenizer(input_txt, return_tensors="pt").to(args.device)
    num_input_tokens = inputs["input_ids"].shape[-1]
    output = model.generate(
        **inputs,
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        num_return_sequences=1,
    )
    response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
    logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
    return response


if __name__ == "__main__":
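
The function reads every setting from attributes of `args`. The sketch below shows one way a parser could supply those attributes. The attribute names come from the listing above. The flag spellings, the default values, and the helper name `build_parser` are assumptions for illustration and may differ from the script's actual `__main__` block.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser exposing the attributes that generate(args) reads."""
    parser = argparse.ArgumentParser(description="Sketch of the inference arguments")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda:0")  # assumed default
    # Any value other than "sft" takes the base-prompt branch in the listing.
    parser.add_argument("--prompt_style", type=str, default="sft")  # assumed default
    parser.add_argument("--input_txt", type=str, required=True)
    parser.add_argument("--max_new_tokens", type=int, default=128)  # assumed default
    parser.add_argument("--do_sample", action="store_true")
    parser.add_argument("--temperature", type=float, default=1.0)  # assumed default
    parser.add_argument("--top_k", type=int, default=50)  # assumed default
    parser.add_argument("--top_p", type=float, default=1.0)  # assumed default
    return parser


# Parse an explicit argument list so the sketch runs without a command line.
args = build_parser().parse_args(
    ["--model_path", "/path/to/model", "--input_txt", "Hello", "--prompt_style", "base"]
)
print(args.prompt_style, args.max_new_tokens, args.do_sample)
```

The resulting namespace could then be passed to `generate(args)`, provided a model checkpoint exists at the given path.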

Callers 1

Calls 10

tokenizer · Function · 0.85
load_model · Function · 0.70
copy · Method · 0.45
append_message · Method · 0.45
get_prompt · Method · 0.45
to · Method · 0.45
generate · Method · 0.45
decode · Method · 0.45
cpu · Method · 0.45
info · Method · 0.45
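
The `generate`, `cpu`, and `decode` calls listed above work together: for a decoder-only model, the sequence returned by `model.generate` begins with the prompt tokens, so the function slices them off before decoding. Here is a plain-Python sketch of that slicing, with nested lists standing in for tensors and made-up token ids. The helper `strip_prompt_tokens` is hypothetical and not part of the repository.

```python
from typing import List


def strip_prompt_tokens(output_ids: List[List[int]], num_input_tokens: int) -> List[int]:
    """Return only the newly generated ids of the first sequence.

    Mirrors output[0, num_input_tokens:] from the listing, with nested
    lists standing in for a 2-D tensor of shape (batch, sequence).
    """
    return output_ids[0][num_input_tokens:]


prompt_ids = [101, 7592, 2088]                  # made-up ids for the prompt
generated = [prompt_ids + [2023, 2003, 102]]    # prompt followed by new tokens
new_ids = strip_prompt_tokens(generated, len(prompt_ids))
print(new_ids)  # [2023, 2003, 102]
```

Without the slice, the decoded response would repeat the prompt text before the model's answer.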

Tested by

no test coverage detected
