| 127 | |
| 128 | |
class ContextParallelSplitHook(ModelHook):
    """Model hook that rewrites a module's forward inputs according to a context-parallel plan.

    Before the wrapped module's forward runs, every input named in ``metadata`` is looked up among the
    positional/keyword arguments and passed through ``self._prepare_cp_input`` (defined outside this
    block; the plan entry for that input is forwarded alongside the value).
    """

    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
        """Store the per-input plan and the parallel configuration.

        Args:
            metadata: Mapping from forward-parameter name to its plan entry (iterated with ``.items()``).
            parallel_config: Context-parallel configuration; only stored here.
        """
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config
        # Filled in by `initialize_hook`, once the class of the wrapped module is known.
        self.module_forward_metadata = None

    def initialize_hook(self, module):
        """Record forward metadata for the (unwrapped) module's class and return ``module`` unchanged."""
        self.module_forward_metadata = ModuleForwardMetadata(_cls=unwrap_module(module).__class__)
        return module

    def pre_forward(self, module, *args, **kwargs):
        """Prepare the planned inputs of ``module`` before its forward pass.

        For each ``(name, plan)`` pair in ``self.metadata``:

        * entries that are a ``ContextParallelInput`` with ``split_output`` set are skipped here;
        * inputs that resolve to ``None`` are skipped;
        * a tensor input is replaced by ``self._prepare_cp_input(tensor, plan)``;
        * a list/tuple input is replaced by a new **list** in which every tensor element whose
          corresponding plan entry does not set ``split_output`` has been prepared the same way.

        Returns:
            A ``(args, kwargs)`` pair: the positional arguments as a tuple and the keyword-argument dict,
            both with the prepared values substituted.

        Raises:
            ValueError: If a list/tuple input and its plan differ in length, if the input is of an
                unsupported type, or if the resolved input cannot be written back to args/kwargs.
        """
        positional = list(args)
        forward_metadata = self.module_forward_metadata

        for name, plan in self.metadata.items():
            # This entry is flagged with split_output, so there is nothing to do on the input side.
            if isinstance(plan, ContextParallelInput) and plan.split_output:
                continue

            # The parameter may have been supplied either positionally or as a keyword argument.
            value, is_kwarg, index = forward_metadata._get_parameter_from_args_kwargs(name, positional, kwargs)
            if value is None:
                continue

            if isinstance(value, torch.Tensor):
                value = self._prepare_cp_input(value, plan)
            elif isinstance(value, (list, tuple)):
                if len(value) != len(plan):
                    raise ValueError(
                        f"Expected input model plan to have {len(value)} elements, but got {len(plan)}."
                    )
                # Non-tensor elements, and elements whose plan entry sets split_output, pass through as-is.
                # `plan[i]` is only consulted for tensor elements (short-circuit), as in the loop form.
                value = [
                    self._prepare_cp_input(item, plan[i])
                    if torch.is_tensor(item) and not plan[i].split_output
                    else item
                    for i, item in enumerate(value)
                ]
            else:
                raise ValueError(f"Unsupported input type: {type(value)}")

            # Write the prepared value back to wherever it was found.
            if is_kwarg:
                kwargs[name] = value
                continue
            if index is None or index >= len(positional):
                raise ValueError(
                    f"An unexpected error occurred while processing the input '{name}'. Please open an "
                    f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
                    f"example along with the full stack trace."
                )
            positional[index] = value

        return tuple(positional), kwargs
| 186 |
no outgoing calls
no test coverage detected
searching dependent graphs…