load_pretrained_model(model=model, model_path="", modules_to_load=None, exclude_modules="")
(model, model_path, modules_to_load=None, exclude_modules=None)
| 64 | |
| 65 | # Load and save |
def load_pretrained_model(model, model_path, modules_to_load=None, exclude_modules=None):
    """Copy (part of) a saved state dict from ``model_path`` into ``model``.

    Example::

        load_pretrained_model(model=model, model_path="",
                              modules_to_load=None, exclude_modules="")

    The checkpoint is read with ``torch.load``, filtered by key prefix, merged
    on top of ``model.state_dict()`` and handed to ``model.load_state_dict``.
    Exclusion is applied first, then selection.

    Args:
        model: Object providing ``state_dict()`` and ``load_state_dict()``.
        model_path: Path given to ``torch.load``. The loaded object must be a
            mapping of state-dict keys to values (``.items()`` is called on it).
        modules_to_load: Key prefixes to copy, either as a comma-separated
            string or as an iterable of strings. When empty/``None`` every key
            that survived exclusion is copied.
        exclude_modules: Key prefixes to drop, same format as
            ``modules_to_load``. When empty/``None`` nothing is dropped.

    Returns:
        None. ``model`` is updated in place.
    """

    def _prefixes(spec):
        # Normalise "a,b" strings and list/tuple inputs to a tuple of prefixes.
        # Blank entries (from "a," or "a,,b") are discarded on purpose:
        # str.startswith("") is True for every key, so a stray comma would
        # otherwise match the whole checkpoint.
        if not spec:
            return ()
        parts = spec.split(",") if isinstance(spec, str) else spec
        return tuple(p.strip() for p in parts if p and p.strip())

    # NOTE(review): torch.load unpickles the file; only use checkpoints from a
    # trusted source.
    # The map_location callable returns the storage as deserialised, i.e. the
    # tensors stay on CPU regardless of the device they were saved from.
    model_dict = torch.load(model_path, map_location=lambda storage, loc: storage)

    excluded = _prefixes(exclude_modules)
    if excluded:
        # str.startswith accepts a tuple, so one pass handles every prefix.
        model_dict = {k: v for k, v in model_dict.items() if not k.startswith(excluded)}

    wanted = _prefixes(modules_to_load)
    if wanted:
        src_dict = {k: v for k, v in model_dict.items() if k.startswith(wanted)}
    else:
        src_dict = model_dict

    # Start from the model's current state so keys absent from the checkpoint
    # selection keep their existing values.
    dst_state = model.state_dict()
    dst_state.update(src_dict)
    model.load_state_dict(dst_state)
| 86 | def torch_save(model, save_path, optimizer=None, local_rank=0): |
| 87 | if local_rank != 0: |
| 88 | return |
no test coverage detected