Repository: github.com/apple/ml-mgie

Function patch_FSDP_use_orig_params

mgie_train.py:761–765 (function)

Source from the content-addressed store, hash-verified

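from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

def patch_FSDP_use_orig_params(func):
    # Wrap FSDP.__init__ so that use_orig_params defaults to True
    # (PyTorch's own default is False). An explicit keyword value
    # passed by the caller is still honored.
    def wrap_func(*args, **kwargs):
        use_orig_params = kwargs.pop('use_orig_params', True)
        return func(*args, **kwargs, use_orig_params=use_orig_params)
    return wrap_func

# Monkey-patch at import time: every FSDP instance constructed afterwards
# in this process picks up the new default.
FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)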
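The wrapper pattern can be exercised without PyTorch. The sketch below is an illustration, not the repository's code: `ShardedWrapperStandIn` is a hypothetical class that only mimics FSDP's `use_orig_params=False` keyword default, and `functools.wraps` is an addition the original patch does not use (it keeps the patched method's `__name__` and docstring intact).

```python
import functools


class ShardedWrapperStandIn:
    """Hypothetical stand-in for FSDP; only mirrors the use_orig_params=False default."""

    def __init__(self, module, use_orig_params=False):
        self.module = module
        self.use_orig_params = use_orig_params


def patch_use_orig_params(func):
    """Same idea as patch_FSDP_use_orig_params, with functools.wraps added."""
    @functools.wraps(func)
    def wrap_func(*args, **kwargs):
        # Use the caller's keyword value if present, otherwise default to True.
        use_orig_params = kwargs.pop("use_orig_params", True)
        return func(*args, **kwargs, use_orig_params=use_orig_params)
    return wrap_func


# Monkey-patch the stand-in's constructor, as the original does for FSDP.__init__.
ShardedWrapperStandIn.__init__ = patch_use_orig_params(ShardedWrapperStandIn.__init__)

default_case = ShardedWrapperStandIn("model")                         # flag omitted
explicit_case = ShardedWrapperStandIn("model", use_orig_params=False)  # flag as keyword

# Caveat: passing the flag positionally now collides with the injected keyword.
try:
    ShardedWrapperStandIn("model", False)
    positional_rejected = False
except TypeError:
    positional_rejected = True

print(default_case.use_orig_params, explicit_case.use_orig_params, positional_rejected)
```

Omitting the flag yields `True`, an explicit keyword `False` is respected, and a positional value raises `TypeError` ("got multiple values for argument"). Under this patch, callers should therefore pass `use_orig_params` by keyword only.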

Callers (1)

train (function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected