Repository: github.com/PaddlePaddle/PaddleNLP

Function create_predictor

llm/predictor.py:1178–1508
(
    predictor_args: PredictorArgument,
    model_args: ModelArgument,
    tensor_parallel_degree: int = 1,
    tensor_parallel_rank: int = 0,
)


Source excerpt (lines 1178–1235 of llm/predictor.py; the function body continues to line 1508):

```python
def create_predictor(
    predictor_args: PredictorArgument,
    model_args: ModelArgument,
    tensor_parallel_degree: int = 1,
    tensor_parallel_rank: int = 0,
):
    tokenizer = AutoTokenizer.from_pretrained(
        predictor_args.model_name_or_path,
    )
    # init chat_template for tokenizer
    init_chat_template(tokenizer, predictor_args.model_name_or_path, predictor_args.chat_template)

    # TODO(wj-Mcat): fix llama tokenzier pad_token bug
    if isinstance(tokenizer, LlamaTokenizer) and not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.unk_token

    config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)

    max_position_embeddings = get_model_max_position_embeddings(config)
    if max_position_embeddings is None:
        max_position_embeddings = 2048
        logger.warning("Can not retrieval `max_position_embeddings` from config.json, use default value 2048")

    if predictor_args.src_length is None:
        if predictor_args.max_length is None:
            predictor_args.src_length = get_default_max_encoding_length(config)
            predictor_args.max_length = get_default_max_decoding_length(config)
        else:
            predictor_args.src_length = max_position_embeddings - predictor_args.max_length
            if predictor_args.src_length <= 0:
                raise ValueError(
                    f"--max_length<{predictor_args.max_length}> param should be smaller "
                    f"than max_position_embeddings<{max_position_embeddings}>"
                )
    else:
        if predictor_args.max_length is None:
            predictor_args.max_length = max_position_embeddings - predictor_args.src_length
            if predictor_args.max_length <= 0:
                raise ValueError(
                    f"--src_length<{predictor_args.src_length}> param should be smaller "
                    f"than max_position_embeddings<{max_position_embeddings}>"
                )
        else:
            if predictor_args.src_length + predictor_args.max_length > max_position_embeddings:
                raise ValueError(
                    f"The sum of src_length<{predictor_args.src_length}> and "
                    f"max_length<{predictor_args.max_length}> should be smaller than or equal to "
                    f"the maximum position embedding size<{max_position_embeddings}>"
                )

    # update config parameter for inference predictor
    if predictor_args.decode_strategy == "greedy_search":
        predictor_args.top_p = 0.0
        predictor_args.temperature = 1.0

    tensor_parallel_rank, tensor_parallel_degree = init_dist_env()
    if not predictor_args.inference_model:
        tokenizer.padding_side = "left"
```

Callers (3)

main (function) · 0.90
flask_server.py (file) · 0.90
predict (function) · 0.85

Tested by

no test coverage detected
