MCPcopy
hub / github.com/dmlc/dgl / _test

Function _test

tests/python/common/test_heterograph-kernel.py:91–148  ·  view source on GitHub ↗
(red, partial)

Source from the content-addressed store, hash-verified

89
90def test_copy_src_reduce():
def _test(red, partial):
    """Compare the builtin ``copy_u`` message + reducer with the UDF pair.

    The same message passing is executed twice on one random graph: first
    with ``fn.copy_u`` and ``builtin[red]`` (output read from ``"r1"``),
    then with ``udf_copy_src`` and ``udf_reduce[red]`` (output read from
    ``"r2"``).  Both the reduced node outputs and the gradients with
    respect to node feature ``"u"`` must agree.

    Parameters
    ----------
    red
        Key used to look up the reducer in both ``builtin`` and
        ``udf_reduce``.
    partial
        When truthy, only every second node id in ``[0, 100)`` is updated
        through ``g.pull``; otherwise ``g.update_all`` is used.
    """
    num_nodes = 100
    g = dgl.from_networkx(nx.erdos_renyi_graph(num_nodes, 0.1))
    # Every node gets a self-loop so that no node has zero in-degree,
    # see https://github.com/dmlc/dgl/issues/761
    g.add_edges(g.nodes(), g.nodes())
    g = g.to(F.ctx())
    feat_u, feat_v, feat_e = generate_feature(g, "none", "none")
    # Target nodes for the partial update; unused for the full update.
    nid = (
        F.tensor(list(range(0, num_nodes, 2)), g.idtype) if partial else None
    )

    def _attach_fresh_features():
        # Re-attach clones of the generated features so that gradients
        # left over from an earlier run are reset.
        g.ndata["u"] = F.attach_grad(F.clone(feat_u))
        g.ndata["v"] = F.attach_grad(F.clone(feat_v))
        g.edata["e"] = F.attach_grad(F.clone(feat_e))

    def _propagate(message_func, reduce_func):
        # Partial runs pull into the selected nodes only; full runs
        # update every node.
        if partial:
            g.pull(nid, message_func, reduce_func)
        else:
            g.update_all(message_func, reduce_func)

    def _print_error(a, b):
        print("ERROR: Test copy_src_{} partial: {}".format(red, partial))
        flat_a = F.asnumpy(a).flatten()
        flat_b = F.asnumpy(b).flatten()
        for pos, (x, y) in enumerate(zip(flat_a, flat_b)):
            if np.allclose(x, y):
                continue
            print("@{} {} v.s. {}".format(pos, x, y))

    def _expect_close(lhs, rhs, banner=None):
        # Dump the mismatching entries before failing the assertion.
        if not F.allclose(lhs, rhs):
            if banner is not None:
                print(banner)
            _print_error(lhs, rhs)
        assert F.allclose(lhs, rhs)

    # Run 1: builtin message and reduce functions.
    _attach_fresh_features()
    with F.record_grad():
        _propagate(
            fn.copy_u(u="u", out="m"),
            builtin[red](msg="m", out="r1"),
        )
        r1 = g.ndata["r1"]
        F.backward(F.reduce_sum(r1))
        n_grad1 = F.grad(g.ndata["u"])

    # Run 2: user-defined message and reduce functions.
    _attach_fresh_features()
    with F.record_grad():
        _propagate(udf_copy_src, udf_reduce[red])
        r2 = g.ndata["r2"]
        F.backward(F.reduce_sum(r2))
        n_grad2 = F.grad(g.ndata["u"])

    _expect_close(r1, r2)
    _expect_close(n_grad1, n_grad2, banner="node grad")

Callers 3

test_copy_src_reduceFunction · 0.70
test_copy_edge_reduceFunction · 0.70
test_all_binary_builtinsFunction · 0.70

Calls 13

generate_featureFunction · 0.85
target_feature_switchFunction · 0.85
update_allMethod · 0.80
gradMethod · 0.80
formatMethod · 0.80
_print_errorFunction · 0.70
add_edgesMethod · 0.45
nodesMethod · 0.45
toMethod · 0.45
ctxMethod · 0.45
cloneMethod · 0.45
pullMethod · 0.45

Tested by

no test coverage detected