| 209 | |
| 210 | class BasePredictor: |
def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None):
    """Initialise the predictor's model config, tokenizer, parallel info and generation config.

    Args:
        config: Predictor arguments. ``config.model_name_or_path`` is used to
            load the model config, the default tokenizer and the generation
            config.
        tokenizer: Optional pre-built tokenizer. When ``None``, one is loaded
            from ``config.model_name_or_path`` with ``padding_side="left"``.

    Side effects:
        Sets ``self.model_config``, ``self.config``, ``self.tokenizer``,
        ``self.return_tensors``, ``self.tensor_parallel_rank``,
        ``self.tensor_parallel_degree`` and ``self.generation_config``.
        The tensor-parallel rank/degree are also written onto
        ``self.model_config``.
    """
    self.model_config = AutoConfig.from_pretrained(config.model_name_or_path)
    self.config: PredictorArgument = config
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path, padding_side="left")

    self.tokenizer = tokenizer

    # NOTE(review): presumably "pd" selects Paddle tensors from the tokenizer — confirm.
    self.return_tensors = "pd"

    # Query the distributed environment once and share the result with the
    # model config, instead of calling init_dist_env() a second time.
    self.tensor_parallel_rank, self.tensor_parallel_degree = init_dist_env()
    self.model_config.tensor_parallel_rank = self.tensor_parallel_rank
    self.model_config.tensor_parallel_degree = self.tensor_parallel_degree

    try:
        self.generation_config = GenerationConfig.from_pretrained(config.model_name_or_path)
    except Exception:
        # Best-effort: a missing/unloadable generation config is not fatal, so fall
        # back to None. Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit still propagate.
        logger.warning(
            "Can't find generation config, so it will not use generation_config field in the model config"
        )
        self.generation_config = None
| 230 | |
| 231 | def _preprocess(self, source): |
| 232 | if self.tokenizer.chat_template is not None: |