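Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.

Args:
- model: Model to wrap
- device: Target device (e.g., accelerator.device)
- offload: Whether to enable CPU parameter offloading
- use_orig_params: Whether to use original parameters
- limit_all_gathers: Whether to limit all gathers
- fsdp_kwargs: FSDP arguments (sharding_strategy, etc.) — usually from Accelerate config
- transformer_layer_cls: Classes for auto-wrapping (if not using policy from fsdp_kwargs)

Returns: FSDP-wrapped model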
(
model: torch.nn.Module,
device: str | torch.device,
offload: bool = True,
use_orig_params: bool = True,
limit_all_gathers: bool = True,
fsdp_kwargs: dict[str, Any] | None = None,
transformer_layer_cls: set[type[torch.nn.Module]] | None = None,
)
| 518 | |
| 519 | |
def wrap_with_fsdp(
    model: torch.nn.Module,
    device: str | torch.device,
    offload: bool = True,
    use_orig_params: bool = True,
    limit_all_gathers: bool = True,
    fsdp_kwargs: dict[str, Any] | None = None,
    transformer_layer_cls: set[type[torch.nn.Module]] | None = None,
) -> FSDP:
    """
    Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.

    Args:
        model: Model to wrap
        device: Target device (e.g., accelerator.device), passed to FSDP as ``device_id``
        offload: Whether to enable CPU parameter offloading
        use_orig_params: Whether to use original parameters
        limit_all_gathers: Whether to limit all gathers
        fsdp_kwargs: FSDP arguments (sharding_strategy, etc.) — usually from Accelerate config.
            Applied last, so any key given here (including ``auto_wrap_policy``) overrides
            the defaults built by this function.
        transformer_layer_cls: Classes for auto-wrapping (if not using policy from fsdp_kwargs).
            A set (or other iterable) of module classes; a single class is also accepted.
            When None, the class of ``model.model.language_model.layers[0]`` is used.

    Returns:
        FSDP-wrapped model

    Raises:
        AttributeError: If ``transformer_layer_cls`` is None and the model does not expose
            ``model.language_model.layers``.
    """

    logger = get_logger(__name__)

    if transformer_layer_cls is None:
        # Set the default layers if transformer_layer_cls is not provided
        transformer_layer_cls = type(model.model.language_model.layers[0])
        logger.info(f"transformer_layer_cls is not provided, auto-inferred as {transformer_layer_cls.__name__}")

    # Normalise to a flat set of classes. Wrapping a caller-supplied set in another set
    # literal fails with "unhashable type: 'set'", so only a bare class gets wrapped.
    if isinstance(transformer_layer_cls, type):
        layer_classes = {transformer_layer_cls}
    else:
        layer_classes = set(transformer_layer_cls)

    # Auto-wrap policy keyed on the transformer layer classes
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=layer_classes)

    config = {
        "device_id": device,
        "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
        "use_orig_params": use_orig_params,
        "limit_all_gathers": limit_all_gathers,
        "auto_wrap_policy": auto_wrap_policy,
    }

    # Caller-provided kwargs take precedence over the defaults above
    if fsdp_kwargs:
        config.update(fsdp_kwargs)

    fsdp_model = FSDP(model, **config)
    return fsdp_model
| 568 | |
| 569 | |
| 570 | # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 |