The input info for a single step. Args: block_tables (torch.Tensor, optional): Sequences' BlockTables. Defaults to None. sequence_lengths (torch.Tensor): A tensor containing sequence lengths. fd_inter_tensor (torch.Tensor, optional): A tensor representing intermediate data for flash decoding. Defaults to None.
| 51 | |
| 52 | @dataclass |
| 53 | class InputMetaData(RPC_PARAM): |
| 54 | """The input info for a single step |
| 55 | |
| 56 | Args: |
| 57 | block_tables (torch.Tensor, optional): Sequences' BlockTables Defaults to None. |
| 58 | sequence_lengths (torch.Tensor): A tensor containing sequence lengths. |
| 59 | fd_inter_tensor (torch.Tensor, optional): A tensor representing intermediate data for flash decoding. Defaults to None. |
| 60 | batch_size (int, optional): The current batch size. Defaults to 64. |
| 61 | is_prompts (bool, optional): Indicates whether prefill or decoding. Defaults to False(decoding). |
| 62 | use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally |
| 63 | use_cuda_graph (bool, optional): Indicates whether to use the CUDA graph. Defaults to False. |
| 64 | kv_seq_len (int, optional): Key-value sequence length. Defaults to 512. |
| 65 | head_dim (int, optional): Head dimension. Defaults to 32. |
| 66 | high_precision(bool, optional): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, Defaults to False. |
| 67 | dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32. |
| 68 | use_spec_dec (bool): Indicate whether to use speculative decoding. |
| 69 | num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True. |
| 70 | batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of current batch. Only used for `repetition_penalty`, `no_repeat_ngram_size` in sampler process. |
| 71 | """ |
| 72 | |
| 73 | block_tables: torch.Tensor = None |
| 74 | sequence_lengths: torch.Tensor = None |
| 75 | fd_inter_tensor: FDIntermTensors = None |
| 76 | batch_size: int = 64 # current_batch_size |
| 77 | is_prompts: bool = False |
| 78 | use_cuda_kernel: bool = False |
| 79 | use_cuda_graph: bool = False |
| 80 | kv_seq_len: int = 512 |
| 81 | head_dim: int = 32 |
| 82 | high_precision: bool = False |
| 83 | dtype: torch.dtype = torch.float32 |
| 84 | use_spec_dec: bool = False |
| 85 | num_tokens_to_verify: int = 0 |
| 86 | batch_token_ids: Optional[List[List[int]]] = ( |
| 87 | None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process |
| 88 | ) |
| 89 | |
def to_rpc_param(self) -> Dict[str, any]:
    """Flatten this step's metadata into a plain dict suitable for RPC.

    Tensor fields are turned into nested Python lists with ``.tolist()``.
    The dtype is sent as its bare name: the text after the last ``.`` in
    ``str(self.dtype)``, e.g. ``torch.float16`` -> ``"float16"``. All other
    fields are copied through unchanged. ``fd_inter_tensor`` is deliberately
    absent from the payload.

    Note:
        ``block_tables`` and ``sequence_lengths`` must already be tensors.
        Their dataclass default is ``None``, and calling ``.tolist()`` on
        ``None`` raises ``AttributeError``.

    Returns:
        Dict[str, any]: Keys in this order: block_tables, sequence_lengths,
        batch_size, is_prompts, use_cuda_kernel, use_cuda_graph, kv_seq_len,
        head_dim, high_precision, dtype, use_spec_dec, num_tokens_to_verify,
        batch_token_ids.
    """
    # Tensors are converted first, so a missing table fails before anything else.
    block_table_list = self.block_tables.tolist()
    seq_len_list = self.sequence_lengths.tolist()
    # "torch.float16" -> "float16"; a string with no dot is returned unchanged.
    dtype_name = str(self.dtype).rpartition(".")[2]

    # Keyword order fixes the key order of the resulting dict.
    return dict(
        block_tables=block_table_list,
        sequence_lengths=seq_len_list,
        batch_size=self.batch_size,
        is_prompts=self.is_prompts,
        use_cuda_kernel=self.use_cuda_kernel,
        use_cuda_graph=self.use_cuda_graph,
        kv_seq_len=self.kv_seq_len,
        head_dim=self.head_dim,
        high_precision=self.high_precision,
        dtype=dtype_name,
        use_spec_dec=self.use_spec_dec,
        num_tokens_to_verify=self.num_tokens_to_verify,
        batch_token_ids=self.batch_token_ids,
    )
| 106 | |
| 107 | @staticmethod |
| 108 | def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData": |
| 109 | """ |
| 110 | We intentionally don't use `dict.get` method to ensure we pass the right rpc param, or program will show error message |
no outgoing calls
no test coverage detected
searching dependent graphs…