hub / github.com/hpcaitech/ColossalAI / check_inference_engine

Function check_inference_engine

tests/test_infer/test_inference_engine.py:26–85
(use_engine=False, prompt_template=None, do_sample=True, policy=None)


def check_inference_engine(use_engine=False, prompt_template=None, do_sample=True, policy=None):
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=50000,
            hidden_size=512,
            intermediate_size=1536,
            num_attention_heads=4,
            num_key_value_heads=2,
            num_hidden_layers=16,
        )
    ).cuda()
    model = model.eval()
    inputs = [
        # "Introduce today's Beijing, such as the Forbidden City, Tiananmen, the Great Wall, or other sights,"
        "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
        # "Introduce Wuhan,"
        "介绍一下武汉,",
    ]

    output_len = 38
    do_sample = do_sample
    top_p = 0.5
    top_k = 50

    if use_engine:
        inference_config = InferenceConfig(
            max_output_len=output_len,
            prompt_template=prompt_template,
            dtype="fp32",
            use_cuda_kernel=True,
            tp_size=dist.get_world_size(),
        )
        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
        assert inference_engine.generation_config.max_new_tokens == output_len
        inference_engine.add_request(prompts=inputs)
        assert inference_engine.request_handler._has_waiting()
        generation_config = GenerationConfig(
            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
        )
        outputs = inference_engine.generate(generation_config=generation_config)
    else:
        if prompt_template:
            # apply prompt template
            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
        inputs = inputs.cuda()
        generation_config = GenerationConfig(
            do_sample=do_sample,
            dtype="fp32",
            top_p=top_p,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=output_len,
        )
        outputs = model.generate(inputs, generation_config=generation_config)
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
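In the non-engine branch, the prompt template is applied by hand with `str.format` before tokenization, while the engine branch receives the template name through `InferenceConfig`. The sketch below isolates that formatting step. `EXAMPLE_PROMPT_TEMPLATES` and `apply_prompt_template` are hypothetical stand-ins; the actual contents of `_DEFAULT_PROMPT_TEMPLATES` are not shown in the source and may differ.

```python
# Hypothetical stand-in for _DEFAULT_PROMPT_TEMPLATES; the real templates may differ.
EXAMPLE_PROMPT_TEMPLATES = {
    "example": "[INST] {input_text} [/INST]",
}


def apply_prompt_template(prompts, template_name=None):
    """Wrap each prompt in the named template, or return the prompts unchanged."""
    if not template_name:
        return list(prompts)
    template = EXAMPLE_PROMPT_TEMPLATES[template_name]
    # Each template is assumed to expose a single {input_text} placeholder.
    return [template.format(input_text=text) for text in prompts]


wrapped = apply_prompt_template(["Tell me about Wuhan,"], "example")
print(wrapped)
```

Applying the template before tokenization keeps the two branches comparable: both should end up generating from the same wrapped text.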

Callers

nothing calls this directly
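Because the function calls `setup_seed(20)` on every invocation, a hypothetical caller could presumably run it once with `use_engine=True` and once with `use_engine=False` and compare the two results, even with `do_sample=True`. The toy below illustrates only the seeding idea with the standard library; `sample_tokens` is an invented stand-in, not part of ColossalAI.

```python
import random


def sample_tokens(seed, length=5, vocab_size=50000):
    """Toy stand-in for a sampled generation path: draw token ids from a seeded RNG."""
    rng = random.Random(seed)
    return [rng.randrange(vocab_size) for _ in range(length)]


# Same seed gives identical draws, so two code paths can be compared with ==.
reference_path = sample_tokens(seed=20)
engine_path = sample_tokens(seed=20)
assert reference_path == engine_path
```

Real GPU kernels add other sources of nondeterminism, so matching seeds are a necessary condition for equal outputs rather than a guarantee.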

Calls 11

add_request · Method · 0.95
generate · Method · 0.95
InferenceConfig · Class · 0.90
InferenceEngine · Class · 0.90
_has_waiting · Method · 0.80
setup_seed · Function · 0.70
from_pretrained · Method · 0.45
cuda · Method · 0.45
eval · Method · 0.45
get_world_size · Method · 0.45
generate · Method · 0.45
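Both `generate` calls listed above receive `top_k=50` and `top_p=0.5`. As a rough illustration of what those two parameters mean, the sketch below applies a common formulation of top-k followed by top-p (nucleus) filtering to a plain probability list. It is not the code used by either `generate` implementation, and `top_k_top_p_filter` is a hypothetical helper.

```python
def top_k_top_p_filter(probs, top_k=50, top_p=0.5):
    """Return {token_id: renormalized prob} after top-k, then top-p filtering."""
    # Rank token ids by probability, highest first, and keep at most top_k.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]

    # Renormalize over the top-k survivors before applying the nucleus cut.
    k_total = sum(probs[i] for i in ranked)

    # Keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for token_id in ranked:
        kept.append(token_id)
        cumulative += probs[token_id] / k_total
        if cumulative >= top_p:
            break

    # Renormalize again so the kept probabilities sum to 1.
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


filtered = top_k_top_p_filter([0.4, 0.3, 0.2, 0.1], top_k=3, top_p=0.5)
print(sorted(filtered))
```

With `top_p=0.5`, only the most likely tokens covering half of the probability mass remain candidates, which narrows sampling considerably compared with `top_k=50` alone.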

Tested by

no test coverage detected
