MCPcopy
hub / github.com/z-lab/dflash / dflash_generate

Function dflash_generate

dflash/model.py:63–169  ·  view source on GitHub ↗
(
    model: "DFlashDraftModel",
    target: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int,
    stop_token_ids: Optional[list[int]],
    temperature: float,
    block_size: Optional[int] = None,
    mask_token_id: Optional[int] = None,
    return_stats: bool = False,
)

Source from the content-addressed store, hash-verified

61
62@torch.inference_mode()
63def dflash_generate(
64 model: "DFlashDraftModel",
65 target: nn.Module,
66 input_ids: torch.LongTensor,
67 max_new_tokens: int,
68 stop_token_ids: Optional[list[int]],
69 temperature: float,
70 block_size: Optional[int] = None,
71 mask_token_id: Optional[int] = None,
72 return_stats: bool = False,
73):
74 num_input_tokens = input_ids.shape[1]
75 max_length = num_input_tokens + max_new_tokens
76 block_size = model.block_size if block_size is None else block_size
77 mask_token_id = model.mask_token_id if mask_token_id is None else mask_token_id
78
79 output_ids = torch.full(
80 (1, max_length + block_size), mask_token_id, dtype=torch.long, device=target.device,
81 )
82 position_ids = torch.arange(output_ids.shape[1], device=target.device).unsqueeze(0)
83 past_key_values_target = DynamicCache()
84 past_key_values_draft = DynamicCache()
85
86 prefill_start = _cuda_time() if return_stats else None
87 output = target(
88 input_ids,
89 position_ids=position_ids[:, :num_input_tokens],
90 past_key_values=past_key_values_target,
91 use_cache=True,
92 logits_to_keep=1,
93 output_hidden_states=block_size > 1,
94 )
95
96 output_ids[:, :num_input_tokens] = input_ids
97 output_ids[:, num_input_tokens:num_input_tokens + 1] = sample(output.logits, temperature)
98 if block_size > 1:
99 target_hidden = extract_context_feature(output.hidden_states, model.target_layer_ids)
100 time_to_first_token = _cuda_time() - prefill_start if return_stats else None
101
102 decode_start = _cuda_time() if return_stats else None
103 acceptance_lengths = []
104 start = num_input_tokens
105 draft_prefill = True
106
107 while start < max_length:
108 block_output_ids = output_ids[:, start : start + block_size].clone()
109 block_position_ids = position_ids[:, start : start + block_size]
110 if block_size > 1:
111 noise_embedding = target.model.embed_tokens(block_output_ids)
112 draft_logits = target.lm_head(model(
113 target_hidden=target_hidden,
114 noise_embedding=noise_embedding,
115 position_ids=position_ids[:, past_key_values_draft.get_seq_length(): start + block_size],
116 past_key_values=past_key_values_draft,
117 use_cache=True,
118 is_causal=False,
119 )[:, 1 - block_size :, :])
120 past_key_values_draft.crop(start)

Callers 2

_run_transformers · Function · 0.85
spec_generate · Method · 0.85

Calls 3

_cuda_time · Function · 0.85
sample · Function · 0.85
extract_context_feature · Function · 0.85

Tested by

no test coverage detected