Repository: github.com/hpcaitech/ColossalAI

Function check_spec_dec

tests/test_infer/test_inference_engine.py:96–163
(num_layers, max_length)


```python
def check_spec_dec(num_layers, max_length):
    torch.manual_seed(123)

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # Dummy configs for testing
    toy_config = LlamaConfig(num_hidden_layers=num_layers)
    toy_config.pad_token_id = tokenizer.eos_token_id
    drafter_model = LlamaForCausalLM(toy_config)
    drafter_model = drafter_model.eval().cuda()
    large_config = LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=8,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    )
    large_config.pad_token_id = tokenizer.eos_token_id
    main_model = LlamaForCausalLM(large_config)

    inference_config = InferenceConfig(
        dtype="fp16",
        micro_batch_size=1,
        max_batch_size=1,
        max_input_len=128,
        max_output_len=128,
        prefill_ratio=1.2,
        block_size=16,
    )
    engine = InferenceEngine(main_model, tokenizer, inference_config)
    engine.enable_spec_dec(drafter_model, n_spec_tokens=5)

    dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
    generation_config = GenerationConfig(
        pad_token_id=tokenizer.eos_token_id,
        max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
    )
    out, out_token_ids = engine.generate(
        prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
    )
    engine.disable_spec_dec()
    engine.clear_spec_dec()

    assert not engine.use_spec_dec
    assert engine.drafter is None and engine.drafter_model is None

    max_new_tokens = max_length - dummy_inputs.size(1)
    assert len(out) == 1
    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens

    # test GLIDE model
    glide_config = GlideLlamaConfig(
        intermediate_size=8192,
        large_hidden_size=4096,
        large_num_attention_heads=32,
        num_hidden_layers=num_layers,
    )
    # ...
```
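In this test the drafter's `toy_config` only overrides `num_hidden_layers`, so it inherits the `LlamaConfig` defaults (`hidden_size=4096`, `intermediate_size=11008`, `num_attention_heads=32`, `vocab_size=32000`). The main model's `large_config` spells out those same widths with 8 layers, so the two models differ only in depth. The rough parameter count below is a sketch that assumes the standard Hugging Face Llama layout: bias-free q/k/v/o and gate/up/down projections, two RMSNorm weight vectors per layer, and an untied `lm_head`. `llama_param_count` is a hypothetical helper, not a ColossalAI or transformers API. `num_layers=2` is also a hypothetical value, because the test's parametrization is not shown in this excerpt.

```python
def llama_param_count(hidden_size=4096, intermediate_size=11008, num_hidden_layers=32,
                      num_attention_heads=32, num_key_value_heads=None,
                      vocab_size=32000, tie_word_embeddings=False):
    """Hypothetical helper: weight count of a Hugging Face-style Llama (no biases)."""
    if num_key_value_heads is None:
        num_key_value_heads = num_attention_heads  # plain multi-head attention
    head_dim = hidden_size // num_attention_heads
    kv_dim = num_key_value_heads * head_dim
    # q_proj and o_proj are hidden x hidden; k_proj and v_proj are hidden x kv_dim
    attn = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    # gate_proj and up_proj map hidden -> intermediate; down_proj maps back
    mlp = 3 * hidden_size * intermediate_size
    # input_layernorm and post_attention_layernorm weights
    norms = 2 * hidden_size
    per_layer = attn + mlp + norms
    embed = vocab_size * hidden_size
    lm_head = 0 if tie_word_embeddings else vocab_size * hidden_size
    final_norm = hidden_size
    return num_hidden_layers * per_layer + embed + lm_head + final_norm


main_params = llama_param_count(num_hidden_layers=8)     # mirrors large_config
drafter_params = llama_param_count(num_hidden_layers=2)  # toy_config, assuming num_layers=2
print(f"main:    {main_params:,} params, ~{main_params * 2 / 1e9:.2f} GB of fp16 weights")
print(f"drafter: {drafter_params:,} params, ~{drafter_params * 2 / 1e9:.2f} GB of fp16 weights")
```

Under these assumptions the 8-layer main model has about 1.88B parameters (roughly 3.76 GB of fp16 weights before any KV cache), and a 2-layer drafter has about 0.67B. The drafter is cheaper mainly because it has fewer layers, while the embedding and `lm_head` matrices stay the same size.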

Callers

nothing calls this directly

Calls 13

generate · Method · 0.95
InferenceConfig · Class · 0.90
InferenceEngine · Class · 0.90
GlideLlamaConfig · Class · 0.90
enable_spec_dec · Method · 0.80
disable_spec_dec · Method · 0.80
clear_spec_dec · Method · 0.80
manual_seed · Method · 0.45
from_pretrained · Method · 0.45
cuda · Method · 0.45
eval · Method · 0.45
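The calls to `enable_spec_dec`, `generate`, `disable_spec_dec` and `clear_spec_dec` exercise speculative decoding. `n_spec_tokens=5` presumably sets how many tokens the drafter proposes per step before the main model verifies them. The sketch below shows greedy draft-then-verify in plain Python. It is a generic illustration, not ColossalAI's implementation, and it rests on several assumptions: the "models" are hypothetical next-token functions, there is no EOS handling or KV cache, and `max_length=32` is a hypothetical value. It reproduces the length bookkeeping the test asserts (`max_new_tokens = max_length` minus the 10 prompt tokens). It also checks that greedy verification gives the same output as the main model's own greedy decoding.

```python
from typing import Callable, List

# Hypothetical stand-in for a model: maps a token prefix to its greedy (argmax) next token.
NextTokenFn = Callable[[List[int]], int]


def greedy_decode(model: NextTokenFn, prompt: List[int], max_new_tokens: int) -> List[int]:
    """Baseline: one main-model step per generated token."""
    seq = list(prompt)
    for _ in range(max_new_tokens):
        seq.append(model(seq))
    return seq[len(prompt):]


def speculative_decode(main_model: NextTokenFn, drafter: NextTokenFn, prompt: List[int],
                       max_new_tokens: int, n_spec_tokens: int = 5) -> List[int]:
    """Greedy draft-then-verify decoding (a generic sketch, not ColossalAI's code)."""
    seq = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(seq) < target_len:
        # 1. The drafter proposes up to n_spec_tokens tokens autoregressively.
        k = min(n_spec_tokens, target_len - len(seq))
        draft: List[int] = []
        for _ in range(k):
            draft.append(drafter(seq + draft))
        # 2. The main model checks each drafted position. A real engine scores all
        #    positions in one batched forward pass; here it is called per position.
        accepted: List[int] = []
        for tok in draft:
            expected = main_model(seq + accepted)
            if tok != expected:
                # First mismatch: keep the main model's token and discard the rest.
                accepted.append(expected)
                break
            accepted.append(tok)
        seq.extend(accepted)  # always at least one token, never past target_len
    return seq[len(prompt):]


# Toy models (hypothetical): the drafter proposes 0 whenever the last token is a multiple of 3.
def toy_main(seq: List[int]) -> int:
    return (seq[-1] * 7 + 3) % 100


def toy_drafter(seq: List[int]) -> int:
    return toy_main(seq) if seq[-1] % 3 else 0


prompt = list(range(5, 15))  # 10 prompt tokens, like dummy_inputs
max_length = 32              # hypothetical; mirrors GenerationConfig(max_length=...)
max_new_tokens = max_length - len(prompt)
out = speculative_decode(toy_main, toy_drafter, prompt, max_new_tokens, n_spec_tokens=5)
assert out == greedy_decode(toy_main, prompt, max_new_tokens)
assert len(out) == max_new_tokens
```

With greedy verification, the output depends only on the main model. The drafter affects how many tokens are accepted per verification step, and therefore how many sequential main-model steps are needed. That is where the speedup comes from when the drafter is accurate.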

Tested by

no test coverage detected
