MCPcopy
hub / github.com/hpcaitech/ColossalAI / prepare_input

Method prepare_input

colossalai/inference/core/llm_engine.py:671–717  ·  view source on GitHub ↗
(self, batch: BatchBucket)

Source from the content-addressed store, hash-verified

669 self.request_handler.add_sequence(sequence)
670
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
    """Assemble the inputs for one model forward step over ``batch``.

    Args:
        batch: The bucket of sequences scheduled for this step.

    Returns:
        A tuple ``(input_ids, output_tensor, input_meta_data)``:

        * ``input_ids`` is the tensor returned by ``batch.get_1D_inputs()``.
        * ``output_tensor`` is a zero-filled buffer of shape
          ``(n_tokens, batch.num_heads * batch.head_dim)`` with the batch's
          dtype and device.
        * ``input_meta_data`` is an ``InputMetaData`` carrying the block tables,
          sequence lengths, the longest sequence length (``kv_seq_len``) and the
          CUDA-kernel / CUDA-graph / speculative-decoding flags for this step.

    ``n_tokens`` depends on the batch kind:

    * prompt batch: the sum of all sequence lengths;
    * decoding batch: ``batch.current_batch_size``;
    * speculative decoding: ``(batch.num_tokens_to_verify + 1) * batch.current_batch_size``.
    """
    token_ids = batch.get_1D_inputs()
    seq_lens = batch.get_sequence_lengths()

    # Number of rows needed in the output buffer for this step.
    if batch.is_prompts:
        # Prefill: one row per prompt token across the whole batch.
        num_tokens = seq_lens.sum().item()
    elif batch.use_spec_dec:
        # Speculative decoding: num_tokens_to_verify + 1 rows per sequence.
        tokens_per_seq = batch.num_tokens_to_verify + 1
        # NOTE(review): this compares the per-sequence count with the full length
        # of token_ids, which only matches when the batch holds a single sequence
        # (or get_1D_inputs returns one sequence's tokens) — confirm with callers.
        assert tokens_per_seq == token_ids.size(0)
        num_tokens = tokens_per_seq * batch.current_batch_size
    else:
        # Decoding: one row per sequence in the batch.
        num_tokens = batch.current_batch_size

    hidden_width = batch.num_heads * batch.head_dim
    output_buffer = torch.zeros((num_tokens, hidden_width), dtype=batch.dtype, device=batch.device)

    # The batch's token ids are only handed over when at least one of these
    # generation options is active; otherwise None is passed.
    gen_cfg = self.generation_config
    needs_token_history = (
        gen_cfg.repetition_penalty != 1.0
        or gen_cfg.no_repeat_ngram_size > 0
        or gen_cfg.forced_eos_token_id is not None
    )
    batch_token_ids = batch.batch_token_ids if needs_token_history else None

    # CUDA graphs are captured only for specific decoding batch sizes, so a graph
    # is used only for non-prompt batches whose size has a registered runner.
    use_cuda_graph = bool(
        self.use_cuda_graph
        and not batch.is_prompts
        and batch.current_batch_size in self.graph_runners.keys()
    )

    meta = InputMetaData(
        block_tables=batch.get_block_table_tensor(),
        sequence_lengths=seq_lens,
        fd_inter_tensor=batch.fd_inter_tensor,
        batch_size=batch.current_batch_size,
        is_prompts=batch.is_prompts,
        use_cuda_kernel=self.inference_config.use_cuda_kernel,
        use_cuda_graph=use_cuda_graph,
        high_precision=self.high_precision,
        kv_seq_len=seq_lens.max().item(),
        head_dim=batch.head_dim,
        dtype=batch.dtype,
        use_spec_dec=batch.use_spec_dec,
        num_tokens_to_verify=batch.num_tokens_to_verify,
        batch_token_ids=batch_token_ids,
    )

    return token_ids, output_buffer, meta
718
719 def step(self) -> List[str]:
720 """

Callers 3

steps_spec_dec · Method · 0.95
step · Method · 0.95
async_step · Method · 0.45

Calls 5

InputMetaData · Class · 0.90
get_1D_inputs · Method · 0.80
get_sequence_lengths · Method · 0.80
size · Method · 0.45

Tested by

no test coverage detected