(optimizer, args)
| 54 | |
| 55 | |
def get_scheduler(optimizer, args):
    """Build the learning-rate scheduler selected by ``args``.

    Args:
        optimizer: The ``torch.optim.Optimizer`` to schedule. Both schedules
            start from each parameter group's initial ``lr``.
        args: Config namespace. Reads ``schuse`` (enable scheduling),
            ``schusech`` (``'cos'`` selects cosine annealing; any other value
            selects inverse decay), ``max_epoch`` and ``steps_per_epoch``
            (cosine period), and ``lr_gamma`` / ``lr_decay`` (inverse-decay
            parameters).

    Returns:
        A ``torch.optim.lr_scheduler`` instance, or ``None`` when
        ``args.schuse`` is falsy.
    """
    if not args.schuse:
        return None

    if args.schusech == 'cos':
        # T_max spans max_epoch * steps_per_epoch scheduler steps, so the
        # caller presumably calls scheduler.step() once per iteration.
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_epoch * args.steps_per_epoch)

    # Inverse decay: lr(x) = initial_lr * (1 + gamma * x) ** (-decay).
    # LambdaLR already multiplies each group's initial lr by the returned
    # factor, so the factor must not include args.lr again. The previous
    # version did, which applied the learning rate twice (e.g. 0.1 -> 0.01 at
    # step 0). This matches the cosine branch, which also starts from the
    # optimizer's own initial lr.
    # Snapshot the hyper-parameters so later mutation of `args` cannot
    # silently change the schedule.
    gamma = float(args.lr_gamma)
    decay = float(args.lr_decay)
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: (1. + gamma * float(step)) ** (-decay))