Repository: github.com/hpcaitech/ColossalAI

Class RPCInferenceEngine

colossalai/inference/core/rpc_engine.py:36–297


Source excerpt, file lines 34–93 (the class itself continues to line 297):

class RPCInferenceEngine(InferenceEngine):
    """
    InferenceEngine which manages the inference process.

    NOTE This `RPCInferenceEngine` is designed for multiple-card/online serving.
    The original `InferenceEngine` is designed for single-card and offline serving, though it also supports multi-card offline inference.

    Args:
        model_or_path (nn.Module or str): Path or nn.Module of this model. Currently the `nn.Module` format is not supported.
        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use.
        inference_config (InferenceConfig): Stores the configuration information related to inference.
        verbose (bool): Whether or not to log the generation process.
        model_policy ("Policy"): The Shardformer policy used to shard the model. It is determined by the model type if not provided.
    """

    def __init__(
        self,
        model_or_path: Union[nn.Module, str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: InferenceConfig,
        verbose: bool = False,
        model_policy: Policy = None,
    ) -> None:
        """
        If you pass a real model loaded by transformers, initialization will take quite a long time.
        Currently a model (nn.Module) is not supported as the `model_or_path` parameter.
        """

        torch.multiprocessing.set_start_method("spawn", force=True)

        self.inference_config = inference_config
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)

        try:
            if isinstance(model_or_path, str):
                self.model_config = AutoConfig.from_pretrained(
                    model_or_path, trust_remote_code=True, torch_dtype=self.dtype
                )
            elif isinstance(model_or_path, nn.Module):
                self.logger.error(
                    f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
                )
                # self.model_config = model_or_path.config
            else:
                self.logger.error(
                    f"An exception occurred during loading model Config: Please pass right param for {__class__.__name__}\n"
                )
        except Exception as e:
            self.logger.error(
                f"An exception occurred during loading model Config: {e}, The path should be transformers-like\n"
            )
        self.generation_config = inference_config.to_generation_config(self.model_config)

        self.tp_size = inference_config.tp_size
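The constructor first calls `torch.multiprocessing.set_start_method("spawn", force=True)`. `torch.multiprocessing` is a drop-in wrapper around Python's `multiprocessing`, and the CUDA runtime does not support the `fork` start method, so GPU worker processes need `spawn` (or `forkserver`). `force=True` overrides whatever start method was already set for the whole interpreter. A more scoped alternative, sketched here with the standard library only, is to request a spawn context object and create queues and processes from it. `get_spawn_context`, `_square` and `run_in_spawned_worker` are hypothetical names for illustration, not part of ColossalAI; since `torch.multiprocessing` mirrors the `multiprocessing` API, the same pattern should carry over, but that is an assumption this sketch does not exercise.

```python
import multiprocessing as mp
from multiprocessing.context import BaseContext


def get_spawn_context() -> BaseContext:
    # A context object scopes the start method to the queues and processes
    # created from it, instead of changing the interpreter-wide default the
    # way set_start_method("spawn", force=True) does.
    return mp.get_context("spawn")


def _square(x: int, out) -> None:
    # Hypothetical worker body. Under "spawn" it must be a module-level
    # function so the child interpreter can import it by name.
    out.put(x * x)


def run_in_spawned_worker(x: int) -> int:
    # Hypothetical helper: run _square in a freshly spawned interpreter.
    ctx = get_spawn_context()
    out = ctx.Queue()
    proc = ctx.Process(target=_square, args=(x, out))
    proc.start()
    result = out.get()  # drain the queue before join() to avoid a deadlock
    proc.join()
    return result
```

Called from a script, `run_in_spawned_worker(7)` returns 49. With `spawn`, that call must sit under an `if __name__ == "__main__":` guard, because the child process re-imports the main module.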
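Setting `self.tokenizer.pad_token = self.tokenizer.eos_token` is a common workaround for causal-LM tokenizers that ship without a pad token: prompts of different lengths can only be batched once some id is designated for padding. The sketch below shows what that padding amounts to in plain Python. The hypothetical `left_pad_batch` helper pads with the EOS id and builds an attention mask so padded positions are ignored. Left padding is assumed because decoder-only models generate from the right end of each sequence; the excerpt does not show which side the engine actually pads on, and the EOS id `2` in the example is an arbitrary assumed value.

```python
from typing import List, Tuple


def left_pad_batch(
    batch: List[List[int]], pad_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-id sequences to equal length and build attention masks.

    Hypothetical helper: pad_id plays the role of the tokenizer's EOS id
    when pad_token has been aliased to eos_token.
    """
    max_len = max((len(seq) for seq in batch), default=0)
    input_ids: List[List[int]] = []
    attention_mask: List[List[int]] = []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append([pad_id] * n_pad + list(seq))
        # 0 marks padding and 1 marks real tokens, so an EOS id used as
        # padding is still excluded from attention.
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


ids, mask = left_pad_batch([[5, 6, 7], [8]], pad_id=2)  # 2: assumed EOS id
# ids  == [[5, 6, 7], [2, 2, 8]]
# mask == [[1, 1, 1], [0, 0, 1]]
```

Because the mask, not the id, decides which positions count, reusing EOS as the pad id does not make the model treat padding as an end-of-sequence signal.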
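As far as the excerpt shows, every failure branch of the config-loading `try` block only logs an error. `self.model_config` is then never assigned, so the following line, `inference_config.to_generation_config(self.model_config)`, would fail with an `AttributeError` rather than surfacing the logged cause. The sketch below shows a fail-fast alternative. It is an illustrative design, not what ColossalAI does: `resolve_model_config` is a hypothetical helper, and `load_config` stands in for `transformers.AutoConfig.from_pretrained`, injected so the sketch runs without transformers installed.

```python
from typing import Any, Callable


def resolve_model_config(model_or_path: Any, load_config: Callable[[str], Any]) -> Any:
    """Hypothetical helper: load a model config from a path or fail loudly."""
    if not isinstance(model_or_path, str):
        # Mirrors the excerpt's rule that only string paths are supported,
        # but raises instead of logging and continuing.
        raise TypeError(
            f"model_or_path must be a str path, got {type(model_or_path).__name__}; "
            "nn.Module inputs are not supported"
        )
    try:
        return load_config(model_or_path)
    except Exception as exc:
        # Chain the original error so the root cause stays visible, and stop
        # before any code can read an undefined config.
        raise ValueError(
            f"could not load model config from {model_or_path!r}; "
            "the path should be a transformers-style model directory or hub id"
        ) from exc


config = resolve_model_config("my-model", lambda path: {"name": path})
# config == {"name": "my-model"}
```

Raising at the point of failure keeps the error message next to its cause, instead of deferring to an unrelated `AttributeError` a few lines later.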

Callers (1)

check_inference_engine (function) · 0.90

Calls

no outgoing calls

Tested by (1)

check_inference_engine (function) · 0.72
