(name: str, device: str | torch.device = "cuda")
| 696 | |
| 697 | |
def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder:
    """Construct the autoencoder for config ``name`` and load its checkpoint.

    The checkpoint location is resolved with ``get_checkpoint_path`` from the
    config's ``repo_id`` / ``repo_ae`` fields, using ``"FLUX_AE"`` as the last
    argument (presumably an override key such as an env var name; confirm in
    ``get_checkpoint_path``).

    Args:
        name: Key into ``configs`` selecting which model configuration to use.
        device: Device passed (as a string) to ``load_sft`` for the
            checkpoint tensors.

    Returns:
        The ``AutoEncoder`` whose parameters are the loaded checkpoint tensors.
    """
    cfg = configs[name]
    weights_file = str(get_checkpoint_path(cfg.repo_id, cfg.repo_ae, "FLUX_AE"))

    print("Init AE")
    # Build the module on the meta device so no real storage is allocated for
    # its initial parameters; the checkpoint tensors are attached in place
    # below via ``assign=True``.
    with torch.device("meta"):
        model = AutoEncoder(cfg.ae_params)

    print(f"Loading AE checkpoint: {weights_file}")
    state = load_sft(weights_file, device=str(device))
    # strict=False: key mismatches are reported rather than raised.
    # NOTE(review): any parameter listed in ``missing`` is never assigned and
    # so remains a meta tensor; confirm callers treat the warning as fatal.
    missing, unexpected = model.load_state_dict(state, strict=False, assign=True)
    print_load_warning(missing, unexpected)
    return model
| 712 | |
| 713 | |
| 714 | def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict: |
no test coverage detected
searching dependent graphs…