MCPcopy
hub / github.com/TencentARC/Pixal3D / finetune_from

Method finetune_from

pixal3d/trainers/basic.py:489–554  ·  view source on GitHub ↗

Finetune from a checkpoint. Should be called by all processes.

(self, finetune_ckpt)

Source from the content-addressed store, hash-verified

487 return remapped_ckpt
488
def finetune_from(self, finetune_ckpt):
    """
    Finetune from a checkpoint.
    Should be called by all processes.

    For every model in ``self.models`` whose name is a key of
    ``finetune_ckpt``, the checkpoint at that path is loaded with
    ``torch.load(read_file_dist(path), weights_only=True)``. Its keys are
    remapped via ``self._remap_checkpoint_keys`` and then reconciled against
    the model's current ``state_dict()``:

    * keys in the checkpoint but not in the model are dropped (with a warning);
    * keys whose tensor shape differs from the model's keep the model's
      current value (with a warning);
    * keys in the model but missing from the checkpoint are an error, unless
      they are listed in ``ALLOWED_MISSING_KEYS``. Allowed missing keys keep
      the model's current value.

    Models not named in ``finetune_ckpt`` keep their current weights. The
    resulting per-model state dicts are written into ``self.master_params``.
    On the master process they are also written into every entry of
    ``self.ema_params``.

    Args:
        finetune_ckpt: mapping of model name -> checkpoint path.

    Raises:
        RuntimeError: on every process, if a loaded checkpoint is missing keys
            that the model has and that are not in ``ALLOWED_MISSING_KEYS``.
    """
    # Allow missing keys (e.g., register_buffer parameters)
    ALLOWED_MISSING_KEYS = {'rope_phases'}

    if self.is_master:
        print('\nFinetuning from:')
        for name, path in finetune_ckpt.items():
            print(f' - {name}: {path}')

    model_ckpts = {}
    for name, model in self.models.items():
        model_state_dict = model.state_dict()
        if name not in finetune_ckpt:
            if self.is_master:
                print(f'Warning: {name} not found in finetune_ckpt, skipped.')
            model_ckpts[name] = model_state_dict
            continue

        model_ckpt = torch.load(
            read_file_dist(finetune_ckpt[name]),
            map_location=self.device,
            weights_only=True,
        )

        # Remap checkpoint keys to handle structural changes (e.g., ProjectAttention wrapper)
        model_ckpt = self._remap_checkpoint_keys(model_ckpt, model_state_dict)

        # Check extra keys (in ckpt but not in model) and shape mismatches.
        # Build a fresh dict instead of writing sentinels into the one being iterated.
        reconciled = {}
        for k, v in model_ckpt.items():
            if k not in model_state_dict:
                if self.is_master:
                    print(f'Warning: {k} not found in model_state_dict, skipped.')
                continue
            if v.shape != model_state_dict[k].shape:
                if self.is_master:
                    print(f'Warning: {k} shape mismatch, {v.shape} vs {model_state_dict[k].shape}, skipped.')
                v = model_state_dict[k]
            reconciled[k] = v
        model_ckpt = reconciled

        # Check missing keys (in model but not in ckpt)
        missing_keys = set(model_state_dict) - set(model_ckpt)
        unexpected_missing = missing_keys - ALLOWED_MISSING_KEYS
        if unexpected_missing:
            # Only the master prints, but every rank must raise. This method is
            # called by all processes. A rank that silently continued would
            # diverge from the failed master and could block in a later
            # collective operation.
            if self.is_master:
                print(f'Error: Missing keys in checkpoint: {unexpected_missing}')
            raise RuntimeError(f'Missing keys in checkpoint: {unexpected_missing}')

        allowed_missing = missing_keys & ALLOWED_MISSING_KEYS
        if allowed_missing and self.is_master:
            print(f'Info: Using model initialized values for: {allowed_missing}')

        # Fill in missing keys (using model initialized values)
        for k in missing_keys:
            model_ckpt[k] = model_state_dict[k]

        model_ckpts[name] = model_ckpt
        model.load_state_dict(model_ckpt)

    self._state_dicts_to_master_params(self.master_params, model_ckpts)
    if self.is_master:
        # EMA copies are refreshed on the master only.
        # NOTE(review): presumably only the master holds ema_params — confirm.
        for i, _ema_rate in enumerate(self.ema_rate):
            self._state_dicts_to_master_params(self.ema_params[i], model_ckpts)
    del model_ckpts

Callers 1

__init__ · Method · 0.95

Calls 7

check_ddp · Method · 0.95
read_file_dist · Function · 0.85
state_dict · Method · 0.45
load · Method · 0.45
load_state_dict · Method · 0.45

Tested by

no test coverage detected