MCPcopy Index your code
hub / github.com/Standard-Intelligence/hertz-dev / load_ckpt

Function load_ckpt

utils/dist.py:78–98  ·  view source on GitHub ↗
(load_from_location, expected_hash=None)

Source from the content-addressed store, hash-verified

76 return rank, local_rank, world_size
77
def load_ckpt(load_from_location, expected_hash=None):
    """Download a checkpoint from the Hugging Face Hub and load it onto the CPU.

    The process for which ``local0()`` is true downloads
    ``{load_from_location}.pt`` from the ``si-pbc/hertz-dev`` repo and, when
    ``expected_hash`` is given, verifies the file's MD5 digest. If
    ``torch.distributed`` is initialized, the path held by rank 0 is then
    broadcast to every rank before loading.

    Args:
        load_from_location: Checkpoint name inside the repo, without the
            ``.pt`` suffix.
        expected_hash: Optional hex MD5 digest of the checkpoint file. When
            ``None`` the integrity check is skipped.

    Returns:
        Whatever object ``T.load`` deserializes from the checkpoint file.

    Raises:
        RuntimeError: If the digest still mismatches after a bounded number
            of fresh downloads, or if this process ends up without a
            checkpoint path (it did not download and nothing was broadcast).
    """
    os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1' #Disable this to speed up debugging errors with downloading from the hub

    max_attempts = 3
    hash_chunk_size = 1 << 20  # hash in 1 MiB pieces instead of reading the whole file into memory

    # Bind the name on every rank: processes that skip the download still
    # need a placeholder that the broadcast below can overwrite.
    save_path = None

    if local0():
        repo_id = "si-pbc/hertz-dev"
        filename = f"{load_from_location}.pt"
        for attempt in range(1, max_attempts + 1):
            print0(f'Loading checkpoint from repo_id {repo_id} and filename {filename}. This may take a while...')
            # On a retry, force a fresh download: deleting the returned path
            # may only remove a link into the hub cache, in which case a plain
            # call would hand back the same corrupted file again.
            save_path = hf_hub_download(repo_id=repo_id, filename=filename, force_download=attempt > 1)
            print0(f'Downloaded checkpoint to {save_path}')
            if expected_hash is None:
                break

            # MD5 is used here purely as an integrity check against the
            # caller-supplied digest, not as a security measure.
            digest = hashlib.md5()
            with open(save_path, 'rb') as f:
                for chunk in iter(lambda: f.read(hash_chunk_size), b''):
                    digest.update(chunk)
            file_hash = digest.hexdigest()
            if file_hash == expected_hash:
                break

            os.remove(save_path)
            if attempt == max_attempts:
                raise RuntimeError(
                    f'Hash mismatch for {save_path} after {max_attempts} attempts. '
                    f'Expected {expected_hash} but got {file_hash}.'
                )
            print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.')

    if T.distributed.is_initialized():
        # NOTE(review): this sends rank 0's path to all ranks, which assumes
        # that path is readable from every rank (e.g. single node or shared
        # cache directory) -- confirm for multi-node runs.
        path_holder = [save_path]
        T.distributed.broadcast_object_list(path_holder, src=0)
        save_path = path_holder[0]

    if save_path is None:
        raise RuntimeError(
            f'No checkpoint path available for {load_from_location}.pt: this process did not '
            'download the checkpoint and torch.distributed is not initialized to receive it.'
        )

    # SECURITY: weights_only=False performs full unpickling, which can execute
    # arbitrary code embedded in the file. Only use with trusted checkpoints.
    loaded = T.load(save_path, weights_only=False, map_location='cpu')
    print0(f'Loaded checkpoint from {save_path}')
    return loaded

Callers 5

make_tokenizerFunction · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90

Calls 2

local0Function · 0.85
print0Function · 0.85

Tested by

no test coverage detected