(
self,
*,
config: ParallelConfig | ContextParallelConfig,
cp_plan: dict[str, ContextParallelModelPlan] | None = None,
)
| 1583 | ) |
| 1584 | |
| 1585 | def enable_parallelism( |
| 1586 | self, |
| 1587 | *, |
| 1588 | config: ParallelConfig | ContextParallelConfig, |
| 1589 | cp_plan: dict[str, ContextParallelModelPlan] | None = None, |
| 1590 | ): |
| 1591 | logger.warning( |
| 1592 | "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning." |
| 1593 | ) |
| 1594 | |
| 1595 | if not torch.distributed.is_available() and not torch.distributed.is_initialized(): |
| 1596 | raise RuntimeError( |
| 1597 | "torch.distributed must be available and initialized before calling `enable_parallelism`." |
| 1598 | ) |
| 1599 | |
| 1600 | from ..hooks.context_parallel import apply_context_parallel |
| 1601 | from .attention import AttentionModuleMixin |
| 1602 | from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry |
| 1603 | from .attention_processor import Attention, MochiAttention |
| 1604 | |
| 1605 | if isinstance(config, ContextParallelConfig): |
| 1606 | config = ParallelConfig(context_parallel_config=config) |
| 1607 | |
| 1608 | rank = torch.distributed.get_rank() |
| 1609 | world_size = torch.distributed.get_world_size() |
| 1610 | device_type = torch._C._get_accelerator().type |
| 1611 | device_module = torch.get_device_module(device_type) |
| 1612 | device = torch.device(device_type, rank % device_module.device_count()) |
| 1613 | |
| 1614 | attention_classes = (Attention, MochiAttention, AttentionModuleMixin) |
| 1615 | |
| 1616 | if config.context_parallel_config is not None: |
| 1617 | for module in self.modules(): |
| 1618 | if not isinstance(module, attention_classes): |
| 1619 | continue |
| 1620 | |
| 1621 | processor = module.processor |
| 1622 | if processor is None or not hasattr(processor, "_attention_backend"): |
| 1623 | continue |
| 1624 | |
| 1625 | attention_backend = processor._attention_backend |
| 1626 | if attention_backend is None: |
| 1627 | attention_backend, _ = _AttentionBackendRegistry.get_active_backend() |
| 1628 | else: |
| 1629 | attention_backend = AttentionBackendName(attention_backend) |
| 1630 | |
| 1631 | if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend): |
| 1632 | compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel) |
| 1633 | raise ValueError( |
| 1634 | f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' " |
| 1635 | f"is using backend '{attention_backend.value}' which does not support context parallelism. " |
| 1636 | f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before " |
| 1637 | f"calling `model.enable_parallelism()`." |
| 1638 | ) |
| 1639 | |
| 1640 | # All modules use the same attention processor and backend. We don't need to |
| 1641 | # iterate over all modules after checking the first processor |
| 1642 | break |
no test coverage detected