(func)
| 759 | |
| 760 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP |
def patch_FSDP_use_orig_params(func):
    """Wrap ``func`` so that ``use_orig_params`` defaults to ``True``.

    The wrapper forwards every positional and keyword argument to ``func``
    unchanged. If the caller did not pass a ``use_orig_params`` keyword,
    ``use_orig_params=True`` is added; an explicit keyword value from the
    caller is passed through as given.

    Only keyword arguments are inspected: a value supplied positionally is
    not detected by the wrapper.

    Args:
        func: Callable that accepts a ``use_orig_params`` keyword argument.

    Returns:
        A wrapper around ``func`` that keeps its name, docstring and other
        metadata (via ``functools.wraps``).
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(func)
    def wrap_func(*args, **kwargs):
        # Fill in the default only when the caller did not choose a value.
        kwargs.setdefault('use_orig_params', True)
        return func(*args, **kwargs)

    return wrap_func
| 766 | |
# Replace FSDP's constructor with the wrapped version so that instances are
# built with use_orig_params=True unless the caller passes that keyword.
setattr(FSDP, "__init__", patch_FSDP_use_orig_params(FSDP.__init__))
| 768 |