()
| 27 | |
| 28 | |
def test_optimizer_error():
    """Check SGD.init() input validation and the init-before-get_function guard.

    Accepted without error: a single float32 Var and a one-element list of
    Vars. A list holding a float64 Var is also accepted, and the optimizer
    returned by init() then reports dtype "float64".

    Each of these must raise ValueError: a list repeating the same Var, a list
    of Vars with different dtypes, a Var annotated with a Tuple type, an int64
    Var, and a relax.Tuple expression.

    Calling get_function() on an optimizer that was never init()-ed must raise
    RuntimeError with the message matched below.
    """
    f32_var = relax.Var("x1", R.Tensor((3, 3), "float32"))
    f64_var = relax.Var("x2", R.Tensor((3, 3), "float64"))
    tuple_typed_var = relax.Var("x3", R.Tuple([R.Tensor((3, 3), "float32")]))
    i64_var = relax.Var("x4", R.Tensor((3, 3), "int64"))
    tuple_expr = relax.Tuple([f32_var])

    # Accepted inputs: each call must complete without raising.
    for accepted in (f32_var, [f32_var]):
        SGD(0.01).init(accepted)
    assert SGD(0.01).init([f64_var]).dtype == "float64"

    # Rejected inputs: each one is fed to a fresh optimizer and must raise.
    rejected_inputs = (
        [f32_var, f32_var],
        [f32_var, f64_var],
        tuple_typed_var,
        i64_var,
        tuple_expr,
    )
    for bad_input in rejected_inputs:
        with pytest.raises(ValueError):
            SGD(0.01).init(bad_input)

    # get_function() is only valid once init() has been called.
    with pytest.raises(
        RuntimeError,
        match=r"Please call init\(\) for the optimizer before calling get_function\(\)",
    ):
        SGD(0.01).get_function()
| 56 | |
| 57 | |
| 58 | def test_sgd_simple(): |
nothing calls this directly
no test coverage detected
searching dependent graphs…