MCPcopy
hub / github.com/Zyphra/Zonos / _decode_one_token

Method _decode_one_token

zonos/model.py:118–179  ·  view source on GitHub ↗

Single-step decode. Prepares the hidden states, possibly replicates them for CFG, and then delegates to `_compute_logits`. Below we wrap this function with a simple CUDA Graph capturing mechanism, doing 3 warmup steps if needed and then capturing or replaying the graph. We only recapture if the batch size changes.

(
        self,
        input_ids: torch.Tensor,
        inference_params: InferenceParams,
        cfg_scale: float,
        allow_cudagraphs: bool = True,
    )

Source from the content-addressed store, hash-verified

116 return logits
117
118 def _decode_one_token(
119 self,
120 input_ids: torch.Tensor,
121 inference_params: InferenceParams,
122 cfg_scale: float,
123 allow_cudagraphs: bool = True,
124 ) -> torch.Tensor:
125 """
126 Single-step decode. Prepares the hidden states, possibly replicates them
127 for CFG, and then delegates to `_compute_logits`.
128
129 Below we wrap this function with a simple CUDA Graph capturing mechanism,
130 doing 3 warmup steps if needed and then capturing or replaying the graph.
131 We only recapture if the batch size changes.
132 """
133 # TODO: support cfg_scale==1
134 if cfg_scale == 1.0:
135 hidden_states = self.embed_codes(input_ids)
136 return self._compute_logits(hidden_states, inference_params, cfg_scale)
137
138 bsz = input_ids.size(0)
139
140 if not allow_cudagraphs or input_ids.device.type != "cuda":
141 hidden_states_local = self.embed_codes(input_ids)
142 hidden_states_local = hidden_states_local.repeat(2, 1, 1)
143 return self._compute_logits(hidden_states_local, inference_params, cfg_scale)
144
145 need_capture = (self._cg_graph is None) or (self._cg_batch_size != bsz)
146
147 if need_capture:
148 self._cg_graph = None
149
150 self._cg_batch_size = bsz
151 self._cg_inference_params = inference_params
152 self._cg_scale = cfg_scale
153
154 for _ in range(3):
155 hidden_states = self.embed_codes(input_ids)
156 hidden_states = hidden_states.repeat(2, 1, 1) # because cfg != 1.0
157 logits = self._compute_logits(hidden_states, inference_params, cfg_scale)
158
159 self._cg_input_ids = input_ids.clone()
160 self._cg_logits = torch.empty_like(logits)
161
162 g = torch.cuda.CUDAGraph()
163
164 def capture_region():
165 hidden_states_local = self.embed_codes(self._cg_input_ids)
166 hidden_states_local = hidden_states_local.repeat(2, 1, 1)
167 self._cg_logits = self._compute_logits(hidden_states_local, self._cg_inference_params, self._cg_scale)
168
169 with torch.cuda.graph(g):
170 capture_region()
171
172 self._cg_graph = g
173
174 else:
175 self._cg_input_ids.copy_(input_ids)

Callers

nothing calls this directly

Calls 2

embed_codes (Method) · 0.95
_compute_logits (Method) · 0.95

Tested by

no test coverage detected