(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState)
| 242 | |
| 243 | |
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState):
    """Restore the LR scheduler from a checkpoint, retargeting it to ``learning_rate``.

    The saved scheduler state is read from ``<ckpt_path>/schedulder.pt`` via
    ``llm_load``. Before the state is loaded into ``lr_scheduler``, every stored base
    learning rate is replaced by ``learning_rate``. The same happens to the base
    learning rates of the nested ``after_scheduler_dict``, when that key is present.
    The scheduler's ``last_epoch`` is then set to ``train_state.step_count + 1``.
    Finally, each optimizer param group's ``lr`` is multiplied by
    ``learning_rate / <old stored base lr at the same index>``, i.e. rescaled in
    proportion to the change in base learning rate.

    Args:
        ckpt_path: Directory containing the checkpoint files.
        lr_scheduler: Scheduler whose ``load_state_dict`` receives the patched state.
        optimizer: Optimizer whose ``param_groups[i]["lr"]`` values are rescaled.
        learning_rate: New base learning rate applied to all groups.
        train_state: Supplies ``step_count``, which is used to set ``last_epoch``.

    Raises:
        IndexError: If the stored ``base_lrs`` list is empty, or if the optimizer has
            more param groups than there are stored base learning rates.
        ZeroDivisionError: If a stored base learning rate is zero (plain-number lrs).
    """
    # NOTE: "schedulder.pt" (sic) is the literal on-disk filename read here. Presumably
    # the saving side uses the same spelling, so verify that before renaming it.
    states = llm_load(os.path.join(ckpt_path, "schedulder.pt"))

    stored_lrs = states["base_lrs"]
    previous_lr = stored_lrs[0]
    if learning_rate != previous_lr and gpc.is_rank_for_log():
        logger.warning(f"Using new learning rate {learning_rate} to replace old learn rate {previous_lr}.")

    # Snapshot the stored base lrs; they are needed below to rescale the optimizer groups.
    old_base_lrs = copy.deepcopy(stored_lrs)

    def _retarget(lrs):
        # One copy of the new base lr per stored entry.
        return [learning_rate] * len(lrs)

    states["base_lrs"] = _retarget(stored_lrs)
    if "after_scheduler_dict" in states:
        after_states = states["after_scheduler_dict"]
        after_states["base_lrs"] = _retarget(after_states["base_lrs"])

    lr_scheduler.load_state_dict(states)
    # NOTE(review): the +1 offset is kept as-is; confirm it matches how step_count and
    # last_epoch are counted elsewhere.
    lr_scheduler.last_epoch = train_state.step_count + 1

    # Compute every ratio up front, so a bad stored lr fails before any group is modified.
    scale_by_group = [learning_rate / old_lr for old_lr in old_base_lrs]
    for group_idx, group in enumerate(optimizer.param_groups):
        group["lr"] = group["lr"] * scale_by_group[group_idx]
    torch.cuda.empty_cache()

    if gpc.is_rank_for_log():
        logger.info(f"reload load_scheduler:{lr_scheduler}")
| 269 | |
| 270 | class CheckpointManager: |
no test coverage detected