| 12 | class DynamicSwapInstaller: |
@staticmethod
def _install_module(module: torch.nn.Module, **kwargs):
    """Swap ``module``'s class for a subclass that converts tensors on attribute access.

    ``module.__class__`` is replaced by a dynamically created subclass named
    ``'DynamicSwap_' + original_class.__name__``. Its ``__getattr__`` returns
    each registered parameter or buffer after ``tensor.to(**kwargs)``. The
    tensors stored in ``_parameters`` / ``_buffers`` are never reassigned.

    The original class is saved in
    ``module.__dict__['forge_backup_original_class']``. Only ``module`` itself
    is patched; child modules keep their own class.

    Args:
        module: The module whose class is swapped in place.
        **kwargs: Forwarded verbatim to ``Tensor.to`` (e.g. ``device=``, ``dtype=``).
    """
    original_class = module.__class__
    module.__dict__['forge_backup_original_class'] = original_class

    def hacked_get_attr(self, name: str):
        # __getattr__ only runs after normal attribute lookup fails. For
        # nn.Module that is the case for parameters/buffers, which live in the
        # _parameters/_buffers dicts rather than in the instance __dict__.
        _parameters = self.__dict__.get('_parameters')
        if _parameters is not None and name in _parameters:
            p = _parameters[name]
            if p is None:
                return None
            if p.__class__ == torch.nn.Parameter:
                # Exact-type check: a plain Parameter is re-wrapped so callers
                # still receive a Parameter with the same requires_grad flag.
                # Anything else (including Parameter subclasses) is returned
                # as whatever its .to() produces.
                return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
            return p.to(**kwargs)
        _buffers = self.__dict__.get('_buffers')
        if _buffers is not None and name in _buffers:
            b = _buffers[name]
            # Buffers may be registered as None (register_buffer(name, None)).
            return None if b is None else b.to(**kwargs)
        # Resolve through original_class's own MRO. Using
        # super(original_class, self) would skip a __getattr__ defined
        # directly on original_class.
        return original_class.__getattr__(self, name)

    module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
        '__getattr__': hacked_get_attr,
    })
| 40 | |
| 41 | @staticmethod |
| 42 | def _uninstall_module(module: torch.nn.Module): |