Applies [First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching) to a given module. First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler to implement generically for a wide range of models and has been integrated first for experimental purposes.
(module: torch.nn.Module, config: FirstBlockCacheConfig)
| 191 | |
| 192 | |
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    """
    Applies [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
    to a given module.

    First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
    to implement generically for a wide range of models and has been integrated first for experimental purposes.

    Blocks are collected, in iteration order, from every direct child of `module` that is a `torch.nn.ModuleList`
    and whose attribute name is listed in `_ALL_TRANSFORMER_BLOCK_IDENTIFIERS`. The first collected block receives
    the head-block hook (configured with `config.threshold`), the last one receives the block hook with
    `is_tail=True`, and every block in between receives the plain block hook. All hooks share one state manager.

    Args:
        module (`torch.nn.Module`):
            The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
            Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
        config (`FirstBlockCacheConfig`):
            The configuration to use for applying the FBCache method.

    Raises:
        `ValueError`:
            If fewer than two blocks are found in `module`. A distinct head block and tail block are both required,
            and no hook is applied when this error is raised.

    Example:
        ```python
        >>> import torch
        >>> from diffusers import CogView4Pipeline
        >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

        >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

        >>> prompt = "A photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
        >>> image.save("output.png")
        ```
    """

    state_manager = StateManager(FBCSharedBlockState, (), {})
    blocks = []

    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        blocks.extend((f"{name}.{index}", block) for index, block in enumerate(submodule))

    # Validate before touching any block so a failure never leaves the module partially hooked, and so the caller
    # gets an actionable message instead of a bare "pop from empty list" IndexError.
    if len(blocks) < 2:
        raise ValueError(
            f"First Block Cache requires at least 2 transformer blocks (a head block and a tail block), but found "
            f"{len(blocks)} in module of type `{type(module).__name__}`."
        )

    (head_block_name, head_block), *remaining_blocks, (tail_block_name, tail_block) = blocks

    logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Applying FBCBlockHook to '{name}'")
        _apply_fbc_block_hook(block, state_manager)

    logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
| 247 | |
| 248 | |
| 249 | def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None: |
nothing calls this directly
no test coverage detected
searching dependent graphs…