MCPcopy
hub / github.com/PaddlePaddle/PaddleOCR / load_model

Function load_model

ppocr/utils/save_load.py:66–196  ·  view source on GitHub ↗

Load the model from a checkpoint or from a pretrained model.

(config, model, optimizer=None, model_type="det", ema=None)

Source from the content-addressed store, hash-verified

64
65
66def load_model(config, model, optimizer=None, model_type="det", ema=None):
67 """
68 load model from checkpoint or pretrained_model
69 """
70 logger = get_logger()
71 global_config = config["Global"]
72 checkpoints = global_config.get("checkpoints")
73 pretrained_model = global_config.get("pretrained_model")
74 best_model_dict = {}
75 is_float16 = False
76 is_nlp_model = model_type == "kie" and config["Architecture"]["algorithm"] not in [
77 "SDMGR"
78 ]
79
80 if is_nlp_model is True:
81 # NOTE: for kie model dsitillation, resume training is not supported now
82 if config["Architecture"]["algorithm"] in ["Distillation"]:
83 return best_model_dict
84 checkpoints = config["Architecture"]["Backbone"]["checkpoints"]
85 # load kie method metric
86 if checkpoints:
87 if os.path.exists(os.path.join(checkpoints, "metric.states")):
88 with open(os.path.join(checkpoints, "metric.states"), "rb") as f:
89 states_dict = pickle.load(f, encoding="latin1")
90 best_model_dict = states_dict.get("best_model_dict", {})
91 if "epoch" in states_dict:
92 best_model_dict["start_epoch"] = states_dict["epoch"] + 1
93 logger.info("resume from {}".format(checkpoints))
94
95 if optimizer is not None:
96 if checkpoints[-1] in ["/", "\\"]:
97 checkpoints = checkpoints[:-1]
98 if os.path.exists(checkpoints + ".pdopt"):
99 optim_dict = paddle.load(checkpoints + ".pdopt")
100 optimizer.set_state_dict(optim_dict)
101 else:
102 logger.warning(
103 "{}.pdopt is not exists, params of optimizer is not loaded".format(
104 checkpoints
105 )
106 )
107
108 return best_model_dict
109
110 if checkpoints:
111 if checkpoints.endswith(".pdparams"):
112 checkpoints = checkpoints.replace(".pdparams", "")
113 assert os.path.exists(
114 checkpoints + ".pdparams"
115 ), "The {}.pdparams does not exists!".format(checkpoints)
116
117 # load params from trained model
118 params = paddle.load(checkpoints + ".pdparams")
119 state_dict = model.state_dict()
120 new_state_dict = {}
121 for key, value in state_dict.items():
122 if key not in params:
123 logger.warning(

Callers 15

export (Function) · 0.90
main (Function) · 0.90
main (Function) · 0.90
main (Function) · 0.90
main (Function) · 0.90
main (Function) · 0.90
main (Function) · 0.90
main (Function) · 0.90
__init__ (Method) · 0.90
main (Function) · 0.90
__init__ (Method) · 0.90
main (Function) · 0.90

Calls 5

get_logger (Function) · 0.90
format (Method) · 0.80
set_state_dict (Method) · 0.80
load_pretrained_params (Function) · 0.70
get (Method) · 0.65

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…