MCPcopy
hub / github.com/dmlc/dgl / test_csrmm_backward

Function test_csrmm_backward

tests/python/common/test_sparse_ops-csr.py:67–110  ·  view source on GitHub ↗
(idtype, dtype, num_vtypes)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("dtype", [F.float32, F.float64])
@pytest.mark.parametrize("num_vtypes", [1, 2])
def test_csrmm_backward(idtype, dtype, num_vtypes):
    """Check ``dgl.adj_product_graph`` forward values and gradients.

    Two random weighted graphs are built:

    * ``A``: node type "A" to node type "B";
    * ``B``: node type "B" back to "A" (``num_vtypes == 1``) or on to "C"
      (``num_vtypes == 2``).

    Their adjacency product ``C`` is formed on the edge feature ``"w"``.
    The test then checks two things against ``F.matmul`` on the dense
    equivalents of the two graphs:

    * the forward edge weights of ``C``;
    * the gradients with respect to the sparse edge weights of ``A``
      and ``B``.

    Parameters
    ----------
    idtype
        Index dtype used to build the graphs.
    dtype
        Backend floating-point dtype for the edge weights (float32/float64).
    num_vtypes
        Expected number of node types in the product graph ``C``.
    """
    # Shape of A is (num_rows, num_mid); shape of B is (num_mid, num_cols).
    # These sizes feed both the graph construction and the dense reference
    # buffer, so the two cannot drift apart.
    num_rows, num_mid, num_edges = 3, 4, 6
    # When num_vtypes == 1, B's destination type is "A" again, so it must
    # have the same number of nodes as A's source type.
    num_cols = num_rows
    # NOTE(review): the size arguments are passed positionally exactly as
    # before; presumably they are (rows, cols, edge count) — confirm against
    # _random_simple_graph's signature.
    a, A = _random_simple_graph(
        idtype, dtype, F.ctx(), num_rows, num_mid, num_edges, "A", "B", "AB"
    )
    b, B = _random_simple_graph(
        idtype,
        dtype,
        F.ctx(),
        num_mid,
        num_cols,
        num_edges,
        "B",
        "A" if num_vtypes == 1 else "C",
        "BA",
    )
    # Edge endpoints in edge-id order, matching the per-edge order of
    # edata["w"], so dense gradients can be gathered edge by edge later.
    A_row, A_col = A.edges(order="eid")
    B_row, B_col = B.edges(order="eid")
    A_row = F.asnumpy(A_row)
    A_col = F.asnumpy(A_col)
    B_row = F.asnumpy(B_row)
    B_col = F.asnumpy(B_col)
    # Dense reference operands, built as fresh leaf tensors from numpy data.
    # They therefore share no autograd history with the sparse edge weights.
    # NOTE(review): assumes `a`/`b` expose .todense() (e.g. scipy.sparse).
    a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype))
    b_dense = F.attach_grad(F.tensor(b.todense(), dtype=dtype))

    A.edata["w"] = F.attach_grad(A.edata["w"])
    B.edata["w"] = F.attach_grad(B.edata["w"])

    with F.record_grad():
        C = dgl.adj_product_graph(A, B, "w")
        assert len(C.ntypes) == num_vtypes
        assert len(C.etypes) == 1
        # Scatter C's edge weights into a dense matrix for comparison with
        # the dense matmul result.
        C_dense = np.zeros((num_rows, num_cols))
        C_row, C_col = C.edges(order="eid")
        C_row = F.asnumpy(C_row)
        C_col = F.asnumpy(C_col)
        C_dense[C_row, C_col] = F.asnumpy(C.edata["w"])
        c_dense = F.matmul(a_dense, b_dense)
        assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4)

        # One backward pass through both losses. The sparse and dense paths
        # have disjoint leaves, so each leaf receives only its own path's
        # gradient.
        F.backward(F.reduce_sum(C.edata["w"]) + F.reduce_sum(c_dense))
        # Gather the dense gradients at each edge's (row, col) position,
        # in eid order, so they line up with the per-edge sparse gradients.
        a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]
        b_dense_grad = F.asnumpy(F.grad(b_dense))[B_row, B_col]
        A_spspmm_grad = F.asnumpy(F.grad(A.edata["w"]))
        B_spspmm_grad = F.asnumpy(F.grad(B.edata["w"]))
        assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)
        assert np.allclose(b_dense_grad, B_spspmm_grad, rtol=1e-4, atol=1e-4)
111
112
113@parametrize_idtype

Callers 1

Calls 6

_random_simple_graph · Function · 0.85
asnumpy · Method · 0.80
grad · Method · 0.80
ctx · Method · 0.45
edges · Method · 0.45
backward · Method · 0.45

Tested by

no test coverage detected