MCPcopy
hub / github.com/lm-sys/FastChat / main

Function main

fastchat/serve/huggingface_api.py:16–56  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

14
@torch.inference_mode()
def main(args):
    """Generate one reply to ``args.message`` and print the exchange.

    Loads the model and tokenizer through ``load_model``, wraps the message
    in the conversation template selected by ``args.model_path``, calls
    ``model.generate`` once, and prints the user message followed by the
    decoded reply.

    Args:
        args: Parsed command-line namespace. This function reads
            ``model_path``, ``device``, ``num_gpus``, ``max_gpu_memory``,
            ``load_8bit``, ``cpu_offloading``, ``revision``, ``debug``,
            ``message``, ``temperature``, ``repetition_penalty`` and
            ``max_new_tokens``.
    """
    # Load model
    model, tokenizer = load_model(
        args.model_path,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        revision=args.revision,
        debug=args.debug,
    )

    # Build the prompt with a conversation template
    msg = args.message
    conv = get_conversation_template(args.model_path)
    conv.append_message(conv.roles[0], msg)
    # The second role is appended with no text; presumably this is the turn
    # the model is expected to complete (confirm against the template class).
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Run inference
    inputs = tokenizer([prompt], return_tensors="pt").to(args.device)

    # Temperatures at or below 1e-5 are treated as a request for
    # non-sampling (deterministic) decoding.
    do_sample = args.temperature > 1e-5
    generate_kwargs = {
        "do_sample": do_sample,
        "repetition_penalty": args.repetition_penalty,
        "max_new_tokens": args.max_new_tokens,
    }
    if do_sample:
        # Temperature only affects sampling. Forwarding it alongside
        # do_sample=False is at best ignored and triggers a generation-config
        # warning in recent transformers releases, so pass it only here.
        generate_kwargs["temperature"] = args.temperature
    output_ids = model.generate(**inputs, **generate_kwargs)

    if model.config.is_encoder_decoder:
        output_ids = output_ids[0]
    else:
        # Skip the leading positions that correspond to the prompt tokens so
        # only the newly generated part is decoded.
        output_ids = output_ids[0][len(inputs["input_ids"][0]) :]
    outputs = tokenizer.decode(
        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
    )

    # Print results
    print(f"{conv.roles[0]}: {msg}")
    print(f"{conv.roles[1]}: {outputs}")
57
58
59if __name__ == "__main__":

Callers 1

huggingface_api.pyFile · 0.70

Calls 6

load_modelFunction · 0.90
append_messageMethod · 0.80
get_promptMethod · 0.80
toMethod · 0.80
generateMethod · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…