Finetune from a checkpoint. Should be called by all processes.
(self, finetune_ckpt)
| 487 | return remapped_ckpt |
| 488 | |
def finetune_from(self, finetune_ckpt):
    """
    Finetune from a checkpoint.
    Should be called by all processes.

    For every model in ``self.models`` that has an entry in ``finetune_ckpt``,
    the checkpoint is loaded, its keys are passed through
    ``self._remap_checkpoint_keys``, and the result is reconciled with the
    model's current state dict before ``load_state_dict`` is called:

    * keys present only in the checkpoint are dropped (with a warning);
    * keys whose shape differs from the model keep the model's current value
      (with a warning);
    * keys present only in the model are filled from the model's current
      value, but only if they are listed in ``ALLOWED_MISSING_KEYS``.

    The resulting state dicts are then forwarded to
    ``self._state_dicts_to_master_params`` for ``self.master_params`` and, on
    the master process, for each entry of ``self.ema_params``.

    Args:
        finetune_ckpt: Mapping from model name (a key of ``self.models``) to
            the checkpoint location handed to ``read_file_dist``. Models
            without an entry are not reloaded; their current state dict is
            used as-is.

    Raises:
        RuntimeError: If a checkpoint lacks a model key that is not listed in
            ``ALLOWED_MISSING_KEYS``. Raised on every process, not only on
            the master.
    """
    # Allow missing keys (e.g., register_buffer parameters).
    # Compared against full state-dict key names (exact match).
    ALLOWED_MISSING_KEYS = {'rope_phases'}

    if self.is_master:
        print('\nFinetuning from:')
        for name, path in finetune_ckpt.items():
            print(f' - {name}: {path}')

    model_ckpts = {}
    for name, model in self.models.items():
        model_state_dict = model.state_dict()

        if name not in finetune_ckpt:
            if self.is_master:
                print(f'Warning: {name} not found in finetune_ckpt, skipped.')
            model_ckpts[name] = model_state_dict
            continue

        loaded_ckpt = torch.load(
            read_file_dist(finetune_ckpt[name]),
            map_location=self.device,
            weights_only=True,
        )

        # Remap checkpoint keys to handle structural changes (e.g., ProjectAttention wrapper)
        loaded_ckpt = self._remap_checkpoint_keys(loaded_ckpt, model_state_dict)

        # Keep only the entries the model can accept: drop extra keys and fall
        # back to the model's own tensor when shapes disagree.
        model_ckpt = {}
        for k, v in loaded_ckpt.items():
            if k not in model_state_dict:
                if self.is_master:
                    print(f'Warning: {k} not found in model_state_dict, skipped.')
                continue
            if v.shape != model_state_dict[k].shape:
                if self.is_master:
                    print(f'Warning: {k} shape mismatch, {v.shape} vs {model_state_dict[k].shape}, skipped.')
                v = model_state_dict[k]
            model_ckpt[k] = v

        # Check missing keys (in model but not in ckpt)
        missing_keys = set(model_state_dict) - set(model_ckpt)
        unexpected_missing = missing_keys - ALLOWED_MISSING_KEYS
        if unexpected_missing:
            # Only the master prints, but every process must raise: this
            # method runs on all processes, and a rank that skipped the error
            # would continue with initialization values while the master
            # aborts.
            if self.is_master:
                print(f'Error: Missing keys in checkpoint: {unexpected_missing}')
            raise RuntimeError(f'Missing keys in checkpoint: {unexpected_missing}')

        allowed_missing = missing_keys & ALLOWED_MISSING_KEYS
        if allowed_missing and self.is_master:
            print(f'Info: Using model initialized values for: {allowed_missing}')

        # Fill in missing keys (using model initialized values)
        for k in missing_keys:
            model_ckpt[k] = model_state_dict[k]

        model_ckpts[name] = model_ckpt
        model.load_state_dict(model_ckpt)

    self._state_dicts_to_master_params(self.master_params, model_ckpts)
    if self.is_master:
        # One EMA parameter set per configured EMA rate.
        for i, _ in enumerate(self.ema_rate):
            self._state_dicts_to_master_params(self.ema_params[i], model_ckpts)
    del model_ckpts
no test coverage detected