MCPcopy
hub / github.com/hpcaitech/ColossalAI / __init__

Method __init__

colossalai/inference/core/llm_engine.py:58–105  ·  view source on GitHub ↗
(
        self,
        model_or_path: Union[nn.Module, str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
        inference_config: InferenceConfig = None,
        verbose: bool = False,
        model_policy: Union[Policy, type[Policy]] = None,
    )

Source from the content-addressed store, hash-verified

56 """
57
def __init__(
    self,
    model_or_path: Union[nn.Module, str],
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
    inference_config: InferenceConfig = None,
    verbose: bool = False,
    model_policy: Union[Policy, type[Policy]] = None,
) -> None:
    """Initialise the engine: model, generation config, tokenizer, KV cache and optional CUDA graphs.

    Args:
        model_or_path: A model instance or a string, forwarded unchanged to
            ``init_model`` (presumably a checkpoint path or model name --
            confirm against ``init_model``).
        tokenizer: Tokenizer stored on the engine. Required even though the
            signature defaults it to ``None``. NOTE: its ``pad_token`` is
            overwritten in place with its ``eos_token``, so the caller's
            tokenizer object is mutated.
        inference_config: Inference settings. Required even though the
            signature defaults it to ``None``; dtype, precision, CUDA-graph
            and speculative-decoding options are all read from it.
        verbose: Stored on the engine; when true and CUDA graphs are enabled,
            an info message is logged before graph capture.
        model_policy: Forwarded unchanged to ``init_model``.

    Raises:
        ValueError: If ``inference_config`` or ``tokenizer`` is ``None``.
    """
    # Both arguments default to None for signature compatibility but are
    # dereferenced unconditionally below. Fail fast with a clear message
    # instead of an AttributeError on NoneType -- and, for the tokenizer,
    # before the model has been initialised.
    if inference_config is None:
        raise ValueError("`inference_config` must be provided, got None.")
    if tokenizer is None:
        raise ValueError("`tokenizer` must be provided, got None.")

    self.inference_config = inference_config
    self.dtype = inference_config.dtype
    self.high_precision = inference_config.high_precision

    self.verbose = verbose
    self.logger = get_dist_logger(__name__)
    self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

    self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

    # `self.model_config` is not assigned in this method; it is expected to
    # have been set by `init_model` above.
    self.generation_config = inference_config.to_generation_config(self.model_config)
    self.generation_config_dict = self.generation_config.to_dict()

    self.tokenizer = tokenizer
    # Reuse the EOS token as the padding token (mutates the caller's tokenizer).
    self.tokenizer.pad_token = self.tokenizer.eos_token

    self.request_handler = RequestHandler(self.inference_config, self.model_config)
    self.k_cache, self.v_cache = self.request_handler.get_kvcache()
    # DISCUSS maybe move this into batch info?

    self.counter = count()

    self.use_cuda_graph = self.inference_config.use_cuda_graph
    if self.use_cuda_graph:
        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.
        if verbose:
            self.logger.info("Colossal AI CUDA Graph Capture on")

        self.capture_model(self.k_cache, self.v_cache)

    # The drafter model and related speculative-decoding attributes are set
    # later by `enable_spec_dec`; only placeholders are created here.
    self.use_spec_dec = self.inference_config.use_spec_dec

    self.drafter_model = None
    self.drafter = None
    self.use_glide = False
    self.n_spec_tokens = self.inference_config.max_n_spec_tokens

    self._verify_args()
106
107 def init_model(
108 self,

Callers

nothing calls this directly

Calls 10

init_modelMethod · 0.95
capture_modelMethod · 0.95
_verify_argsMethod · 0.95
get_dist_loggerFunction · 0.90
RequestHandlerClass · 0.85
to_generation_configMethod · 0.80
get_kvcacheMethod · 0.80
to_dictMethod · 0.45
infoMethod · 0.45

Tested by

no test coverage detected