MCPcopy
hub / github.com/dmlc/dgl / test_spmm

Function test_spmm

tests/python/common/ops/test_ops.py:119–181  ·  view source on GitHub ↗
(idtype, dtype, g, shp, msg, reducer)

Source from the content-addressed store, hash-verified

@parametrize_idtype
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_spmm(idtype, dtype, g, shp, msg, reducer):
    """Check ``gspmm`` against message passing with Python UDFs.

    The same random source-node and edge features are fed to two paths:

    * ``gspmm(g, msg, reducer, u, e)`` on gradient-tracked copies ``u``/``e``;
    * ``g.update_all(udf_msg[msg], udf_reduce[reducer])`` on gradient-tracked
      copies stored in ``g.srcdata["x"]`` and ``g.edata["w"]``.

    The forward outputs must be close. When the graph has edges, the
    gradients of each output's sum must also match. This is checked only for
    the operands whose gradients the test reads: ``u``/``x`` unless ``msg`` is
    ``"copy_rhs"``, and ``e``/``w`` unless ``msg`` is ``"copy_lhs"``.

    Parameters
    ----------
    idtype
        Index dtype the graph is cast to via ``g.astype``; presumably
        supplied by ``@parametrize_idtype``.
    dtype
        NumPy float dtype of the random features (float32 or float64).
    g
        Graph under test; moved to ``F.ctx()``.
    shp
        Pair of trailing feature shapes. ``shp[0]`` is appended to the number
        of source nodes; ``shp[1]`` is appended to the number of edges.
    msg
        Message name passed to ``gspmm``; also the key into ``udf_msg``.
    reducer
        Reduce name passed to ``gspmm``; also the key into ``udf_reduce``.

    NOTE(review): ``g``, ``shp``, ``msg`` and ``reducer`` have no decorator in
    this excerpt. Presumably they are parametrized by decorators above the two
    shown; confirm against the full file. The listing header also cites
    source lines 119-181, but this excerpt stops at line 176, so the
    function's final lines are not shown here.
    """
    g = g.astype(idtype).to(F.ctx())
    print(g)
    print(g.idtype)

    # Random inputs in [1, 2): one row per source node or per edge, followed
    # by the trailing feature dims from ``shp``. The +1 keeps every value away
    # from zero (presumably so "div" messages stay well conditioned).
    hu = F.tensor(
        np.random.rand(*((g.number_of_src_nodes(),) + shp[0])).astype(dtype) + 1
    )
    he = F.tensor(
        np.random.rand(*((g.num_edges(),) + shp[1])).astype(dtype) + 1
    )
    print("u shape: {}, e shape: {}".format(F.shape(hu), F.shape(he)))

    # Inputs for the reference (UDF) path: gradient-tracked copies attached
    # to the graph so update_all's UDFs can read them as "x" and "w".
    g.srcdata["x"] = F.attach_grad(F.clone(hu))
    g.edata["w"] = F.attach_grad(F.clone(he))
    print("SpMM(message func: {}, reduce func: {})".format(msg, reducer))

    # Inputs for the gspmm path: separate clones, so the two paths' gradients
    # are read from different tensors.
    u = F.attach_grad(F.clone(hu))
    e = F.attach_grad(F.clone(he))
    with F.record_grad():
        v = gspmm(g, msg, reducer, u, e)
        if reducer in ["max", "min"]:
            # Presumably destinations with no incoming edges come back as
            # +/-inf for max/min; zero them so the comparison and sum are
            # finite.
            v = F.replace_inf_with_zero(v)
        if g.num_edges() > 0:
            F.backward(F.reduce_sum(v))
            # The gradient of u is skipped for copy_rhs and the gradient of e
            # for copy_lhs, presumably because those messages ignore that
            # operand.
            if msg != "copy_rhs":
                grad_u = F.grad(u)
            if msg != "copy_lhs":
                grad_e = F.grad(e)

    with F.record_grad():
        # Reference computation. The reducer UDF is expected to write its
        # result to g.dstdata["v"], which is read back below.
        g.update_all(udf_msg[msg], udf_reduce[reducer])
        if g.num_edges() > 0:
            v1 = g.dstdata["v"]
            assert F.allclose(v, v1)
            print("forward passed")

            F.backward(F.reduce_sum(v1))
            if msg != "copy_rhs":
                if reducer in [
                    "min",
                    "max",
                ]:  # there might be some numerical errors
                    # Relative L1 error with a 1% tolerance instead of
                    # allclose. This tolerates small mismatches, possibly
                    # from ties picking a different argmin/argmax.
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.srcdata["x"]) - grad_u)
                    ) / F.reduce_sum(F.abs(grad_u))
                    assert F.as_scalar(rate) < 1e-2, rate
                else:
                    assert F.allclose(F.grad(g.srcdata["x"]), grad_u)
            if msg != "copy_lhs":
                if reducer in ["min", "max"]:
                    # Same relative-error check as above, for edge features.
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.edata["w"]) - grad_e)
                    ) / F.reduce_sum(F.abs(grad_e))
                    assert F.as_scalar(rate) < 1e-2, rate
                else:
                    assert F.allclose(F.grad(g.edata["w"]), grad_e)
            print("backward passed")

Callers

nothing calls this directly

Calls 12

gspmm — Function · 0.90
number_of_src_nodes — Method · 0.80
format — Method · 0.80
grad — Method · 0.80
update_all — Method · 0.80
to — Method · 0.45
astype — Method · 0.45
ctx — Method · 0.45
num_edges — Method · 0.45
shape — Method · 0.45
clone — Method · 0.45
backward — Method · 0.45

Tested by

no test coverage detected