MCPcopy Index your code
hub / github.com/huggingface/diffusers / FBCBlockHook

Class FBCBlockHook

src/diffusers/hooks/first_block_cache.py:145–190  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

143
144
class FBCBlockHook(ModelHook):
    """Forward hook for a transformer block that takes part in first-block caching.

    On each call the hook consults the shared state from ``state_manager``:

    * If ``should_compute`` is truthy, the wrapped block's original forward is
      run. A hook created with ``is_tail=True`` additionally stores the
      difference between that output and ``head_block_output`` on the shared
      state as ``tail_block_residuals``.
    * Otherwise the block is skipped and its inputs are handed back unchanged,
      arranged in the positions the block's metadata declares for its outputs.

    ``should_compute`` and ``head_block_output`` are only read here; they are
    presumably populated by a separate hook on the first block (not visible in
    this class -- confirm against the rest of the module).
    """

    def __init__(self, state_manager: StateManager, is_tail: bool = False):
        """Store the shared state manager and whether this hook sits on the tail block."""
        super().__init__()
        self.state_manager = state_manager
        self.is_tail = is_tail
        # Filled in by ``initialize_hook`` once the wrapped module's class is known.
        self._metadata = None

    def initialize_hook(self, module):
        """Look up the registry metadata for the unwrapped module's class and return ``module``."""
        block_cls = unwrap_module(module).__class__
        self._metadata = TransformerBlockRegistry.get(block_cls)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Run the wrapped block, or pass its inputs through when computation is skipped.

        Args:
            module: The hooked module (unused here; the original forward is
                reached through ``self.fn_ref``).
            *args: Positional arguments meant for the block's original forward.
            **kwargs: Keyword arguments meant for the block's original forward.

        Returns:
            The original forward's output when the shared state asks for
            computation. Otherwise the incoming ``hidden_states`` alone, or a
            2-tuple holding ``hidden_states`` and ``encoder_hidden_states`` at
            the metadata's return indices.
        """
        meta = self._metadata

        # Inputs are captured up front so they can be returned as-is on a skip.
        input_hidden = meta._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
        input_encoder = (
            meta._get_parameter_from_args_kwargs("encoder_hidden_states", args, kwargs)
            if meta.return_encoder_hidden_states_index is not None
            else None
        )

        state = self.state_manager.get_state()

        if not state.should_compute:
            # Skip path: the block is not executed; echo the inputs back in the
            # layout the block would normally produce.
            if input_encoder is None:
                return input_hidden
            passthrough = [None, None]
            passthrough[meta.return_hidden_states_index] = input_hidden
            passthrough[meta.return_encoder_hidden_states_index] = input_encoder
            return tuple(passthrough)

        output = self.fn_ref.original_forward(*args, **kwargs)
        if not self.is_tail:
            return output

        # Tail block: record how far the output moved from the head block's output.
        if isinstance(output, tuple):
            # NOTE(review): this branch indexes with return_encoder_hidden_states_index
            # without checking it for None -- confirm a tuple output always implies
            # the index is set.
            hidden_delta = output[meta.return_hidden_states_index] - state.head_block_output[0]
            encoder_delta = output[meta.return_encoder_hidden_states_index] - state.head_block_output[1]
        else:
            hidden_delta = output - state.head_block_output
            encoder_delta = None
        state.tail_block_residuals = (hidden_delta, encoder_delta)
        return output
191
192
193def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:

Callers 1

_apply_fbc_block_hookFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…