| 10 | |
| 11 | |
def test_softmax():
    """Check that ``softmax`` gives the same result via ``index`` and via ``ptr``.

    Four equal scores are split by ``index`` into the groups {0, 1}, {2} and
    {3}. The two-element group is expected to normalize to 0.5 per entry and
    each singleton group to 1. ``ptr`` = [0, 2, 3, 4] describes the same
    grouping as boundaries, and must reproduce the ``index`` result exactly.
    The check runs for a 1-D input and for its column-vector view. It then
    runs once more through the TorchScript-compiled ``softmax``.
    """
    index = torch.tensor([0, 0, 1, 2])
    ptr = torch.tensor([0, 2, 3, 4])
    flat = torch.tensor([1., 1., 1., 1.])

    # (input, expected nested-list result), checked in this order.
    cases = [
        (flat, [0.5, 0.5, 1, 1]),
        (flat.view(-1, 1), [[0.5], [0.5], [1], [1]]),
    ]
    for src, expected in cases:
        out = softmax(src, index)
        assert out.tolist() == expected
        assert softmax(src, ptr=ptr).tolist() == out.tolist()

    # After the loop, ``src``/``out`` hold the column-vector case; the
    # scripted function must agree with the eager result for it.
    scripted = torch.jit.script(softmax)
    assert torch.allclose(scripted(src, index), out)
| 28 | |
| 29 | |
| 30 | def test_softmax_backward(): |