The folder should contain weight files named like the following. - folder - model_tp{tp_rank}_pp{pp_rank}.pt If the tensor-parallel (tp) size used later differs from the one the weights were saved with, the weights need to be converted before loading.
(folder, model)
| 97 | |
| 98 | |
def load_model_checkpoint(folder, model):
    """
    Load the current rank's model weights from a sharded checkpoint folder.

    The folder is expected to contain one weight file per (tensor, pipeline) rank:
    - folder
        - model_tp{tp_rank}_pp{pp_rank}.pt

    The tensor-parallel and pipeline-parallel sizes of the current run must match
    the ones inferred from the saved file names; if they differ, the weights need
    to be converted before loading.

    Args:
        folder: Checkpoint location; it is listed with ``get_fns`` and joined with
            the shard file name to build the path that is loaded.
        model: Object whose ``load_state_dict(states, strict=False)`` receives the
            loaded states.

    Raises:
        AssertionError: If the pipeline or tensor parallel size inferred from the
            file names differs from the current one.
    """

    tp_size = gpc.get_world_size(ParallelMode.TENSOR)
    pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    # Infer the saved parallel sizes from the highest tp/pp indices found in the
    # weight file names ("model_tp{i}_pp{j}.<ext>"); ".md5" files are skipped.
    max_pp, max_tp = 0, 0
    for fn in get_fns(folder):
        if fn.startswith("model_t") and not fn.endswith(".md5"):
            segments = os.path.splitext(fn)[0].split("_")
            max_pp = max(max_pp, int(segments[-1][2:]))
            max_tp = max(max_tp, int(segments[-2][2:]))

    # Raise explicitly rather than using `assert`: assert statements are removed
    # under `python -O`, which would silently skip this validation. The exception
    # type and messages are kept identical for existing callers.
    if pp_size != max_pp + 1:
        raise AssertionError(
            f"The weights are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
        )
    if tp_size != max_tp + 1:
        raise AssertionError(
            f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
        )

    should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
    fp = os.path.join(folder, should_load_name)
    states = llm_load(fp, map_location=get_current_device())

    # strict=False: mismatched keys are reported as warnings instead of raising.
    missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
    if missing_k:
        logger.warning(f"Warning: missing keys {missing_k}")
    if unexpected_keys:
        logger.warning(f"Warning: unexpected keys {unexpected_keys}")

    # avoid to cuda oom, Ref: https://discuss.pytorch.org/t/load-state-dict-causes-memory-leak/36189/11
    del states
    torch.cuda.empty_cache()
| 141 | |
| 142 | |
| 143 | def save_optimizer_checkpoint(optim, state_path): |
no test coverage detected