(alg, args)
| 21 | |
| 22 | |
def get_optimizer(alg, args):
    """Build an Adam optimizer over the parameters selected by ``get_params``.

    Args:
        alg: Object passed through unchanged to ``get_params`` to collect the
            parameters (or parameter groups) to optimize.
        args: Config namespace. Reads ``lr``, ``weight_decay`` and ``beta1``;
            optionally reads ``beta2``. When ``beta2`` is absent it falls back
            to 0.9, the value previously hard-coded here.

    Returns:
        A ``torch.optim.Adam`` instance.
    """
    params = get_params(alg, args)
    # beta2 used to be hard-coded to 0.9 (PyTorch's own default is 0.999).
    # Read it from args when provided so it is tunable like beta1; keep 0.9
    # as the fallback so existing configs behave exactly as before.
    # NOTE(review): if args already carries an unrelated `beta2` attribute,
    # confirm it is meant to be Adam's second-moment coefficient.
    beta2 = getattr(args, 'beta2', 0.9)
    optimizer = torch.optim.Adam(
        params,
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(args.beta1, beta2),
    )
    return optimizer
| 28 | |
| 29 | |
| 30 | def get_scheduler(optimizer, args): |