MCP · copy
hub / github.com/InternLM/lmdeploy / _make_infer_outputs

Method _make_infer_outputs

lmdeploy/pytorch/engine/engine_loop.py:258–368  ·  view source on GitHub ↗

Make infer output.

(
        self,
        batched_outputs: 'BatchedOutputs',
        running: 'SeqList',
        model_inputs: 'ModelInputs',
        delta: 'ModelInputsDelta',
    )

Source retrieved from the content-addressed store (hash-verified)

256
257 @record_function('make_infer_outputs')
258 def _make_infer_outputs(
259 self,
260 batched_outputs: 'BatchedOutputs',
261 running: 'SeqList',
262 model_inputs: 'ModelInputs',
263 delta: 'ModelInputsDelta',
264 ):
265 """Make infer output."""
266
def __get_logit(msg, logits: torch.Tensor, seq_length: list[int], idx: int):
    """Return the logits belonging to sequence ``idx``.

    ``logits`` is split along dim 0 into pieces of the sizes listed in
    ``seq_length`` and piece ``idx`` is selected. When ``msg.all_logits`` is
    non-empty (chunked long context), the piece is appended through
    ``msg.append_logits``, ``msg.logits`` is returned instead, and
    ``msg.all_logits`` is resized to 0.
    """
    piece = logits.split(seq_length)[idx]
    if len(msg.all_logits) == 0:
        return piece

    # for chunked long context: merge with what earlier chunks stored
    msg.append_logits(piece)
    merged = msg.logits
    msg.all_logits.resize(0)
    return merged
276
def __get_logprobs(batched_outputs: 'BatchedOutputs'):
    """Get valid logprobs.

    Args:
        batched_outputs: object exposing ``stop_pos`` (one entry per
            sequence) and ``logprobs``, which is either ``None`` or an object
            with ``indices`` and ``vals`` tensors whose first dimension is
            laid out as ``batch_size`` consecutive groups of equal length.

    Returns:
        A list with one entry per sequence. Each entry is ``None`` when
        ``logprobs`` is ``None``; otherwise it is a list of ``(vals, indices)``
        pairs, one per kept decode step. A step is kept when its first index
        is non-negative and, if the sequence has a stop position
        (``stop_pos > -1``), the step offset does not exceed that position.
        An empty batch yields an empty list.
    """
    batch_size = batched_outputs.stop_pos.size(0)
    logprobs = batched_outputs.logprobs
    if logprobs is None:
        return [None] * batch_size
    if batch_size == 0:
        # Nothing to slice; also avoids dividing by zero below.
        return []

    all_indices = logprobs.indices
    all_vals = logprobs.vals
    num_decode_tokens = all_indices.shape[0] // batch_size
    range_tensor = torch.arange(num_decode_tokens, device=all_indices.device)

    results = []
    for idx in range(batch_size):
        start = idx * num_decode_tokens
        end = start + num_decode_tokens
        seq_indices = all_indices[start:end]
        # negative first index marks a step without a valid logprob
        mask = seq_indices[:, 0] >= 0
        stop_pos = batched_outputs.stop_pos[idx]
        # only apply when stopped
        if stop_pos > -1:
            mask = torch.logical_and(mask, stop_pos >= range_tensor)
        indices = seq_indices[mask].tolist()
        vals = all_vals[start:end][mask].tolist()
        results.append(list(zip(vals, indices)))
    return results
299
300 logits = batched_outputs.logits
301 all_routed_experts = batched_outputs.all_routed_experts
302
303 if model_inputs is not None and (model_inputs.is_chunk and not model_inputs.is_last_chunk):
304 # chunk long context does not need to update seqs and outputs
305 seq = running[0]
306 seq.append_routed_experts(all_routed_experts)
307 seq.append_logits(logits)
308 return dict()
309
310 new_token_timestamp = batched_outputs.new_token_timestamp
311 logprobs = batched_outputs.logprobs
312
313 all_logprobs = __get_logprobs(batched_outputs)
314
315 seq_length = [seq.num_token_ids for seq in running]

Callers 1

Calls 7

RequestMetricsClass · 0.90
InferOutputClass · 0.85
append_routed_expertsMethod · 0.80
append_logitsMethod · 0.80
get_block_tableMethod · 0.80
update_runningMethod · 0.45
getMethod · 0.45

Tested by

no test coverage detected