(
self,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Union[Policy, type[Policy]] = None,
)
| 56 | """ |
| 57 | |
| 58 | def __init__(  # Set up the model, generation config, tokenizer, KV cache and optional CUDA-graph capture. |
| 59 | self, |
| 60 | model_or_path: Union[nn.Module, str],  # Forwarded to init_model(). |
| 61 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,  # Stored as self.tokenizer; its pad_token is overwritten below. |
| 62 | inference_config: InferenceConfig = None,  # Dereferenced unconditionally below, so the None default is not usable as-is. |
| 63 | verbose: bool = False,  # Stored as self.verbose; also gates the CUDA-graph log message. |
| 64 | model_policy: Union[Policy, type[Policy]] = None,  # Forwarded to init_model(). |
| 65 | ) -> None: |
| 66 | self.inference_config = inference_config |
| 67 | self.dtype = inference_config.dtype  # NOTE(review): raises AttributeError if inference_config is left as None. |
| 68 | self.high_precision = inference_config.high_precision |
| 69 | |
| 70 | self.verbose = verbose |
| 71 | self.logger = get_dist_logger(__name__) |
| 72 | self.model_shard_infer_config = inference_config.to_model_shard_inference_config()  # Handed to init_model() below. |
| 73 | |
| 74 | self.init_model(model_or_path, model_policy, self.model_shard_infer_config)  # NOTE(review): presumably sets self.model_config, which is read below but never assigned in this method - confirm in init_model. |
| 75 | |
| 76 | self.generation_config = inference_config.to_generation_config(self.model_config) |
| 77 | self.generation_config_dict = self.generation_config.to_dict() |
| 78 | |
| 79 | self.tokenizer = tokenizer |
| 80 | self.tokenizer.pad_token = self.tokenizer.eos_token  # Reuse the EOS token as the pad token. NOTE(review): raises AttributeError if tokenizer is left as None. |
| 81 | |
| 82 | self.request_handler = RequestHandler(self.inference_config, self.model_config) |
| 83 | self.k_cache, self.v_cache = self.request_handler.get_kvcache()  # Key/value caches are obtained from the request handler. |
| 84 | # DISCUSS maybe move this into batch info? |
| 85 | |
| 86 | self.counter = count()  # presumably itertools.count - confirm against the file's imports. |
| 87 | |
| 88 | self.use_cuda_graph = self.inference_config.use_cuda_graph |
| 89 | if self.use_cuda_graph: |
| 90 | self.graph_runners: Dict[int, CUDAGraphRunner] = {} |
| 91 | self.graph_memory_pool = None # Set during graph capture. |
| 92 | if verbose: |
| 93 | self.logger.info("Colossal AI CUDA Graph Capture on") |
| 94 | |
| 95 | self.capture_model(self.k_cache, self.v_cache)  # NOTE(review): indentation is not visible in this listing - confirm this call sits inside the use_cuda_graph branch. |
| 96 | |
| 97 | # The drafter model and related speculative-decoding attributes are set later by `enable_spec_dec`. |
| 98 | self.use_spec_dec = self.inference_config.use_spec_dec |
| 99 | |
| 100 | self.drafter_model = None |
| 101 | self.drafter = None |
| 102 | self.use_glide = False |
| 103 | self.n_spec_tokens = self.inference_config.max_n_spec_tokens |
| 104 | |
| 105 | self._verify_args()  # Called last, after every attribute above has been assigned. |
| 106 | |
| 107 | def init_model( |
| 108 | self, |
nothing calls this directly
no test coverage detected