(
model: "DFlashDraftModel",
target: nn.Module,
input_ids: torch.LongTensor,
max_new_tokens: int,
stop_token_ids: Optional[list[int]],
temperature: float,
block_size: Optional[int] = None,
mask_token_id: Optional[int] = None,
return_stats: bool = False,
)
| 61 | |
| 62 | @torch.inference_mode() |
| 63 | def dflash_generate( |
| 64 | model: "DFlashDraftModel", |
| 65 | target: nn.Module, |
| 66 | input_ids: torch.LongTensor, |
| 67 | max_new_tokens: int, |
| 68 | stop_token_ids: Optional[list[int]], |
| 69 | temperature: float, |
| 70 | block_size: Optional[int] = None, |
| 71 | mask_token_id: Optional[int] = None, |
| 72 | return_stats: bool = False, |
| 73 | ): |
| 74 | num_input_tokens = input_ids.shape[1] |
| 75 | max_length = num_input_tokens + max_new_tokens |
| 76 | block_size = model.block_size if block_size is None else block_size |
| 77 | mask_token_id = model.mask_token_id if mask_token_id is None else mask_token_id |
| 78 | |
| 79 | output_ids = torch.full( |
| 80 | (1, max_length + block_size), mask_token_id, dtype=torch.long, device=target.device, |
| 81 | ) |
| 82 | position_ids = torch.arange(output_ids.shape[1], device=target.device).unsqueeze(0) |
| 83 | past_key_values_target = DynamicCache() |
| 84 | past_key_values_draft = DynamicCache() |
| 85 | |
| 86 | prefill_start = _cuda_time() if return_stats else None |
| 87 | output = target( |
| 88 | input_ids, |
| 89 | position_ids=position_ids[:, :num_input_tokens], |
| 90 | past_key_values=past_key_values_target, |
| 91 | use_cache=True, |
| 92 | logits_to_keep=1, |
| 93 | output_hidden_states=block_size > 1, |
| 94 | ) |
| 95 | |
| 96 | output_ids[:, :num_input_tokens] = input_ids |
| 97 | output_ids[:, num_input_tokens:num_input_tokens + 1] = sample(output.logits, temperature) |
| 98 | if block_size > 1: |
| 99 | target_hidden = extract_context_feature(output.hidden_states, model.target_layer_ids) |
| 100 | time_to_first_token = _cuda_time() - prefill_start if return_stats else None |
| 101 | |
| 102 | decode_start = _cuda_time() if return_stats else None |
| 103 | acceptance_lengths = [] |
| 104 | start = num_input_tokens |
| 105 | draft_prefill = True |
| 106 | |
| 107 | while start < max_length: |
| 108 | block_output_ids = output_ids[:, start : start + block_size].clone() |
| 109 | block_position_ids = position_ids[:, start : start + block_size] |
| 110 | if block_size > 1: |
| 111 | noise_embedding = target.model.embed_tokens(block_output_ids) |
| 112 | draft_logits = target.lm_head(model( |
| 113 | target_hidden=target_hidden, |
| 114 | noise_embedding=noise_embedding, |
| 115 | position_ids=position_ids[:, past_key_values_draft.get_seq_length(): start + block_size], |
| 116 | past_key_values=past_key_values_draft, |
| 117 | use_cache=True, |
| 118 | is_causal=False, |
| 119 | )[:, 1 - block_size :, :]) |
| 120 | past_key_values_draft.crop(start) |
no test coverage detected