Unified API to get any scheduler from its name. Args: name (`str` or `SchedulerType`): The name of the scheduler to use. optimizer (`torch.optim.Optimizer`): The optimizer that will be used during training. step_rules (`str`, *optional*):
(
name: str | SchedulerType,
optimizer: Optimizer,
step_rules: str | None = None,
num_warmup_steps: int | None = None,
num_training_steps: int | None = None,
num_cycles: int = 1,
power: float = 1.0,
last_epoch: int = -1,
)
| 286 | |
| 287 | |
| 288 | def get_scheduler( |
| 289 | name: str | SchedulerType, |
| 290 | optimizer: Optimizer, |
| 291 | step_rules: str | None = None, |
| 292 | num_warmup_steps: int | None = None, |
| 293 | num_training_steps: int | None = None, |
| 294 | num_cycles: int = 1, |
| 295 | power: float = 1.0, |
| 296 | last_epoch: int = -1, |
| 297 | ) -> LambdaLR: |
| 298 | """ |
| 299 | Unified API to get any scheduler from its name. |
| 300 | |
| 301 | Args: |
| 302 | name (`str` or `SchedulerType`): |
| 303 | The name of the scheduler to use. |
| 304 | optimizer (`torch.optim.Optimizer`): |
| 305 | The optimizer that will be used during training. |
| 306 | step_rules (`str`, *optional*): |
| 307 | A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler. |
| 308 | num_warmup_steps (`int`, *optional*): |
| 309 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being |
| 310 | optional), the function will raise an error if it's unset and the scheduler type requires it. |
| 311 | num_training_steps (`int``, *optional*): |
| 312 | The number of training steps to do. This is not required by all schedulers (hence the argument being |
| 313 | optional), the function will raise an error if it's unset and the scheduler type requires it. |
| 314 | num_cycles (`int`, *optional*): |
| 315 | The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. |
| 316 | power (`float`, *optional*, defaults to 1.0): |
| 317 | Power factor. See `POLYNOMIAL` scheduler |
| 318 | last_epoch (`int`, *optional*, defaults to -1): |
| 319 | The index of the last epoch when resuming training. |
| 320 | """ |
| 321 | name = SchedulerType(name) |
| 322 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] |
| 323 | if name == SchedulerType.CONSTANT: |
| 324 | return schedule_func(optimizer, last_epoch=last_epoch) |
| 325 | |
| 326 | if name == SchedulerType.PIECEWISE_CONSTANT: |
| 327 | return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch) |
| 328 | |
| 329 | # All other schedulers require `num_warmup_steps` |
| 330 | if num_warmup_steps is None: |
| 331 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") |
| 332 | |
| 333 | if name == SchedulerType.CONSTANT_WITH_WARMUP: |
| 334 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch) |
| 335 | |
| 336 | # All other schedulers require `num_training_steps` |
| 337 | if num_training_steps is None: |
| 338 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") |
| 339 | |
| 340 | if name == SchedulerType.COSINE_WITH_RESTARTS: |
| 341 | return schedule_func( |
| 342 | optimizer, |
| 343 | num_warmup_steps=num_warmup_steps, |
| 344 | num_training_steps=num_training_steps, |
| 345 | num_cycles=num_cycles, |
no test coverage detected
searching dependent graphs…