MCPcopy
hub / github.com/hpcaitech/ColossalAI / check_inference_engine

Function check_inference_engine

tests/test_infer/test_rpc_engine.py:22–71  ·  view source on GitHub ↗
(tp_size, use_engine=False, prompt_template=None, do_sample=True, policy=None)

Source from the content-addressed store, hash-verified

20
21
def check_inference_engine(tp_size, use_engine=False, prompt_template=None, do_sample=True, policy=None):
    """Generate text for two fixed prompts with one of two back-ends.

    The random seed is fixed via ``setup_seed(20)`` before anything else so
    that repeated calls start from the same seed.

    Args:
        tp_size: Forwarded to ``InferenceConfig`` as ``tp_size``; only read
            when ``use_engine`` is true.
        use_engine: When true, generate through ``RPCInferenceEngine``;
            otherwise load the model with ``AutoModelForCausalLM`` and call
            its ``generate`` directly.
        prompt_template: Key of a prompt template. In the engine path it is
            handed to ``InferenceConfig``; in the other path it is looked up
            in ``_DEFAULT_PROMPT_TEMPLATES`` and applied to every prompt
            before tokenization.
        do_sample: Forwarded to ``GenerationConfig`` in both paths.
        policy: Passed to ``RPCInferenceEngine`` as ``model_policy``; unused
            in the non-engine path.

    Returns:
        Engine path: whatever ``RPCInferenceEngine.generate`` returns.
        Non-engine path: the result of ``tokenizer.batch_decode`` on the
        generated ids, with special tokens skipped.
    """
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model_path = "meta-llama/Llama-2-7b-hf"  # remote model path
    prompts = [
        "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
        "介绍一下武汉,",
    ]

    # Generation settings shared by both back-ends.
    new_token_budget = 38
    nucleus_p = 0.5
    top_k_candidates = 50

    if use_engine:
        engine_config = InferenceConfig(
            max_output_len=new_token_budget,
            prompt_template=prompt_template,
            dtype="fp32",
            use_cuda_kernel=True,
            tp_size=tp_size,
        )
        engine = RPCInferenceEngine(model_path, tokenizer, engine_config, verbose=True, model_policy=policy)
        # The engine's own generation config must reflect the requested output length.
        assert engine.generation_config.max_new_tokens == new_token_budget

        engine.add_request(prompts=prompts)
        # The prompts just added should be visible as waiting requests.
        assert engine.request_handler._has_waiting()

        engine_generation_config = GenerationConfig(
            max_new_tokens=new_token_budget,
            do_sample=do_sample,
            dtype="fp32",
            top_p=nucleus_p,
            top_k=top_k_candidates,
        )
        return engine.generate(generation_config=engine_generation_config)

    # Non-engine path: the template is applied by hand here, whereas the
    # engine path hands it to InferenceConfig instead.
    if prompt_template:
        template = _DEFAULT_PROMPT_TEMPLATES[prompt_template]
        prompts = [template.format(input_text=raw_prompt) for raw_prompt in prompts]

    hf_model = AutoModelForCausalLM.from_pretrained(model_path).cuda()

    # Reuse the EOS token for padding so the two prompts can be batched.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    encoded = tokenizer.batch_encode_plus(prompts, padding=True, return_tensors="pt")
    input_ids = encoded["input_ids"].cuda()

    hf_generation_config = GenerationConfig(
        do_sample=do_sample,
        dtype="fp32",
        top_p=nucleus_p,
        top_k=top_k_candidates,
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=new_token_budget,
    )
    generated_ids = hf_model.generate(input_ids, generation_config=hf_generation_config)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
72
73
74def run_engine(tp_size, **kwargs):

Callers 1

run_engineFunction · 0.70

Calls 8

InferenceConfigClass · 0.90
RPCInferenceEngineClass · 0.90
_has_waitingMethod · 0.80
setup_seedFunction · 0.70
from_pretrainedMethod · 0.45
add_requestMethod · 0.45
generateMethod · 0.45
cudaMethod · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…