MCPcopy Index your code
hub / github.com/microsoft/TRELLIS.2 / load

Method load

trellis2/trainers/basic.py:319–366  ·  view source on GitHub ↗

Load a checkpoint. Should be called by all processes.

(self, load_dir, step=0)

Source from the content-addressed store, hash-verified

317 master_params[i].data.copy_(param.data)
318
def load(self, load_dir, step=0):
    """
    Restore training state from a checkpoint.
    Should be called by all processes.

    Files are read from ``<load_dir>/ckpts``:

    - ``{name}_step{step:07d}.pt`` for every entry of ``self.models``
      (read through ``read_file_dist`` on every process),
    - ``{name}_ema{ema_rate}_step{step:07d}.pt`` for every EMA rate
      (read on the master process only),
    - ``misc_step{step:07d}.pt`` holding optimizer, step counter, data
      sampler and the optional scaler / log-scale / LR scheduler /
      elastic controller / gradient-clipper states.

    Args:
        load_dir: Directory that contains the ``ckpts`` sub-directory.
        step: Step number used to build the checkpoint file names.
            ``self.step`` itself is taken from the misc checkpoint,
            not from this argument.
    """
    if self.is_master:
        print(f'\nLoading checkpoint from step {step}...', end='')

    ckpt_dir = os.path.join(load_dir, 'ckpts')

    # Model weights: load into each model, then hand all state dicts to
    # _state_dicts_to_master_params together with self.master_params.
    loaded_states = {}
    for name, model in self.models.items():
        weights_path = os.path.join(ckpt_dir, f'{name}_step{step:07d}.pt')
        state = torch.load(
            read_file_dist(weights_path),
            map_location=self.device,
            weights_only=True,
        )
        loaded_states[name] = state
        model.load_state_dict(state)
    self._state_dicts_to_master_params(self.master_params, loaded_states)
    del loaded_states

    # EMA weights are only restored on the master process. They are read
    # straight from disk rather than through read_file_dist.
    # NOTE(review): presumably read_file_dist needs every rank to take
    # part, which would not work inside this master-only branch - confirm.
    if self.is_master:
        for idx, rate in enumerate(self.ema_rate):
            ema_states = {
                name: torch.load(
                    os.path.join(ckpt_dir, f'{name}_ema{rate}_step{step:07d}.pt'),
                    map_location=self.device,
                    weights_only=True,
                )
                for name, _ in self.models.items()
            }
            self._state_dicts_to_master_params(self.ema_params[idx], ema_states)
            del ema_states

    # Everything that is not a model weight lives in the misc checkpoint.
    # NOTE: weights_only=False performs full unpickling, so only load
    # checkpoints that come from a trusted source.
    misc_path = os.path.join(ckpt_dir, f'misc_step{step:07d}.pt')
    misc = torch.load(
        read_file_dist(misc_path),
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    self.optimizer.load_state_dict(misc['optimizer'])
    self.step = misc['step']
    self.data_sampler.load_state_dict(misc['data_sampler'])

    # Loss-scaling state is only restored for float16 mixed precision.
    mode = self.mix_precision_mode
    if mode in ('amp', 'inflat_all') and self.mix_precision_dtype == torch.float16:
        if mode == 'amp':
            self.scaler.load_state_dict(misc['scaler'])
        else:
            self.log_scale = misc['log_scale']

    if self.lr_scheduler_config is not None:
        self.lr_scheduler.load_state_dict(misc['lr_scheduler'])
    if self.elastic_controller_config is not None:
        self.elastic_controller.load_state_dict(misc['elastic_controller'])

    # A plain float (or None) gradient clip carries no state to restore.
    clip = self.grad_clip
    if not isinstance(clip, (float, type(None))):
        clip.load_state_dict(misc['grad_clip'])
    del misc

    multi_process = self.world_size > 1
    if multi_process:
        # Wait until every process has finished loading.
        dist.barrier()
    if self.is_master:
        print(' Done.')

    if multi_process:
        self.check_ddp()
367
368 def save(self, non_blocking=True):
369 """

Callers 15

__init__Method · 0.95
train.pyFile · 0.45
shapeimage_to_texFunction · 0.45
app_texturing.pyFile · 0.45
app.pyFile · 0.45
__init__Method · 0.45
from_pretrainedMethod · 0.45
from_pretrainedFunction · 0.45
finetune_fromMethod · 0.45
__init__Method · 0.45

Calls 5

check_ddpMethod · 0.95
read_file_distFunction · 0.85
load_state_dictMethod · 0.45
deviceMethod · 0.45

Tested by

no test coverage detected