hub / github.com/hpcaitech/ColossalAI / check_inference_engine

Function check_inference_engine

tests/test_infer/test_cuda_graph.py:21–71
(use_cuda_graph=False, batch_size=32)

def check_inference_engine(use_cuda_graph=False, batch_size=32):
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = (
        LlamaForCausalLM(
            LlamaConfig(
                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
            )
        )
        .cuda()
        .half()
    )
    model = model.eval()

    prompts_token_ids = []
    for i in range(batch_size):
        prompts_token_ids.append(
            np.random.randint(low=0, high=100, size=random.randint(1, max(1024 // batch_size, 32))).tolist()
        )

    input_len = 1024
    output_len = 128
    do_sample = False
    top_p = 0.5
    top_k = 50

    if use_cuda_graph:
        inference_config = InferenceConfig(
            max_batch_size=batch_size,
            max_input_len=input_len,
            max_output_len=output_len,
            use_cuda_kernel=False,
            use_cuda_graph=True,
            block_size=16,
        )
    else:
        inference_config = InferenceConfig(
            max_batch_size=batch_size,
            max_input_len=input_len,
            max_output_len=output_len,
            use_cuda_kernel=False,
            use_cuda_graph=False,
            block_size=16,
        )

    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    assert inference_engine.generation_config.max_new_tokens == output_len
    generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
    outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config)

    return outputs

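The prompt-building loop in the function draws each prompt length from `random.randint(1, max(1024 // batch_size, 32))` and each token id from the range [0, 100). The sketch below reproduces only those bounds using the standard library. `max_prompt_len` and `make_prompts` are hypothetical helper names, and `random.Random` stands in for `np.random.randint`, so the concrete values will not match what the test generates.

```python
import random
from typing import List


def max_prompt_len(batch_size: int) -> int:
    # Same upper bound as in the test: 1024 tokens split across the batch,
    # but never fewer than 32.
    return max(1024 // batch_size, 32)


def make_prompts(batch_size: int, seed: int = 20) -> List[List[int]]:
    # Hypothetical stand-in for the NumPy-based loop in the test.
    rng = random.Random(seed)
    upper = max_prompt_len(batch_size)
    prompts = []
    for _ in range(batch_size):
        length = rng.randint(1, upper)  # inclusive on both ends
        prompts.append([rng.randrange(0, 100) for _ in range(length)])
    return prompts


if __name__ == "__main__":
    for bs in (1, 16, 32, 64):
        print(bs, max_prompt_len(bs))
```

With the default `batch_size=32` the bound works out to 32 tokens, far below the `max_input_len` of 1024 passed to `InferenceConfig`.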
Callers 1

check_output_consistency · Function · 0.70

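The caller's name suggests it compares generations produced with and without CUDA graphs; its body is not shown on this page. Below is a minimal sketch of such a comparison under that assumption. `assert_outputs_match` is a hypothetical helper, and `fake_engine` is a stand-in for `check_inference_engine` (assumed here to return a list of strings) so the example runs without a GPU.

```python
from typing import Callable, List


def assert_outputs_match(run: Callable[..., List[str]], batch_size: int) -> None:
    # Run the same workload twice, toggling only use_cuda_graph.
    graph_out = run(use_cuda_graph=True, batch_size=batch_size)
    eager_out = run(use_cuda_graph=False, batch_size=batch_size)
    assert len(graph_out) == len(eager_out), "number of sequences differs"
    for i, (g, e) in enumerate(zip(graph_out, eager_out)):
        assert g == e, f"sequence {i} differs:\nCUDA graph: {g}\neager: {e}"


def fake_engine(use_cuda_graph: bool = False, batch_size: int = 32) -> List[str]:
    # Stand-in only: deterministic output that ignores use_cuda_graph.
    return [f"output-{i}" for i in range(batch_size)]


if __name__ == "__main__":
    assert_outputs_match(fake_engine, batch_size=4)
    print("outputs match")
```

Exact equality is a plausible criterion here because the function sets `do_sample=False` and reseeds with `setup_seed(20)` on every call, so both runs should see the same prompts and decode greedily.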
Calls 10

generate · Method · 0.95
InferenceConfig · Class · 0.90
InferenceEngine · Class · 0.90
half · Method · 0.80
tolist · Method · 0.80
setup_seed · Function · 0.70
from_pretrained · Method · 0.45
cuda · Method · 0.45
eval · Method · 0.45
append · Method · 0.45

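The two `InferenceConfig(...)` calls in the function pass identical keyword arguments except for `use_cuda_graph`. One way to make that explicit is sketched below with a plain dict, so it runs without ColossalAI installed. `build_config_kwargs` is a hypothetical helper and not part of the repository.

```python
from typing import Any, Dict


def build_config_kwargs(use_cuda_graph: bool, batch_size: int) -> Dict[str, Any]:
    # Values copied from the test; only use_cuda_graph varies between branches.
    return {
        "max_batch_size": batch_size,
        "max_input_len": 1024,
        "max_output_len": 128,
        "use_cuda_kernel": False,
        "use_cuda_graph": use_cuda_graph,
        "block_size": 16,
    }


if __name__ == "__main__":
    on = build_config_kwargs(True, 32)
    off = build_config_kwargs(False, 32)
    # Collect the keys whose values differ between the two configurations.
    differing = {k for k in on if on[k] != off[k]}
    print(differing)
```

In the test itself the dict would presumably be unpacked as `InferenceConfig(**build_config_kwargs(use_cuda_graph, batch_size))`, since the source already passes every one of these as a keyword argument.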
Tested by

no test coverage detected
