(vae_type: str)
| 46 | |
| 47 | |
def build_config(vae_type: str) -> dict:
    """Build the PRX model config dict for the given VAE variant.

    Args:
        vae_type: Which VAE variant the config targets; must be ``"flux"``
            (uses ``PRXFlux``) or ``"dc-ae"`` (uses ``PRXDCAE``).

    Returns:
        A plain ``dict`` produced by ``dataclasses.asdict`` on the selected
        config, with its ``"axes_dim"`` entry converted to a ``list``.

    Raises:
        ValueError: If ``vae_type`` is not one of the supported values.
    """
    # Explicit branches (rather than a lookup table) so an unsupported value
    # is rejected without touching either config class.
    if vae_type == "flux":
        cfg = PRXFlux()
    elif vae_type == "dc-ae":
        cfg = PRXDCAE()
    else:
        raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")

    config_dict = asdict(cfg)
    # NOTE(review): axes_dim is presumably a tuple on the dataclass; it is
    # normalized to a list, likely so the config serializes cleanly — confirm.
    config_dict["axes_dim"] = list(config_dict["axes_dim"])  # type: ignore[index]
    return config_dict
| 59 | |
| 60 | |
| 61 | def create_parameter_mapping(depth: int) -> dict: |