(model, path)
| 197 | |
| 198 | |
def load_pretrained_params(model, path):
    """Load pretrained weights from a ``.pdparams`` file into ``model``.

    Only entries whose key exists in ``model.state_dict()`` and whose shape
    matches the model's parameter are applied; everything else is skipped
    with a warning. Entries whose dtype differs from the model's parameter
    are cast to the model's dtype before being applied.

    Args:
        model: Object exposing ``state_dict()`` and ``set_state_dict()``.
        path: Location of the weights, with or without the ``.pdparams``
            extension. It is first passed through ``maybe_download_params``
            (presumably resolving a remote location to a local file --
            confirm against that helper).

    Returns:
        bool: True if at least one loaded entry that is also present in the
        model was stored as ``paddle.float16``, otherwise False.

    Raises:
        AssertionError: If ``<path>.pdparams`` does not exist.
    """
    logger = get_logger()
    path = maybe_download_params(path)

    suffix = ".pdparams"
    # Strip only the trailing extension. str.replace() would also remove
    # ".pdparams" appearing earlier in the path (e.g. in a directory name)
    # and silently point at a different file.
    if path.endswith(suffix):
        path = path[: -len(suffix)]
    assert os.path.exists(
        path + suffix
    ), "The {}.pdparams does not exists!".format(path)

    params = paddle.load(path + suffix)
    state_dict = model.state_dict()

    new_state_dict = {}
    is_float16 = False

    for k1 in params:
        if k1 not in state_dict:
            logger.warning("The pretrained params {} not in model".format(k1))
            continue

        value = params[k1]
        target = state_dict[k1]

        # Recorded before the cast so the caller learns the checkpoint
        # itself was stored in half precision.
        if value.dtype == paddle.float16:
            is_float16 = True
        if value.dtype != target.dtype:
            value = value.astype(target.dtype)

        if list(target.shape) == list(value.shape):
            new_state_dict[k1] = value
        else:
            logger.warning(
                "The shape of model params {} {} not matched with loaded params {} {} !".format(
                    k1, target.shape, k1, value.shape
                )
            )

    model.set_state_dict(new_state_dict)
    if is_float16:
        logger.info(
            "The parameter type is float16, which is converted to float32 when loading"
        )
    logger.info("load pretrain successful from {}".format(path))
    return is_float16
| 239 | |
| 240 | |
| 241 | def save_model( |
no test coverage detected
searching dependent graphs…