MCP · copy · Index your code
hub / github.com/apache/tvm / test_simple

Function test_simple

tests/python/relax/test_training_setup_trainer.py:31–99  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

29
30
31def test_simple():
32 # fmt: off
33 @I.ir_module
34 class Backbone:
35 I.module_attrs({"param_num": 1, "state_num": 0})
36 @R.function
37 def backbone(x: R.Tensor((2, 2), "float64"), y: R.Tensor((2, 2), "float64")):
38 with R.dataflow():
39 x1 = x + y
40 R.output(x1)
41 return x1
42
43 @I.ir_module
44 class Expected:
45 I.module_attrs({"input_num": 1, "param_num": 1, "state_num": 0})
46 @R.function
47 def backbone(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64")) -> R.Tensor((2, 2), dtype="float64"):
48 with R.dataflow():
49 x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
50 R.output(x1)
51 return x1
52
53 @R.function
54 def backbone_loss(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tensor((), dtype="float64"):
55 with R.dataflow():
56 x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
57 lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
58 lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
59 gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
60 R.output(gv)
61 return gv
62
63 @R.function
64 def backbone_loss_adjoint(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tensor((), dtype="float64"), R.Tuple(R.Tensor((2, 2), dtype="float64"))):
65 with R.dataflow():
66 x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
67 lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
68 lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
69 gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
70 gv_adjoint: R.Tensor((), dtype="float64") = R.ones(R.shape([]), dtype="float64")
71 lv1_adjoint: R.Tensor((2, 2), dtype="float64") = R.broadcast_to(gv_adjoint, R.shape([2, 2]))
72 lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
73 lv_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
74 lv_adjoint1: R.Tensor((2, 2), dtype="float64") = R.add(lv_adjoint, lv_1)
75 x1_adjoint: R.Tensor((2, 2), dtype="float64") = lv_adjoint1
76 y_adjoint: R.Tensor((2, 2), dtype="float64") = x1_adjoint
77 y_adjoint_out: R.Tensor((2, 2), dtype="float64") = y_adjoint
78 R.output(gv, y_adjoint_out)
79 return (gv, (y_adjoint_out,))
80
81 @R.function
82 def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.Tuple(R.Tensor((2, 2), dtype="float64")), optim_states: R.Tuple(R.Tensor((), dtype="int64"))) -> R.Tuple(R.Tuple(R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((), dtype="int64"))):
83 with R.dataflow():
84 num_steps: R.Tensor((), dtype="int64") = optim_states[0]
85 num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
86 y: R.Tensor((2, 2), dtype="float64") = params[0]
87 y_grad: R.Tensor((2, 2), dtype="float64") = gradients[0]
88 lv: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_grad)

Callers

nothing calls this directly

Calls 5

SetupTrainer (class) · 0.90
MSELoss (class) · 0.90
SGD (class) · 0.90
assert_structural_equal (function) · 0.90
without_attr (method) · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…