MCPcopy Index your code
hub / github.com/openai/guided-diffusion / save

Method save

guided_diffusion/train_util.py:232–255  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

230 logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
231
def save(self):
    """Write model, EMA, and optimizer checkpoints for the current step.

    Files are written into the directory returned by ``get_blob_logdir()``
    and are named after ``self.step + self.resume_step``, zero-padded to
    six digits:

    * ``model{step}.pt``       -- state dict of the master parameters
    * ``ema_{rate}_{step}.pt`` -- one file per entry of ``self.ema_rate``
    * ``opt{step}.pt``         -- ``self.opt.state_dict()``

    Only the process for which ``dist.get_rank() == 0`` writes files; every
    process calls ``dist.barrier()`` before returning.
    """

    def stamped(prefix):
        # Build "<prefix><6-digit step>.pt"; formatted lazily at each use.
        return f"{prefix}{(self.step+self.resume_step):06d}.pt"

    def save_checkpoint(rate, params):
        # The state-dict conversion runs on every rank, before the rank
        # check; only rank 0 goes on to log and write.
        weights = self.mp_trainer.master_params_to_state_dict(params)
        if dist.get_rank() != 0:
            return
        logger.log(f"saving model {rate}...")
        # A falsy rate (the 0 passed below) denotes the non-EMA model.
        filename = stamped(f"ema_{rate}_") if rate else stamped("model")
        with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as out:
            th.save(weights, out)

    # Main model first, then each EMA copy paired with its rate.
    save_checkpoint(0, self.mp_trainer.master_params)
    for ema_rate, ema_weights in zip(self.ema_rate, self.ema_params):
        save_checkpoint(ema_rate, ema_weights)

    if dist.get_rank() == 0:
        opt_path = bf.join(get_blob_logdir(), stamped("opt"))
        with bf.BlobFile(opt_path, "wb") as out:
            th.save(self.opt.state_dict(), out)

    # Presumably torch.distributed's barrier, keeping other ranks from
    # running ahead of the rank-0 writes -- confirm against the imports.
    dist.barrier()
256
257
258def parse_resume_step_from_filename(filename):

Callers 4

run_loopMethod · 0.95
save_checkpointMethod · 0.80
save_modelFunction · 0.80
dump_imagesFunction · 0.80

Calls 1

get_blob_logdirFunction · 0.85

Tested by

no test coverage detected