(self, batch: BatchBucket)
| 669 | self.request_handler.add_sequence(sequence) |
| 670 | |
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
    """Assemble the input ids, a zeroed output buffer and the metadata for one batch.

    Args:
        batch: The batch bucket to read inputs, sequence lengths, block tables
            and tensor settings (dtype, device, head layout) from.

    Returns:
        A 3-tuple ``(input_ids, output_tensor, input_meta_data)`` where
        ``input_ids`` is whatever ``batch.get_1D_inputs()`` returns,
        ``output_tensor`` is a zero-filled tensor with one row per token and
        ``num_heads * head_dim`` columns, and ``input_meta_data`` is the
        ``InputMetaData`` describing the batch.

    Raises:
        AssertionError: In the speculative-decoding path, when
            ``num_tokens_to_verify + 1`` differs from ``input_ids.size(0)``.
    """
    input_ids = batch.get_1D_inputs()
    sequence_lengths = batch.get_sequence_lengths()

    # Number of rows in the output buffer depends on the phase of the batch.
    if batch.is_prompts:
        # Prompt batch: one row per token across all sequences.
        num_rows = sequence_lengths.sum().item()
    elif batch.use_spec_dec:
        tokens_per_sequence = batch.num_tokens_to_verify + 1
        # NOTE(review): this compares the per-sequence count with the full
        # input length before scaling by batch size — presumably speculative
        # decoding runs with a single sequence; confirm.
        assert tokens_per_sequence == input_ids.size(0)
        num_rows = tokens_per_sequence * batch.current_batch_size
    else:
        # Plain decoding batch: one row per sequence.
        num_rows = batch.current_batch_size

    hidden_width = batch.num_heads * batch.head_dim
    output_tensor = torch.zeros((num_rows, hidden_width), dtype=batch.dtype, device=batch.device)

    # The per-sequence token ids are forwarded only when one of these
    # generation settings is active; otherwise None is passed.
    gen_cfg = self.generation_config
    wants_token_ids = (
        gen_cfg.repetition_penalty != 1.0
        or gen_cfg.no_repeat_ngram_size > 0
        or gen_cfg.forced_eos_token_id is not None
    )
    batch_token_ids = batch.batch_token_ids if wants_token_ids else None

    # A CUDA graph is usable only for a decoding batch whose size has a
    # graph runner registered for it.
    use_cuda_graph = bool(
        self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys()
    )

    meta_kwargs = dict(
        block_tables=batch.get_block_table_tensor(),
        sequence_lengths=sequence_lengths,
        fd_inter_tensor=batch.fd_inter_tensor,
        batch_size=batch.current_batch_size,
        is_prompts=batch.is_prompts,
        use_cuda_kernel=self.inference_config.use_cuda_kernel,
        use_cuda_graph=use_cuda_graph,
        high_precision=self.high_precision,
        kv_seq_len=sequence_lengths.max().item(),
        head_dim=batch.head_dim,
        dtype=batch.dtype,
        use_spec_dec=batch.use_spec_dec,
        num_tokens_to_verify=batch.num_tokens_to_verify,
        batch_token_ids=batch_token_ids,
    )
    input_meta_data = InputMetaData(**meta_kwargs)

    return input_ids, output_tensor, input_meta_data
| 718 | |
| 719 | def step(self) -> List[str]: |
| 720 | """ |
no test coverage detected