MCPcopy Index your code
hub / github.com/GuyTevet/motion-diffusion-model / save

Method save

train/training_loop.py:402–444  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

400 return self.step + self.resume_step
401
def save(self):
    """Write the model checkpoint and the optimizer state under ``self.save_dir``.

    Two files are produced, in this order:

    1. The model checkpoint, named by ``self.ckpt_file_name()``. The weights
       come from ``self.model.state_dict()`` when ``self.use_fp16`` is set,
       otherwise from ``self.mp_trainer.master_params_to_state_dict``.
       Entries whose key starts with ``'clip_model.'`` are removed before
       saving. When ``self.args.use_ema`` is set, the file holds a dict with
       the keys ``'model'`` and ``'model_avg'`` instead of the bare weights.
    2. The optimizer state, named ``opt<total_step, zero-padded to 9>.pt``.
       When ``self.use_fp16`` is set, the file holds a dict with the keys
       ``'opt'`` and ``'scaler'`` instead of the bare optimizer state.
    """

    def strip_clip(weights):
        # Do not save CLIP weights: collect the matching keys first, then
        # delete, so the mapping is never modified while being iterated.
        doomed = tuple(
            key for key in weights.keys() if key.startswith('clip_model.')
        )
        for key in doomed:
            del weights[key]

    def write_model():
        if not self.use_fp16:
            payload = self.mp_trainer.master_params_to_state_dict(
                self.mp_trainer.master_params)
        else:
            payload = self.model.state_dict()
        strip_clip(payload)

        if self.args.use_ema:
            # Save both the model and the average model.
            averaged = self.model_avg.state_dict()
            strip_clip(averaged)
            payload = {'model': payload, 'model_avg': averaged}

        logger.log(f"saving model...")
        ckpt_name = self.ckpt_file_name()
        ckpt_path = bf.join(self.save_dir, ckpt_name)
        with bf.BlobFile(ckpt_path, "wb") as out:
            torch.save(payload, out)

    write_model()

    opt_path = bf.join(self.save_dir, f"opt{(self.total_step()):09d}.pt")
    with bf.BlobFile(opt_path, "wb") as out:
        opt_payload = self.opt.state_dict()
        if self.use_fp16:
            # With fp16 the scaler state is stored next to the optimizer state.
            opt_payload = {
                'opt': opt_payload,
                'scaler': self.scaler.state_dict(),
            }

        torch.save(opt_payload, out)
445
446
447def parse_resume_step_from_filename(filename):

Callers 6

run_loopMethod · 0.95
mainFunction · 0.45
mainFunction · 0.45
save_npyMethod · 0.45
npy2smplMethod · 0.45
save_checkpointMethod · 0.45

Calls 1

total_stepMethod · 0.95

Tested by

no test coverage detected