MCPcopy
hub / github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI / go

Function go

infer/lib/train/utils.py:25–51  ·  view source on GitHub ↗
(model, bkey)

Source retrieved from the content-addressed store (hash-verified)

23
24 ##################
def go(model, bkey):
    """Copy compatible weights from ``checkpoint_dict[bkey]`` into ``model``.

    Only keys that the model itself declares are considered. For each entry of
    the model's ``state_dict()``:

    * if the key exists in the checkpoint and the shapes match, the checkpoint
      value is used;
    * otherwise the model keeps its current value and the fallback is logged.

    Checkpoint keys the model does not declare are ignored. The merged dict is
    loaded with ``strict=False``.

    ``checkpoint_dict`` and ``logger`` are read from the enclosing scope.

    Args:
        model: Module to load into. If it has a ``.module`` attribute
            (presumably a DataParallel / DistributedDataParallel wrapper; TODO
            confirm), the weights are loaded into ``model.module``.
        bkey: Key of the sub-state-dict inside ``checkpoint_dict``.

    Returns:
        The same ``model`` object that was passed in (wrapper included).

    Raises:
        KeyError: If ``bkey`` is not present in ``checkpoint_dict``.
    """
    saved_state_dict = checkpoint_dict[bkey]
    # Unwrap once instead of repeating the hasattr check for both the read
    # and the load.
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():  # shapes the model expects
        if k not in saved_state_dict:
            # Missing from the pretrained checkpoint: keep the model's own value
            # (per the original author's comment, its random initialization).
            logger.info("%s is not in the checkpoint", k)
            new_state_dict[k] = v
            continue
        saved = saved_state_dict[k]
        # getattr preserves the old best-effort behavior of the bare `except`:
        # an entry without `.shape` is treated as a mismatch instead of crashing.
        saved_shape = getattr(saved, "shape", None)
        if saved_shape != v.shape:
            # The key is present but incompatible. Warn and keep the model's
            # value. Do not also log "not in the checkpoint", which would be
            # false here.
            logger.warning(
                "shape-%s-mismatch. need: %s, get: %s",
                k,
                v.shape,
                saved_shape,
            )
            new_state_dict[k] = v
            continue
        new_state_dict[k] = saved
    target.load_state_dict(new_state_dict, strict=False)
    return model
52
53 go(combd, "combd")
54 model = go(sbd, "sbd")

Callers 1

load_checkpoint_d (Function) · 0.85

Calls 1

items (Method) · 0.80

Tested by

no test coverage detected