(model: nn.Module, **kwargs)
| 37 | |
| 38 | |
def FSDP2Wrapper(model: nn.Module, **kwargs) -> nn.Module:
    """Apply ``fully_shard`` to selected submodules of ``model``, then to ``model``.

    A submodule yielded by ``iterate_submodules(model)`` is sharded on its own
    when its exact type is a key of
    ``GradSampleModuleFastGradientClippingFSDP.GRAD_SAMPLERS`` or
    ``GradSampleModuleFastGradientClippingFSDP.NORM_SAMPLERS``, or when
    ``has_trainable_params(module)`` is false. Finally the root ``model`` is
    passed to ``fully_shard`` and the result is returned.

    Args:
        model: Root module to wrap.
        **kwargs: Only the following keys are read; any others are ignored.
            mp_policy: ``MixedPrecisionPolicy`` used for the root module and
                for every selected submodule that is not a high-precision
                layer. Defaults to ``MixedPrecisionPolicy()``.
            opacus_high_precision_layers: Module class, or an iterable of
                module classes (tuple, list, ...). Selected submodules that
                are instances of any of these classes (subclasses included)
                are sharded with
                ``MixedPrecisionPolicy(param_dtype=torch.get_default_dtype())``
                instead of ``mp_policy``. Defaults to no such layers.

    Returns:
        The value returned by ``fully_shard(model, mp_policy=mp_policy)``.
    """
    sampler_classes = set(
        GradSampleModuleFastGradientClippingFSDP.GRAD_SAMPLERS.keys()
    ).union(GradSampleModuleFastGradientClippingFSDP.NORM_SAMPLERS.keys())
    mp_policy = kwargs.get("mp_policy", MixedPrecisionPolicy())

    # isinstance() only accepts a type or a tuple of types (or a union), so a
    # list of classes -- the shape suggested by the historical ``[]`` default --
    # used to raise TypeError. Normalise once, up front, to a tuple.
    high_precision_layers = kwargs.get("opacus_high_precision_layers", ())
    if high_precision_layers is None:
        high_precision_layers = ()
    elif isinstance(high_precision_layers, type):
        high_precision_layers = (high_precision_layers,)
    else:
        high_precision_layers = tuple(high_precision_layers)

    for module in iterate_submodules(model):
        # Skip modules that are neither registered sampler types (exact type
        # match) nor free of trainable parameters.
        if type(module) not in sampler_classes and has_trainable_params(module):
            continue

        # isinstance() against an empty tuple is always False, so no separate
        # emptiness check is needed.
        if isinstance(module, high_precision_layers):
            # For certain layers, higher precision is needed to stabilize the
            # training of DP-SGD.
            fully_shard(
                module,
                mp_policy=MixedPrecisionPolicy(
                    param_dtype=torch.get_default_dtype()
                ),
            )
        else:
            fully_shard(module, mp_policy=mp_policy)

    return fully_shard(model, mp_policy=mp_policy)
no test coverage detected