MCPcopy
hub / github.com/dmlc/dgl / test_add_sparse_diag

Function test_add_sparse_diag

tests/python/pytorch/sparse/test_elementwise_op.py:181–201  ·  view source on GitHub ↗
(val_shape)

Source from the content-addressed store, hash-verified

179
@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_sparse_diag(val_shape):
    """Check sparse-matrix + diagonal-matrix addition against dense addition.

    A sparse COO matrix ``A`` and a diagonal matrix ``D`` are added four
    ways: ``A + D``, ``D + A``, ``dglsp.add(A, D)`` and ``dglsp.add(D, A)``.
    Each result, densified, must be close to ``A.to_dense() + D.to_dense()``.

    ``val_shape`` is the trailing shape of every stored value: ``()`` for
    scalar entries, ``(2,)`` for a length-2 vector per entry.
    """
    device = F.ctx()

    # Sparse operand: nonzeros at (1, 0), (0, 3) and (2, 2).
    rows = torch.tensor([1, 0, 2]).to(device)
    cols = torch.tensor([0, 3, 2]).to(device)
    nnz_vals = torch.randn(rows.shape + val_shape).to(device)
    A = dglsp.from_coo(rows, cols, nnz_vals)

    # Diagonal operand of shape (3, 4): diag_shape[0] diagonal values, each
    # carrying the same trailing val_shape as the sparse values. A distinct
    # local name is used so the parametrized val_shape is not rebound.
    diag_shape = (3, 4)
    diag_vals = torch.randn((diag_shape[0],) + val_shape).to(device)
    D = dglsp.diag(diag_vals, shape=diag_shape)

    expected = A.to_dense() + D.to_dense()
    candidates = (
        (A + D).to_dense(),
        (D + A).to_dense(),
        dglsp.add(A, D).to_dense(),
        dglsp.add(D, A).to_dense(),
    )
    for result in candidates:
        assert torch.allclose(expected, result)
202
203
204@pytest.mark.parametrize("val_shape", [(), (2,)])

Callers

nothing calls this directly

Calls 3

to_dense · Method · 0.80
ctx · Method · 0.45
to · Method · 0.45

Tested by

no test coverage detected