(num_layers, max_length)
| 94 | |
| 95 | |
| 96 | def check_spec_dec(num_layers, max_length): |
| 97 | torch.manual_seed(123) |
| 98 | |
| 99 | tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") |
| 100 | # Dummy configs for testing |
| 101 | toy_config = LlamaConfig(num_hidden_layers=num_layers) |
| 102 | toy_config.pad_token_id = tokenizer.eos_token_id |
| 103 | drafter_model = LlamaForCausalLM(toy_config) |
| 104 | drafter_model = drafter_model.eval().cuda() |
| 105 | large_config = LlamaConfig( |
| 106 | hidden_size=4096, |
| 107 | intermediate_size=11008, |
| 108 | num_attention_heads=32, |
| 109 | num_hidden_layers=8, |
| 110 | num_key_value_heads=32, |
| 111 | max_position_embeddings=2048, |
| 112 | ) |
| 113 | large_config.pad_token_id = tokenizer.eos_token_id |
| 114 | main_model = LlamaForCausalLM(large_config) |
| 115 | |
| 116 | inference_config = InferenceConfig( |
| 117 | dtype="fp16", |
| 118 | micro_batch_size=1, |
| 119 | max_batch_size=1, |
| 120 | max_input_len=128, |
| 121 | max_output_len=128, |
| 122 | prefill_ratio=1.2, |
| 123 | block_size=16, |
| 124 | ) |
| 125 | engine = InferenceEngine(main_model, tokenizer, inference_config) |
| 126 | engine.enable_spec_dec(drafter_model, n_spec_tokens=5) |
| 127 | |
| 128 | dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda") |
| 129 | generation_config = GenerationConfig( |
| 130 | pad_token_id=tokenizer.eos_token_id, |
| 131 | max_length=max_length, |
| 132 | eos_token_id=tokenizer.eos_token_id, |
| 133 | ) |
| 134 | out, out_token_ids = engine.generate( |
| 135 | prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True |
| 136 | ) |
| 137 | engine.disable_spec_dec() |
| 138 | engine.clear_spec_dec() |
| 139 | |
| 140 | assert not engine.use_spec_dec |
| 141 | assert engine.drafter is None and engine.drafter_model is None |
| 142 | |
| 143 | max_new_tokens = max_length - dummy_inputs.size(1) |
| 144 | assert len(out) == 1 |
| 145 | assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens |
| 146 | |
| 147 | # test GLIDE model |
| 148 | glide_config = GlideLlamaConfig( |
| 149 | intermediate_size=8192, |
| 150 | large_hidden_size=4096, |
| 151 | large_num_attention_heads=32, |
| 152 | num_hidden_layers=num_layers, |
| 153 | ) |
nothing calls this directly
no test coverage detected
searching dependent graphs…