github.com/Vchitect/Latte

Method _load_pretrained_parameters

train_with_img_pl.py:57–78
(self, args)


```python
import torch  # needed at module level


def _load_pretrained_parameters(self, args):
    # Load every tensor onto the CPU, whatever device it was saved from.
    checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        self.logging.info("Using ema ckpt!")
        checkpoint = checkpoint["ema"]

    model_dict = self.model.state_dict()
    # 1. filter out unnecessary keys (checkpoint keys the current model lacks)
    pretrained_dict = {}
    for k, v in checkpoint.items():
        if k in model_dict:
            pretrained_dict[k] = v
        else:
            self.logging.info(f"Ignoring: {k}")
    self.logging.info(
        f"Successfully loaded {len(pretrained_dict) / len(checkpoint) * 100:.2f}% "
        "of the original pretrained model weights"
    )

    # 2. overwrite entries in the existing state dict; keys missing from the
    #    checkpoint keep the model's current (freshly initialized) values
    model_dict.update(pretrained_dict)
    self.model.load_state_dict(model_dict)
    self.logging.info(f"Successfully loaded model at {args.pretrained}!")

    # self.global_step = int(args.pretrained.split("/")[-1].split(".")[0])  # dirty implementation
```
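The filter-and-merge logic of this method can be exercised without a model or a checkpoint file. The sketch below is a minimal, framework-free approximation, assuming plain dicts stand in for PyTorch state dicts; `merge_pretrained`, its `check_shapes` flag, and the example key names are hypothetical and not part of Latte. The optional shape check is an addition: the original method filters by key name only, so a tensor whose shape differs between checkpoint and model would still make `load_state_dict` raise.

```python
from typing import Any, Dict, List, Mapping, Tuple


def merge_pretrained(
    checkpoint: Mapping[str, Any],
    model_state: Mapping[str, Any],
    check_shapes: bool = False,
) -> Tuple[Dict[str, Any], List[str], float]:
    """Hypothetical helper mirroring _load_pretrained_parameters.

    Returns (merged state dict, ignored checkpoint keys, % of entries kept).
    """
    # Checkpoints written by train.py wrap the EMA weights under "ema".
    if "ema" in checkpoint:
        checkpoint = checkpoint["ema"]

    kept: Dict[str, Any] = {}
    ignored: List[str] = []
    for key, value in checkpoint.items():
        if key not in model_state:
            ignored.append(key)
        elif check_shapes and (
            getattr(value, "shape", None) != getattr(model_state[key], "shape", None)
        ):
            # Extra safeguard, NOT in the original method: skip shape mismatches.
            ignored.append(key)
        else:
            kept[key] = value

    # Guard against an empty checkpoint instead of dividing by zero.
    percent = 100.0 * len(kept) / len(checkpoint) if checkpoint else 0.0

    # Start from the model's own entries, then overwrite with the kept ones.
    merged = dict(model_state)
    merged.update(kept)
    return merged, ignored, percent


# Hypothetical key names, for illustration only.
model_state = {"blocks.0.weight": "init_w", "blocks.0.bias": "init_b", "temporal_embed": "init_t"}
checkpoint = {"ema": {"blocks.0.weight": "pre_w", "blocks.0.bias": "pre_b", "y_embedder.table": "pre_y"}}

merged, ignored, percent = merge_pretrained(checkpoint, model_state)
print(ignored)                   # ['y_embedder.table']
print(f"{percent:.2f}%")         # 66.67%
print(merged["temporal_embed"])  # init_t (kept its initial value)
```

A roughly comparable built-in route is `self.model.load_state_dict(checkpoint, strict=False)`, which returns the missing and unexpected keys instead of logging them; like the original method, it still raises on size mismatches.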

Callers (1)
  __init__ (method)

Calls (2)
  load (torch.load)
  update (dict.update on the model state dict)

Tested by
  no test coverage detected