MCPcopy
hub / github.com/microsoft/Cream / load_model

Function load_model

CDARTS/CDARTS_segmentation/tools/utils/pyt_utils.py:78–115  ·  view source on GitHub ↗
(model, model_file, is_restore=False)

Source from the content-addressed store, hash-verified

76
77
def load_model(model, model_file, is_restore=False, map_location='cpu'):
    """Load checkpoint weights into ``model`` non-strictly and return it.

    Args:
        model: Module whose ``load_state_dict(..., strict=False)`` is called;
            it is updated in place.
        model_file: Either a path (``str``) passed to ``torch.load``, or an
            already-loaded state dict used as-is. When a file is loaded and the
            result has a top-level ``'model'`` key, the value under that key is
            used as the state dict.
        is_restore: If True, every checkpoint key is prefixed with
            ``'module.'`` before loading (presumably so a checkpoint saved from
            a bare module can be loaded into a DataParallel/DDP wrapper --
            confirm against callers). Keys that already carry the prefix are
            prefixed again.
        map_location: Forwarded to ``torch.load`` when ``model_file`` is a
            path. Defaults to ``'cpu'`` so GPU-saved checkpoints can be read on
            CPU-only hosts and do not allocate memory on the device they were
            saved from. ``load_state_dict`` copies values into the model's own
            tensors, so the loaded model does not depend on this choice.

    Returns:
        The same ``model`` object.

    Keys present in the model but not the checkpoint ("missing") and keys
    present in the checkpoint but not the model ("unexpected") are logged as
    warnings rather than raised, because loading is non-strict.
    """
    t_start = time.time()
    if isinstance(model_file, str):
        state_dict = torch.load(model_file, map_location=map_location)
        # Full training checkpoints nest the weights under 'model'.
        if 'model' in state_dict:
            state_dict = state_dict['model']
    else:
        state_dict = model_file
    t_ioend = time.time()

    if is_restore:
        state_dict = OrderedDict(('module.' + k, v) for k, v in state_dict.items())

    model.load_state_dict(state_dict, strict=False)

    # strict=False silently skips mismatches, so report them explicitly.
    ckpt_keys = set(state_dict.keys())
    own_keys = set(model.state_dict().keys())
    missing_keys = own_keys - ckpt_keys
    unexpected_keys = ckpt_keys - own_keys

    if missing_keys:
        logger.warning('Missing key(s) in state_dict: %s',
                       ', '.join(sorted(str(k) for k in missing_keys)))

    if unexpected_keys:
        logger.warning('Unexpected key(s) in state_dict: %s',
                       ', '.join(sorted(str(k) for k in unexpected_keys)))

    # Drop this function's reference to the checkpoint tensors.
    del state_dict
    t_end = time.time()
    logger.info(
        "Load model, Time usage:\n\tIO: %s, initialize parameters: %s",
        t_ioend - t_start, t_end - t_ioend)

    return model
116
117
118def parse_devices(input_devices):

Callers 2

runMethod · 0.90
runMethod · 0.90

Calls 3

formatMethod · 0.80
load_state_dictMethod · 0.45
state_dictMethod · 0.45

Tested by 1

runMethod · 0.72