()
| 183 | |
| 184 | |
def _get_transformers_attn_impl() -> str:
    """Choose the attention implementation name for the Transformers backend.

    Returns:
        ``"flash_attention_2"`` if the ``flash_attn`` package can be imported;
        otherwise ``"sdpa"``, after logging a warning that includes the
        install command for ``flash-attn``.
    """
    try:
        # Import is only a probe for availability; the module itself is unused.
        import flash_attn  # noqa: F401
    except ImportError:
        fallback_notice = (
            "flash_attn not installed. Falling back to torch.sdpa. Speedup will be lower. "
            "For optimal speedup in Transformers backend, please install: "
            "pip install flash-attn --no-build-isolation"
        )
        logger.warning(fallback_notice)
        return "sdpa"

    return "flash_attention_2"
| 196 | |
| 197 | |
| 198 | def _run_transformers(args: argparse.Namespace) -> None: |