Apply context parallel on a model.
(
module: torch.nn.Module,
parallel_config: ContextParallelConfig,
plan: dict[str, ContextParallelModelPlan],
)
| 78 | |
| 79 | |
def apply_context_parallel(
    module: torch.nn.Module,
    parallel_config: ContextParallelConfig,
    plan: dict[str, ContextParallelModelPlan],
) -> None:
    """Register context-parallel hooks on the submodules of ``module`` named by ``plan``.

    Each key of ``plan`` is resolved with ``_get_submodule_by_name``; the result
    may be a single module or a list of modules, and every resolved module
    receives its own hook instance.

    The kind of hook depends on the type of the plan entry:

    * ``dict`` -> ``ContextParallelSplitHook``, registered under the name built
      from ``_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE``.
    * ``ContextParallelOutput``, or a ``list``/``tuple`` of them ->
      ``ContextParallelGatherHook``, registered under the name built from
      ``_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE``.

    Args:
        module: Root module whose submodules are looked up by name.
        parallel_config: Configuration forwarded unchanged to every hook.
        plan: Mapping from submodule identifier to its plan entry.

    Raises:
        ValueError: If a plan entry is of an unsupported type, or if a
            ``list``/``tuple`` entry contains anything other than
            ``ContextParallelOutput`` instances.
    """
    logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")

    for module_id, cp_model_plan in plan.items():
        resolved = _get_submodule_by_name(module, module_id)
        # The lookup may yield one module or a list of them; always iterate a list.
        targets = resolved if isinstance(resolved, list) else [resolved]

        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(targets)} modules")

        for target in targets:
            is_split_plan = isinstance(cp_model_plan, dict)
            if not is_split_plan and not isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")

            if is_split_plan:
                hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
            else:
                # A bare ContextParallelOutput is wrapped once; the rebinding is
                # kept so later targets reuse the same list object.
                if isinstance(cp_model_plan, ContextParallelOutput):
                    cp_model_plan = [cp_model_plan]
                if any(not isinstance(entry, ContextParallelOutput) for entry in cp_model_plan):
                    raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
                hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)

            HookRegistry.check_if_exists_or_initialize(target).register_hook(hook, hook_name)
| 110 | |
| 111 | |
| 112 | def remove_context_parallel(module: torch.nn.Module, plan: dict[str, ContextParallelModelPlan]) -> None: |
no test coverage detected
searching dependent graphs…