MCPcopy
hub / github.com/TencentARC/Pixal3D / save

Method save

pixal3d/trainers/basic.py:391–444  ·  view source on GitHub ↗

Save a checkpoint. Should be called only by the rank 0 process.

(self, non_blocking=True)

Source from the content-addressed store, hash-verified

389 self.check_ddp()
390
def save(self, non_blocking=True):
    """
    Write the current training state to ``<output_dir>/ckpts``.

    Must only be invoked on the master (rank 0) process.

    Files produced for the current step (zero-padded to 7 digits):
      * ``{name}_step{step}.pt`` for every state dict derived from the
        master parameters,
      * ``{name}_ema{rate}_step{step}.pt`` for every EMA rate,
      * ``misc_step{step}.pt`` holding optimizer / sampler state plus the
        optional scaler, log scale, LR scheduler, elastic controller and
        gradient-clipper state.

    Args:
        non_blocking: When true, each ``torch.save`` call runs on its own
            background thread; otherwise files are written synchronously.

    NOTE(review): the background threads are started but never joined, so
    in non-blocking mode ' Done.' may be printed before the files are
    fully written.
    """
    assert self.is_master, 'save() should be called only by the rank 0 process.'
    print(f'\nSaving checkpoint at step {self.step}...', end='')

    def _dump(payload, stem):
        # Single place that decides between threaded and synchronous writes.
        destination = os.path.join(self.output_dir, 'ckpts', f'{stem}_step{self.step:07d}.pt')
        if not non_blocking:
            torch.save(payload, destination)
            return
        threading.Thread(target=torch.save, args=(payload, destination)).start()

    def _on_cpu(state_dict):
        # Move every entry to CPU before it is serialized.
        return {key: value.cpu() for key, value in state_dict.items()}

    # Master weights, one file per named state dict.
    master_state_dicts = self._master_params_to_state_dicts(self.master_params)
    for name, state_dict in master_state_dicts.items():
        _dump(_on_cpu(state_dict), name)

    # EMA weights, one file per (named state dict, EMA rate) pair.
    for idx, ema_rate in enumerate(self.ema_rate):
        ema_state_dicts = self._master_params_to_state_dicts(self.ema_params[idx])
        for name, state_dict in ema_state_dicts.items():
            _dump(_on_cpu(state_dict), f'{name}_ema{ema_rate}')

    # Everything else that is needed to resume training.
    misc_state = {
        'optimizer': self.optimizer.state_dict(),
        'step': self.step,
        'data_sampler': self.data_sampler.state_dict(),
    }
    if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
        misc_state['scaler'] = self.scaler.state_dict()
    elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
        misc_state['log_scale'] = self.log_scale
    if self.lr_scheduler_config is not None:
        misc_state['lr_scheduler'] = self.lr_scheduler.state_dict()
    if self.elastic_controller_config is not None:
        misc_state['elastic_controller'] = self.elastic_controller.state_dict()
    # A plain float clip threshold carries no state; only stateful clippers are saved.
    if not (self.grad_clip is None or isinstance(self.grad_clip, float)):
        misc_state['grad_clip'] = self.grad_clip.state_dict()
    _dump(misc_state, 'misc')

    print(' Done.')
445
446 def _remap_checkpoint_keys(self, model_ckpt, model_state_dict):
447 """

Callers 9

check_abortMethod · 0.95
runMethod · 0.95
preprocessFunction · 0.45
generate_3dFunction · 0.45
run_inferenceFunction · 0.45
mainFunction · 0.45
mainFunction · 0.45
extract_imageFunction · 0.45
visualize_projectionMethod · 0.45

Calls 3

cpuMethod · 0.45
state_dictMethod · 0.45

Tested by

no test coverage detected