MCPcopy
hub / github.com/hpcaitech/ColossalAI / InputMetaData

Class InputMetaData

colossalai/inference/config.py:53–147  ·  view source on GitHub ↗

The input info for a single step. Args: block_tables (torch.Tensor, optional): Sequences' BlockTables. Defaults to None. sequence_lengths (torch.Tensor): A tensor containing sequence lengths. fd_inter_tensor (torch.Tensor, optional): A tensor representing intermediate data for flash decoding…

Source from the content-addressed store, hash-verified

51
52@dataclass
53class InputMetaData(RPC_PARAM):
54 """The input info for a single step
55
56 Args:
57 block_tables (torch.Tensor, optional): Sequences' BlockTables. Defaults to None.
58 sequence_lengths (torch.Tensor, optional): A tensor containing sequence lengths. Defaults to None.
59 fd_inter_tensor (FDIntermTensors, optional): Intermediate data for flash decoding. Defaults to None.
60 batch_size (int, optional): The current batch size. Defaults to 64.
61 is_prompts (bool, optional): Indicates whether the step is prefill (True) or decoding (False). Defaults to False (decoding).
62 use_cuda_kernel (bool, optional): Whether to use the CUDA kernel, which is faster but occasionally loses some precision. Defaults to False.
63 use_cuda_graph (bool, optional): Indicates whether to use the CUDA graph. Defaults to False.
64 kv_seq_len (int, optional): Key-value sequence length. Defaults to 512.
65 head_dim (int, optional): Head dimension. Defaults to 32.
66 high_precision (bool, optional): Whether to use float32 for underlying calculations of float16 data to achieve higher precision. Defaults to False.
67 dtype (torch.dtype, optional): The computation dtype of tensors. Defaults to torch.float32.
68 use_spec_dec (bool, optional): Indicates whether to use speculative decoding. Defaults to False.
69 num_tokens_to_verify (int, optional): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True. Defaults to 0.
70 batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of current batch. Only used for `repetition_penalty`, `no_repeat_ngram_size` in sampler process.
71 """
72
73 block_tables: torch.Tensor = None
74 sequence_lengths: torch.Tensor = None
75 fd_inter_tensor: FDIntermTensors = None
76 batch_size: int = 64 # current_batch_size
77 is_prompts: bool = False
78 use_cuda_kernel: bool = False
79 use_cuda_graph: bool = False
80 kv_seq_len: int = 512
81 head_dim: int = 32
82 high_precision: bool = False
83 dtype: torch.dtype = torch.float32
84 use_spec_dec: bool = False
85 num_tokens_to_verify: int = 0
86 batch_token_ids: Optional[List[List[int]]] = (
87 None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
88 )
89
def to_rpc_param(self) -> Dict[str, any]:
    """Return this object's fields as a dict of plain values.

    ``block_tables`` and ``sequence_lengths`` are converted with
    ``.tolist()``, so both must be set before calling (``None`` has no
    ``tolist`` and raises ``AttributeError``). ``dtype`` is reduced to the
    text after the last ``.`` of ``str(self.dtype)``, e.g. ``"float16"``
    for ``torch.float16``. The remaining listed fields are copied
    unchanged; ``fd_inter_tensor`` is not included in the result.

    Returns:
        Dict[str, any]: Mapping from field name to its converted value.
    """
    # Tensor-valued fields first: converted to (nested) Python lists.
    payload = {
        "block_tables": self.block_tables.tolist(),
        "sequence_lengths": self.sequence_lengths.tolist(),
    }

    # Scalar fields that are copied verbatim, in their output order.
    for field_name in (
        "batch_size",
        "is_prompts",
        "use_cuda_kernel",
        "use_cuda_graph",
        "kv_seq_len",
        "head_dim",
        "high_precision",
    ):
        payload[field_name] = getattr(self, field_name)

    # Keep only the short dtype name (the part after the final dot).
    payload["dtype"] = str(self.dtype).rpartition(".")[2]

    payload["use_spec_dec"] = self.use_spec_dec
    payload["num_tokens_to_verify"] = self.num_tokens_to_verify
    payload["batch_token_ids"] = self.batch_token_ids
    return payload
106
107 @staticmethod
108 def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData":
109 """
110 We intentionally don't use `dict.get` method to ensure we pass the right rpc param, or program will show error message

Callers 4

prepare_inputMethod · 0.90
capture_modelMethod · 0.90
prepare_inputMethod · 0.90
from_rpc_paramMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…