MCPcopy Index your code
hub / github.com/huggingface/diffusers / register_hook

Method register_hook

src/diffusers/hooks/hooks.py:177–215  ·  view source on GitHub ↗
(self, hook: ModelHook, name: str)

Source from the content-addressed store, hash-verified

175 self._fn_refs = []
176
def register_hook(self, hook: ModelHook, name: str) -> None:
    """Install ``hook`` on the tracked module under the unique key ``name``.

    The module's ``forward`` is replaced by a wrapper that calls the hook's
    ``pre_forward``, then the forward callable, then the hook's
    ``post_forward``. If the hook exposes a ``new_forward`` attribute, that
    callable (with the module pre-bound as first argument) is used as the
    forward callable, and the module's previous ``forward`` is kept on the
    function reference as ``original_forward``.

    Args:
        hook: Hook to install. ``hook.initialize_hook`` is called with the
            currently tracked module; its return value becomes the tracked
            module from then on.
        name: Key under which the hook is stored in ``self.hooks``.

    Raises:
        ValueError: If a hook is already registered under ``name``.
    """
    if name in self.hooks:
        raise ValueError(
            f"Hook with name {name} already exists in the registry. Please use a different name or "
            "first remove the existing hook and then add a new one."
        )

    # Everything below operates on whatever module initialize_hook hands back.
    self._module_ref = hook.initialize_hook(self._module_ref)
    target = self._module_ref

    def create_new_forward(ref: HookFunctionReference):
        # Factory so the wrapper closes over exactly this function reference.
        def new_forward(module, *args, **kwargs):
            args, kwargs = ref.pre_forward(module, *args, **kwargs)
            result = ref.forward(*args, **kwargs)
            return ref.post_forward(module, result)

        return new_forward

    # Captured before the module's forward is overwritten further down.
    previous_forward = target.forward

    fn_ref = HookFunctionReference()
    fn_ref.pre_forward = hook.pre_forward
    fn_ref.post_forward = hook.post_forward
    fn_ref.forward = previous_forward

    if hasattr(hook, "new_forward"):
        # The hook supplies its own forward: remember the previous one and
        # route calls through the hook's implementation bound to the module.
        fn_ref.original_forward = previous_forward
        bound_override = functools.partial(hook.new_forward, target)
        fn_ref.forward = functools.update_wrapper(bound_override, hook.new_forward)

    wrapper = create_new_forward(fn_ref)
    # partial pre-binds the module; update_wrapper copies the wrapper's
    # metadata onto the resulting partial object.
    target.forward = functools.update_wrapper(functools.partial(wrapper, target), wrapper)

    # Bookkeeping: the hook, its registration order, and its function reference.
    hook.fn_ref = fn_ref
    self.hooks[name] = hook
    self._hook_order.append(name)
    self._fn_refs.append(fn_ref)
216
def get_hook(self, name: str) -> ModelHook | None:
    """Return the hook stored under ``name``, or ``None`` if there is none."""
    # Mapping.get already falls back to None for a missing key.
    return self.hooks.get(name)

Calls 2

initialize_hookMethod · 0.45