MCPcopy
hub / github.com/pyg-team/pytorch_geometric / test_basic

Function test_basic

test/test_hash_tensor.py:37–59  ·  view source on GitHub ↗
(dtype, device)

Source from the content-addressed store, hash-verified

@withHashTensor
@pytest.mark.parametrize('dtype', KEY_DTYPES)
def test_basic(dtype, device):
    """Check repr, dtype, device and shape of newly constructed `HashTensor`s.

    Two constructions are exercised:

    * keys only -- the string form must list the values returned by
      ``as_tensor()`` (plus the device when the tensor lives on CUDA), and
      the result must report ``torch.int64`` and a 1-D size equal to the
      number of keys;
    * keys plus a ``(num_keys, 2)`` value matrix -- the result must report
      ``torch.float`` and the value matrix's shape.

    Args:
        dtype: Key dtype under test, drawn from ``KEY_DTYPES``.
        device: Device on which keys and values are allocated.
    """
    # Boolean keys get a two-element key set; all other dtypes use three
    # integer keys given in descending order.
    if dtype == torch.bool:
        key = torch.tensor([True, False], device=device)
    else:
        key = torch.tensor([2, 1, 0], dtype=dtype, device=device)
    num_keys = key.size(0)

    # --- Construction from keys alone -----------------------------------
    keys_only = HashTensor(key)
    if not keys_only.is_cuda:
        assert str(keys_only) == (
            f"HashTensor({keys_only.as_tensor().tolist()})")
    else:
        # On CUDA the repr additionally carries the device.
        assert str(keys_only) == (
            f"HashTensor({keys_only.as_tensor().tolist()}, "
            f"device='{keys_only.device}')")

    assert keys_only.dtype == torch.int64
    assert keys_only.device == device
    assert keys_only.size() == (num_keys, )

    # --- Construction from keys and an explicit value matrix -------------
    value = torch.randn(num_keys, 2, device=device)
    with_values = HashTensor(key, value)
    assert str(with_values).startswith("HashTensor([")
    assert with_values.dtype == torch.float
    assert with_values.device == device
    assert with_values.size() == (num_keys, 2)
60
61
62@withCUDA

Callers

nothing calls this directly

Calls 4

as_tensorMethod · 0.95
HashTensorClass · 0.90
tolistMethod · 0.45
sizeMethod · 0.45

Tested by

no test coverage detected