| 76 | |
| 77 | |
def load_model(model, model_file, is_restore=False):
    """Load weights into ``model`` from a checkpoint path or an in-memory state dict.

    Args:
        model: The ``torch.nn.Module`` whose parameters/buffers are filled in.
        model_file: Either a filesystem path (``str`` or ``os.PathLike``) to a
            file readable by ``torch.load``, or an already-loaded mapping. In
            both cases a top-level ``'model'`` entry, if present, is unwrapped
            and used as the state dict.
        is_restore: If True, every key is prefixed with ``'module.'`` before
            loading, so a plain state dict can be loaded into a model whose
            weights live under a ``module`` submodule (presumably a
            ``DataParallel``-style wrapper -- confirm against callers).

    Returns:
        The same ``model`` object. Loading is non-strict: missing and
        unexpected keys are logged as warnings instead of raising.
    """
    import os  # local import: only needed for the path-type check below

    t_start = time.perf_counter()
    if isinstance(model_file, (str, os.PathLike)):
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts
        # and avoids a transient copy on the saving device; load_state_dict
        # copies the values onto the model's own devices anyway.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(model_file, map_location='cpu')
    else:
        state_dict = model_file
    # Full checkpoints keep the weights under 'model'; unwrap them whether
    # they were read from disk or passed in directly.
    if isinstance(state_dict, dict) and 'model' in state_dict:
        state_dict = state_dict['model']
    t_ioend = time.perf_counter()

    if is_restore:
        state_dict = OrderedDict(
            ('module.' + k, v) for k, v in state_dict.items())

    model.load_state_dict(state_dict, strict=False)
    ckpt_keys = set(state_dict.keys())
    own_keys = set(model.state_dict().keys())
    missing_keys = own_keys - ckpt_keys
    unexpected_keys = ckpt_keys - own_keys

    # Sorted so the warning text is stable across runs (set order is not).
    if missing_keys:
        logger.warning('Missing key(s) in state_dict: {}'.format(
            ', '.join(sorted(str(k) for k in missing_keys))))

    if unexpected_keys:
        logger.warning('Unexpected key(s) in state_dict: {}'.format(
            ', '.join(sorted(str(k) for k in unexpected_keys))))

    # Drop the local reference so the (possibly large) checkpoint can be freed.
    del state_dict
    t_end = time.perf_counter()
    logger.info(
        "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format(
            t_ioend - t_start, t_end - t_ioend))

    return model
| 116 | |
| 117 | |
| 118 | def parse_devices(input_devices): |