Repository: github.com/apache/tvm

Function test_optimizer_error()

tests/python/relax/test_training_optimizer.py, lines 29–55


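```python
# Imports that live at the top of the test file (module paths assumed).
import pytest
from tvm import relax
from tvm.relax.training.optimizer import SGD
from tvm.script import relax as R


def test_optimizer_error():
    # Candidate parameters: two float tensors of different precision,
    # a tuple-typed Var, an integer tensor, and a relax.Tuple expression.
    x1 = relax.Var("x1", R.Tensor((3, 3), "float32"))
    x2 = relax.Var("x2", R.Tensor((3, 3), "float64"))
    x3 = relax.Var("x3", R.Tuple([R.Tensor((3, 3), "float32")]))
    x4 = relax.Var("x4", R.Tensor((3, 3), "int64"))
    x5 = relax.Tuple([x1])

    # fine cases: a single Var, or a list of Vars sharing one float dtype
    SGD(0.01).init(x1)
    SGD(0.01).init([x1])
    # init() returns the optimizer, whose dtype follows the parameters
    assert SGD(0.01).init([x2]).dtype == "float64"

    # the same parameter passed twice
    with pytest.raises(ValueError):
        SGD(0.01).init([x1, x1])
    # parameters with mixed dtypes (float32 and float64)
    with pytest.raises(ValueError):
        SGD(0.01).init([x1, x2])
    # a tuple-typed Var instead of a tensor
    with pytest.raises(ValueError):
        SGD(0.01).init(x3)
    # a non-floating-point dtype
    with pytest.raises(ValueError):
        SGD(0.01).init(x4)
    # a relax.Tuple expression instead of a Var
    with pytest.raises(ValueError):
        SGD(0.01).init(x5)
    # get_function() called before init()
    with pytest.raises(
        RuntimeError,
        match="Please call init\\(\\) for the optimizer before calling get_function\\(\\)",
    ):
        SGD(0.01).get_function()
```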
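The ValueError cases in this test imply a small set of rules for what init() accepts: each parameter must be a tensor-typed Var (not a tuple-typed Var or a relax.Tuple expression), its dtype must be floating point, no parameter may appear twice, and all parameters must share one dtype. The sketch below is a minimal pure-Python model of those rules, not TVM's implementation; TensorParam, OtherParam, and check_params are hypothetical names, and the set of floating-point dtypes is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Union

# Assumption: the dtypes treated as floating point in this model.
FLOAT_DTYPES = {"float16", "float32", "float64"}


@dataclass(eq=False)
class TensorParam:
    """Hypothetical stand-in for a relax.Var whose struct info is a tensor."""
    name: str
    dtype: str


@dataclass(eq=False)
class OtherParam:
    """Hypothetical stand-in for anything else, e.g. a tuple-typed Var or a relax.Tuple."""
    name: str


Param = Union[TensorParam, OtherParam]


def check_params(params: Union[Param, Sequence[Param]]) -> Optional[str]:
    """Validate parameters the way the test expects SGD.init to, returning their dtype."""
    # A single parameter is treated like a one-element list.
    if not isinstance(params, (list, tuple)):
        params = [params]
    seen_ids = set()
    dtype: Optional[str] = None
    for p in params:
        if not isinstance(p, TensorParam):
            raise ValueError(f"{p.name}: each parameter must be a tensor variable")
        if p.dtype not in FLOAT_DTYPES:
            raise ValueError(f"{p.name}: dtype {p.dtype} is not floating point")
        # Duplicates are detected by object identity, like passing x1 twice.
        if id(p) in seen_ids:
            raise ValueError(f"{p.name}: duplicate parameter")
        if dtype is not None and p.dtype != dtype:
            raise ValueError(f"{p.name}: dtype {p.dtype} differs from {dtype}")
        seen_ids.add(id(p))
        dtype = p.dtype
    return dtype
```

For example, check_params([TensorParam("x2", "float64")]) returns "float64", mirroring the dtype assertion in the test, while a list containing the same TensorParam object twice raises ValueError.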

Callers

nothing calls this directly

Calls (4)

SGD · class · 0.90
Tensor · method · 0.80
init · method · 0.45
get_function · method · 0.45
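The last check in the test pins down a call-order contract between the init and get_function methods: get_function() must not be called before init(). The following is a minimal sketch of that contract under stated assumptions: ToySGD is a hypothetical class, parameters are plain names, values are Python floats, and the update is plain SGD (w ← w − lr·g) with no weight decay or momentum. It returns a Python callable rather than building TVM IR, so it models only the ordering rule and the chaining style, not TVM's optimizer.

```python
from typing import Callable, List, Optional, Sequence

Step = Callable[[List[float], List[float]], List[float]]


class ToySGD:
    """Hypothetical model of the init()-before-get_function() contract (not TVM's SGD)."""

    def __init__(self, lr: float) -> None:
        self.lr = lr
        self._params: Optional[List[str]] = None

    def init(self, params: Sequence[str]) -> "ToySGD":
        self._params = list(params)
        # Returning self lets calls chain, as in SGD(0.01).init([x2]).dtype.
        return self

    def get_function(self) -> Step:
        # Refuse to build an update function until init() has recorded the parameters.
        if self._params is None:
            raise RuntimeError(
                "Please call init() for the optimizer before calling get_function()"
            )
        lr, n = self.lr, len(self._params)

        def step(values: List[float], grads: List[float]) -> List[float]:
            if len(values) != n or len(grads) != n:
                raise ValueError(f"expected {n} values and {n} gradients")
            # Plain SGD update: w <- w - lr * g
            return [w - lr * g for w, g in zip(values, grads)]

        return step
```

Chaining mirrors the test: ToySGD(0.5).init(["w1", "w2"]).get_function() returns a step function, and calling it with values [1.0, 2.0] and gradients [0.5, 1.0] yields [0.75, 1.5].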

Tested by

no test coverage detected
