Load the model from a checkpoint or from a pretrained_model.
(config, model, optimizer=None, model_type="det", ema=None)
| 64 | |
| 65 | |
| 66 | def load_model(config, model, optimizer=None, model_type="det", ema=None): |
| 67 | """ |
| 68 | load model from checkpoint or pretrained_model |
| 69 | """ |
| 70 | logger = get_logger() |
| 71 | global_config = config["Global"] |
| 72 | checkpoints = global_config.get("checkpoints") |
| 73 | pretrained_model = global_config.get("pretrained_model") |
| 74 | best_model_dict = {} |
| 75 | is_float16 = False |
| 76 | is_nlp_model = model_type == "kie" and config["Architecture"]["algorithm"] not in [ |
| 77 | "SDMGR" |
| 78 | ] |
| 79 | |
| 80 | if is_nlp_model is True: |
| 81 | # NOTE: for kie model dsitillation, resume training is not supported now |
| 82 | if config["Architecture"]["algorithm"] in ["Distillation"]: |
| 83 | return best_model_dict |
| 84 | checkpoints = config["Architecture"]["Backbone"]["checkpoints"] |
| 85 | # load kie method metric |
| 86 | if checkpoints: |
| 87 | if os.path.exists(os.path.join(checkpoints, "metric.states")): |
| 88 | with open(os.path.join(checkpoints, "metric.states"), "rb") as f: |
| 89 | states_dict = pickle.load(f, encoding="latin1") |
| 90 | best_model_dict = states_dict.get("best_model_dict", {}) |
| 91 | if "epoch" in states_dict: |
| 92 | best_model_dict["start_epoch"] = states_dict["epoch"] + 1 |
| 93 | logger.info("resume from {}".format(checkpoints)) |
| 94 | |
| 95 | if optimizer is not None: |
| 96 | if checkpoints[-1] in ["/", "\\"]: |
| 97 | checkpoints = checkpoints[:-1] |
| 98 | if os.path.exists(checkpoints + ".pdopt"): |
| 99 | optim_dict = paddle.load(checkpoints + ".pdopt") |
| 100 | optimizer.set_state_dict(optim_dict) |
| 101 | else: |
| 102 | logger.warning( |
| 103 | "{}.pdopt is not exists, params of optimizer is not loaded".format( |
| 104 | checkpoints |
| 105 | ) |
| 106 | ) |
| 107 | |
| 108 | return best_model_dict |
| 109 | |
| 110 | if checkpoints: |
| 111 | if checkpoints.endswith(".pdparams"): |
| 112 | checkpoints = checkpoints.replace(".pdparams", "") |
| 113 | assert os.path.exists( |
| 114 | checkpoints + ".pdparams" |
| 115 | ), "The {}.pdparams does not exists!".format(checkpoints) |
| 116 | |
| 117 | # load params from trained model |
| 118 | params = paddle.load(checkpoints + ".pdparams") |
| 119 | state_dict = model.state_dict() |
| 120 | new_state_dict = {} |
| 121 | for key, value in state_dict.items(): |
| 122 | if key not in params: |
| 123 | logger.warning( |
no test coverage detected
searching dependent graphs…