MCPcopy Index your code
hub / github.com/hpcaitech/ColossalAI / from_rpc_param

Method from_rpc_param

colossalai/inference/config.py:108–133  ·  view source on GitHub ↗

The `dict.get` method is intentionally not used: every RPC parameter is read by direct key lookup, so a missing or misnamed parameter raises an error instead of silently falling back to a default.

(rpc_dict: Dict[str, any])

Source from the content-addressed store, hash-verified

106
@staticmethod
def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData":
    """
    Rebuild an `InputMetaData` from the plain dict received as an RPC parameter.

    Every field is read with `rpc_dict[key]` rather than `dict.get`, so a missing
    parameter raises `KeyError` instead of being silently defaulted.

    Args:
        rpc_dict: Mapping with one entry per `InputMetaData` field. `dtype` is the
            name of an attribute of the `torch` module (resolved with `getattr`);
            `block_tables` and `sequence_lengths` are handed to `torch.tensor`;
            all remaining entries are forwarded unchanged.

    Returns:
        The reconstructed `InputMetaData`.

    Raises:
        KeyError: if an expected key is absent from `rpc_dict`.
        AttributeError: if `rpc_dict["dtype"]` does not name a `torch` attribute.
    """
    # NOTE(review): the `any` in the annotation above is the builtin function;
    # `typing.Any` is presumably what was intended -- confirm and fix alongside the imports.
    from colossalai.accelerator import get_accelerator

    def _as_int_tensor(key: str) -> torch.Tensor:
        # The key lookup happens before the device query, and the device is
        # queried once per tensor, exactly as in the inlined form.
        return torch.tensor(rpc_dict[key], dtype=torch.int, device=get_accelerator().get_current_device())

    # Fields copied verbatim from the RPC payload, in the order they are looked up.
    forwarded_keys = (
        "batch_size",
        "is_prompts",
        "use_cuda_kernel",
        "use_cuda_graph",
        "kv_seq_len",
        "head_dim",
        "high_precision",
        "use_spec_dec",
        "num_tokens_to_verify",
        "batch_token_ids",
    )

    # The dtype travels as a string name and is mapped back onto the torch attribute.
    resolved_dtype = getattr(torch, rpc_dict["dtype"])
    block_tables = _as_int_tensor("block_tables")
    sequence_lengths = _as_int_tensor("sequence_lengths")
    forwarded = {key: rpc_dict[key] for key in forwarded_keys}

    return InputMetaData(
        block_tables=block_tables,
        sequence_lengths=sequence_lengths,
        dtype=resolved_dtype,
        **forwarded,
    )
134
135 def __repr__(self) -> str:
136 return (

Callers

nothing calls this directly

Calls 3

get_acceleratorFunction · 0.90
InputMetaDataClass · 0.85
get_current_deviceMethod · 0.45

Tested by

no test coverage detected