hub / github.com/meta-pytorch/opacus / FSDP2Wrapper

Function FSDP2Wrapper

opacus/utils/fsdp_utils.py:39–61
(model: nn.Module, **kwargs)

Source from the content-addressed store, hash-verified

def FSDP2Wrapper(model: nn.Module, **kwargs) -> nn.Module:
    sampler_classes = set(
        list(GradSampleModuleFastGradientClippingFSDP.GRAD_SAMPLERS.keys())
        + list(GradSampleModuleFastGradientClippingFSDP.NORM_SAMPLERS.keys())
    )
    mp_policy = kwargs.get("mp_policy", MixedPrecisionPolicy())
    opacus_high_precision_layers = kwargs.get("opacus_high_precision_layers", [])
    for module in iterate_submodules(model):
        if (type(module) in sampler_classes) or (not has_trainable_params(module)):
            if len(opacus_high_precision_layers) > 0 and isinstance(
                module, opacus_high_precision_layers
            ):
                # For certain layers, higher precision is needed to stabilize the training of DP-SGD.
                fully_shard(
                    module,
                    mp_policy=MixedPrecisionPolicy(
                        param_dtype=torch.get_default_dtype()
                    ),
                )
            else:
                fully_shard(module, mp_policy=mp_policy)
    model = fully_shard(model, mp_policy=mp_policy)
    return model
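
The selection logic above can be traced without PyTorch or a distributed setup. The following is a minimal sketch, not the Opacus implementation: `Module`, `Linear`, `LayerNorm`, `Custom`, and `plan_sharding` are hypothetical stand-ins, and the traversal order of `iterate_submodules` and the `trainable` flag (standing in for `has_trainable_params`) are assumptions, since neither helper's source is shown here. Instead of calling `fully_shard`, it records which policy each module would receive.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Set, Tuple


@dataclass
class Module:
    """Hypothetical stand-in for nn.Module: a name, a flag, and children."""
    name: str
    trainable: bool = False  # assumed meaning: the module owns trainable parameters
    children: List["Module"] = field(default_factory=list)


class Linear(Module): ...
class LayerNorm(Module): ...
class Custom(Module): ...


def iterate_submodules(root: Module) -> Iterator[Module]:
    """Assumed traversal: every descendant of root, parents before children."""
    for child in root.children:
        yield child
        yield from iterate_submodules(child)


def plan_sharding(
    root: Module,
    sampler_classes: Set[type],
    high_precision_layers: Tuple[type, ...] = (),
) -> Dict[str, str]:
    """Mirror the branching of FSDP2Wrapper, recording a label per module."""
    plan: Dict[str, str] = {}
    for module in iterate_submodules(root):
        # Shard modules with a registered sampler, or with no trainable params.
        if type(module) in sampler_classes or not module.trainable:
            if len(high_precision_layers) > 0 and isinstance(
                module, high_precision_layers
            ):
                plan[module.name] = "high_precision"  # default-dtype policy
            else:
                plan[module.name] = "mp_policy"  # caller-supplied policy
    plan[root.name] = "mp_policy"  # the root is always sharded last
    return plan


model = Module("model", children=[
    Linear("fc", trainable=True),
    LayerNorm("norm", trainable=True),
    Custom("custom", trainable=True),  # trainable, no sampler: not sharded on its own
    Module("dropout"),                 # no trainable params: sharded
])
plan = plan_sharding(model, {Linear, LayerNorm}, (LayerNorm,))
print(plan)
```

The sketch passes the high-precision classes as a tuple because Python's `isinstance` accepts a type or a tuple of types and raises `TypeError` for a list. The same constraint would seem to apply to the `opacus_high_precision_layers` keyword in the listing, whose default is an empty list that is never passed to `isinstance`.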

Callers 3

demo_basic · Function · 0.90
init_training · Function · 0.90
train · Function · 0.90

Calls 2

has_trainable_params · Function · 0.90
iterate_submodules · Function · 0.85

Tested by

no test coverage detected