MCPcopy Index your code
hub / github.com/openai/improved-diffusion / save

Method save

improved_diffusion/train_util.py:271–294  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

269 logger.logkv("lg_loss_scale", self.lg_loss_scale)
270
def save(self):
    """Write model, EMA and optimizer checkpoints, then synchronize ranks.

    Every rank converts each parameter set to a state dict via
    ``_master_params_to_state_dict``; only rank 0 logs and writes files.
    Files are placed under ``get_blob_logdir()`` and named with the
    zero-padded sum of ``self.step`` and ``self.resume_step``:

    * ``model<step>.pt``       -- the master parameters (rate ``0``)
    * ``ema_<rate>_<step>.pt`` -- one per entry of ``self.ema_rate``
    * ``opt<step>.pt``         -- ``self.opt.state_dict()``

    All ranks meet at ``dist.barrier()`` before returning.
    """

    def _blob_path(prefix):
        # Shared naming scheme: <prefix><step, six digits>.pt in the log dir.
        return bf.join(
            get_blob_logdir(),
            f"{prefix}{(self.step+self.resume_step):06d}.pt",
        )

    def _parameter_sets():
        # A falsy rate marks the plain (non-EMA) model parameters.
        yield 0, self.master_params
        yield from zip(self.ema_rate, self.ema_params)

    for rate, params in _parameter_sets():
        # The conversion runs on every rank, not just the one that writes.
        state = self._master_params_to_state_dict(params)
        if dist.get_rank() != 0:
            continue
        logger.log(f"saving model {rate}...")
        prefix = f"ema_{rate}_" if rate else "model"
        with bf.BlobFile(_blob_path(prefix), "wb") as out:
            th.save(state, out)

    if dist.get_rank() == 0:
        with bf.BlobFile(_blob_path("opt"), "wb") as out:
            th.save(self.opt.state_dict(), out)

    # Hold every rank here until all have reached this point.
    dist.barrier()
295
296 def _master_params_to_state_dict(self, master_params):
297 if self.use_fp16:

Callers 4

run_loop · Method · 0.95
save_checkpoint · Method · 0.80
dump_images · Function · 0.80
main · Function · 0.80

Calls 1

get_blob_logdir · Function · 0.85

Tested by

no test coverage detected