@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag(val_shape, mat_shape):
    """Check creation and basic attributes of a matrix built by ``diag``.

    ``diag(val, mat_shape)`` is called with a random tensor whose first
    dimension is the number of diagonal entries. A trailing dimension such
    as ``(3, 2)`` presumably means vector-valued entries (TODO confirm
    against ``diag``'s docs). When ``mat_shape`` is ``None``, the resulting
    matrix is expected to be square with side ``val_shape[0]``.
    """
    ctx = F.ctx()
    # creation
    val = torch.randn(val_shape).to(ctx)
    mat = diag(val, mat_shape)

    # val, shape attributes
    assert torch.allclose(mat.val, val)
    expected_shape = (
        (val_shape[0], val_shape[0]) if mat_shape is None else mat_shape
    )
    assert mat.shape == expected_shape

    # nnz: one stored entry per diagonal position
    assert mat.nnz == val.shape[0]
    # dtype
    assert mat.dtype == val.dtype
    # device
    assert mat.device == val.device

    # row, col, val: the i-th stored entry sits at position (i, i)
    edge_index = torch.arange(val.shape[0], device=mat.device)
    row, col = mat.coo()
    assert torch.allclose(row, edge_index)
    assert torch.allclose(col, edge_index)
    # Compare the stored values against the input tensor (not against
    # ``mat.val`` itself, which would be a tautology).
    assert torch.allclose(mat.val, val)
| 848 | |
| 849 | |
| 850 | @pytest.mark.parametrize("shape", [(3, 3), (3, 5), (5, 3)]) |