Function parallelize_model

Defined in composer/distributed/prepare_distributed.py (lines 138–195) of github.com/mosaicml/composer.

Prepare a model for distributed training. This function currently applies FSDP2 to the model, initializes parameters, and optionally applies activation checkpointing. When an optimizer is provided, it keeps the optimizer's parameters consistent with the model's parameters.
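Optimizer consistency matters because sharding typically replaces each parameter object, so an optimizer built beforehand would otherwise keep stale references. The sketch below is a minimal plain-Python model of that idea, assuming (from the source's NOTE about parameter FQNs) that synchronization matches optimizer parameters to model parameters by fully qualified name. `Param`, `ToyModel`, `ToyOptimizer`, `sync_by_fqn`, `fake_shard`, `fake_wrap`, and `prepare` are all hypothetical stand-ins, not composer or PyTorch APIs; only the `cm if optimizer is not None else nullcontext()` shape is taken from the source.

```python
from contextlib import contextmanager, nullcontext


class Param:
    """Hypothetical stand-in for a parameter tensor; only object identity matters."""

    def __init__(self, fqn):
        self.fqn = fqn


class ToyModel:
    """Hypothetical model exposing its parameters by fully qualified name (FQN)."""

    def __init__(self, fqns):
        self.params = {fqn: Param(fqn) for fqn in fqns}

    def named_parameters(self):
        return list(self.params.items())


class ToyOptimizer:
    """Hypothetical optimizer holding direct references to parameter objects."""

    def __init__(self, params):
        self.params = list(params)


@contextmanager
def sync_by_fqn(optimizer, model):
    # On entry, record which FQN each optimizer parameter belongs to.
    id_to_fqn = {id(p): fqn for fqn, p in model.named_parameters()}
    fqns = [id_to_fqn[id(p)] for p in optimizer.params]
    yield
    # On exit, the model holds new parameter objects, so rebind them by FQN.
    # A KeyError here means some FQN changed inside the context.
    current = dict(model.named_parameters())
    optimizer.params = [current[fqn] for fqn in fqns]


def fake_shard(model):
    # Stand-in for sharding: every parameter object is replaced, FQNs are kept.
    model.params = {fqn: Param(fqn) for fqn in model.params}


def fake_wrap(model):
    # Stand-in for wrapping submodules: a new segment appears in every FQN.
    model.params = {f"wrapped.{fqn}": Param(f"wrapped.{fqn}") for fqn in model.params}


def prepare(model, optimizer=None):
    # Same shape as the source: synchronize only when an optimizer is given.
    with sync_by_fqn(optimizer, model) if optimizer is not None else nullcontext():
        fake_shard(model)


model = ToyModel(["layers.0.weight", "layers.0.bias"])
opt = ToyOptimizer(p for _, p in model.named_parameters())
prepare(model, opt)
print(all(p is model.params[p.fqn] for p in opt.params))  # True
```

Calling `fake_wrap` inside `sync_by_fqn` makes the rebinding raise `KeyError`, which is the failure mode the NOTE in the source avoids by running `apply_ac` before entering the synchronization context.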

(
    model: torch.nn.Module,
    config: FSDP2Config,
    optimizer: Optional[torch.optim.Optimizer] = None,
    precision: Optional[Union[str, Precision]] = None,
    fsdp_wrap_policy: Optional[CustomPolicy] = None,
    activation_checkpointing_check_fn: Optional[Callable] = None,
    param_init_fn: Callable[[torch.nn.Module], None] = lambda m: None,
)
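The `precision` argument accepts None, a string, or a `Precision` member. The normalization at the top of the function body can be sketched in plain Python as below. `Precision` here is a hypothetical stand-in enum whose member values are assumed, `on_gpu` replaces the real `isinstance(get_device(), DeviceGPU)` check, and the subsequent `_validate_precision` call is omitted.

```python
from enum import Enum
from typing import Optional, Union


class Precision(str, Enum):
    """Hypothetical stand-in for composer's Precision; member values are assumed."""

    FP32 = "fp32"
    AMP_FP16 = "amp_fp16"
    AMP_BF16 = "amp_bf16"


def resolve_precision(precision: Optional[Union[str, Precision]], on_gpu: bool) -> Precision:
    # None: pick a device-dependent default (on_gpu stands in for the DeviceGPU check).
    if precision is None:
        precision = Precision.AMP_FP16 if on_gpu else Precision.FP32
    # str (an existing member also matches, since Precision subclasses str): look up by value.
    # An unrecognized string raises ValueError.
    elif isinstance(precision, str):
        precision = Precision(precision)
    return precision


default_cpu = resolve_precision(None, on_gpu=False)  # Precision.FP32
from_string = resolve_precision("amp_bf16", on_gpu=True)  # Precision.AMP_BF16
```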

Source:
def parallelize_model(
    model: torch.nn.Module,
    config: FSDP2Config,
    optimizer: Optional[torch.optim.Optimizer] = None,
    precision: Optional[Union[str, Precision]] = None,
    fsdp_wrap_policy: Optional[CustomPolicy] = None,
    activation_checkpointing_check_fn: Optional[Callable] = None,
    param_init_fn: Callable[[torch.nn.Module], None] = lambda m: None,
):
    """Prepare a model for distributed training.

    This function currently applies FSDP2 to the model, initializes parameters,
    and optionally applies activation checkpointing. When an optimizer is provided,
    it keeps the optimizer's parameters consistent with the model's parameters.

    Args:
        model (torch.nn.Module): The model to prepare for distributed training.
        config (FSDP2Config): The configuration for FSDP distributed training.
        optimizer (Optional[torch.optim.Optimizer]): The optimizer to synchronize with the model.
            If provided, its parameters are kept in sync with the model's parameters during sharding.
        precision (Optional[Union[str, Precision]]): The precision to use for the model. A string is
            converted to ``Precision``. If None, defaults to AMP_FP16 on GPU and FP32 on other devices.
        fsdp_wrap_policy (Optional[CustomPolicy]): Custom policy to determine which modules should
            be wrapped with FSDP. If None, the default wrapping behavior is used.
        activation_checkpointing_check_fn (Optional[Callable]): Function that determines whether a
            module's activations should be checkpointed or offloaded. Only used when activation
            checkpointing or CPU offloading is enabled in the config.
        param_init_fn (Callable[[torch.nn.Module], None]): Function to initialize model parameters
            after FSDP wrapping. Defaults to a no-op function.

    Raises:
        ValueError: If the config is not an FSDP2Config, or if activation_checkpointing_check_fn is provided
            but neither activation_checkpointing nor activation_cpu_offload is enabled in the config.
    """
    if precision is None:
        precision = Precision.AMP_FP16 if isinstance(get_device(), DeviceGPU) else Precision.FP32
    elif isinstance(precision, str):
        precision = Precision(precision)
    _validate_precision(precision, get_device())

    if activation_checkpointing_check_fn is not None:
        if not config.activation_checkpointing and not config.activation_cpu_offload:
            raise ValueError(
                'Activation checkpointing or offloading must be enabled if activation_checkpointing_check_fn is provided',
            )

    if config.activation_checkpointing or config.activation_cpu_offload:
        apply_ac(
            model,
            config.activation_checkpointing,
            config.activation_cpu_offload,
            activation_checkpointing_check_fn,
        )

    # Use the context manager for optimizer synchronization if an optimizer is provided
    with sync_optimizer_and_model_params(optimizer, model) if optimizer is not None else nullcontext():
        _parallelize_model_helper(model, config, precision, fsdp_wrap_policy, param_init_fn)
        # NOTE: apply_ac cannot run inside this context: it wraps and replaces submodules,
        # which changes the parameters' FQNs.
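The activation-checkpointing guard and the module-selection hook can be sketched without PyTorch. `ActivationFlags`, `should_apply_ac`, `checkpoint_blocks_only`, and the `Embedding`/`Block` classes are hypothetical stand-ins. The assumption that `activation_checkpointing_check_fn` takes a module and returns a bool follows the docstring's description, but how `apply_ac` actually invokes it is not shown in this source.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class ActivationFlags:
    """Hypothetical subset of FSDP2Config holding only the two activation flags."""

    activation_checkpointing: bool = False
    activation_cpu_offload: bool = False


def should_apply_ac(config: ActivationFlags, check_fn: Optional[Callable] = None) -> bool:
    """Mirror the source's guard: validate check_fn, then decide whether to wrap."""
    enabled = config.activation_checkpointing or config.activation_cpu_offload
    # With neither feature enabled, check_fn would never be used, so reject it early.
    if check_fn is not None and not enabled:
        raise ValueError(
            "Activation checkpointing or offloading must be enabled "
            "if activation_checkpointing_check_fn is provided",
        )
    return enabled


class Embedding:  # hypothetical stand-in module type
    pass


class Block:  # hypothetical stand-in for a transformer block
    pass


def checkpoint_blocks_only(module) -> bool:
    # Hypothetical check_fn: select only block-like modules.
    return isinstance(module, Block)


modules = [Embedding(), Block(), Block()]
if should_apply_ac(ActivationFlags(activation_checkpointing=True), checkpoint_blocks_only):
    selected = [m for m in modules if checkpoint_blocks_only(m)]
    print(len(selected))  # 2
```

Raising early, rather than silently ignoring a check function that would never be consulted, surfaces a likely configuration mistake at setup time.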

Calls (6):

get_device · function
Precision · class
_validate_precision · function
apply_ac · function