MCPcopy
hub / github.com/microsoft/Swin-Transformer / build_scheduler

Function build_scheduler

lr_scheduler.py:16–63  ·  view source on GitHub ↗
(config, optimizer, n_iter_per_epoch)

Source from the content-addressed store, hash-verified

14
15
def build_scheduler(config, optimizer, n_iter_per_epoch):
    """Build the learning-rate scheduler selected by ``config.TRAIN.LR_SCHEDULER.NAME``.

    Epoch-based settings from ``config.TRAIN`` are converted to iteration
    counts by multiplying by ``n_iter_per_epoch``. Every scheduler is built
    with ``t_in_epochs=False``, so the returned object is expected to be
    advanced per optimizer iteration rather than per epoch (presumably via
    ``step_update`` in the training loop — confirm against the caller).

    Args:
        config: Config node. Always reads ``TRAIN.EPOCHS``,
            ``TRAIN.WARMUP_EPOCHS``, ``TRAIN.WARMUP_LR`` and
            ``TRAIN.LR_SCHEDULER.NAME``. Further keys are read only by the
            scheduler that is selected:
            ``TRAIN.MIN_LR`` and ``LR_SCHEDULER.WARMUP_PREFIX`` for cosine,
            ``LR_SCHEDULER.DECAY_EPOCHS`` and ``DECAY_RATE`` for step,
            ``LR_SCHEDULER.MULTISTEPS`` and ``GAMMA`` for multistep.
        optimizer: Optimizer whose learning rate the scheduler drives.
        n_iter_per_epoch: Number of optimizer iterations per epoch.

    Returns:
        A ``CosineLRScheduler``, ``LinearLRScheduler``, ``StepLRScheduler`` or
        ``MultiStepLRScheduler`` instance.

    Raises:
        ValueError: If the scheduler name is not one of ``'cosine'``,
            ``'linear'``, ``'step'`` or ``'multistep'`` (case-sensitive).
            Previously such names silently produced ``None``.
    """
    sched_cfg = config.TRAIN.LR_SCHEDULER
    name = sched_cfg.NAME
    num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
    warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)

    if name == 'cosine':
        # When WARMUP_PREFIX is set, the cosine length excludes the warmup
        # iterations, so warmup + cosine together still span num_steps.
        return CosineLRScheduler(
            optimizer,
            t_initial=(num_steps - warmup_steps) if sched_cfg.WARMUP_PREFIX else num_steps,
            t_mul=1.,
            lr_min=config.TRAIN.MIN_LR,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            cycle_limit=1,
            t_in_epochs=False,
            warmup_prefix=sched_cfg.WARMUP_PREFIX,
        )

    if name == 'linear':
        # NOTE(review): lr_min_rate is hard-coded and WARMUP_PREFIX is not
        # applied here (t_initial always includes warmup). The 0.01 is
        # presumably the final LR as a fraction of the base LR — confirm in
        # LinearLRScheduler.
        return LinearLRScheduler(
            optimizer,
            t_initial=num_steps,
            lr_min_rate=0.01,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )

    if name == 'step':
        decay_steps = int(sched_cfg.DECAY_EPOCHS * n_iter_per_epoch)
        return StepLRScheduler(
            optimizer,
            decay_t=decay_steps,
            decay_rate=sched_cfg.DECAY_RATE,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )

    if name == 'multistep':
        # Milestones are given in epochs in the config; convert them to
        # iteration indices to match t_in_epochs=False.
        multi_steps = [i * n_iter_per_epoch for i in sched_cfg.MULTISTEPS]
        return MultiStepLRScheduler(
            optimizer,
            milestones=multi_steps,
            gamma=sched_cfg.GAMMA,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )

    # Fail fast: returning None here would only surface later as an
    # AttributeError inside the training loop.
    raise ValueError(
        f"Unknown LR scheduler name {name!r}; "
        "expected one of 'cosine', 'linear', 'step', 'multistep'"
    )
64
65
66class LinearLRScheduler(Scheduler):

Callers 4

main (function) · 0.90
main (function) · 0.90
main (function) · 0.90
main (function) · 0.90

Calls 2

LinearLRScheduler (class) · 0.85

Tested by

no test coverage detected