(self, ckpt_path: str)
| 420 | self.model = model_weight_initializer(model=self.model) |
| 421 | |
def _load_resuming_checkpoint(self, ckpt_path: str):
    """Restore trainer state from the checkpoint at ``ckpt_path`` so training can resume.

    The file is opened through ``g_pathmgr`` and deserialized with
    ``torch.load`` onto the CPU. The following are then restored, in order:

    - the model weights;
    - the optimizer state;
    - the loss state (strict);
    - the epoch and step counters;
    - the recorded elapsed time;
    - the AMP grad-scaler state, only when AMP is enabled and the checkpoint
      carries a "scaler" entry;
    - the best meter values (``{}`` when absent);
    - the training dataset's own checkpoint state, only when the checkpoint
      has a "train_dataset" entry and ``self.train_dataset`` is set.

    Args:
        ckpt_path: Location of the checkpoint file to resume from.

    Note:
        The entries "model", "optimizer", "loss", "epoch" and "steps" are
        looked up with ``[...]``, so a checkpoint missing any of them raises
        (KeyError, assuming the checkpoint is a plain dict — TODO confirm).
        ``torch.load`` unpickles the file; only resume from trusted
        checkpoints.
    """
    logging.info(f"Resuming training from {ckpt_path}")

    with g_pathmgr.open(ckpt_path, "rb") as handle:
        state = torch.load(handle, map_location="cpu")

    # Keys listed in skip_saving_parameters are tolerated as missing —
    # presumably parameters deliberately left out when the checkpoint was
    # written; verify against the saving side.
    load_state_dict_into_model(
        model=self.model,
        state_dict=state["model"],
        ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters,
    )

    self.optim.optimizer.load_state_dict(state["optimizer"])
    self.loss.load_state_dict(state["loss"], strict=True)

    self.epoch = state["epoch"]
    self.steps = state["steps"]
    # NOTE(review): resolves to None when the checkpoint has no
    # "time_elapsed" entry — confirm downstream consumers handle None.
    self.ckpt_time_elapsed = state.get("time_elapsed")

    # The scaler state is only meaningful when mixed precision is on.
    amp_enabled = self.optim_conf.amp.enabled
    if amp_enabled and "scaler" in state:
        self.scaler.load_state_dict(state["scaler"])

    self.best_meter_values = state.get("best_meter_values", {})

    has_dataset_state = "train_dataset" in state
    if has_dataset_state and self.train_dataset is not None:
        self.train_dataset.load_checkpoint_state(state["train_dataset"])
| 446 | |
def is_intermediate_val_epoch(self, epoch):
    """Return whether validation should run after ``epoch`` mid-training.

    An epoch qualifies when two conditions hold:

    - it is a multiple of ``self.val_epoch_freq``;
    - it is strictly less than ``self.max_epochs - 1``, presumably the final
      epoch under zero-based numbering (confirm with the training loop).
    """
    # Off-schedule epochs are rejected first; max_epochs is not consulted for them.
    if epoch % self.val_epoch_freq != 0:
        return False
    return epoch < self.max_epochs - 1
no test coverage detected