We intentionally don't use the `dict.get` method: indexing the RPC dict directly ensures every required RPC parameter is present, so the program raises an error instead of silently using a default when one is missing.
(rpc_dict: Dict[str, any])
| 106 | |
# NOTE(review): `any` in the annotation below is the builtin function, not
# `typing.Any`; it is left unchanged because this chunk does not show whether
# `Any` is imported in this file.
@staticmethod
def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData":
    """
    Rebuild an ``InputMetaData`` from the plain dict received over RPC.

    Every field is read with ``rpc_dict[key]`` rather than ``dict.get`` on
    purpose: a missing RPC parameter raises ``KeyError`` immediately instead
    of silently falling back to ``None``.

    Args:
        rpc_dict: Mapping of field name to value. ``"dtype"`` must be the name
            of a ``torch`` attribute that is a ``torch.dtype`` (e.g.
            ``"float16"``). ``"block_tables"`` and ``"sequence_lengths"`` must
            be data accepted by ``torch.tensor``. All other values are passed
            to ``InputMetaData`` unchanged.

    Returns:
        A new ``InputMetaData`` whose ``block_tables`` and ``sequence_lengths``
        are ``torch.int`` tensors created on the current accelerator device.

    Raises:
        KeyError: If a required key is missing from ``rpc_dict``.
        AttributeError: If ``rpc_dict["dtype"]`` is not an attribute of ``torch``.
        TypeError: If ``rpc_dict["dtype"]`` names a ``torch`` attribute that is
            not a ``torch.dtype``.
    """
    from colossalai.accelerator import get_accelerator

    dtype_name = rpc_dict["dtype"]
    dtype = getattr(torch, dtype_name)
    # getattr(torch, name) resolves ANY torch attribute (functions, submodules, ...),
    # so reject names that are not real dtypes instead of forwarding them.
    if not isinstance(dtype, torch.dtype):
        raise TypeError(f"rpc_dict['dtype'] must name a torch dtype, got {dtype_name!r}")

    # Resolve the target device once and reuse it for both tensors.
    device = get_accelerator().get_current_device()
    return InputMetaData(
        block_tables=torch.tensor(rpc_dict["block_tables"], dtype=torch.int, device=device),
        sequence_lengths=torch.tensor(rpc_dict["sequence_lengths"], dtype=torch.int, device=device),
        batch_size=rpc_dict["batch_size"],
        is_prompts=rpc_dict["is_prompts"],
        use_cuda_kernel=rpc_dict["use_cuda_kernel"],
        use_cuda_graph=rpc_dict["use_cuda_graph"],
        kv_seq_len=rpc_dict["kv_seq_len"],
        head_dim=rpc_dict["head_dim"],
        high_precision=rpc_dict["high_precision"],
        dtype=dtype,
        use_spec_dec=rpc_dict["use_spec_dec"],
        num_tokens_to_verify=rpc_dict["num_tokens_to_verify"],
        batch_token_ids=rpc_dict["batch_token_ids"],
    )
| 134 | |
| 135 | def __repr__(self) -> str: |
| 136 | return ( |
nothing calls this directly
no test coverage detected