MCPcopy
hub / github.com/huggingface/diffusers / convert_transformer

Function convert_transformer

scripts/convert_cosmos_to_diffusers.py:579–639  ·  view source on GitHub ↗
(
    transformer_type: str,
    state_dict: Optional[Dict[str, Any]] = None,
    weights_only: bool = True,
)

Source from the content-addressed store, hash-verified

577
578
579def convert_transformer(
580 transformer_type: str,
581 state_dict: Optional[Dict[str, Any]] = None,
582 weights_only: bool = True,
583):
584 PREFIX_KEY = "net."
585
586 if "Cosmos-1.0" in transformer_type:
587 TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
588 TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
589 elif "Cosmos-2.0" in transformer_type:
590 TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
591 TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
592 elif "Cosmos-2.5" in transformer_type:
593 TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
594 TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
595 else:
596 assert False
597
598 with init_empty_weights():
599 config = TRANSFORMER_CONFIGS[transformer_type]
600 transformer = CosmosTransformer3DModel(**config)
601
602 old2new = {}
603 new2old = {}
604 for key in list(state_dict.keys()):
605 new_key = key[:]
606 if new_key.startswith(PREFIX_KEY):
607 new_key = new_key.removeprefix(PREFIX_KEY)
608 for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
609 new_key = new_key.replace(replace_key, rename_key)
610 print(key, "->", new_key, flush=True)
611 assert new_key not in new2old, f"new key {new_key} already mapped"
612 assert key not in old2new, f"old key {key} already mapped"
613 old2new[key] = new_key
614 new2old[new_key] = key
615 update_state_dict_(state_dict, key, new_key)
616
617 for key in list(state_dict.keys()):
618 for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
619 if special_key not in key:
620 continue
621 handler_fn_inplace(key, state_dict)
622
623 expected_keys = set(transformer.state_dict().keys())
624 mapped_keys = set(state_dict.keys())
625 missing_keys = expected_keys - mapped_keys
626 unexpected_keys = mapped_keys - expected_keys
627 if missing_keys:
628 print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr)
629 for k in missing_keys:
630 print(k)
631 sys.exit(1)
632 if unexpected_keys:
633 print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
634 for k in unexpected_keys:
635 print(k)
636 sys.exit(2)

Callers 1

Calls 4

update_state_dict_ — Function · 0.70
state_dict — Method · 0.45
load_state_dict — Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…