(
predictor_args: PredictorArgument,
model_args: ModelArgument,
tensor_parallel_degree: int = 1,
tensor_parallel_rank: int = 0,
)
| 1176 | |
| 1177 | |
| 1178 | def create_predictor( |
| 1179 | predictor_args: PredictorArgument, |
| 1180 | model_args: ModelArgument, |
| 1181 | tensor_parallel_degree: int = 1, |
| 1182 | tensor_parallel_rank: int = 0, |
| 1183 | ): |
| 1184 | tokenizer = AutoTokenizer.from_pretrained( |
| 1185 | predictor_args.model_name_or_path, |
| 1186 | ) |
| 1187 | # init chat_template for tokenizer |
| 1188 | init_chat_template(tokenizer, predictor_args.model_name_or_path, predictor_args.chat_template) |
| 1189 | |
| 1190 | # TODO(wj-Mcat): fix llama tokenzier pad_token bug |
| 1191 | if isinstance(tokenizer, LlamaTokenizer) and not tokenizer.pad_token: |
| 1192 | tokenizer.pad_token = tokenizer.unk_token |
| 1193 | |
| 1194 | config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) |
| 1195 | |
| 1196 | max_position_embeddings = get_model_max_position_embeddings(config) |
| 1197 | if max_position_embeddings is None: |
| 1198 | max_position_embeddings = 2048 |
| 1199 | logger.warning("Can not retrieval `max_position_embeddings` from config.json, use default value 2048") |
| 1200 | |
| 1201 | if predictor_args.src_length is None: |
| 1202 | if predictor_args.max_length is None: |
| 1203 | predictor_args.src_length = get_default_max_encoding_length(config) |
| 1204 | predictor_args.max_length = get_default_max_decoding_length(config) |
| 1205 | else: |
| 1206 | predictor_args.src_length = max_position_embeddings - predictor_args.max_length |
| 1207 | if predictor_args.src_length <= 0: |
| 1208 | raise ValueError( |
| 1209 | f"--max_length<{predictor_args.max_length}> param should be smaller " |
| 1210 | f"than max_position_embeddings<{max_position_embeddings}>" |
| 1211 | ) |
| 1212 | else: |
| 1213 | if predictor_args.max_length is None: |
| 1214 | predictor_args.max_length = max_position_embeddings - predictor_args.src_length |
| 1215 | if predictor_args.max_length <= 0: |
| 1216 | raise ValueError( |
| 1217 | f"--src_length<{predictor_args.src_length}> param should be smaller " |
| 1218 | f"than max_position_embeddings<{max_position_embeddings}>" |
| 1219 | ) |
| 1220 | else: |
| 1221 | if predictor_args.src_length + predictor_args.max_length > max_position_embeddings: |
| 1222 | raise ValueError( |
| 1223 | f"The sum of src_length<{predictor_args.src_length}> and " |
| 1224 | f"max_length<{predictor_args.max_length}> should be smaller than or equal to " |
| 1225 | f"the maximum position embedding size<{max_position_embeddings}>" |
| 1226 | ) |
| 1227 | |
| 1228 | # update config parameter for inference predictor |
| 1229 | if predictor_args.decode_strategy == "greedy_search": |
| 1230 | predictor_args.top_p = 0.0 |
| 1231 | predictor_args.temperature = 1.0 |
| 1232 | |
| 1233 | tensor_parallel_rank, tensor_parallel_degree = init_dist_env() |
| 1234 | if not predictor_args.inference_model: |
| 1235 | tokenizer.padding_side = "left" |
no test coverage detected
searching dependent graphs…