hub / github.com/dmlc/dgl / test_diag

Function test_diag

tests/python/pytorch/sparse/test_sparse_matrix.py:820–847
(val_shape, mat_shape)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag(val_shape, mat_shape):
    ctx = F.ctx()
    # creation
    val = torch.randn(val_shape).to(ctx)
    mat = diag(val, mat_shape)

    # val, shape attributes
    assert torch.allclose(mat.val, val)
    if mat_shape is None:
        mat_shape = (val_shape[0], val_shape[0])
    assert mat.shape == mat_shape

    val = torch.randn(val_shape).to(ctx)

    # nnz
    assert mat.nnz == val.shape[0]
    # dtype
    assert mat.dtype == val.dtype
    # device
    assert mat.device == val.device

    # row, col, val
    edge_index = torch.arange(len(val)).to(mat.device)
    row, col = mat.coo()
    val = mat.val
    assert torch.allclose(row, edge_index)
    assert torch.allclose(col, edge_index)
    assert torch.allclose(val, val)


@pytest.mark.parametrize("shape", [(3, 3), (3, 5), (5, 3)])
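
The properties that test_diag checks can be illustrated without DGL or PyTorch. The following is a minimal pure-Python sketch, not DGL's implementation: `DiagSketch` and `diag_sketch` are hypothetical names introduced here, and plain lists stand in for tensors. It assumes what the test asserts: the shape defaults to a square (n, n), there is one stored entry per diagonal value, and the COO row and column indices are both 0..n-1. Because the last assertion in the listing compares `val` with itself after `val = mat.val`, the sketch compares the stored values against a separately kept copy instead.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class DiagSketch:
    """Hypothetical stand-in for a diagonal sparse matrix (not DGL's class)."""

    val: list
    shape: Tuple[int, int]

    @property
    def nnz(self) -> int:
        # One stored entry per diagonal value (scalar or vector-valued).
        return len(self.val)

    def coo(self) -> Tuple[List[int], List[int]]:
        # Entry i sits at position (i, i), so row and column indices coincide.
        idx = list(range(self.nnz))
        return idx, list(idx)


def diag_sketch(val, shape: Optional[Tuple[int, int]] = None) -> DiagSketch:
    # Assumed default, mirroring the test: a square (n, n) matrix.
    n = len(val)
    if shape is None:
        shape = (n, n)
    return DiagSketch(val=list(val), shape=shape)


expected = [0.5, -1.0, 2.0]  # reference copy kept apart from the matrix
mat = diag_sketch(expected, (3, 5))
row, col = mat.coo()

assert mat.shape == (3, 5)
assert mat.nnz == len(expected)
assert row == col == [0, 1, 2]
# Compare against the untouched reference rather than mat.val itself.
assert mat.val == expected
```

In the real test, the same idea would amount to keeping the tensor passed to `diag` under its own name and comparing `mat.val` against it in the final check.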

Callers

nothing calls this directly

Calls 4

diag · Function · 0.90
coo · Method · 0.80
ctx · Method · 0.45
to · Method · 0.45

Tested by

no test coverage detected