Repository: github.com/black-forest-labs/flux

Function load_ae

src/flux/util.py, lines 698–711
load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder


```python
def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder:
    config = configs[name]
    ckpt_path = str(get_checkpoint_path(config.repo_id, config.repo_ae, "FLUX_AE"))

    # Loading the autoencoder
    print("Init AE")
    with torch.device("meta"):
        ae = AutoEncoder(config.ae_params)

    print(f"Loading AE checkpoint: {ckpt_path}")
    sd = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    print_load_warning(missing, unexpected)
    return ae
```

Callers (8)

get_models · function · 0.90 (listed 3 times)
main · function · 0.90 (listed 5 times)

Calls (3)

AutoEncoder · class · 0.90
get_checkpoint_path · function · 0.85
print_load_warning · function · 0.85
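The callee print_load_warning receives the two key lists that load_state_dict returns when strict=False. Their meaning can be illustrated with plain set logic: "missing" keys are expected by the model but absent from the checkpoint, and "unexpected" keys are in the checkpoint but unknown to the model. This is a stdlib-only sketch; diff_state_dict_keys and format_load_warning are hypothetical helpers, and the message wording is illustrative, not the text flux actually prints.

```python
def diff_state_dict_keys(
    model_keys: list[str], checkpoint_keys: list[str]
) -> tuple[list[str], list[str]]:
    """Classify key mismatches the way load_state_dict(strict=False) reports them.

    missing:    expected by the model but absent from the checkpoint
    unexpected: present in the checkpoint but unknown to the model
    """
    ckpt = set(checkpoint_keys)
    known = set(model_keys)
    # Keep the input order so the report is deterministic.
    missing = [k for k in model_keys if k not in ckpt]
    unexpected = [k for k in checkpoint_keys if k not in known]
    return missing, unexpected


def format_load_warning(missing: list[str], unexpected: list[str]) -> str:
    """Hypothetical stand-in for print_load_warning: build the message
    instead of printing it, and return "" when every key matched."""
    parts = []
    if missing:
        parts.append(f"{len(missing)} missing key(s): " + ", ".join(missing))
    if unexpected:
        parts.append(f"{len(unexpected)} unexpected key(s): " + ", ".join(unexpected))
    return "\n".join(parts)


model_keys = ["encoder.weight", "encoder.bias", "decoder.weight", "decoder.bias"]
ckpt_keys = ["encoder.weight", "encoder.bias", "decoder.weight", "extra.scale"]
print(format_load_warning(*diff_state_dict_keys(model_keys, ckpt_keys)))
```

Returning a string rather than printing keeps the helper easy to test; a caller that wants load_ae's behavior can simply print the result when it is non-empty.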

Tested by

no test coverage detected
