Repository: github.com/openai/guided-diffusion

Method _load_ema_parameters

guided_diffusion/train_util.py, lines 125–139
(self, rate)
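The `rate` argument selects which exponential-moving-average (EMA) copy of the parameters to load. The trainer can keep several EMA copies, one per decay rate. The sketch below shows the update rule such a copy follows: target = rate * target + (1 - rate) * source. It uses plain Python lists in place of tensors. The list-based `update_ema` here is an illustrative stand-in modeled on the repository's tensor helper, not the repository's own code.

```python
from typing import List


def update_ema(target: List[float], source: List[float], rate: float = 0.99) -> None:
    """Blend source into target in place: target <- rate * target + (1 - rate) * source."""
    if len(target) != len(source):
        raise ValueError("target and source must have the same length")
    for i, (t, s) in enumerate(zip(target, source)):
        # A rate close to 1.0 means the EMA copy changes slowly.
        target[i] = rate * t + (1.0 - rate) * s


ema = [0.0, 0.0]
weights = [1.0, 2.0]
update_ema(ema, weights, rate=0.9)
print(ema)  # approximately [0.1, 0.2]
```

Because each rate produces a different EMA copy, the loader is parameterized by `rate` so that it can find the matching checkpoint file.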

def _load_ema_parameters(self, rate):
    # Module-level names used here (imported or defined in train_util.py):
    # copy, dist (torch.distributed), dist_util, logger,
    # find_resume_checkpoint, find_ema_checkpoint.

    # Fallback: start the EMA copy from the current master parameters.
    ema_params = copy.deepcopy(self.mp_trainer.master_params)

    main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
    # Look for the EMA checkpoint saved next to the main checkpoint for this
    # resume step and EMA rate; falsy if none exists.
    ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
    if ema_checkpoint:
        # Only rank 0 reads the checkpoint from disk.
        if dist.get_rank() == 0:
            logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
            state_dict = dist_util.load_state_dict(
                ema_checkpoint, map_location=dist_util.dev()
            )
            ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)

    # Broadcast rank 0's tensors so every rank ends up with the same EMA copy.
    dist_util.sync_params(ema_params)
    return ema_params
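The lookup-then-fallback logic above can be exercised without PyTorch or a process group. The sketch below makes three assumptions. First, EMA files sit in the same directory as the main checkpoint. Second, they are named `ema_<rate>_<step, 6 digits>.pt`, which mirrors how the repository appears to save them. Third, the distributed load and broadcast steps are omitted. `find_ema_checkpoint_sketch`, `load_ema_sketch`, and the `load_fn` callback are hypothetical names. `load_fn` stands in for `dist_util.load_state_dict` followed by `state_dict_to_master_params`.

```python
import copy
import os
import tempfile
from typing import Any, Callable, Optional


def find_ema_checkpoint_sketch(
    main_checkpoint: Optional[str], step: int, rate: float
) -> Optional[str]:
    """Return the EMA checkpoint path next to main_checkpoint, or None if absent."""
    if main_checkpoint is None:
        return None
    # Assumed naming convention: ema_<rate>_<step zero-padded to 6 digits>.pt
    filename = f"ema_{rate}_{step:06d}.pt"
    path = os.path.join(os.path.dirname(main_checkpoint), filename)
    return path if os.path.exists(path) else None


def load_ema_sketch(
    master_params: Any,
    main_checkpoint: Optional[str],
    step: int,
    rate: float,
    load_fn: Callable[[str], Any],
) -> Any:
    # Fallback: the EMA copy starts as a deep copy of the master parameters.
    ema_params = copy.deepcopy(master_params)
    ema_checkpoint = find_ema_checkpoint_sketch(main_checkpoint, step, rate)
    if ema_checkpoint:
        # Hypothetical loader callback; no distributed rank check here.
        ema_params = load_fn(ema_checkpoint)
    return ema_params


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        main = os.path.join(d, "model001000.pt")
        # Create an empty EMA file for rate 0.9999 at step 1000.
        open(os.path.join(d, "ema_0.9999_001000.pt"), "w").close()
        print(find_ema_checkpoint_sketch(main, 1000, 0.9999))  # path to that file
        print(find_ema_checkpoint_sketch(main, 1000, 0.999))   # None
```

A missing file is not an error. The method silently keeps the deep-copied master parameters, so a resumed run without a saved EMA file restarts that EMA copy from the current weights.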

Callers (1)

__init__ (method) · 0.95
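As we read the upstream source, `__init__` parses `ema_rate` (a float or a comma-separated string) into a list of rates. When training resumes (nonzero `resume_step`), it calls this method once per rate. Otherwise it deep-copies the master parameters once per rate. The standalone sketch below reproduces that control flow. `parse_ema_rates` and `init_ema_params` are hypothetical names, and the caller-supplied `load_ema` stands in for `_load_ema_parameters`.

```python
import copy
from typing import Any, Callable, List, Union


def parse_ema_rates(ema_rate: Union[float, str]) -> List[float]:
    """Accept one float or a comma-separated string such as "0.9999,0.999"."""
    if isinstance(ema_rate, float):
        return [ema_rate]
    return [float(x) for x in ema_rate.split(",")]


def init_ema_params(
    ema_rate: Union[float, str],
    resume_step: int,
    master_params: Any,
    load_ema: Callable[[float], Any],
) -> List[Any]:
    rates = parse_ema_rates(ema_rate)
    if resume_step:
        # Resuming: load one EMA parameter set per rate.
        return [load_ema(rate) for rate in rates]
    # Fresh run: each EMA copy starts from the current master parameters.
    return [copy.deepcopy(master_params) for _ in rates]


params = [0.5, -1.0]
fresh = init_ema_params("0.9999,0.999", 0, params, load_ema=lambda r: None)
resumed = init_ema_params(0.9999, 1000, params, load_ema=lambda r: f"ema@{r}")
print(len(fresh), resumed)  # 2 ['ema@0.9999']
```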

Calls (4)

find_resume_checkpoint (function) · 0.85
find_ema_checkpoint (function) · 0.85
log (method) · 0.80

Tested by

no test coverage detected