| 76 | return rank, local_rank, world_size |
| 77 | |
def _md5_of_file(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at *path*, reading it in chunks.

    Chunked reading keeps memory use bounded for multi-GB checkpoint files.
    """
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def _download_verified_ckpt(load_from_location, expected_hash, max_attempts):
    """Download ``{load_from_location}.pt`` from the Hub and return its local path.

    If ``expected_hash`` is given, the file's MD5 must match it. On mismatch the
    file is deleted and downloaded again, up to ``max_attempts`` times in total.

    Raises:
        RuntimeError: if the hash still mismatches after the final attempt.
    """
    repo_id = "si-pbc/hertz-dev"
    filename = f"{load_from_location}.pt"
    print0(f'Loading checkpoint from repo_id {repo_id} and filename {load_from_location}.pt. This may take a while...')
    attempts = max(1, max_attempts)  # always try at least once
    file_hash = None
    for attempt in range(attempts):
        # On a retry, force a fresh download. Otherwise the Hub cache may just
        # re-link the same cached (possibly corrupt) blob instead of fetching it.
        save_path = hf_hub_download(repo_id=repo_id, filename=filename, force_download=attempt > 0)
        print0(f'Downloaded checkpoint to {save_path}')
        if expected_hash is None:
            return save_path
        file_hash = _md5_of_file(save_path)
        if file_hash == expected_hash:
            return save_path
        print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.')
        os.remove(save_path)
    raise RuntimeError(
        f'Checkpoint {filename} from {repo_id} failed MD5 verification after '
        f'{attempts} attempt(s): expected {expected_hash}, last got {file_hash}.'
    )


def load_ckpt(load_from_location, expected_hash=None, *, max_attempts=3):
    """Download a checkpoint from the Hugging Face Hub and load it on the CPU.

    Only ranks where ``local0()`` is true download the file (and optionally
    verify its MD5). When ``torch.distributed`` is initialized, the path
    resolved on global rank 0 is broadcast to every rank. Every rank then loads
    the checkpoint from that path.

    Args:
        load_from_location: Checkpoint name inside the ``si-pbc/hertz-dev`` repo,
            without the ``.pt`` suffix.
        expected_hash: Optional hex MD5 digest the downloaded file must match.
        max_attempts: Total number of download attempts allowed when the hash
            keeps mismatching (keyword-only).

    Returns:
        The object deserialized by ``torch.load`` from the checkpoint file.

    Raises:
        RuntimeError: on a downloading rank, if MD5 verification keeps failing.
    """
    # Enables hf_transfer for faster downloads. Set to '0' when debugging
    # download errors from the Hub.
    os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
    # Ranks that do not download receive the path via the broadcast below.
    save_path = None
    if local0():
        save_path = _download_verified_ckpt(load_from_location, expected_hash, max_attempts)
    if T.distributed.is_initialized():
        # broadcast_object_list is a collective: every rank must call it.
        holder = [save_path]
        T.distributed.broadcast_object_list(holder, src=0)
        save_path = holder[0]
    # NOTE(security): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints from a trusted source. expected_hash is the only
    # integrity check, and only when it is supplied.
    loaded = T.load(save_path, weights_only=False, map_location='cpu')
    print0(f'Loaded checkpoint from {save_path}')
    return loaded