(
transformer_type: str,
state_dict: Optional[Dict[str, Any]] = None,
weights_only: bool = True,
)
| 577 | |
| 578 | |
| 579 | def convert_transformer( |
| 580 | transformer_type: str, |
| 581 | state_dict: Optional[Dict[str, Any]] = None, |
| 582 | weights_only: bool = True, |
| 583 | ): |
| 584 | PREFIX_KEY = "net." |
| 585 | |
| 586 | if "Cosmos-1.0" in transformer_type: |
| 587 | TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 |
| 588 | TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 |
| 589 | elif "Cosmos-2.0" in transformer_type: |
| 590 | TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 |
| 591 | TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 |
| 592 | elif "Cosmos-2.5" in transformer_type: |
| 593 | TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 |
| 594 | TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 |
| 595 | else: |
| 596 | assert False |
| 597 | |
| 598 | with init_empty_weights(): |
| 599 | config = TRANSFORMER_CONFIGS[transformer_type] |
| 600 | transformer = CosmosTransformer3DModel(**config) |
| 601 | |
| 602 | old2new = {} |
| 603 | new2old = {} |
| 604 | for key in list(state_dict.keys()): |
| 605 | new_key = key[:] |
| 606 | if new_key.startswith(PREFIX_KEY): |
| 607 | new_key = new_key.removeprefix(PREFIX_KEY) |
| 608 | for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): |
| 609 | new_key = new_key.replace(replace_key, rename_key) |
| 610 | print(key, "->", new_key, flush=True) |
| 611 | assert new_key not in new2old, f"new key {new_key} already mapped" |
| 612 | assert key not in old2new, f"old key {key} already mapped" |
| 613 | old2new[key] = new_key |
| 614 | new2old[new_key] = key |
| 615 | update_state_dict_(state_dict, key, new_key) |
| 616 | |
| 617 | for key in list(state_dict.keys()): |
| 618 | for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): |
| 619 | if special_key not in key: |
| 620 | continue |
| 621 | handler_fn_inplace(key, state_dict) |
| 622 | |
| 623 | expected_keys = set(transformer.state_dict().keys()) |
| 624 | mapped_keys = set(state_dict.keys()) |
| 625 | missing_keys = expected_keys - mapped_keys |
| 626 | unexpected_keys = mapped_keys - expected_keys |
| 627 | if missing_keys: |
| 628 | print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr) |
| 629 | for k in missing_keys: |
| 630 | print(k) |
| 631 | sys.exit(1) |
| 632 | if unexpected_keys: |
| 633 | print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr) |
| 634 | for k in unexpected_keys: |
| 635 | print(k) |
| 636 | sys.exit(2) |
no test coverage detected
searching dependent graphs…