hub / github.com/hpcaitech/ColossalAI / InferenceEngine

Class InferenceEngine

colossalai/inference/core/engine.py:16–133

InferenceEngine manages the inference process. Args: model_or_path (nn.Module, DiffusionPipeline, or str): the model, given as an nn.Module, a DiffusionPipeline, or a path to the model. tokenizer (Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]): the tokenizer to use.

Source from the content-addressed store, hash-verified

class InferenceEngine:
    """
    InferenceEngine which manages the inference process..

    Args:
        model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model.
        tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
        model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
    """

    def __init__(
        self,
        model_or_path: Union[nn.Module, str, DiffusionPipeline],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
        inference_config: InferenceConfig = None,
        verbose: bool = False,
        model_policy: Union[Policy, Type[Policy]] = None,
    ) -> None:
        self.__dict__["_initialized"] = False  # use __dict__ directly to avoid calling __setattr__
        self.model_type = get_model_type(model_or_path=model_or_path)
        self.engine = None
        if self.model_type == ModelType.LLM:
            from .llm_engine import LLMEngine

            self.engine = LLMEngine(
                model_or_path=model_or_path,
                tokenizer=tokenizer,
                inference_config=inference_config,
                verbose=verbose,
                model_policy=model_policy,
            )
        elif self.model_type == ModelType.DIFFUSION_MODEL:
            from .diffusion_engine import DiffusionEngine

            self.engine = DiffusionEngine(
                model_or_path=model_or_path,
                inference_config=inference_config,
                verbose=verbose,
                model_policy=model_policy,
            )
        elif self.model_type == ModelType.UNKNOWN:
            self.logger.error(f"Model Type either Difffusion or LLM!")

        self._initialized = True
        self._verify_args()

    def _verify_args(self) -> None:
        """Verify the input args"""
        assert self.engine is not None, "Please init Engine first"
        assert self._initialized, "Engine must be initialized"

    def generate(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        *args,
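
The constructor above picks a backend engine from the detected model type and then checks that one was actually created. The following is a minimal, standalone sketch of that dispatch pattern, not the ColossalAI implementation. FakeLLMEngine, FakeDiffusionEngine, guess_model_type, EngineFacade, and the "llm:" / "diffusion:" prefixes are all hypothetical names invented for illustration; the real get_model_type and engine classes are not shown in this listing and may behave differently.

```python
from enum import Enum
from typing import Any


class ModelType(Enum):
    LLM = "llm"
    DIFFUSION_MODEL = "diffusion"
    UNKNOWN = "unknown"


class FakeLLMEngine:
    """Hypothetical stand-in for an LLM backend (not ColossalAI's LLMEngine)."""

    def __init__(self, model_or_path: Any) -> None:
        self.model_or_path = model_or_path

    def generate(self, prompts: str) -> str:
        return f"text for {prompts!r}"


class FakeDiffusionEngine:
    """Hypothetical stand-in for a diffusion backend."""

    def __init__(self, model_or_path: Any) -> None:
        self.model_or_path = model_or_path

    def generate(self, prompts: str) -> str:
        return f"image for {prompts!r}"


def guess_model_type(model_or_path: Any) -> ModelType:
    # Assumption: a toy prefix rule used only to make the sketch runnable.
    if isinstance(model_or_path, str):
        if model_or_path.startswith("llm:"):
            return ModelType.LLM
        if model_or_path.startswith("diffusion:"):
            return ModelType.DIFFUSION_MODEL
    return ModelType.UNKNOWN


class EngineFacade:
    """Selects one backend at construction time and forwards generate() to it."""

    def __init__(self, model_or_path: Any) -> None:
        self.model_type = guess_model_type(model_or_path)
        self.engine = None
        if self.model_type is ModelType.LLM:
            self.engine = FakeLLMEngine(model_or_path)
        elif self.model_type is ModelType.DIFFUSION_MODEL:
            self.engine = FakeDiffusionEngine(model_or_path)
        # The sketch raises instead of asserting, so the check is not
        # skipped when Python runs with the -O flag.
        if self.engine is None:
            raise ValueError(f"unsupported model: {model_or_path!r}")

    def generate(self, prompts: str) -> str:
        return self.engine.generate(prompts)
```

In this sketch, EngineFacade("llm:demo").generate("hi") is forwarded to the text backend, while an unrecognised input raises ValueError during construction. The listing above instead logs an error for ModelType.UNKNOWN and relies on the assertions in _verify_args to stop an engine-less instance.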

Callers 12

check_inference_engine · Function · 0.90
check_inference_engine · Function · 0.90
check_streamingllm · Function · 0.90
check_inference_engine · Function · 0.90
check_spec_dec · Function · 0.90
_run_engine · Function · 0.90
check_inference_engine · Function · 0.90
benchmark_inference · Function · 0.90
infer · Function · 0.90
benchmark_inference · Function · 0.90
infer · Function · 0.90
benchmark_colossalai · Function · 0.90

Calls

no outgoing calls

Tested by 7

check_inference_engine · Function · 0.72
check_inference_engine · Function · 0.72
check_streamingllm · Function · 0.72
check_inference_engine · Function · 0.72
check_spec_dec · Function · 0.72
_run_engine · Function · 0.72
check_inference_engine · Function · 0.72
