| 298 | model_state_dict_lst = [None] * total_shards |
| 299 | |
def process_one_shard(rank: int, model_state_dict_lst: list):
    """Load the checkpoint shard for ``rank`` and record it in the shared list.

    The shard file is ``model_world_size_{world_size}_rank_{rank}.pt`` inside
    ``self.config.local_dir``; ``self`` and ``world_size`` are not parameters
    and are taken from the surrounding scope.

    Args:
        rank: Index of the shard to load; also the slot that is written in
            ``model_state_dict_lst``.
        model_state_dict_lst: List that receives the loaded object at
            position ``rank`` (mutated in place).

    Returns:
        The object returned by ``torch.load`` for this shard.
    """
    shard_dir = Path(self.config.local_dir)
    shard_file = shard_dir / f"model_world_size_{world_size}_rank_{rank}.pt"
    # NOTE(review): weights_only=False permits full unpickling of the file,
    # so the shard files must come from a trusted source.
    loaded_shard = torch.load(shard_file, map_location="cpu", weights_only=False)
    # Each call writes only its own index of the shared list.
    model_state_dict_lst[rank] = loaded_shard
    return loaded_shard
| 305 | |
| 306 | with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: |
| 307 | futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] |