()
| 18 | |
| 19 | |
def test_default_get_optimizer():
    """Check that the default CLI configuration yields a plain AdamW optimizer.

    The command line is patched to contain only an empty program name, so
    every option parsed by ``setup_parser`` keeps its default value.
    ``get_optimizer`` is then expected to report the fully-qualified AdamW
    class name and an empty optimizer-args string.
    """
    with patch("sys.argv", [""]):
        default_args = setup_parser().parse_args()

        # A single two-element trainable tensor is enough for the optimizer
        # to be constructed.
        weights = Parameter(torch.tensor([1.5, 1.5]))

        name, extra_args, built = get_optimizer(default_args, [weights])

        assert name == "torch.optim.adamw.AdamW"
        assert extra_args == ""
        assert isinstance(built, torch.optim.AdamW)
| 31 | |
| 32 | |
| 33 | def test_get_schedulefree_optimizer(): |
nothing calls this directly
no test coverage detected