(params, lr, betas, eps, momentum, optimizer_name)
| 350 | |
| 351 | |
def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
    """Construct an optimizer chosen by name.

    NOTE(review): assumes the file-level name ``optim`` is ``torch.optim``
    (the constructor signatures used below match it) — confirm at the import.

    Args:
        params: Parameters to optimize; passed through unchanged to the
            optimizer constructor.
        lr: Learning rate, forwarded to every optimizer.
        betas: Forwarded only to Adam / AdamW.
        eps: Forwarded only to Adam / AdamW.
        momentum: Forwarded only to SGD.
        optimizer_name: ``"adamw"``, ``"adam"`` or ``"sgd"``; matched
            case-insensitively.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    # Normalize once instead of calling .lower() in every branch.
    name = optimizer_name.lower()
    if name == "adamw":
        return optim.AdamW(params, lr=lr, betas=betas, eps=eps)
    if name == "sgd":
        return optim.SGD(params, lr=lr, momentum=momentum)
    if name == "adam":
        return optim.Adam(params, lr=lr, betas=betas, eps=eps)
    # Keep the original message as a prefix (greppable), but report the
    # offending value and the accepted choices so the caller can fix it.
    raise ValueError(
        f"optimizer name is not correct: {optimizer_name!r} "
        "(expected one of 'adam', 'adamw', 'sgd')"
    )
no outgoing calls
no test coverage detected