(model, model_path)
| 8 | |
| 9 | |
def load_head_from_pretrained_model(model, model_path):
    """Copy the output-head parameters of a pretrained checkpoint into ``model``.

    Only entries whose names contain ``"decoder.embed."``, ``"ctc."`` or
    ``"decoder.output_layer."`` (substring match) are taken from the
    checkpoint. Every other parameter/buffer of ``model`` keeps its current
    value. Head entries that ``model`` does not define are skipped with a
    warning, instead of making the strict ``load_state_dict`` call fail with
    "Unexpected key(s)".

    Args:
        model: ``torch.nn.Module`` that receives the parameters (updated in place).
        model_path: Path of a checkpoint written with ``torch.save``. It is either
            a state dict or a dict holding the state dict under the key ``"model"``.

    Raises:
        RuntimeError: propagated from ``model.load_state_dict`` when a copied
            tensor's shape differs from the model's corresponding entry.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    # map_location="cpu" keeps every storage on CPU, so a GPU-saved
    # checkpoint can be read on any machine.
    checkpoint = torch.load(model_path, map_location="cpu")
    if "model" in checkpoint:
        # Training checkpoints wrap the weights under "model".
        checkpoint = checkpoint["model"]

    head_markers = ("decoder.embed.", "ctc.", "decoder.output_layer.")
    dst_state = model.state_dict()

    src_dict = {}
    for key, value in checkpoint.items():
        if not any(marker in key for marker in head_markers):
            continue
        if key not in dst_state:
            # Would otherwise be reported as an unexpected key by the
            # strict load below and abort the whole load.
            logging.warning("skipping %s: not present in the model", key)
            continue
        src_dict[key] = value

    # Log in the model's own parameter order.
    for key in dst_state:
        if key in src_dict:
            logging.info("loading %s", key)

    dst_state.update(src_dict)
    model.load_state_dict(dst_state)
| 21 | |
| 22 | def load_adapter_from_pretrained_model(model, model_path, src_adapter, tgt_adapter): |
| 23 | ''' |
no test coverage detected