MCPcopy
hub / github.com/dmlc/dgl / _test

Function _test

tests/python/common/test_heterograph-apply-edges.py:71–120  ·  view source on GitHub ↗
(mfunc)

Source from the content-addressed store, hash-verified

@parametrize_idtype
def test_unary_copy_u(idtype):
    """Check that ``apply_edges`` agrees per relation and over all relations."""

    def _test(mfunc):
        """Compare per-relation vs. all-relation ``apply_edges`` with ``mfunc``.

        ``mfunc("h", "m")`` is first applied to every canonical edge type one
        at a time, then to all edge types in a single ``apply_edges`` call.
        The ``"m"`` features produced on the ``plays`` relation, and the
        gradients they propagate back to the ``user`` node features, must
        match between the two phases.

        Parameters
        ----------
        mfunc : callable
            Message-function builder such as ``fn.copy_u``; it is called as
            ``mfunc("h", "m")``.
        """

        def _print_error(a, b):
            """Print every flattened index at which ``a`` and ``b`` differ."""
            for i, (x, y) in enumerate(
                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
            ):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        g = create_test_heterograph(idtype)

        x1 = F.randn((g.num_nodes("user"), feat_size))
        x2 = F.randn((g.num_nodes("developer"), feat_size))

        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        #################################################################
        # apply_edges() is called on each relation type separately
        #################################################################

        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(mfunc("h", "m"), etype=rel)
            r1 = g["plays"].edata["m"]
            F.backward(r1, F.ones(r1.shape))
            # Snapshot the gradient of the leaf x1. F.grad may return the
            # live grad buffer, which the second backward pass would
            # overwrite or accumulate into. That would make the final
            # gradient comparison compare a tensor with itself.
            n_grad1 = F.clone(F.grad(x1))

        # On a graph with several edge types, g.edata["m"] is a dict keyed by
        # canonical edge type, so calling .clear() on it never removed the
        # feature from the graph. Delete "m" relation by relation instead, so
        # the second phase cannot pass by reading stale first-phase results.
        for rel in g.canonical_etypes:
            if "m" in g.edges[rel].data:
                del g.edges[rel].data["m"]
        # Reset x1's gradient so phase 2 is not added on top of phase 1.
        F.attach_grad(x1)

        #################################################################
        # apply_edges() is called on all relation types
        #################################################################

        # Record here too: backends that need explicit recording cannot
        # back-propagate through operations performed outside this scope.
        with F.record_grad():
            g.apply_edges(mfunc("h", "m"))
            r2 = g["plays"].edata["m"]
            F.backward(r2, F.ones(r2.shape))
            n_grad2 = F.grad(x1)

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(n_grad1, n_grad2):
            print("node grad")
            _print_error(n_grad1, n_grad2)
        assert F.allclose(n_grad1, n_grad2)

    _test(fn.copy_u)

Callers 3

test_unary_copy_u (Function) · 0.70
test_unary_copy_e (Function) · 0.70
test_binary_op (Function) · 0.70

Calls 9

grad (Method) · 0.80
format (Method) · 0.80
create_test_heterograph (Function) · 0.70
_print_error (Function) · 0.70
num_nodes (Method) · 0.45
apply_edges (Method) · 0.45
backward (Method) · 0.45
clear (Method) · 0.45
num_edges (Method) · 0.45

Tested by

no test coverage detected