@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("shape", [(3, 5), (5, 5)])
def test_val_like(val_shape, shape):
    """Check that ``val_like`` keeps the sparsity structure and swaps values.

    The same three non-zero entries are built in COO, CSR and CSC form. For
    each, ``val_like`` is applied with a fresh value tensor of shape
    ``val_shape`` (one scalar per entry, or a length-2 vector per entry).
    The result must have the original structure and carry the new values.
    """

    def check_val_like(A, B, new_val):
        # Structure (shape, nnz, coordinates, device) is inherited from A ...
        assert A.shape == B.shape
        assert A.nnz == B.nnz
        assert torch.allclose(torch.stack(A.coo()), torch.stack(B.coo()))
        assert A.val.device == B.val.device
        # ... while the values must be exactly the ones that were passed in.
        # Without this, a val_like that kept A's old values would still pass.
        assert B.val.shape == new_val.shape
        assert torch.allclose(B.val, new_val)

    ctx = F.ctx()

    # COO: rows reach index 2 and columns reach index 4, so every
    # parametrized ``shape`` (3 or 5 rows, 5 columns) can hold these entries.
    row = torch.tensor([1, 1, 2]).to(ctx)
    col = torch.tensor([2, 4, 3]).to(ctx)
    val = torch.randn(3).to(ctx)
    coo_A = from_coo(row, col, val, shape)
    new_val = torch.randn(val_shape).to(ctx)
    coo_B = val_like(coo_A, new_val)
    check_val_like(coo_A, coo_B, new_val)

    # CSR / CSC: reuse coo_A's compressed indices. The third return value
    # (presumably a value-index permutation) is discarded, so A's values may
    # not line up entry-for-entry with coo_A's. That is fine here, because
    # only A's structure and B's values are asserted.
    indptr, indices, _ = coo_A.csr()
    csr_A = from_csr(indptr, indices, val, shape)
    csr_B = val_like(csr_A, new_val)
    check_val_like(csr_A, csr_B, new_val)

    indptr, indices, _ = coo_A.csc()
    csc_A = from_csc(indptr, indices, val, shape)
    csc_B = val_like(csc_A, new_val)
    check_val_like(csc_A, csc_B, new_val)
| 415 | |
| 416 | |
| 417 | def test_coalesce(): |