| 175 | self._fn_refs = [] |
| 176 | |
def register_hook(self, hook: ModelHook, name: str) -> None:
    """Register ``hook`` under ``name`` and route the module's ``forward`` through it.

    The hook is first given the tracked module via ``hook.initialize_hook``;
    whatever that call returns becomes the new tracked module. The module's
    ``forward`` attribute is then replaced by a wrapper that calls
    ``hook.pre_forward``, the forward target, and ``hook.post_forward`` in
    that order. The forward target is ``hook.new_forward`` (bound to the
    module) when the hook has that attribute, otherwise the module's
    ``forward`` as it was just before this registration.

    Args:
        hook: The hook to install.
        name: Key under which the hook is stored; must not already be in use.

    Raises:
        ValueError: If a hook is already registered under ``name``.
    """
    if name in self.hooks:
        raise ValueError(
            f"Hook with name {name} already exists in the registry. Please use a different name or "
            "first remove the existing hook and then add a new one."
        )

    # The factory and closure names are kept as-is: ``functools.update_wrapper``
    # copies ``__name__``/``__qualname__`` from the closure onto the installed forward.
    def create_new_forward(function_reference: HookFunctionReference):
        def new_forward(module, *args, **kwargs):
            hooked_args, hooked_kwargs = function_reference.pre_forward(module, *args, **kwargs)
            result = function_reference.forward(*hooked_args, **hooked_kwargs)
            return function_reference.post_forward(module, result)

        return new_forward

    # ``initialize_hook`` may hand back a different object, so the forward
    # must be read from the returned module, not the one passed in.
    module = hook.initialize_hook(self._module_ref)
    self._module_ref = module
    previous_forward = module.forward

    reference = HookFunctionReference()
    reference.pre_forward = hook.pre_forward
    reference.post_forward = hook.post_forward
    if hasattr(hook, "new_forward"):
        # The hook supplies its own forward: remember the one it replaces
        # and bind the module as the replacement's first argument.
        reference.original_forward = previous_forward
        bound_replacement = functools.partial(hook.new_forward, module)
        reference.forward = functools.update_wrapper(bound_replacement, hook.new_forward)
    else:
        reference.forward = previous_forward

    hooked_forward = create_new_forward(reference)
    module.forward = functools.update_wrapper(functools.partial(hooked_forward, module), hooked_forward)

    # Bookkeeping: the hook keeps its function reference, and the registry
    # records the hook, its registration order, and the reference.
    hook.fn_ref = reference
    self.hooks[name] = hook
    self._hook_order.append(name)
    self._fn_refs.append(reference)
| 216 | |
def get_hook(self, name: str) -> ModelHook | None:
    """Return the hook stored under ``name``, or ``None`` if there is none."""
    # ``dict.get`` already defaults to ``None`` for a missing key.
    registered = self.hooks.get(name)
    return registered