| 56 | |
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
def test_pow(val_shape):
    """Check that ``A ** v`` and ``power(A, v)`` raise stored values elementwise.

    Both the operator form and the functional form must raise every stored
    value to ``exponent`` while leaving the sparsity pattern untouched:
    same shape, same row indices, same column indices.
    """
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(val_shape).to(ctx)
    A = from_coo(row, col, val, shape=(3, 4))
    exponent = 2

    # Exercise both the operator form (A ** v) and the functional form
    # (power(A, v)); they must produce identical results.
    for A_new in (A**exponent, power(A, exponent)):
        assert torch.allclose(A_new.val, val**exponent)
        assert A_new.shape == A.shape
        new_row, new_col = A_new.coo()
        # Indices are integers, so compare them exactly. torch.equal also
        # requires matching sizes, whereas allclose broadcasts and could
        # hide a shape mismatch in the returned indices.
        assert torch.equal(new_row, row)
        assert torch.equal(new_col, col)
| 80 | |
| 81 | |
| 82 | @pytest.mark.parametrize("op", ["add", "sub"]) |