github.com/pyg-team/pytorch_geometric

Function test_softmax

test/utils/test_softmax.py:12–27
Signature: test_softmax()


# Module-level imports (defined above line 12 in the file).
import torch

from torch_geometric.utils import softmax


def test_softmax():
    src = torch.tensor([1., 1., 1., 1.])
    index = torch.tensor([0, 0, 1, 2])
    ptr = torch.tensor([0, 2, 3, 4])

    # Entries 0 and 1 share group 0; entries 2 and 3 are singleton groups.
    out = softmax(src, index)
    assert out.tolist() == [0.5, 0.5, 1, 1]
    # The CSR-style `ptr` boundaries describe the same grouping as `index`.
    assert softmax(src, ptr=ptr).tolist() == out.tolist()

    # A 2-D input is normalized per column along dim 0.
    src = src.view(-1, 1)
    out = softmax(src, index)
    assert out.tolist() == [[0.5], [0.5], [1], [1]]
    assert softmax(src, ptr=ptr).tolist() == out.tolist()

    # The function must also compile under TorchScript.
    jit = torch.jit.script(softmax)
    assert torch.allclose(jit(src, index), out)
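The assertions in test_softmax rely on softmax normalizing each group of entries independently, where a group is given either by a per-element `index` or by CSR-style `ptr` boundaries. The following is a minimal pure-Python sketch of that behavior, assuming the semantics implied by the test's expected values. `grouped_softmax` and `ptr_to_index` are hypothetical helpers for illustration only, not PyG API; PyG's own `softmax` works on tensors rather than lists.

```python
import math
from typing import Dict, List, Optional, Sequence


def ptr_to_index(ptr: Sequence[int]) -> List[int]:
    """Hypothetical helper: expand CSR-style boundaries into a group index.

    Group g covers positions ptr[g] up to, but not including, ptr[g + 1].
    """
    index: List[int] = []
    for group, (start, end) in enumerate(zip(ptr[:-1], ptr[1:])):
        index.extend([group] * (end - start))
    return index


def grouped_softmax(src: Sequence[float],
                    index: Optional[Sequence[int]] = None,
                    ptr: Optional[Sequence[int]] = None) -> List[float]:
    """Hypothetical helper: softmax computed separately within each group."""
    if index is None:
        if ptr is None:
            raise ValueError("pass either index or ptr")
        index = ptr_to_index(ptr)
    if len(index) != len(src):
        raise ValueError("index must have one entry per element of src")

    # Per-group maximum, subtracted before exp() for numerical stability.
    group_max: Dict[int, float] = {}
    for x, g in zip(src, index):
        group_max[g] = max(x, group_max.get(g, -math.inf))

    # Shifted exponentials and their per-group sums.
    exps = [math.exp(x - group_max[g]) for x, g in zip(src, index)]
    group_sum: Dict[int, float] = {}
    for e, g in zip(exps, index):
        group_sum[g] = group_sum.get(g, 0.0) + e

    # Each entry is divided by the sum of its own group only.
    return [e / group_sum[g] for e, g in zip(exps, index)]


src = [1.0, 1.0, 1.0, 1.0]
print(grouped_softmax(src, index=[0, 0, 1, 2]))  # [0.5, 0.5, 1.0, 1.0]
print(grouped_softmax(src, ptr=[0, 2, 3, 4]))    # [0.5, 0.5, 1.0, 1.0]
```

Because `ptr_to_index([0, 2, 3, 4])` yields `[0, 0, 1, 2]`, both calls describe the same grouping, which mirrors why the test expects `softmax(src, ptr=ptr)` and `softmax(src, index)` to agree.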

Callers

nothing calls this directly

Calls 4

softmax (function) · 0.90
jit (function) · 0.85
view (method) · 0.80
tolist (method) · 0.45

Tested by

no test coverage detected