MCPcopy
hub / github.com/TencentARC/Pixal3D / load

Method load

pixal3d/trainers/basic.py:342–389  ·  view source on GitHub ↗

Load a checkpoint. Should be called by all processes.

(self, load_dir, step=0)

Source from the content-addressed store, hash-verified

340 master_params[i].data.copy_(param.data)
341
def load(self, load_dir, step=0):
    """
    Restore the trainer from the checkpoint files saved for ``step``.

    Should be called by all processes: when ``world_size > 1`` this method
    enters ``dist.barrier()`` and then ``check_ddp()``, so every rank must
    reach it. (``read_file_dist`` is presumably a distributed read as well --
    TODO confirm.)

    Files are read from ``<load_dir>/ckpts/``. Their names carry the step
    zero-padded to seven digits; for ``step=100`` they are
    ``<model>_step0000100.pt``, ``<model>_ema<rate>_step0000100.pt`` and
    ``misc_step0000100.pt``.

    Args:
        load_dir: Run directory containing the ``ckpts`` subfolder.
        step: Training step whose checkpoint should be loaded.

    Restores the following, in order:
        1. Every model's weights (on all ranks) and ``self.master_params``.
        2. ``self.ema_params`` (master rank only).
        3. The optimizer, ``self.step`` and the data sampler.
        4. Whichever of the AMP scaler, ``log_scale``, LR scheduler, elastic
           controller and grad-clip state this trainer is configured to use.
    """
    def ckpt_path(stem):
        # e.g. stem 'misc', step 100 -> <load_dir>/ckpts/misc_step0000100.pt
        return os.path.join(load_dir, 'ckpts', f'{stem}_step{step:07d}.pt')

    if self.is_master:
        print(f'\nLoading checkpoint from step {step}...', end='')

    # Model weights: read on every rank and applied to each model immediately.
    loaded_states = {}
    for model_name, net in self.models.items():
        state = torch.load(
            read_file_dist(ckpt_path(model_name)),
            map_location=self.device,
            weights_only=True,
        )
        loaded_states[model_name] = state
        net.load_state_dict(state)
    self._state_dicts_to_master_params(self.master_params, loaded_states)
    del loaded_states

    # EMA weights are only restored on the master rank, read with plain torch.load.
    if self.is_master:
        for idx, rate in enumerate(self.ema_rate):
            ema_states = {
                model_name: torch.load(
                    ckpt_path(f'{model_name}_ema{rate}'),
                    map_location=self.device,
                    weights_only=True,
                )
                for model_name in self.models
            }
            self._state_dicts_to_master_params(self.ema_params[idx], ema_states)
            del ema_states

    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only point load_dir at trusted checkpoints.
    misc = torch.load(
        read_file_dist(ckpt_path('misc')),
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    self.optimizer.load_state_dict(misc['optimizer'])
    self.step = misc['step']
    self.data_sampler.load_state_dict(misc['data_sampler'])

    # Extra mixed-precision state exists only for float16 training.
    precision_mode = self.mix_precision_mode
    if self.mix_precision_dtype == torch.float16:
        if precision_mode == 'amp':
            self.scaler.load_state_dict(misc['scaler'])
        elif precision_mode == 'inflat_all':
            self.log_scale = misc['log_scale']

    if self.lr_scheduler_config is not None:
        self.lr_scheduler.load_state_dict(misc['lr_scheduler'])
    if self.elastic_controller_config is not None:
        self.elastic_controller.load_state_dict(misc['elastic_controller'])
    # A float clip threshold carries no state; any other non-None value is
    # treated as a stateful clipper with load_state_dict.
    # NOTE(review): an int threshold would also reach load_state_dict here and
    # fail -- confirm grad_clip is always a float or a clipper object.
    if self.grad_clip is not None and not isinstance(self.grad_clip, float):
        self.grad_clip.load_state_dict(misc['grad_clip'])
    del misc

    distributed = self.world_size > 1
    if distributed:
        dist.barrier()
    if self.is_master:
        print(' Done.')
    if distributed:
        self.check_ddp()
390
391 def save(self, non_blocking=True):
392 """

Callers 15

__init__ · Method · 0.95
train.py · File · 0.45
unpack_state · Function · 0.45
progress_poll · Function · 0.45
_dual_grid_mesh · Function · 0.45
worker · Function · 0.45
_pbr_voxelize_view · Function · 0.45
_pbr_voxelize · Function · 0.45
main · Function · 0.45
_dual_grid_mesh_view · Function · 0.45
check_sha256 · Function · 0.45

Calls 5

check_ddp · Method · 0.95
read_file_dist · Function · 0.85
load_state_dict · Method · 0.45
device · Method · 0.45

Tested by

no test coverage detected