()
| 1509 | |
| 1510 | |
| 1511 | def predict(): |
| 1512 | parser = PdArgumentParser((PredictorArgument, ModelArgument)) |
| 1513 | predictor_args, model_args = parser.parse_args_into_dataclasses() |
| 1514 | |
| 1515 | paddle.set_device(predictor_args.device) |
| 1516 | paddle.set_default_dtype(predictor_args.dtype) |
| 1517 | |
| 1518 | tensor_parallel_degree = paddle.distributed.get_world_size() |
| 1519 | if tensor_parallel_degree > 1: |
| 1520 | strategy = fleet.DistributedStrategy() |
| 1521 | strategy.hybrid_configs = { |
| 1522 | "dp_degree": 1, |
| 1523 | "mp_degree": tensor_parallel_degree, |
| 1524 | "pp_degree": 1, |
| 1525 | "sharding_degree": 1, |
| 1526 | } |
| 1527 | fleet.init(is_collective=True, strategy=strategy) |
| 1528 | |
| 1529 | predictor = create_predictor(predictor_args, model_args) |
| 1530 | source_texts = [] |
| 1531 | target_texts = [] |
| 1532 | if model_args.data_file: |
| 1533 | with open(model_args.data_file, "r", encoding="utf-8") as f: |
| 1534 | for line in f: |
| 1535 | example = json.loads(line) |
| 1536 | if isinstance(example["src"], str) or predictor.tokenizer.chat_template is None: |
| 1537 | if isinstance(example["src"], str): |
| 1538 | source_texts.append(example["src"]) |
| 1539 | target_texts.append(example["tgt"]) |
| 1540 | else: |
| 1541 | # load multi-rounds dataset |
| 1542 | source_texts.append(example["src"][0]) |
| 1543 | target_texts.append(example["tgt"][0]) |
| 1544 | else: |
| 1545 | source_texts.append(list(zip(example["src"], example["tgt"]))) |
| 1546 | target_texts.append("") |
| 1547 | |
| 1548 | else: |
| 1549 | source_texts = ["解释一下“温故而知新”", "你好,请问你是谁?"] |
| 1550 | target_texts = ["", ""] |
| 1551 | |
| 1552 | batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size) |
| 1553 | batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size) |
| 1554 | |
| 1555 | with open(model_args.output_file, "w", encoding="utf-8") as f: |
| 1556 | for bs, batch_source_text in enumerate(batch_source_texts): |
| 1557 | logger.info("Start predict") |
| 1558 | outputs = predictor.predict(batch_source_text) |
| 1559 | logger.info("End predict") |
| 1560 | |
| 1561 | if predictor.tensor_parallel_rank > 0: |
| 1562 | continue |
| 1563 | for output, source, target in zip(outputs, batch_source_texts[bs], batch_target_texts[bs]): |
| 1564 | print("***********Source**********") |
| 1565 | print(source) |
| 1566 | print("***********Target**********") |
| 1567 | print(target) |
| 1568 | print("***********Output**********") |
searching dependent graphs…