Load a PyTorch file without redundant fetches across MPI ranks.
Signature: `load_state_dict(path, **kwargs)`
| 51 | |
| 52 | |
def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.

    Only rank 0 reads ``path``, through ``bf.BlobFile``. The raw bytes are
    broadcast to every other rank in ``MPI.COMM_WORLD``. Each rank then
    deserializes them with ``th.load``.

    :param path: location of the file, opened in binary mode by rank 0.
    :param kwargs: forwarded unchanged to ``th.load`` (e.g. ``map_location``).
    :return: whatever ``th.load`` returns for the file contents.

    NOTE: this is a collective call, so every rank must invoke it. If rank 0
    raises while reading the file, the other ranks are left blocked in
    ``bcast``.
    NOTE: ``th.load`` unpickles its input; only load trusted files.
    """
    comm = MPI.COMM_WORLD
    # A pickle-based bcast sends its payload as one MPI message, and MPI
    # message counts are C ints. A single broadcast of a multi-GiB checkpoint
    # can therefore fail, so the bytes are sent in bounded chunks instead.
    chunk_size = 2 ** 30
    if comm.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        # Ceiling division; an empty file yields zero chunks (and b"" below).
        num_chunks = (len(data) + chunk_size - 1) // chunk_size
        comm.bcast(num_chunks)
        for start in range(0, len(data), chunk_size):
            comm.bcast(data[start : start + chunk_size])
    else:
        # Receivers make exactly as many bcast calls as the root sends.
        num_chunks = comm.bcast(None)
        # join() avoids the quadratic cost of repeated bytes concatenation.
        data = b"".join(comm.bcast(None) for _ in range(num_chunks))
    return th.load(io.BytesIO(data), **kwargs)
| 64 | |
| 65 | |
| 66 | def sync_params(params): |