Repository: github.com/yangchris11/samurai

Method _load_resuming_checkpoint

sam2/training/trainer.py:422–445
(self, ckpt_path: str)


```python
def _load_resuming_checkpoint(self, ckpt_path: str):
    logging.info(f"Resuming training from {ckpt_path}")

    # map_location="cpu" loads every tensor onto the CPU, regardless of
    # the device it was saved from.
    with g_pathmgr.open(ckpt_path, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    # Keys matching checkpoint_conf.skip_saving_parameters may be absent
    # from the saved state, so they are not reported as missing.
    load_state_dict_into_model(
        model=self.model,
        state_dict=checkpoint["model"],
        ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters,
    )

    # Required entries: a missing key raises KeyError.
    self.optim.optimizer.load_state_dict(checkpoint["optimizer"])
    self.loss.load_state_dict(checkpoint["loss"], strict=True)
    self.epoch = checkpoint["epoch"]
    self.steps = checkpoint["steps"]
    # Optional entry: None when absent.
    self.ckpt_time_elapsed = checkpoint.get("time_elapsed")

    # The AMP grad-scaler state is restored only when mixed precision is
    # enabled and the checkpoint actually contains it.
    if self.optim_conf.amp.enabled and "scaler" in checkpoint:
        self.scaler.load_state_dict(checkpoint["scaler"])

    # Optional entry: an empty dict when absent.
    self.best_meter_values = checkpoint.get("best_meter_values", {})

    # Dataset state is restored only if it was saved and a training
    # dataset is configured.
    if "train_dataset" in checkpoint and self.train_dataset is not None:
        self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"])
```
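The method reads some checkpoint keys by direct indexing (`model`, `optimizer`, `loss`, `epoch`, `steps`) and others through `.get()` or behind a guard (`time_elapsed`, `scaler`, `best_meter_values`, `train_dataset`). The sketch below encodes that contract as a hypothetical pre-flight check. `missing_required_keys` and `restored_optional_parts` are illustrative helpers, not part of the repository, and the dictionary values are placeholders, since only the keys matter here.

```python
from typing import Any, Dict, List, Mapping

# Keys that _load_resuming_checkpoint indexes directly (checkpoint["..."]);
# if any is absent, resuming fails with a KeyError.
REQUIRED_KEYS = ("model", "optimizer", "loss", "epoch", "steps")


def missing_required_keys(checkpoint: Mapping[str, Any]) -> List[str]:
    """Hypothetical pre-flight check: list required keys absent from a checkpoint."""
    return [key for key in REQUIRED_KEYS if key not in checkpoint]


def restored_optional_parts(
    checkpoint: Mapping[str, Any], amp_enabled: bool, has_train_dataset: bool
) -> Dict[str, bool]:
    """Report which optional components the method would actually restore."""
    return {
        # Loaded only when AMP is enabled *and* the key exists.
        "scaler": amp_enabled and "scaler" in checkpoint,
        # Loaded only when a training dataset is configured *and* the key exists.
        "train_dataset": has_train_dataset and "train_dataset" in checkpoint,
        # Read with .get(): absent keys fall back to None and {} respectively.
        "time_elapsed": "time_elapsed" in checkpoint,
        "best_meter_values": "best_meter_values" in checkpoint,
    }


if __name__ == "__main__":
    # Placeholder values; only the keys matter for these checks.
    ckpt = {"model": {}, "optimizer": {}, "loss": {}, "epoch": 0, "steps": 0, "scaler": {}}
    print(missing_required_keys(ckpt))  # []
    print(restored_optional_parts(ckpt, amp_enabled=False, has_train_dataset=True))
```

Running such a check before constructing the trainer turns a mid-resume `KeyError` into an early, explicit error. It also makes visible that a saved `scaler` state is silently skipped when AMP is turned off.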
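`ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters` lets the model load even though some parameters were never stored. The sketch below shows one way such tolerance can work. It assumes the entries are Unix shell-style glob patterns over parameter names; that is an assumption for illustration, so check `load_state_dict_into_model` for the actual matching rule. The function names and the example parameter names are hypothetical.

```python
from fnmatch import fnmatchcase
from typing import Iterable, List


def unexpected_missing_keys(
    model_keys: Iterable[str],
    checkpoint_keys: Iterable[str],
    ignore_patterns: Iterable[str],
) -> List[str]:
    """Model keys absent from the checkpoint that no ignore pattern covers."""
    present = set(checkpoint_keys)
    patterns = list(ignore_patterns)
    return [
        key
        for key in model_keys
        if key not in present
        # fnmatchcase: case-sensitive shell-style matching on every platform.
        and not any(fnmatchcase(key, pattern) for pattern in patterns)
    ]


def check_missing_keys(
    model_keys: Iterable[str],
    checkpoint_keys: Iterable[str],
    ignore_patterns: Iterable[str],
) -> None:
    """Raise if any missing key is not covered by an ignore pattern."""
    missing = unexpected_missing_keys(model_keys, checkpoint_keys, ignore_patterns)
    if missing:
        raise KeyError(f"Checkpoint is missing keys: {missing}")


if __name__ == "__main__":
    # Hypothetical parameter names, chosen only to exercise the patterns.
    model_keys = ["image_encoder.trunk.weight", "image_encoder.neck.weight", "mask_decoder.weight"]
    ckpt_keys = ["mask_decoder.weight"]
    print(unexpected_missing_keys(model_keys, ckpt_keys, ["image_encoder.*"]))  # []
    print(unexpected_missing_keys(model_keys, ckpt_keys, ["image_encoder.trunk.*"]))  # ['image_encoder.neck.weight']
```

Matching on patterns rather than exact names means that a single entry such as `image_encoder.*` can exempt an entire frozen submodule from both saving and the missing-key check.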

Callers (1)

load_checkpoint (method) · 0.95

Calls (5)

info (method) · 0.80
load_state_dict (method) · 0.80
load (method) · 0.45
get (method) · 0.45

Tested by

no test coverage detected