Load the optimizer state from the local file system or a remote object storage service (OSS). Args: folder (str): The FS/OSS path where the optimizer checkpoint is stored. optim (Optimizer): The optimizer whose state is loaded.
(folder, optim)
| 169 | |
| 170 | |
def load_optimizer_checkpoint(folder, optim):
    """Load the optimizer state from the local file system or remote
    object storage Service (OSS).

    Checkpoint shards in ``folder`` are expected to be named
    ``optimizer_tp{T}_pp{P}_zo{Z}.pt``. The highest T/P/Z indices found are
    compared against the current tensor / pipeline / ZeRO-1 world sizes, and
    the shard matching this rank's local ranks is loaded into ``optim``.

    Args:
        folder (str): The FS/OSS path where the optimizer checkpoint is stored.
        optim (Optimizer): The optimizer whose state is restored through
            ``optim.load_state_dict``.

    Raises:
        AssertionError: If the parallel layout inferred from the shard file
            names does not match the current ZeRO-1, pipeline or tensor
            parallel world size.
    """

    # Infer the parallel layout the checkpoint was saved with from the shard
    # file names; ".md5" files are skipped.
    max_tp, max_pp, max_zero = 0, 0, 0
    for fn in get_fns(folder):
        if fn.startswith("optimizer_") and not fn.endswith(".md5"):
            # "optimizer_tp0_pp1_zo2.pt" -> ("optimizer", "tp0", "pp1", "zo2")
            _, tp, pp, zero = os.path.splitext(fn)[0].split("_")
            # Drop the two-character "tp" / "pp" / "zo" prefix to get the index.
            max_tp = max(max_tp, int(tp[2:]))
            max_pp = max(max_pp, int(pp[2:]))
            max_zero = max(max_zero, int(zero[2:]))

    zero_size = gpc.get_world_size(ParallelMode.ZERO1)
    zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
    tp_size = gpc.get_world_size(ParallelMode.TENSOR)
    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    # Indices are zero-based, so the saved group size is the max index + 1.
    assert (
        zero_size == max_zero + 1
    ), f"The weights are save for {max_zero+1} data parallel, while current has {zero_size} zero broadcast range."
    assert (
        pp_size == max_pp + 1
    ), f"The weights are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
    assert (
        tp_size == max_tp + 1
    ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"

    # Load only the shard that belongs to this rank.
    fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
    states = llm_load(os.path.join(folder, fp), map_location=get_current_device())

    if isinstance(optim, HybridZeroOptimizer):
        # The split plan is kept in a separate per-rank file named after
        # ``optim.rank_unique_id``.
        fp_meta = os.path.join(folder, optim.rank_unique_id)
        try:
            zero_devide_optim_plan = llm_load(fp_meta)
            states.update({"zero_devide_optim_plan": zero_devide_optim_plan})
        except Exception as e:
            # Best effort: a missing or unreadable plan is only reported, and
            # loading continues without it.
            logger.warning(
                f"Failed to read zero optimizer split file '{fp_meta}', for '{e}'. "
                "Please check whether loading ckpts are saved with the HybridZeroOptimizer."
            )

    optim.load_state_dict(states)
    # Drop the local reference to the loaded states, then release cached CUDA
    # memory held by the allocator.
    del states
    torch.cuda.empty_cache()
| 223 | |
| 224 | |
| 225 | def load_sampler(ckpt_path: str, sampler): |
no test coverage detected