hub / github.com/hpcaitech/ColossalAI / check_streamingllm

Function check_streamingllm

tests/test_infer/test_streamingllm.py:27–100
()


def check_streamingllm():
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # Small Llama model built from a config only (randomly initialized, no checkpoint loaded).
    model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=50000,
            hidden_size=512,
            intermediate_size=1536,
            num_attention_heads=4,
            num_key_value_heads=2,
            num_hidden_layers=16,
        )
    ).cuda()
    model = model.eval()

    input_token_ids = data_gen(1, 4)

    output_len = 128

    inference_config = InferenceConfig(
        max_batch_size=1,
        max_output_len=output_len,
        dtype="fp32",
        use_cuda_kernel=True,
        enable_streamingllm=True,
        start_token_size=4,
        generated_token_size=32,
    )

    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    assert inference_engine.generation_config.max_new_tokens == output_len
    inference_engine.add_request(prompts_token_ids=input_token_ids)
    assert inference_engine.request_handler._has_waiting()

    # The checks below rely on the start-token window being exactly one block.
    assert inference_config.start_token_size == inference_config.block_size

    request_handler = inference_engine.request_handler
    running_bb = request_handler.running_bb

    # After 12 steps: sequence length 16, one block in use (-1 entries are unused slots).
    for _ in range(12):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, -1, -1, -1]
    assert running_bb.seq_lengths[0].item() == 16

    # After 16 more steps: sequence length 32, two blocks in use.
    for _ in range(16):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 1, -1, -1]
    assert running_bb.seq_lengths[0].item() == 32

    # After 16 more steps: sequence length 48, three blocks in use.
    for _ in range(16):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 1, 2, -1]
    assert running_bb.seq_lengths[0].item() == 48

    for _ in range(16):
        # (listing ends here; the function continues through line 100 of the source file)
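
The three assertion pairs in the listing follow a simple pattern, which the sketch below reproduces without a GPU or the ColossalAI engine. It is only an illustration under stated assumptions: `BLOCK_SIZE = 16` is inferred from the asserted sequence lengths (the block size is not shown in the listing), `PROMPT_LEN = 4` assumes `data_gen(1, 4)` yields a single 4-token prompt, each `step()` is assumed to add one token, and blocks are assumed to be handed out in ascending order starting at 0. The helper names `expected_seq_len` and `expected_block_table` are hypothetical and are not part of ColossalAI.

```python
import math

BLOCK_SIZE = 16   # assumption: inferred from the asserted lengths 16, 32, 48
MAX_BLOCKS = 4    # width of the block table seen in the assertions
PROMPT_LEN = 4    # assumption: data_gen(1, 4) produces one prompt of 4 tokens


def expected_seq_len(steps: int, prompt_len: int = PROMPT_LEN) -> int:
    """Sequence length after `steps` engine steps, assuming one new token per step."""
    return prompt_len + steps


def expected_block_table(seq_len: int,
                         block_size: int = BLOCK_SIZE,
                         max_blocks: int = MAX_BLOCKS) -> list:
    """Block IDs 0..n-1 for the blocks in use, padded with -1 for unused slots.

    Assumes blocks are allocated in ascending order, which holds in the
    listing only because a single sequence is running.
    """
    used = min(math.ceil(seq_len / block_size), max_blocks)
    return list(range(used)) + [-1] * (max_blocks - used)


# Cumulative step counts at the three checkpoints in the listing: 12, 12+16, 12+16+16.
for steps in (12, 28, 44):
    seq_len = expected_seq_len(steps)
    print(steps, seq_len, expected_block_table(seq_len))
```

The sketch deliberately stops at the third checkpoint. What happens once the sequence outgrows the StreamingLLM window is covered by the part of the function not shown in the listing, so it is not modeled here.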

Callers

nothing calls this directly

Calls 11

add_request · Method · 0.95
step · Method · 0.95
InferenceConfig · Class · 0.90
InferenceEngine · Class · 0.90
_has_waiting · Method · 0.80
tolist · Method · 0.80
setup_seed · Function · 0.70
data_gen · Function · 0.70
from_pretrained · Method · 0.45
cuda · Method · 0.45
eval · Method · 0.45

Tested by

no test coverage detected
