(self, config)
| 58 | |
| 59 | class SerPredictor(object): |
def __init__(self, config):
    """Build the SER model, OCR engine and eval-time preprocessing ops.

    Args:
        config: Full configuration mapping. Keys read here:
            ``Global`` (``use_gpu``; optional ``kie_rec_model_dir``,
            ``kie_det_model_dir`` and ``infer_mode``),
            ``Architecture`` (``algorithm``, ``model_type``),
            ``PostProcess`` and ``Eval.dataset.transforms``.

    Note:
        ``config`` is modified in place. Every transform spec whose op name
        contains ``"Label"`` gets the OCR engine injected under
        ``"ocr_engine"``. The ``KeepKeys`` spec has its ``keep_keys``
        overwritten. ``Global.infer_mode`` is set to ``True`` when it is
        missing or ``None``.
    """
    global_config = config["Global"]
    architecture = config["Architecture"]
    self.algorithm = architecture["algorithm"]

    # build post process
    self.post_process_class = build_post_process(
        config["PostProcess"], global_config
    )

    # build model and load its weights
    self.model = build_model(architecture)
    load_model(config, self.model, model_type=architecture["model_type"])

    # NOTE(review): imported lazily, presumably so this module can be
    # imported without paddleocr installed; confirm.
    from paddleocr import PaddleOCR

    self.ocr_engine = PaddleOCR(
        use_angle_cls=False,
        show_log=False,
        rec_model_dir=global_config.get("kie_rec_model_dir", None),
        det_model_dir=global_config.get("kie_det_model_dir", None),
        use_gpu=global_config["use_gpu"],
    )

    # create data ops: patch the eval transform specs in place, then
    # instantiate them from the same (now patched) list.
    eval_transforms = config["Eval"]["dataset"]["transforms"]
    for op in eval_transforms:
        # The first key of each spec is the op name; its value is the
        # op's kwargs dict.
        op_name = list(op)[0]
        if "Label" in op_name:
            op[op_name]["ocr_engine"] = self.ocr_engine
        elif op_name == "KeepKeys":
            # Fixed field list (presumably what the model / post-processor
            # consume; confirm against __call__).
            op[op_name]["keep_keys"] = [
                "input_ids",
                "bbox",
                "attention_mask",
                "token_type_ids",
                "image",
                "labels",
                "segment_offset_id",
                "ocr_info",
                "entities",
            ]

    # Enable infer_mode when the config leaves it unset (missing or None).
    if global_config.get("infer_mode", None) is None:
        global_config["infer_mode"] = True
    self.ops = create_operators(eval_transforms, global_config)
    self.model.eval()
| 110 | |
| 111 | def __call__(self, data): |
| 112 | with open(data["img_path"], "rb") as f: |
nothing calls this directly
no test coverage detected