Repository: github.com/Rudrabha/Wav2Lip

Function load_checkpoint

hq_wav2lip_train.py:375–395
(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True)
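`load_checkpoint` indexes up to four keys of the deserialized checkpoint. `"state_dict"` is always read. `"optimizer"` is read unless `reset_optimizer=True`, and its value may be `None`, in which case it is skipped. `"global_step"` and `"global_epoch"` are read when `overwrite_global_states=True`. The torch-free sketch below builds a dict with that layout and reports which keys the function would fail on (as a `KeyError`) for a given pair of flags. The helper names `make_checkpoint` and `missing_keys` are hypothetical and are not part of Wav2Lip. Plain dicts stand in for real tensor state dicts.

```python
from typing import Any, Dict, List, Optional


def make_checkpoint(model_state: Dict[str, Any],
                    optimizer_state: Optional[Dict[str, Any]],
                    global_step: int,
                    global_epoch: int) -> Dict[str, Any]:
    # Hypothetical helper: assembles a dict with the keys load_checkpoint reads.
    return {
        "state_dict": model_state,
        "optimizer": optimizer_state,  # may be None; load_checkpoint then skips it
        "global_step": global_step,
        "global_epoch": global_epoch,
    }


def missing_keys(checkpoint: Dict[str, Any],
                 reset_optimizer: bool = False,
                 overwrite_global_states: bool = True) -> List[str]:
    # Hypothetical check: lists the keys load_checkpoint would index for these
    # flags but that are absent (each would raise KeyError in the real function).
    required = ["state_dict"]
    if not reset_optimizer:
        required.append("optimizer")
    if overwrite_global_states:
        required += ["global_step", "global_epoch"]
    return [k for k in required if k not in checkpoint]


ckpt = make_checkpoint({"conv.weight": [0.1]}, None, global_step=1000, global_epoch=3)
print(missing_keys(ckpt))  # []
print(missing_keys({"state_dict": {}}))  # ['optimizer', 'global_step', 'global_epoch']
print(missing_keys({"state_dict": {}}, reset_optimizer=True,
                   overwrite_global_states=False))  # []
```

A weights-only file, one containing nothing but `"state_dict"`, therefore loads only when both flags are set so that neither the optimizer nor the step counters are read.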


def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    if not reset_optimizer:
        optimizer_state = checkpoint["optimizer"]
        if optimizer_state is not None:
            print("Load optimizer state from {}".format(path))
            optimizer.load_state_dict(checkpoint["optimizer"])
    if overwrite_global_states:
        global_step = checkpoint["global_step"]
        global_epoch = checkpoint["global_epoch"]

    return model

if __name__ == "__main__":
    checkpoint_dir = args.checkpoint_dir
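The loop over `s.items()` rewrites the state-dict keys before calling `model.load_state_dict`. Its purpose is presumably to undo the `module.` prefix that `torch.nn.DataParallel` adds when a wrapped model is saved. Note that `str.replace` removes every occurrence of `'module.'` in a key, not only a leading one. The torch-free sketch below contrasts that behaviour with a stricter prefix-only variant. The function names are hypothetical, the prefix-only variant is not what Wav2Lip does, and the key names are illustrative.

```python
from typing import Any, Dict


def strip_module_all(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Same rewrite as the loop in load_checkpoint: str.replace drops EVERY
    # occurrence of 'module.', wherever it appears in the key.
    return {k.replace("module.", ""): v for k, v in state_dict.items()}


def strip_module_leading(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical stricter variant (not in Wav2Lip): only remove a leading
    # 'module.', the prefix added by torch.nn.DataParallel.
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}


saved = {
    "module.conv1.weight": 1,
    "module.block.module.fc.bias": 2,  # contrived key with a nested 'module.'
}
print(strip_module_all(saved))
# {'conv1.weight': 1, 'block.fc.bias': 2}
print(strip_module_leading(saved))
# {'conv1.weight': 1, 'block.module.fc.bias': 2}
```

For ordinary DataParallel checkpoints, both variants give the same result. They differ only when a submodule's own name path contains `module.`.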

Callers 1

Calls 1

_load (function) · 0.70

Tested by

no test coverage detected
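No tests were detected for this function. The sketch below shows one way to exercise its flag handling without PyTorch. It mirrors the control flow of `load_checkpoint` under some assumptions. First, it takes an already deserialized dict instead of a path, so no `_load` call is made. Second, a `counters` dict replaces the script's `global_step`/`global_epoch` globals. Third, a hypothetical `FakeStateful` class stands in for the torch model and optimizer. `load_checkpoint_sketch` is an illustration and is not part of Wav2Lip.

```python
class FakeStateful:
    # Stand-in for a torch model or optimizer: records the state it was given.
    def __init__(self):
        self.loaded = {}

    def load_state_dict(self, state):
        self.loaded = state


def load_checkpoint_sketch(checkpoint: dict, model, optimizer, counters: dict,
                           reset_optimizer: bool = False,
                           overwrite_global_states: bool = True):
    # Torch-free mirror of load_checkpoint's control flow (hypothetical).
    model.load_state_dict({k.replace("module.", ""): v
                           for k, v in checkpoint["state_dict"].items()})
    if not reset_optimizer:
        optimizer_state = checkpoint["optimizer"]
        if optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)
    if overwrite_global_states:
        counters["global_step"] = checkpoint["global_step"]
        counters["global_epoch"] = checkpoint["global_epoch"]
    return model


ckpt = {"state_dict": {"module.w": 1}, "optimizer": {"lr": 1e-4},
        "global_step": 5000, "global_epoch": 12}

# Default flags: weights, optimizer state and counters are all restored.
m, opt = FakeStateful(), FakeStateful()
counters = {"global_step": 0, "global_epoch": 0}
load_checkpoint_sketch(ckpt, m, opt, counters)
print(m.loaded, opt.loaded, counters)
# {'w': 1} {'lr': 0.0001} {'global_step': 5000, 'global_epoch': 12}

# Weights only: optimizer state and counters are left untouched.
m2, opt2 = FakeStateful(), FakeStateful()
counters2 = {"global_step": 0, "global_epoch": 0}
load_checkpoint_sketch(ckpt, m2, opt2, counters2,
                       reset_optimizer=True, overwrite_global_states=False)
print(m2.loaded, opt2.loaded, counters2)
# {'w': 1} {} {'global_step': 0, 'global_epoch': 0}
```

Passing the counters explicitly keeps the sketch testable in isolation. The original relies on module-level globals, which a real test would need to reset between cases.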