Prepare a model for distributed training. This function optionally applies activation checkpointing (or activation CPU offloading), then applies FSDP2 to the model and initializes its parameters. When an optimizer is provided, it keeps the optimizer's parameters consistent with the model's during sharding. Args: model (torch.nn.Module): The model to prepare for distributed training.
(
model: torch.nn.Module,
config: FSDP2Config,
optimizer: Optional[torch.optim.Optimizer] = None,
precision: Optional[Union[str, Precision]] = None,
fsdp_wrap_policy: Optional[CustomPolicy] = None,
activation_checkpointing_check_fn: Optional[Callable] = None,
param_init_fn: Callable[[torch.nn.Module], None] = lambda m: None,
)
| 136 | |
| 137 | |
def parallelize_model(
    model: torch.nn.Module,
    config: FSDP2Config,
    optimizer: Optional[torch.optim.Optimizer] = None,
    precision: Optional[Union[str, Precision]] = None,
    fsdp_wrap_policy: Optional[CustomPolicy] = None,
    activation_checkpointing_check_fn: Optional[Callable] = None,
    param_init_fn: Callable[[torch.nn.Module], None] = lambda m: None,
):
    """Prepare a model for distributed training.

    Optionally applies activation checkpointing / activation CPU offloading, then applies FSDP2
    to the model and initializes its parameters. When an optimizer is provided, parameter
    consistency between the optimizer and the model is maintained during sharding.

    Args:
        model (torch.nn.Module): The model to prepare for distributed training.
        config (FSDP2Config): The configuration for FSDP2 distributed training.
        optimizer (Optional[torch.optim.Optimizer]): The optimizer to synchronize with the model.
            If provided, parameter states will be synchronized during sharding.
        precision (Optional[Union[str, Precision]]): The precision to use for the model. A string
            is converted with ``Precision(precision)``. If None, defaults to
            ``Precision.AMP_FP16`` when the current device is a ``DeviceGPU`` and
            ``Precision.FP32`` otherwise. The resolved precision is then validated against the
            current device.
        fsdp_wrap_policy (Optional[CustomPolicy]): Custom policy to determine which modules should
            be wrapped with FSDP. If None, default wrapping behavior is used.
        activation_checkpointing_check_fn (Optional[Callable]): Function that determines whether a
            module's activations should be checkpointed or offloaded. Only allowed when activation
            checkpointing or activation CPU offloading is enabled in the config.
        param_init_fn (Callable[[torch.nn.Module], None]): Function to initialize model parameters
            after FSDP wrapping. Defaults to a no-op function.

    Raises:
        ValueError: If ``activation_checkpointing_check_fn`` is provided but neither
            ``activation_checkpointing`` nor ``activation_cpu_offload`` is enabled in the config.
            Errors from ``Precision(precision)`` (for an unrecognized string) or from
            ``_validate_precision`` (presumably for a precision unsupported on the current
            device) propagate unchanged.
    """
    # Look up the device once so the default-precision choice and the validation agree.
    device = get_device()
    if precision is None:
        precision = Precision.AMP_FP16 if isinstance(device, DeviceGPU) else Precision.FP32
    elif isinstance(precision, str):
        precision = Precision(precision)
    _validate_precision(precision, device)

    ac_enabled = config.activation_checkpointing or config.activation_cpu_offload

    # A check fn without AC/offload enabled would be silently ignored, so reject it loudly.
    if activation_checkpointing_check_fn is not None and not ac_enabled:
        raise ValueError(
            'Activation checkpointing or offloading must be enabled if activation_checkpointing_check_fn is provided',
        )

    # NOTE: apply_ac must run outside (here: before) the optimizer-sync context below. It wraps
    # and replaces sub-modules, which would invalidate the FQNs of params if done inside it.
    if ac_enabled:
        apply_ac(
            model,
            config.activation_checkpointing,
            config.activation_cpu_offload,
            activation_checkpointing_check_fn,
        )

    # Use the context manager for optimizer synchronization if optimizer is provided
    with sync_optimizer_and_model_params(optimizer, model) if optimizer is not None else nullcontext():
        _parallelize_model_helper(model, config, precision, fsdp_wrap_policy, param_init_fn)