(self, *args, **kwargs)
| 742 | |
@functools.wraps(original_init)
def init(self, *args, **kwargs):
    """Register the module's configuration, then run the wrapped ``__init__``.

    Builds a config dict from the dataclass fields of ``self`` (skipping the
    names in ``self._flax_internal_args``), overlays the keyword arguments
    actually passed, maps positional arguments onto the fields in declaration
    order, drops ``dtype``, and passes the result to
    ``self.register_to_config`` before delegating to ``original_init`` with
    the untouched ``*args`` / ``**kwargs``.

    Raises:
        RuntimeError: if ``self`` is not a ``ConfigMixin`` instance.
    """
    if not isinstance(self, ConfigMixin):
        raise RuntimeError(
            f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
            "not inherit from `ConfigMixin`."
        )

    # Shallow copy of every keyword argument that was passed (no filtering is done).
    init_kwargs = dict(kwargs)

    # Default value for each non-internal dataclass field. Fields without a plain
    # `default` (required fields, or fields using `default_factory`) are recorded as None.
    fields = dataclasses.fields(self)
    default_kwargs = {}
    for field in fields:
        # ignore flax specific attributes
        if field.name in self._flax_internal_args:
            continue
        if field.default is dataclasses.MISSING:
            default_kwargs[field.name] = None
        else:
            default_kwargs[field.name] = getattr(self, field.name)

    # Explicitly passed keyword arguments override the defaults.
    new_kwargs = {**default_kwargs, **init_kwargs}

    # Map positional arguments onto the dataclass fields in declaration order.
    # `zip` stops at the shorter sequence; surplus positional args are left for
    # `original_init` to reject with its own arity error.
    for field, arg in zip(fields, args):
        new_kwargs[field.name] = arg

    # dtype is still forwarded to `original_init`, but must not be registered in the
    # config. Popped *after* positional alignment so a positionally passed dtype is
    # excluded as well.
    new_kwargs.pop("dtype", None)

    # Take note of the parameters that were not passed by keyword, i.e. not present
    # in the loaded config (positionally passed values are included, as before).
    use_default_values = set(new_kwargs) - set(init_kwargs)
    if use_default_values:
        new_kwargs["_use_default_values"] = list(use_default_values)

    self.register_to_config(**new_kwargs)
    original_init(self, *args, **kwargs)
| 784 | cls.__init__ = init |
| 785 | return cls |
no test coverage detected
searching dependent graphs…