(shape, d)
@pytest.mark.parametrize("shape", [(3, 3), (3, 5), (5, 3)])
@pytest.mark.parametrize("d", [None, 2])
def test_identity(shape, d):
    """Check that ``identity(shape, d)`` has the requested shape and all-ones values.

    ``shape`` covers square, wide and tall matrices. When ``d`` is ``None``
    the values are expected to be 1-D; otherwise each stored value is
    expected to be a vector of width ``d``.
    """
    # creation
    mat = identity(shape, d)
    # shape
    assert mat.shape == shape
    # val: the test expects min(shape) stored values (one per diagonal slot).
    len_val = min(shape)
    val_shape = (len_val,) if d is None else (len_val, d)
    # Assert the shape explicitly: torch.allclose broadcasts its operands,
    # so a wrongly-shaped ``mat.val`` could otherwise still compare as close.
    assert mat.val.shape == val_shape
    # Build the reference on the same device as ``mat.val`` so the
    # comparison cannot fail merely because of a device mismatch.
    val = torch.ones(val_shape, device=mat.val.device)
    assert torch.allclose(val, mat.val)
| 866 | |
| 867 | |
| 868 | @pytest.mark.parametrize("val_shape", [(3,), (3, 2)]) |