MCP · copy · Index your code
hub / github.com/huggingface/diffusers / apply_context_parallel

Function apply_context_parallel

src/diffusers/hooks/context_parallel.py:80–109  ·  view source on GitHub ↗

Apply context parallelism to a model.

(
    module: torch.nn.Module,
    parallel_config: ContextParallelConfig,
    plan: dict[str, ContextParallelModelPlan],
)

Source from the content-addressed store, hash-verified

78
79
def apply_context_parallel(
    module: torch.nn.Module,
    parallel_config: ContextParallelConfig,
    plan: dict[str, ContextParallelModelPlan],
) -> None:
    """Apply context parallelism to a model by registering hooks on its submodules.

    For every ``(module_id, cp_model_plan)`` entry in ``plan``, the submodule(s) addressed by
    ``module_id`` are resolved with ``_get_submodule_by_name`` and a hook is registered on each:

    - a ``dict`` plan registers a ``ContextParallelSplitHook`` under the name produced by
      ``_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE``;
    - a ``ContextParallelOutput`` (or a list/tuple of them) registers a
      ``ContextParallelGatherHook`` under the name produced by
      ``_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE``.

    Args:
        module: Root model whose submodules receive the hooks.
        parallel_config: Context-parallel configuration passed to every hook; its ``_mesh`` is logged.
        plan: Mapping from a submodule name (as understood by ``_get_submodule_by_name``) to the
            context-parallel plan for that submodule.

    Raises:
        ValueError: If a plan entry has an unsupported type, or a list/tuple plan contains
            elements that are not ``ContextParallelOutput``.
    """
    # Lazy %-formatting: ``plan`` can be large and is only rendered when DEBUG is enabled.
    logger.debug("Applying context parallel with CP mesh: %s and plan: %s", parallel_config._mesh, plan)

    for module_id, cp_model_plan in plan.items():
        # Validate and normalize the plan once per entry, before resolving submodules, so an
        # invalid plan is rejected even when ``module_id`` matches no submodules, and the loop
        # variable is never rebound inside the per-submodule loop.
        if isinstance(cp_model_plan, dict):
            hook_cls = ContextParallelSplitHook
            hook_plan = cp_model_plan
            hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
        elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
            # A single output spec is treated as a one-element list; lists/tuples pass through.
            if isinstance(cp_model_plan, ContextParallelOutput):
                hook_plan = [cp_model_plan]
            else:
                hook_plan = cp_model_plan
            if not all(isinstance(x, ContextParallelOutput) for x in hook_plan):
                raise ValueError(
                    f"Expected all elements of cp_model_plan for {module_id!r} to be ContextParallelOutput, "
                    f"but got {hook_plan}"
                )
            hook_cls = ContextParallelGatherHook
            hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
        else:
            raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")

        # ``_get_submodule_by_name`` may return a single module or a list of modules.
        submodules = _get_submodule_by_name(module, module_id)
        if not isinstance(submodules, list):
            submodules = [submodules]

        logger.debug(
            "Applying ContextParallelHook to module_id=%r identifying a total of %d modules",
            module_id,
            len(submodules),
        )

        for m in submodules:
            # A fresh hook instance per submodule, as in the original; only the plan object is shared.
            hook = hook_cls(hook_plan, parallel_config)
            registry = HookRegistry.check_if_exists_or_initialize(m)
            registry.register_hook(hook, hook_name)
110
111
112def remove_context_parallel(module: torch.nn.Module, plan: dict[str, ContextParallelModelPlan]) -> None:

Callers 1

enable_parallelism (method) · 0.85

Calls 5

_get_submodule_by_name (function) · 0.85
register_hook (method) · 0.80

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…