(config, device, logger, vdl_writer)
| 44 | |
| 45 | @paddle.no_grad() |
| 46 | def main(config, device, logger, vdl_writer): |
| 47 | global_config = config["Global"] |
| 48 | |
| 49 | # build post process |
| 50 | post_process_class = build_post_process(config["PostProcess"], global_config) |
| 51 | |
| 52 | # build model |
| 53 | if hasattr(post_process_class, "character"): |
| 54 | config["Architecture"]["Head"]["out_channels"] = len( |
| 55 | getattr(post_process_class, "character") |
| 56 | ) |
| 57 | |
| 58 | model = build_model(config["Architecture"]) |
| 59 | algorithm = config["Architecture"]["algorithm"] |
| 60 | |
| 61 | load_model(config, model) |
| 62 | |
| 63 | # create data ops |
| 64 | transforms = [] |
| 65 | for op in config["Eval"]["dataset"]["transforms"]: |
| 66 | op_name = list(op)[0] |
| 67 | if "Encode" in op_name: |
| 68 | continue |
| 69 | if op_name == "KeepKeys": |
| 70 | op[op_name]["keep_keys"] = ["image", "shape"] |
| 71 | transforms.append(op) |
| 72 | |
| 73 | global_config["infer_mode"] = True |
| 74 | ops = create_operators(transforms, global_config) |
| 75 | |
| 76 | save_res_path = config["Global"]["save_res_path"] |
| 77 | os.makedirs(save_res_path, exist_ok=True) |
| 78 | |
| 79 | model.eval() |
| 80 | with open( |
| 81 | os.path.join(save_res_path, "infer.txt"), mode="w", encoding="utf-8" |
| 82 | ) as f_w: |
| 83 | for file in get_image_file_list(config["Global"]["infer_img"]): |
| 84 | logger.info("infer_img: {}".format(file)) |
| 85 | with open(file, "rb") as f: |
| 86 | img = f.read() |
| 87 | data = {"image": img} |
| 88 | batch = transform(data, ops) |
| 89 | images = np.expand_dims(batch[0], axis=0) |
| 90 | shape_list = np.expand_dims(batch[1], axis=0) |
| 91 | |
| 92 | images = paddle.to_tensor(images) |
| 93 | preds = model(images) |
| 94 | post_result = post_process_class(preds, [shape_list]) |
| 95 | |
| 96 | structure_str_list = post_result["structure_batch_list"][0] |
| 97 | bbox_list = post_result["bbox_batch_list"][0] |
| 98 | structure_str_list = structure_str_list[0] |
| 99 | structure_str_list = ( |
| 100 | ["<html>", "<body>", "<table>"] |
| 101 | + structure_str_list |
| 102 | + ["</table>", "</body>", "</html>"] |
| 103 | ) |
no test coverage detected
searching dependent graphs…