MCPcopy
hub / github.com/hpcaitech/ColossalAI / __init__

Method __init__

colossalai/inference/core/rpc_engine.py:51–123  ·  view source on GitHub ↗

If you input a real model loaded by transformers, the init will take quite a long time. Currently, we don't support the model (nn.Module) format as the param.

(
        self,
        model_or_path: Union[nn.Module, str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: InferenceConfig,
        verbose: bool = False,
        model_policy: Policy = None,
    )

Source from the content-addressed store, hash-verified

49 """
50
51 def __init__(
52 self,
53 model_or_path: Union[nn.Module, str],
54 tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
55 inference_config: InferenceConfig,
56 verbose: bool = False,
57 model_policy: Policy = None,
58 ) -> None:
59 """
60 If you input a real model loaded by transformers, the init will take quite a long time
61 Currently we don't support model(nn.Module) format as the param.
62 """
63
64 torch.multiprocessing.set_start_method("spawn", force=True)
65
66 self.inference_config = inference_config
67 self.tokenizer = tokenizer
68 self.tokenizer.pad_token = self.tokenizer.eos_token
69
70 self.verbose = verbose
71 self.logger = get_dist_logger(__name__)
72
73 try:
74 if isinstance(model_or_path, str):
75 self.model_config = AutoConfig.from_pretrained(
76 model_or_path, trust_remote_code=True, torch_dtype=self.dtype
77 )
78 elif isinstance(model_or_path, nn.Module):
79 self.logger.error(
80 f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
81 )
82 # self.model_config = model_or_path.config
83 else:
84 self.logger.error(
85 f"An exception occurred during loading model Config: Please pass right param for {__class__.__name__}\n"
86 )
87 except Exception as e:
88 self.logger.error(
89 f"An exception occurred during loading model Config: {e}, The path should be transformers-like\n"
90 )
91 self.generation_config = inference_config.to_generation_config(self.model_config)
92
93 self.tp_size = inference_config.tp_size
94 self.events = [mp.Event() for _ in range(self.tp_size)]
95
96 # This operation will init the dist env and models
97 self.workers: List[rpcWorkerService] = []
98 self.init_workers()
99
100 asyncio.run(self.init_model(model_or_path, model_policy))
101
102 # init the scheduler and logic block manager
103 self.request_handler = self.init_scheduler(self.inference_config, self.model_config)
104
105 # init the physical cache
106 alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape()
107 self.init_device_cache(alloc_shape)
108

Callers

nothing calls this directly

Calls 13

init_workersMethod · 0.95
init_modelMethod · 0.95
init_schedulerMethod · 0.95
init_device_cacheMethod · 0.95
_verify_argsMethod · 0.95
get_dist_loggerFunction · 0.90
to_generation_configMethod · 0.80
from_pretrainedMethod · 0.45
errorMethod · 0.45
EventMethod · 0.45
runMethod · 0.45

Tested by

no test coverage detected