
Function get_scheduler

src/diffusers/optimization.py:288–360 (github.com/huggingface/diffusers)

Unified API to get any learning-rate scheduler from its name. Returns a `LambdaLR`. Raises `ValueError` if the chosen scheduler needs `num_warmup_steps` or `num_training_steps` and that argument was not provided.
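
A minimal pure-Python sketch of the name-based dispatch and its argument checks, mirroring the order in the source listing: `CONSTANT` and `PIECEWISE_CONSTANT` need no step counts, `CONSTANT_WITH_WARMUP` needs only `num_warmup_steps`, and every other type needs both `num_warmup_steps` and `num_training_steps`. The enum below is an assumed copy of diffusers' `SchedulerType` (its string values are an assumption), and the helpers `required_args` and `check_args` are hypothetical, written for illustration only.

```python
from enum import Enum


class SchedulerType(Enum):
    # Assumed member values; check SchedulerType in src/diffusers/optimization.py.
    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    PIECEWISE_CONSTANT = "piecewise_constant"


def required_args(name):
    """Hypothetical helper: step-count arguments get_scheduler insists on for `name`."""
    name = SchedulerType(name)  # accepts a string or an enum member, like the original
    if name in (SchedulerType.CONSTANT, SchedulerType.PIECEWISE_CONSTANT):
        return set()
    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return {"num_warmup_steps"}
    return {"num_warmup_steps", "num_training_steps"}


def check_args(name, num_warmup_steps=None, num_training_steps=None):
    """Hypothetical helper: raise ValueError in the same order as the source (warmup first)."""
    needed = required_args(name)
    if "num_warmup_steps" in needed and num_warmup_steps is None:
        raise ValueError(f"{SchedulerType(name)} requires `num_warmup_steps`, please provide that argument.")
    if "num_training_steps" in needed and num_training_steps is None:
        raise ValueError(f"{SchedulerType(name)} requires `num_training_steps`, please provide that argument.")
```

An unknown name fails on the first line of either path, because calling `SchedulerType("foo")` raises `ValueError` before any scheduler-specific check runs.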

get_scheduler(
    name: str | SchedulerType,
    optimizer: Optimizer,
    step_rules: str | None = None,
    num_warmup_steps: int | None = None,
    num_training_steps: int | None = None,
    num_cycles: int = 1,
    power: float = 1.0,
    last_epoch: int = -1,
) -> LambdaLR
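
The returned `LambdaLR` multiplies the optimizer's base learning rate by a per-step factor, and `num_warmup_steps`, `num_training_steps` and `num_cycles` shape that factor. The sketch below shows plausible factor functions for a linear schedule and a cosine schedule with hard restarts, assuming diffusers follows the warmup-then-decay formulas common to Hugging Face libraries. The function names and the exact formulas are assumptions; check them against `src/diffusers/optimization.py`.

```python
import math


def linear_with_warmup(step, num_warmup_steps, num_training_steps):
    """Hypothetical LR factor: ramp 0 -> 1 over warmup, then decay linearly to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


def cosine_with_hard_restarts(step, num_warmup_steps, num_training_steps, num_cycles=1):
    """Hypothetical LR factor: warmup, then `num_cycles` cosine decays, each restarting at 1."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    # Fraction of the post-warmup phase completed, in [0, 1).
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    if progress >= 1.0:
        return 0.0
    # Position within the current cycle; the modulo produces the hard restart.
    cycle_pos = (num_cycles * progress) % 1.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycle_pos)))
```

Either function could be wrapped in a closure such as `lambda step: linear_with_warmup(step, 10, 110)` and passed as the `lr_lambda` argument of `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)`.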

Source

def get_scheduler(
    name: str | SchedulerType,
    optimizer: Optimizer,
    step_rules: str | None = None,
    num_warmup_steps: int | None = None,
    num_training_steps: int | None = None,
    num_cycles: int = 1,
    power: float = 1.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional); the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional); the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See the `POLYNOMIAL` scheduler.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, last_epoch=last_epoch)

    if name == SchedulerType.PIECEWISE_CONSTANT:
        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
            # ... listing truncated in this view; the function continues to line 360
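
The docstring describes `step_rules` only as a string representing the step rules for `PIECEWISE_CONSTANT`. The sketch below assumes a comma-separated format such as "1:10,0.1:20,0.01:30,0.005": each `multiplier:step` pair applies while the current step is below `step`, and the bare final value applies from the last threshold onward. This format and the helpers `parse_step_rules` and `piecewise_constant_factor` are assumptions for illustration; confirm the real format against the piecewise-constant schedule function in `src/diffusers/optimization.py`.

```python
def parse_step_rules(step_rules):
    """Hypothetical parser for an assumed "mult:step,...,final_mult" rule string.

    Returns (thresholds, final_multiplier), where thresholds is a list of
    (step, multiplier) pairs sorted by step.
    """
    *pairs, last = step_rules.split(",")
    thresholds = []
    for pair in pairs:
        value_str, step_str = pair.split(":")
        thresholds.append((int(step_str), float(value_str)))
    thresholds.sort()
    return thresholds, float(last)


def piecewise_constant_factor(step, step_rules):
    """LR factor: multiplier of the first threshold the step has not reached yet."""
    thresholds, final_multiplier = parse_step_rules(step_rules)
    for threshold, multiplier in thresholds:
        if step < threshold:
            return multiplier
    return final_multiplier
```

Re-parsing the string on every call keeps the sketch short; a real schedule would parse once and close over the result.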

Callers 15

main (function) · 0.90, repeated for all 12 entries shown

Calls 1

SchedulerType (class) · 0.85

Tested by

no test coverage detected
