(config, optimizer, n_iter_per_epoch)
| 14 | |
| 15 | |
def build_scheduler(config, optimizer, n_iter_per_epoch):
    """Create the iteration-based learning-rate scheduler selected by ``config``.

    Epoch-denominated settings are converted to optimizer-step counts by
    multiplying them with ``n_iter_per_epoch``. Every scheduler is built with
    ``t_in_epochs=False``, so all horizons passed to it are step counts.

    Args:
        config: Configuration object. Always reads ``TRAIN.EPOCHS``,
            ``TRAIN.WARMUP_EPOCHS``, ``TRAIN.LR_SCHEDULER.DECAY_EPOCHS``,
            ``TRAIN.LR_SCHEDULER.MULTISTEPS`` and ``TRAIN.LR_SCHEDULER.NAME``.
            Depending on ``NAME`` it also reads ``TRAIN.WARMUP_LR``,
            ``TRAIN.MIN_LR``, ``TRAIN.LR_SCHEDULER.WARMUP_PREFIX``,
            ``TRAIN.LR_SCHEDULER.DECAY_RATE`` or ``TRAIN.LR_SCHEDULER.GAMMA``.
        optimizer: Optimizer handed unchanged to the scheduler constructor.
        n_iter_per_epoch: Number of optimizer steps in one epoch.

    Returns:
        A ``CosineLRScheduler``, ``LinearLRScheduler``, ``StepLRScheduler`` or
        ``MultiStepLRScheduler`` for ``NAME`` equal to ``'cosine'``,
        ``'linear'``, ``'step'`` or ``'multistep'``; ``None`` for any other
        value of ``NAME``.
    """
    train_cfg = config.TRAIN

    # Every horizon below is measured in optimizer steps, not epochs.
    total_iters = int(train_cfg.EPOCHS * n_iter_per_epoch)
    warmup_iters = int(train_cfg.WARMUP_EPOCHS * n_iter_per_epoch)
    sched_cfg = train_cfg.LR_SCHEDULER
    decay_iters = int(sched_cfg.DECAY_EPOCHS * n_iter_per_epoch)
    milestone_iters = [epoch * n_iter_per_epoch for epoch in sched_cfg.MULTISTEPS]

    name = sched_cfg.NAME

    if name == 'cosine':
        if sched_cfg.WARMUP_PREFIX:
            # With a warmup prefix, the cosine length excludes the warmup steps.
            cosine_len = total_iters - warmup_iters
        else:
            cosine_len = total_iters
        return CosineLRScheduler(
            optimizer,
            t_initial=cosine_len,
            t_mul=1.,
            lr_min=train_cfg.MIN_LR,
            warmup_lr_init=train_cfg.WARMUP_LR,
            warmup_t=warmup_iters,
            cycle_limit=1,
            t_in_epochs=False,
            warmup_prefix=sched_cfg.WARMUP_PREFIX,
        )

    if name == 'linear':
        return LinearLRScheduler(
            optimizer,
            t_initial=total_iters,
            lr_min_rate=0.01,
            warmup_lr_init=train_cfg.WARMUP_LR,
            warmup_t=warmup_iters,
            t_in_epochs=False,
        )

    if name == 'step':
        return StepLRScheduler(
            optimizer,
            decay_t=decay_iters,
            decay_rate=sched_cfg.DECAY_RATE,
            warmup_lr_init=train_cfg.WARMUP_LR,
            warmup_t=warmup_iters,
            t_in_epochs=False,
        )

    if name == 'multistep':
        return MultiStepLRScheduler(
            optimizer,
            milestones=milestone_iters,
            gamma=sched_cfg.GAMMA,
            warmup_lr_init=train_cfg.WARMUP_LR,
            warmup_t=warmup_iters,
            t_in_epochs=False,
        )

    # Unrecognized scheduler names fall through to "no scheduler".
    return None
| 64 | |
| 65 | |
| 66 | class LinearLRScheduler(Scheduler): |
no test coverage detected