MCPcopy Index your code
hub / github.com/huggingface/diffusers / ContextParallelSplitHook

Class ContextParallelSplitHook

src/diffusers/hooks/context_parallel.py:129–217  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

127
128
129class ContextParallelSplitHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
    """Create the hook and remember the plan and configuration it was given.

    Args:
        metadata: The context-parallel plan for the hooked module. `pre_forward` iterates it with
            `.items()`, treating each key as the name of a forward parameter.
        parallel_config: The context-parallel configuration; stored as-is on the instance.
    """
    super().__init__()
    # Filled in by `initialize_hook`, once the module this hook is attached to is known.
    self.module_forward_metadata = None
    self.parallel_config = parallel_config
    self.metadata = metadata
135
def initialize_hook(self, module):
    """Capture the class of the unwrapped module for later argument lookup.

    The class is wrapped in a `ModuleForwardMetadata`, which `pre_forward` uses to locate named
    parameters among the positional and keyword arguments of a call.

    Args:
        module: The module this hook is being attached to.

    Returns:
        The same `module` object, unmodified.
    """
    unwrapped = unwrap_module(module)
    self.module_forward_metadata = ModuleForwardMetadata(_cls=unwrapped.__class__)
    return module
140
def pre_forward(self, module, *args, **kwargs):
    """Run the planned inputs of a forward call through `_prepare_cp_input` before the module sees them.

    For every `(name, plan)` entry in `self.metadata`, the matching argument is looked up among the
    positional and keyword arguments, processed, and written back to where it was found.

    Args:
        module: The hooked module (not used directly here).
        *args: Positional arguments of the forward call.
        **kwargs: Keyword arguments of the forward call.

    Returns:
        A `(args, kwargs)` pair: the positional arguments as a tuple and the keyword-argument dict,
        with the planned entries replaced by their processed values.

    Raises:
        ValueError: If a list/tuple argument and its plan differ in length, if an argument is neither
            a tensor nor a list/tuple, or if a located argument cannot be written back.
    """
    positional = list(args)

    for name, plan in self.metadata.items():
        # Entries flagged with split_output are about the module's output, not its input.
        if isinstance(plan, ContextParallelInput) and plan.split_output:
            continue

        # The parameter may have been supplied either positionally or by keyword.
        value, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
            name, positional, kwargs
        )
        if value is None:
            continue

        if isinstance(value, torch.Tensor):
            value = self._prepare_cp_input(value, plan)
        elif not isinstance(value, (list, tuple)):
            raise ValueError(f"Unsupported input type: {type(value)}")
        else:
            if len(value) != len(plan):
                raise ValueError(
                    f"Expected input model plan to have {len(value)} elements, but got {len(plan)}."
                )
            # Only tensor elements whose own plan does not ask for output splitting are processed;
            # everything else is passed through untouched. The result is always a list, even when
            # the caller supplied a tuple.
            value = [
                self._prepare_cp_input(item, plan[pos])
                if torch.is_tensor(item) and not plan[pos].split_output
                else item
                for pos, item in enumerate(value)
            ]

        # Put the processed value back in the slot it came from.
        if is_kwarg:
            kwargs[name] = value
            continue
        if index is None or index >= len(positional):
            raise ValueError(
                f"An unexpected error occurred while processing the input '{name}'. Please open an "
                f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
                f"example along with the full stack trace."
            )
        positional[index] = value

    return tuple(positional), kwargs
186

Callers 1

apply_context_parallel · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…