Repository: github.com/huggingface/diffusers

Function wrap_with_fsdp

src/diffusers/training_utils.py:520–567

Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
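
With transformer auto-wrapping, FSDP gives each submodule whose type appears in `transformer_layer_cls` its own FSDP unit, and the outer `FSDP(...)` call covers whatever remains. The stdlib-only sketch below imitates just that selection rule (an `isinstance` check against the class set) on a toy module tree. `Module`, `DecoderLayer`, `Embedding`, and `select_wrapped` are hypothetical stand-ins rather than PyTorch classes, and the real policy's recursion and nesting details are deliberately left out.

```python
from dataclasses import dataclass, field


@dataclass
class Module:
    """Hypothetical stand-in for torch.nn.Module: a name plus child modules."""
    name: str
    children: list["Module"] = field(default_factory=list)


class DecoderLayer(Module):
    """Plays the role of a transformer block class."""


class Embedding(Module):
    """A non-transformer submodule."""


def select_wrapped(root: Module, layer_cls: set[type]) -> list[str]:
    """Return names of submodules an isinstance-based policy would wrap.

    A submodule is selected when isinstance(module, tuple(layer_cls)) is true.
    The root is skipped because the outer FSDP(...) call wraps it anyway.
    """
    wrapped = []
    # Depth-first traversal in declaration order.
    stack = list(reversed(root.children))
    while stack:
        module = stack.pop()
        if isinstance(module, tuple(layer_cls)):
            wrapped.append(module.name)
        stack.extend(reversed(module.children))
    return wrapped


model = Module("model", [
    Embedding("embed_tokens"),
    Module("layers", [DecoderLayer("layers.0"), DecoderLayer("layers.1")]),
])
print(select_wrapped(model, {DecoderLayer}))  # ['layers.0', 'layers.1']
```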

(
    model: torch.nn.Module,
    device: str | torch.device,
    offload: bool = True,
    use_orig_params: bool = True,
    limit_all_gathers: bool = True,
    fsdp_kwargs: dict[str, Any] | None = None,
    transformer_layer_cls: set[type[torch.nn.Module]] | None = None,
)
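
The signature annotates `transformer_layer_cls` as a set of classes, but the body passes `transformer_layer_cls={transformer_layer_cls}` to `transformer_auto_wrap_policy`, wrapping the argument in a set literal. That works for a single class (including the auto-inferred default), but passing an actual set raises a `TypeError`, because a set is unhashable and cannot be an element of another set. The pure-Python sketch below reproduces that behavior and shows one possible way to normalize the argument. `wrap_like_source` and `normalize_layer_cls` are hypothetical helpers, not part of diffusers, and the `SimpleNamespace` object only imitates the `model.model.language_model.layers` path that the default inference relies on.

```python
from types import SimpleNamespace


def wrap_like_source(layer_cls):
    """Mimic the source's `{transformer_layer_cls}` set literal."""
    return {layer_cls}


def normalize_layer_cls(model, layer_cls=None) -> set[type]:
    """Hypothetical helper: accept None, a single class, or a collection of classes."""
    if layer_cls is None:
        # Same inference path as wrap_with_fsdp; raises AttributeError for
        # models that do not expose model.model.language_model.layers.
        return {type(model.model.language_model.layers[0])}
    if isinstance(layer_cls, type):
        return {layer_cls}
    return set(layer_cls)


class DecoderLayer:
    pass


class VisionBlock:
    pass


toy = SimpleNamespace(
    model=SimpleNamespace(language_model=SimpleNamespace(layers=[DecoderLayer()]))
)

print(wrap_like_source(DecoderLayer) == {DecoderLayer})  # True
try:
    wrap_like_source({DecoderLayer, VisionBlock})  # a set inside a set literal
except TypeError as exc:
    print(type(exc).__name__)  # TypeError

print(normalize_layer_cls(toy) == {DecoderLayer})  # True
print(normalize_layer_cls(toy, {DecoderLayer, VisionBlock}) == {DecoderLayer, VisionBlock})  # True
```

With a helper like this, the call site could pass the normalized set directly as `transformer_layer_cls=...` instead of wrapping it again.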

Source

def wrap_with_fsdp(
    model: torch.nn.Module,
    device: str | torch.device,
    offload: bool = True,
    use_orig_params: bool = True,
    limit_all_gathers: bool = True,
    fsdp_kwargs: dict[str, Any] | None = None,
    transformer_layer_cls: set[type[torch.nn.Module]] | None = None,
) -> FSDP:
    """
    Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.

    Args:
        model: Model to wrap
        device: Target device (e.g., accelerator.device)
        offload: Whether to enable CPU parameter offloading
        use_orig_params: Whether to use original parameters
        limit_all_gathers: Whether to limit all gathers
        fsdp_kwargs: FSDP arguments (sharding_strategy, etc.) — usually from Accelerate config
        transformer_layer_cls: Classes for auto-wrapping (if not using policy from fsdp_kwargs)

    Returns:
        FSDP-wrapped model
    """

    logger = get_logger(__name__)

    if transformer_layer_cls is None:
        # Set the default layers if transformer_layer_cls is not provided
        transformer_layer_cls = type(model.model.language_model.layers[0])
        logger.info(f"transformer_layer_cls is not provided, auto-inferred as {transformer_layer_cls.__name__}")

    # Add auto-wrap policy if transformer layers specified
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={transformer_layer_cls})

    config = {
        "device_id": device,
        "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
        "use_orig_params": use_orig_params,
        "limit_all_gathers": limit_all_gathers,
        "auto_wrap_policy": auto_wrap_policy,
    }

    if fsdp_kwargs:
        config.update(fsdp_kwargs)

    fsdp_model = FSDP(model, **config)
    return fsdp_model
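
Because `config.update(fsdp_kwargs)` runs after the defaults are assembled, any key supplied in `fsdp_kwargs` replaces the corresponding default, including `auto_wrap_policy`; this is how a policy coming from the Accelerate config takes precedence over `transformer_layer_cls`. The stdlib-only sketch below mirrors that merge order. `build_fsdp_config` is hypothetical, and the string values are placeholders for the real `CPUOffload` object, policy callables, and sharding-strategy enum.

```python
from __future__ import annotations

from typing import Any


def build_fsdp_config(
    device: str,
    offload: bool = True,
    use_orig_params: bool = True,
    limit_all_gathers: bool = True,
    fsdp_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Hypothetical mirror of wrap_with_fsdp's config assembly (no torch needed)."""
    config = {
        "device_id": device,
        # Placeholder for CPUOffload(offload_params=True); None disables offloading.
        "cpu_offload": "CPUOffload(offload_params=True)" if offload else None,
        "use_orig_params": use_orig_params,
        "limit_all_gathers": limit_all_gathers,
        # Placeholder for partial(transformer_auto_wrap_policy, ...).
        "auto_wrap_policy": "transformer_auto_wrap_policy",
    }
    if fsdp_kwargs:
        # Applied last, so caller-supplied keys override the defaults above.
        config.update(fsdp_kwargs)
    return config


cfg = build_fsdp_config(
    "cuda:0",
    offload=False,
    fsdp_kwargs={"auto_wrap_policy": "size_based_policy", "sharding_strategy": "SHARD_GRAD_OP"},
)
print(cfg["auto_wrap_policy"])   # size_based_policy
print(cfg["cpu_offload"])        # None
print(cfg["sharding_strategy"])  # SHARD_GRAD_OP
```

One consequence visible in the source: the default inference of `transformer_layer_cls` runs before the merge, so a model without a `model.model.language_model.layers` attribute path raises an error unless `transformer_layer_cls` is passed explicitly, even when `fsdp_kwargs` supplies its own `auto_wrap_policy`.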

Callers 5

main · Function · 0.90
main · Function · 0.90
main · Function · 0.90
main · Function · 0.90
main · Function · 0.90

Calls 3

get_logger · Function · 0.90
info · Method · 0.80
update · Method · 0.45

Tested by

no test coverage detected
