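Function load_pretrained_params

Repository: github.com/PaddlePaddle/PaddleOCR
Source: ppocr/utils/save_load.py, lines 199–238
Signature: load_pretrained_params(model, path)

Loads a `.pdparams` checkpoint into `model`. It keeps only entries whose name exists in the model and whose shape matches, casting dtypes to the model's. It returns `True` if any float16 tensors were found in the checkpoint.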
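The `path` argument may name the checkpoint with or without its `.pdparams` extension. After calling `maybe_download_params`, the function strips the extension if present and re-appends it when checking for and loading the file.

The sketch below isolates that normalization. It assumes a few things, and none of it is part of PaddleOCR:

- `params_prefix`, `params_file`, and `require_params_file` are hypothetical helpers.
- The download step is skipped.
- It swaps `str.replace` for `str.removesuffix` (Python 3.9+), because `replace` removes every occurrence of `.pdparams` in the path, not only the trailing one.
- It raises `FileNotFoundError` instead of using `assert`, since Python skips `assert` statements when run with `-O`.

```python
import os


def params_prefix(path: str) -> str:
    """Hypothetical helper: return `path` without a trailing ".pdparams".

    str.removesuffix (Python 3.9+) leaves a ".pdparams" elsewhere in the
    path alone; str.replace, as used in the original, removes every occurrence.
    """
    return path.removesuffix(".pdparams")


def params_file(path: str) -> str:
    """Hypothetical helper: the .pdparams file a checkpoint path refers to."""
    return params_prefix(path) + ".pdparams"


def require_params_file(path: str) -> str:
    """Hypothetical existence check that raises instead of using assert
    (assert statements are skipped when Python runs with -O)."""
    file_path = params_file(path)
    if not os.path.exists(file_path):
        raise FileNotFoundError("{} does not exist".format(file_path))
    return file_path


# Both spellings of the checkpoint path resolve to the same file;
# each line prints: output/rec/best_accuracy.pdparams
print(params_file("output/rec/best_accuracy"))
print(params_file("output/rec/best_accuracy.pdparams"))
```

The difference from `replace` only matters for unusual paths, such as a directory whose name contains `.pdparams`.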

```python
# Module-level imports this function relies on in save_load.py.
import os

import paddle

from ppocr.utils.logging import get_logger
from ppocr.utils.network import maybe_download_params


def load_pretrained_params(model, path):
    logger = get_logger()
    # Resolve `path` to a local checkpoint; URLs are downloaded first.
    path = maybe_download_params(path)
    # Accept the checkpoint path with or without the ".pdparams" suffix.
    if path.endswith(".pdparams"):
        path = path.replace(".pdparams", "")
    assert os.path.exists(
        path + ".pdparams"
    ), "The {}.pdparams does not exists!".format(path)

    params = paddle.load(path + ".pdparams")

    state_dict = model.state_dict()

    new_state_dict = {}
    is_float16 = False

    for k1 in params.keys():
        # Checkpoint entries the model does not define are skipped.
        if k1 not in state_dict.keys():
            logger.warning("The pretrained params {} not in model".format(k1))
        else:
            if params[k1].dtype == paddle.float16:
                is_float16 = True
            # Cast to the model's dtype before comparing shapes.
            if params[k1].dtype != state_dict[k1].dtype:
                params[k1] = params[k1].astype(state_dict[k1].dtype)
            # Keep only tensors whose shape matches the model; skip the rest.
            if list(state_dict[k1].shape) == list(params[k1].shape):
                new_state_dict[k1] = params[k1]
            else:
                logger.warning(
                    "The shape of model params {} {} not matched with loaded params {} {} !".format(
                        k1, state_dict[k1].shape, k1, params[k1].shape
                    )
                )

    model.set_state_dict(new_state_dict)
    if is_float16:
        # Tensors were cast to the model's dtype above, which is float32
        # only if the model itself is float32.
        logger.info(
            "The parameter type is float16, which is converted to float32 when loading"
        )
    logger.info("load pretrain successful from {}".format(path))
    return is_float16
```
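The name, dtype, and shape filtering in the loop can be sketched without Paddle. This is a minimal, framework-free sketch under these assumptions:

- `FakeTensor`, `LoadReport`, and `filter_pretrained` are hypothetical names.
- Dtypes are plain strings rather than `paddle` dtypes.
- The `missing` list (model parameters absent from the checkpoint) is an addition; the original function does not compute it.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: only shape and dtype matter here."""
    shape: Tuple[int, ...]
    dtype: str

    def astype(self, dtype: str) -> "FakeTensor":
        # A real tensor would convert its data; this stand-in only relabels the dtype.
        return FakeTensor(self.shape, dtype)


@dataclass
class LoadReport:
    state_dict: Dict[str, FakeTensor]
    unexpected: List[str] = field(default_factory=list)  # in checkpoint, not in model
    mismatched: List[str] = field(default_factory=list)  # same name, different shape
    missing: List[str] = field(default_factory=list)     # in model, not in checkpoint
    saw_float16: bool = False


def filter_pretrained(params: Dict[str, FakeTensor],
                      model_state: Dict[str, FakeTensor]) -> LoadReport:
    report = LoadReport(state_dict={})
    for name, value in params.items():
        if name not in model_state:
            report.unexpected.append(name)  # the original logs a warning and skips
            continue
        target = model_state[name]
        if value.dtype == "float16":
            report.saw_float16 = True
        if value.dtype != target.dtype:
            value = value.astype(target.dtype)  # cast to the model's dtype
        if list(value.shape) == list(target.shape):
            report.state_dict[name] = value
        else:
            report.mismatched.append(name)  # the original logs a warning and skips
    # Not reported by the original function; added for illustration.
    report.missing = [name for name in model_state if name not in params]
    return report


model_state = {"w": FakeTensor((2, 3), "float32"),
               "b": FakeTensor((3,), "float32"),
               "head": FakeTensor((10,), "float32")}
checkpoint = {"w": FakeTensor((2, 3), "float16"),
              "b": FakeTensor((4,), "float32"),
              "extra": FakeTensor((1,), "float32")}
report = filter_pretrained(checkpoint, model_state)
print(sorted(report.state_dict), report.unexpected, report.mismatched,
      report.missing, report.saw_float16)
# prints: ['w'] ['extra'] ['b'] ['head'] True
```

Unlike the original, this sketch leaves the input `params` dict unmodified. It also returns the skipped names instead of only logging them, which makes the filtering straightforward to assert on in tests.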

Callers (2)

__init__ · method · 0.90
load_model · function · 0.70

Calls (4)

get_logger · function · 0.90
maybe_download_params · function · 0.90
format · method · 0.80
set_state_dict · method · 0.80

Tested by

no test coverage detected
