MCPcopy
hub / github.com/dmlc/dgl / _test

Function _test

tests/python/common/test_heterograph-update-all.py:92–152  ·  view source on GitHub ↗
(mfunc, rfunc)

Source from the content-addressed store, hash-verified

90@parametrize_idtype
91def test_unary_copy_u(idtype):
92 def _test(mfunc, rfunc):
93 g = create_test_heterograph_2(idtype)
94 g0 = create_test_heterograph(idtype)
95 g1 = create_test_heterograph_large(idtype)
96 cross_reducer = rfunc.__name__
97 x1 = F.randn((g.num_nodes("user"), feat_size))
98 x2 = F.randn((g.num_nodes("developer"), feat_size))
99 F.attach_grad(x1)
100 F.attach_grad(x2)
101 g.nodes["user"].data["h"] = x1
102 g.nodes["developer"].data["h"] = x2
103
104 #################################################################
105 # multi_update_all(): call msg_passing separately for each etype
106 #################################################################
107
108 with F.record_grad():
109 g.multi_update_all(
110 {
111 etype: (mfunc("h", "m"), rfunc("m", "y"))
112 for etype in g.canonical_etypes
113 },
114 cross_reducer,
115 )
116 r1 = g.nodes["game"].data["y"].clone()
117 r2 = g.nodes["user"].data["y"].clone()
118 r3 = g.nodes["player"].data["y"].clone()
119 loss = r1.sum() + r2.sum() + r3.sum()
120 F.backward(loss)
121 n_grad1 = F.grad(g.nodes["user"].data["h"]).clone()
122 n_grad2 = F.grad(g.nodes["developer"].data["h"]).clone()
123
124 g.nodes["user"].data.clear()
125 g.nodes["developer"].data.clear()
126 g.nodes["game"].data.clear()
127 g.nodes["player"].data.clear()
128
129 #################################################################
130 # update_all(): call msg_passing for all etypes
131 #################################################################
132
133 F.attach_grad(x1)
134 F.attach_grad(x2)
135 g.nodes["user"].data["h"] = x1
136 g.nodes["developer"].data["h"] = x2
137
138 with F.record_grad():
139 g.update_all(mfunc("h", "m"), rfunc("m", "y"))
140 r4 = g.nodes["game"].data["y"]
141 r5 = g.nodes["user"].data["y"]
142 r6 = g.nodes["player"].data["y"]
143 loss = r4.sum() + r5.sum() + r6.sum()
144 F.backward(loss)
145 n_grad3 = F.grad(g.nodes["user"].data["h"])
146 n_grad4 = F.grad(g.nodes["developer"].data["h"])
147
148 assert F.allclose(r1, r4)
149 assert F.allclose(r2, r5)

Callers 3

test_unary_copy_u Function · 0.70
test_unary_copy_e Function · 0.70
test_binary_op Function · 0.70

Calls 15

multi_update_all Method · 0.80
grad Method · 0.80
update_all Method · 0.80
format Method · 0.80
create_test_heterograph Function · 0.70
mfunc Function · 0.70
rfunc Function · 0.70
_print_error Function · 0.70
num_nodes Method · 0.45
clone Method · 0.45

Tested by

no test coverage detected