hub / github.com/dmlc/dgl / test_diag_conversions

Function test_diag_conversions

tests/python/pytorch/sparse/test_sparse_matrix.py:363–381
test_diag_conversions(shape)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("shape", [(3, 5), (5, 5), (5, 4)])
def test_diag_conversions(shape):
    n_rows, n_cols = shape
    nnz = min(shape)
    ctx = F.ctx()
    val = torch.randn(nnz).to(ctx)
    D = diag(val, shape)
    row, col = D.coo()
    assert torch.allclose(row, torch.arange(nnz).to(ctx))
    assert torch.allclose(col, torch.arange(nnz).to(ctx))

    indptr, indices, _ = D.csr()
    exp_indptr = list(range(0, nnz + 1)) + [nnz] * (n_rows - nnz)
    assert torch.allclose(indptr, torch.tensor(exp_indptr).to(ctx))
    assert torch.allclose(indices, torch.arange(nnz).to(ctx))

    indptr, indices, _ = D.csc()
    exp_indptr = list(range(0, nnz + 1)) + [nnz] * (n_cols - nnz)
    assert torch.allclose(indptr, torch.tensor(exp_indptr).to(ctx))
    assert torch.allclose(indices, torch.arange(nnz).to(ctx))
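
The expected index arrays in the test can be worked out without DGL or PyTorch. The sketch below is a plain-Python illustration of the same arithmetic; `expected_diag_layout` is a hypothetical helper written for this page and is not part of DGL. It assumes, as the test does, that a diagonal matrix of shape (n_rows, n_cols) stores min(n_rows, n_cols) entries at positions (i, i).

```python
def expected_diag_layout(shape):
    """Expected index arrays for a diagonal matrix of the given shape.

    Hypothetical helper for illustration only; not part of DGL.
    """
    n_rows, n_cols = shape
    nnz = min(shape)  # one stored entry per position (i, i)
    diagonal = list(range(nnz))
    # Each of the first nnz rows holds exactly one entry, so the row pointer
    # advances by one per row; any remaining rows are empty and repeat nnz.
    csr_indptr = list(range(nnz + 1)) + [nnz] * (n_rows - nnz)
    # CSC applies the same construction to columns instead of rows.
    csc_indptr = list(range(nnz + 1)) + [nnz] * (n_cols - nnz)
    return {
        "coo": (diagonal, diagonal),
        "csr": (csr_indptr, diagonal),
        "csc": (csc_indptr, diagonal),
    }


# Same shapes as the parametrized test: wide, square, tall.
for shape in [(3, 5), (5, 5), (5, 4)]:
    layout = expected_diag_layout(shape)
    print(shape, "csr indptr:", layout["csr"][0], "csc indptr:", layout["csc"][0])
```

For the wide shape (3, 5), the CSR pointer is [0, 1, 2, 3] while the CSC pointer is padded to [0, 1, 2, 3, 3, 3], because the last two columns are empty. The tall shape (5, 4) pads the CSR pointer instead.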

Callers

nothing calls this directly

Calls 7

diag · Function · 0.90
coo · Method · 0.80
csr · Method · 0.80
csc · Method · 0.80
min · Function · 0.50
ctx · Method · 0.45
to · Method · 0.45

Tested by

no test coverage detected