| 143 | |
| 144 | |
class FBCBlockHook(ModelHook):
    """Hook that either runs the wrapped transformer block or passes its inputs through.

    On every call the hook reads the shared state from ``state_manager``:

    * If ``should_compute`` is truthy, the block's original forward is executed
      and its output is returned. When the hook was created with
      ``is_tail=True``, the difference between that output and the stored
      ``head_block_output`` is also written to ``tail_block_residuals``.
    * Otherwise the block is not executed; the incoming ``hidden_states`` (and
      ``encoder_hidden_states``, when present) are returned in the positions
      given by the block's registry metadata.
    """

    def __init__(self, state_manager: StateManager, is_tail: bool = False):
        """Store the shared state manager and whether this hook is the tail hook.

        Args:
            state_manager: Object whose ``get_state()`` yields the shared state.
            is_tail: When ``True``, residuals are recorded after a real forward.
        """
        super().__init__()
        self.state_manager = state_manager
        self.is_tail = is_tail
        # Populated by `initialize_hook` from the transformer block registry.
        self._metadata = None

    def initialize_hook(self, module):
        """Look up the registry metadata for the unwrapped module's class.

        Returns:
            The ``module`` argument, unchanged.
        """
        block_cls = unwrap_module(module).__class__
        self._metadata = TransformerBlockRegistry.get(block_cls)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Run the original forward or return the inputs, as the shared state dictates.

        Args:
            module: The hooked module (not used directly; the original forward
                is reached through ``self.fn_ref``).
            *args: Positional arguments intended for the original forward.
            **kwargs: Keyword arguments intended for the original forward.

        Returns:
            The original forward's output when computation is requested;
            otherwise the input hidden states alone, or a tuple holding the
            input hidden states and encoder hidden states at the metadata's
            return indices.
        """
        meta = self._metadata

        # Capture the inputs up front so they can be handed back on a skip.
        input_hidden = meta._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
        input_encoder_hidden = (
            meta._get_parameter_from_args_kwargs("encoder_hidden_states", args, kwargs)
            if meta.return_encoder_hidden_states_index is not None
            else None
        )

        state = self.state_manager.get_state()

        if not state.should_compute:
            # Skip path: the block is not executed; inputs are passed through.
            if input_encoder_hidden is None:
                return input_hidden
            passthrough = [None, None]
            passthrough[meta.return_hidden_states_index] = input_hidden
            passthrough[meta.return_encoder_hidden_states_index] = input_encoder_hidden
            return tuple(passthrough)

        output = self.fn_ref.original_forward(*args, **kwargs)
        if not self.is_tail:
            return output

        if isinstance(output, tuple):
            # NOTE(review): `output` is indexed via the metadata indices while
            # `head_block_output` uses fixed positions 0 and 1 -- confirm that
            # the code storing `head_block_output` uses that ordering.
            hidden_delta = output[meta.return_hidden_states_index] - state.head_block_output[0]
            encoder_delta = output[meta.return_encoder_hidden_states_index] - state.head_block_output[1]
        else:
            hidden_delta = output - state.head_block_output
            encoder_delta = None

        state.tail_block_residuals = (hidden_delta, encoder_delta)
        return output
| 191 | |
| 192 | |
| 193 | def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None: |
no outgoing calls
no test coverage detected
searching dependent graphs…