(dtype, device)
@withHashTensor
@pytest.mark.parametrize('dtype', KEY_DTYPES)
def test_basic(dtype, device):
    """Check repr, dtype, device and size of a ``HashTensor``.

    The test covers two constructions:
    * keys only — the test expects an ``int64`` result of shape ``(num_keys,)``;
    * keys plus a ``(num_keys, 2)`` random value tensor — the test expects a
      ``float`` result of shape ``(num_keys, 2)``.

    ``dtype`` is parametrized over ``KEY_DTYPES``. Boolean keys use the
    two-element ``[True, False]`` input; all other dtypes use ``[2, 1, 0]``.
    ``device`` is presumably injected by a decorator/fixture defined elsewhere
    (NOTE(review): confirm which one supplies it).
    """
    if dtype == torch.bool:
        key = torch.tensor([True, False], device=device)
    else:
        key = torch.tensor([2, 1, 0], dtype=dtype, device=device)
    num_keys = key.size(0)

    # Keys only: the repr includes the device suffix only for CUDA tensors.
    hash_index = HashTensor(key)
    on_cuda = hash_index.is_cuda
    rendered = str(hash_index)
    key_list = hash_index.as_tensor().tolist()
    if on_cuda:
        expected = f"HashTensor({key_list}, device='{hash_index.device}')"
    else:
        expected = f"HashTensor({key_list})"
    assert rendered == expected

    assert hash_index.dtype == torch.int64
    assert hash_index.device == device
    assert hash_index.size() == (num_keys, )

    # Keys plus values: dtype and trailing dimension are taken from `value`.
    value = torch.randn(num_keys, 2, device=device)
    hash_map = HashTensor(key, value)
    assert str(hash_map).startswith("HashTensor([")
    assert hash_map.dtype == torch.float
    assert hash_map.device == device
    assert hash_map.size() == (num_keys, 2)
| 60 | |
| 61 | |
| 62 | @withCUDA |
nothing calls this directly
no test coverage detected