MCP · Copy · Index your code
hub / github.com/huggingface/diffusers / enable_parallelism

Method enable_parallelism

src/diffusers/models/modeling_utils.py:1585–1670  ·  view source on GitHub ↗
(
        self,
        *,
        config: ParallelConfig | ContextParallelConfig,
        cp_plan: dict[str, ContextParallelModelPlan] | None = None,
    )

Source from the content-addressed store, hash-verified

1583 )
1584
1585 def enable_parallelism(
1586 self,
1587 *,
1588 config: ParallelConfig | ContextParallelConfig,
1589 cp_plan: dict[str, ContextParallelModelPlan] | None = None,
1590 ):
1591 logger.warning(
1592 "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
1593 )
1594
1595 if not torch.distributed.is_available() and not torch.distributed.is_initialized():
1596 raise RuntimeError(
1597 "torch.distributed must be available and initialized before calling `enable_parallelism`."
1598 )
1599
1600 from ..hooks.context_parallel import apply_context_parallel
1601 from .attention import AttentionModuleMixin
1602 from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
1603 from .attention_processor import Attention, MochiAttention
1604
1605 if isinstance(config, ContextParallelConfig):
1606 config = ParallelConfig(context_parallel_config=config)
1607
1608 rank = torch.distributed.get_rank()
1609 world_size = torch.distributed.get_world_size()
1610 device_type = torch._C._get_accelerator().type
1611 device_module = torch.get_device_module(device_type)
1612 device = torch.device(device_type, rank % device_module.device_count())
1613
1614 attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
1615
1616 if config.context_parallel_config is not None:
1617 for module in self.modules():
1618 if not isinstance(module, attention_classes):
1619 continue
1620
1621 processor = module.processor
1622 if processor is None or not hasattr(processor, "_attention_backend"):
1623 continue
1624
1625 attention_backend = processor._attention_backend
1626 if attention_backend is None:
1627 attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
1628 else:
1629 attention_backend = AttentionBackendName(attention_backend)
1630
1631 if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
1632 compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
1633 raise ValueError(
1634 f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
1635 f"is using backend '{attention_backend.value}' which does not support context parallelism. "
1636 f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
1637 f"calling `model.enable_parallelism()`."
1638 )
1639
1640 # All modules use the same attention processor and backend. We don't need to
1641 # iterate over all modules after checking the first processor
1642 break

Callers 4

from_pretrained · Method · 0.80
_context_parallel_worker · Function · 0.80
_custom_mesh_worker · Function · 0.80

Calls 7

setup · Method · 0.95
apply_context_parallel · Function · 0.85
get_active_backend · Method · 0.80
ParallelConfig · Class · 0.70
device · Method · 0.45

Tested by

no test coverage detected