(cls, config_dict, **kwargs)
| 497 | |
| 498 | @classmethod |
| 499 | def extract_init_dict(cls, config_dict, **kwargs): |
| 500 | # Skip keys that were not present in the original config, so default __init__ values were used |
| 501 | used_defaults = config_dict.get("_use_default_values", []) |
| 502 | config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"} |
| 503 | |
| 504 | # 0. Copy origin config dict |
| 505 | original_dict = dict(config_dict.items()) |
| 506 | |
| 507 | # 1. Retrieve expected config attributes from __init__ signature |
| 508 | expected_keys = cls._get_init_keys(cls) |
| 509 | expected_keys.remove("self") |
| 510 | # remove general kwargs if present in dict |
| 511 | if "kwargs" in expected_keys: |
| 512 | expected_keys.remove("kwargs") |
| 513 | # remove flax internal keys |
| 514 | if hasattr(cls, "_flax_internal_args"): |
| 515 | for arg in cls._flax_internal_args: |
| 516 | expected_keys.remove(arg) |
| 517 | |
| 518 | # 2. Remove attributes that cannot be expected from expected config attributes |
| 519 | # remove keys to be ignored |
| 520 | if len(cls.ignore_for_config) > 0: |
| 521 | expected_keys = expected_keys - set(cls.ignore_for_config) |
| 522 | |
| 523 | # load diffusers library to import compatible and original scheduler |
| 524 | diffusers_library = importlib.import_module(__name__.split(".")[0]) |
| 525 | |
| 526 | if cls.has_compatibles: |
| 527 | compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)] |
| 528 | else: |
| 529 | compatible_classes = [] |
| 530 | |
| 531 | expected_keys_comp_cls = set() |
| 532 | for c in compatible_classes: |
| 533 | expected_keys_c = cls._get_init_keys(c) |
| 534 | expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) |
| 535 | expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) |
| 536 | config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls} |
| 537 | |
| 538 | # remove attributes from orig class that cannot be expected |
| 539 | orig_cls_name = config_dict.pop("_class_name", cls.__name__) |
| 540 | if ( |
| 541 | isinstance(orig_cls_name, str) |
| 542 | and orig_cls_name != cls.__name__ |
| 543 | and hasattr(diffusers_library, orig_cls_name) |
| 544 | ): |
| 545 | orig_cls = getattr(diffusers_library, orig_cls_name) |
| 546 | unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys |
| 547 | config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig} |
| 548 | elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)): |
| 549 | raise ValueError( |
| 550 | "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)." |
| 551 | ) |
| 552 | |
| 553 | # remove private attributes |
| 554 | config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} |
| 555 | |
| 556 | # remove quantization_config |
no test coverage detected