hub / github.com/PaddlePaddle/PaddleNLP / predict

Function predict

llm/predictor.py:1511–1574


def predict():
    parser = PdArgumentParser((PredictorArgument, ModelArgument))
    predictor_args, model_args = parser.parse_args_into_dataclasses()

    paddle.set_device(predictor_args.device)
    paddle.set_default_dtype(predictor_args.dtype)

    # With more than one process, use every process for tensor (model) parallelism.
    tensor_parallel_degree = paddle.distributed.get_world_size()
    if tensor_parallel_degree > 1:
        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": tensor_parallel_degree,
            "pp_degree": 1,
            "sharding_degree": 1,
        }
        fleet.init(is_collective=True, strategy=strategy)

    predictor = create_predictor(predictor_args, model_args)
    source_texts = []
    target_texts = []
    if model_args.data_file:
        # The data file is read as JSON Lines: one object with "src" and "tgt" per line.
        with open(model_args.data_file, "r", encoding="utf-8") as f:
            for line in f:
                example = json.loads(line)
                if isinstance(example["src"], str) or predictor.tokenizer.chat_template is None:
                    if isinstance(example["src"], str):
                        source_texts.append(example["src"])
                        target_texts.append(example["tgt"])
                    else:
                        # Multi-round dataset without a chat template: keep the first round only.
                        source_texts.append(example["src"][0])
                        target_texts.append(example["tgt"][0])
                else:
                    # Multi-round dataset with a chat template: pair each source with its target.
                    source_texts.append(list(zip(example["src"], example["tgt"])))
                    target_texts.append("")

    else:
        # Default prompts, in English: 'Explain "reviewing the old to learn the new"'
        # and "Hello, may I ask who you are?"
        source_texts = ["解释一下“温故而知新”", "你好,请问你是谁?"]
        target_texts = ["", ""]

    batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size)
    batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size)

    with open(model_args.output_file, "w", encoding="utf-8") as f:
        for bs, batch_source_text in enumerate(batch_source_texts):
            logger.info("Start predict")
            outputs = predictor.predict(batch_source_text)
            logger.info("End predict")

            # Only tensor-parallel rank 0 reports results.
            if predictor.tensor_parallel_rank > 0:
                continue
            for output, source, target in zip(outputs, batch_source_texts[bs], batch_target_texts[bs]):
                print("***********Source**********")
                print(source)
                print("***********Target**********")
                print(target)
                print("***********Output**********")
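
The data-loading branch in the excerpt handles three record shapes, which is easier to follow in isolation. Below is a minimal sketch of that branching on its own, with no Paddle dependency. The names `load_examples` and `has_chat_template` are hypothetical and stand in for the inline loop and the `predictor.tokenizer.chat_template is None` check. The sample records are invented for illustration.

```python
import json


def load_examples(lines, has_chat_template):
    """Split JSON Lines records into source and target lists.

    Hypothetical helper that mirrors the branching in predict();
    it is not part of PaddleNLP.
    """
    source_texts, target_texts = [], []
    for line in lines:
        example = json.loads(line)
        src, tgt = example["src"], example["tgt"]
        if isinstance(src, str):
            # Single-turn record: use the strings directly.
            source_texts.append(src)
            target_texts.append(tgt)
        elif not has_chat_template:
            # Multi-round record, no chat template: keep the first round only.
            source_texts.append(src[0])
            target_texts.append(tgt[0])
        else:
            # Multi-round record with a chat template: pair each src with its tgt.
            source_texts.append(list(zip(src, tgt)))
            target_texts.append("")
    return source_texts, target_texts


# Invented sample records: one single-turn, one multi-round.
records = [
    '{"src": "Hello", "tgt": "Hi"}',
    '{"src": ["Q1", "Q2"], "tgt": ["A1", "A2"]}',
]
with_template = load_examples(records, has_chat_template=True)
without_template = load_examples(records, has_chat_template=False)
```

With a chat template, the multi-round record becomes a list of `(src, tgt)` pairs and its target is left empty. Without one, only the first round is kept.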

Callers 3

run_predictor · Method · 0.90
predictor.py · File · 0.70

Calls 10

PdArgumentParser · Class · 0.90
create_predictor · Function · 0.85
benchmark · Function · 0.85
get_world_size · Method · 0.80
batchfy_text · Function · 0.70
init · Method · 0.45
append · Method · 0.45
predict · Method · 0.45
write · Method · 0.45
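
The body of `batchfy_text` is not shown on this page. From how `predict` indexes its result (`batch_source_texts[bs]`), it appears to split a list into consecutive batches. The sketch below illustrates that assumed behavior with a hypothetical stand-in named `chunk_texts`. It is not the PaddleNLP implementation, and the real function may differ.

```python
def chunk_texts(texts, batch_size):
    """Split `texts` into consecutive lists of at most `batch_size` items.

    Hypothetical stand-in for batchfy_text; the real behavior is assumed.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]


sources = ["a", "b", "c"]
targets = ["", "", ""]
batch_sources = chunk_texts(sources, 2)
batch_targets = chunk_texts(targets, 2)

# Batching both lists with the same size keeps them index-aligned, so batch
# number `bs` of the sources lines up with batch number `bs` of the targets.
pairs = [
    list(zip(batch_sources[bs], batch_targets[bs]))
    for bs in range(len(batch_sources))
]
```

Because sources and targets are batched with the same size, `predict` can zip each batch of outputs against the matching source and target batches by index.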

Tested by 2

run_predictor · Method · 0.72
