| 55 | self.ema.eval() |
| 56 | |
| 57 | def _load_pretrained_parameters(self, args): |
| 58 | checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage) |
| 59 | if "ema" in checkpoint: # supports checkpoints from train.py |
| 60 | self.logging.info("Using ema ckpt!") |
| 61 | checkpoint = checkpoint["ema"] |
| 62 | |
| 63 | model_dict = self.model.state_dict() |
| 64 | # 1. filter out unnecessary keys |
| 65 | pretrained_dict = {} |
| 66 | for k, v in checkpoint.items(): |
| 67 | if k in model_dict: |
| 68 | pretrained_dict[k] = v |
| 69 | else: |
| 70 | self.logging.info("Ignoring: {}".format(k)) |
| 71 | self.logging.info(f"Successfully Load {len(pretrained_dict) / len(checkpoint.items()) * 100}% original pretrained model weights ") |
| 72 | |
| 73 | # 2. overwrite entries in the existing state dict |
| 74 | model_dict.update(pretrained_dict) |
| 75 | self.model.load_state_dict(model_dict) |
| 76 | self.logging.info(f"Successfully load model at {args.pretrained}!") |
| 77 | |
| 78 | # self.global_step = int(args.pretrained.split("/")[-1].split(".")[0]) # dirty implementation |
| 79 | |
| 80 | def training_step(self, batch, batch_idx): |
| 81 | x = batch["video"].to(self.device) |