MCPcopy
hub / github.com/coqui-ai/TTS / load_checkpoint

Method load_checkpoint

TTS/encoder/models/base_encoder.py:109–161  ·  view source on GitHub ↗
(
        self,
        config: Coqpit,
        checkpoint_path: str,
        eval: bool = False,
        use_cuda: bool = False,
        criterion=None,
        cache=False,
    )

Source from the content-addressed store, hash-verified

107 return criterion
108
def load_checkpoint(
    self,
    config: Coqpit,
    checkpoint_path: str,
    eval: bool = False,
    use_cuda: bool = False,
    criterion=None,
    cache=False,
):
    """Restore encoder weights (and optionally a criterion) from a checkpoint.

    The checkpoint is read with ``load_fsspec`` onto the CPU. A strict
    ``load_state_dict`` of ``state["model"]`` is attempted first. If that
    fails with ``KeyError``/``RuntimeError`` and ``eval`` is False, the model
    falls back to a partial initialization: the current state dict is passed
    through ``set_init_dict`` together with the checkpoint weights and
    ``config``, and the result is loaded.

    Args:
        config: Model config. Forwarded to ``set_init_dict`` on the partial
            initialization path. In eval mode it is also used to build a
            criterion via ``get_criterion`` when
            ``config.map_classid_to_classname`` is set.
        checkpoint_path: Checkpoint location handed to ``load_fsspec``.
        eval: If True:
            - failures of the strict restore are re-raised instead of falling
              back to partial initialization;
            - a criterion is built and restored when none was given and the
              checkpoint has one;
            - the model is switched to eval mode.
        use_cuda: If True, move the model, and the criterion if there is one,
            to CUDA.
        criterion: Optional criterion. Its state is restored from
            ``state["criterion"]`` when present. ``KeyError``/``RuntimeError``
            during that restore is printed and ignored.
        cache: Forwarded to ``load_fsspec``. Presumably this controls local
            caching of remote files; confirm against ``load_fsspec``.

    Returns:
        ``criterion`` when ``eval`` is True, otherwise the tuple
        ``(criterion, state["step"])``. ``criterion`` may be None.

    Raises:
        KeyError, RuntimeError: In eval mode, when the strict restore fails.
        KeyError: When ``eval`` is False and the checkpoint has no ``"step"``
            entry.
    """
    state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
    try:
        self.load_state_dict(state["model"])
        print(" > Model fully restored. ")
    except (KeyError, RuntimeError):
        # In eval mode a checkpoint that does not match the model is an error,
        # not something to silently patch over.
        if eval:
            raise

        print(" > Partial model initialization.")
        model_dict = self.state_dict()
        # Pass the real config object; the previous code referenced an undefined
        # name here, so this fallback always died with NameError.
        model_dict = set_init_dict(model_dict, state["model"], config)
        self.load_state_dict(model_dict)
        del model_dict

    # Restore the state of a caller-supplied criterion, if the checkpoint has one.
    if criterion is not None and "criterion" in state:
        try:
            criterion.load_state_dict(state["criterion"])
        except (KeyError, RuntimeError) as error:
            print(" > Criterion load ignored because of:", error)

    # At inference time, instantiate the encoder-classifier criterion from the
    # config so its weights can be restored from the checkpoint.
    if (
        eval
        and criterion is None
        and "criterion" in state
        and getattr(config, "map_classid_to_classname", None) is not None
    ):
        criterion = self.get_criterion(config, len(config.map_classid_to_classname))
        criterion.load_state_dict(state["criterion"])

    if use_cuda:
        self.cuda()
        if criterion is not None:
            criterion = criterion.cuda()

    if eval:
        self.eval()
        assert not self.training
        return criterion

    return criterion, state["step"]

Callers

nothing calls this directly

Calls 7

get_criterion · Method · 0.95
load_fsspec · Function · 0.90
set_init_dict · Function · 0.90
load_state_dict · Method · 0.80
state_dict · Method · 0.80
device · Method · 0.45
eval · Method · 0.45

Tested by

no test coverage detected