(tokenizer, spec_num: int)
| 18 | |
@pytest.mark.parametrize("spec_num", [SPEC_NUM])
def test_drafter(tokenizer, spec_num: int):
    """Check the outputs of ``Drafter.speculate`` and ``Drafter.trim_kv_cache``.

    A randomly initialised Llama model is used as the drafter model. The test
    verifies the speculated length, the shapes of the returned tokens and
    logits, and the sequence length of the KV cache before and after trimming.

    Args:
        tokenizer: Tokenizer fixture; supplies ``eos_token_id`` (used as the
            pad token) and the vocabulary size via ``len(tokenizer)``.
        spec_num: Number of tokens the drafter is asked to speculate.
    """
    # Fixed seed so the random weights and random prompt are reproducible.
    torch.manual_seed(123)

    device = get_current_device()
    toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
    toy_config.pad_token_id = tokenizer.eos_token_id
    drafter_model = LlamaForCausalLM(toy_config)
    # Place the model on the same device as the drafter and the inputs rather
    # than hard-coding `.cuda()`, so all tensors are guaranteed to be colocated.
    drafter_model = drafter_model.eval().to(device)

    drafter = Drafter(drafter_model, tokenizer, device=device)

    input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
    out = drafter.speculate(input_ids, spec_num)
    # Expected cache length: prompt length plus all speculated tokens but one.
    # NOTE(review): presumably the last speculated token has not been fed back
    # through the model yet -- confirm against Drafter.speculate.
    past_kv_length = input_ids.size(1) + spec_num - 1

    assert out.speculated_length == spec_num
    assert out.next_tokens.shape == (spec_num,)
    assert out.logits.shape == (spec_num, len(tokenizer))
    assert out.past_key_values[0][0].size(2) == past_kv_length

    # Reject every speculated token except one; the cache must shrink by
    # exactly that many positions along dim 2.
    reject_num = max(0, spec_num - 1)
    trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num)
    assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
| 43 | |
| 44 | |
| 45 | def test_spec_dec(tokenizer): |
no test coverage detected
searching dependent graphs…