Make inference outputs.
(
self,
batched_outputs: 'BatchedOutputs',
running: 'SeqList',
model_inputs: 'ModelInputs',
delta: 'ModelInputsDelta',
)
| 256 | |
| 257 | @record_function('make_infer_outputs') |
| 258 | def _make_infer_outputs( |
| 259 | self, |
| 260 | batched_outputs: 'BatchedOutputs', |
| 261 | running: 'SeqList', |
| 262 | model_inputs: 'ModelInputs', |
| 263 | delta: 'ModelInputsDelta', |
| 264 | ): |
| 265 | """Make infer output.""" |
| 266 | |
| 267 | def __get_logit(msg, logits: torch.Tensor, seq_length: list[int], idx: int): |
| 268 | logit = logits.split(seq_length)[idx] |
| 269 | if len(msg.all_logits) > 0: |
| 270 | # for chunked long context |
| 271 | msg.append_logits(logit) |
| 272 | logit = msg.logits |
| 273 | msg.all_logits.resize(0) |
| 274 | |
| 275 | return logit |
| 276 | |
| 277 | def __get_logprobs(batched_outputs: 'BatchedOutputs'): |
| 278 | """Get valid logprobs.""" |
| 279 | batch_size = batched_outputs.stop_pos.size(0) |
| 280 | logprobs = batched_outputs.logprobs |
| 281 | if logprobs is None: |
| 282 | return [None for _ in range(batch_size)] |
| 283 | num_decode_tokens = logprobs.indices.shape[0] // batch_size |
| 284 | results = [[] for _ in range(batch_size)] |
| 285 | range_tensor = torch.arange(num_decode_tokens, device=logprobs.indices.device) |
| 286 | |
| 287 | for idx in range(batch_size): |
| 288 | start = idx * num_decode_tokens |
| 289 | end = (idx + 1) * num_decode_tokens |
| 290 | mask = logprobs.indices[start:end][:, 0] >= 0 |
| 291 | stop_pos = batched_outputs.stop_pos[idx] |
| 292 | # only apply when stopped |
| 293 | if stop_pos > -1: |
| 294 | mask = torch.logical_and(mask, stop_pos >= range_tensor) |
| 295 | indices = logprobs.indices[start:end][mask].tolist() |
| 296 | vals = logprobs.vals[start:end][mask].tolist() |
| 297 | results[idx] = list(zip(vals, indices)) |
| 298 | return results |
| 299 | |
| 300 | logits = batched_outputs.logits |
| 301 | all_routed_experts = batched_outputs.all_routed_experts |
| 302 | |
| 303 | if model_inputs is not None and (model_inputs.is_chunk and not model_inputs.is_last_chunk): |
| 304 | # chunk long context does not need to update seqs and outputs |
| 305 | seq = running[0] |
| 306 | seq.append_routed_experts(all_routed_experts) |
| 307 | seq.append_logits(logits) |
| 308 | return dict() |
| 309 | |
| 310 | new_token_timestamp = batched_outputs.new_token_timestamp |
| 311 | logprobs = batched_outputs.logprobs |
| 312 | |
| 313 | all_logprobs = __get_logprobs(batched_outputs) |
| 314 | |
| 315 | seq_length = [seq.num_token_ids for seq in running] |
no test coverage detected