| 54 | |
| 55 | |
def test_coalesce_jit():
    """Check that ``coalesce`` compiles under TorchScript for each call form.

    Three scripted wrappers cover the three signatures exercised here:
    edge index alone, edge index plus an optional attribute tensor, and edge
    index plus a list of attribute tensors. Each wrapper is called on a small
    four-edge graph. Only the sizes of the outputs are compared against the
    sizes of the corresponding inputs.
    """
    # Form 1: edge index only, returns a single tensor.
    @torch.jit.script
    def index_only(edge_index: Tensor) -> Tensor:
        return coalesce(edge_index)

    # Form 2: edge index plus an attribute tensor that may be None.
    @torch.jit.script
    def with_optional_attr(
        edge_index: Tensor,
        edge_attr: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        return coalesce(edge_index, edge_attr)

    # Form 3: edge index plus a list of attribute tensors.
    @torch.jit.script
    def with_attr_list(
        edge_index: Tensor,
        edge_attr: List[Tensor],
    ) -> Tuple[Tensor, List[Tensor]]:
        return coalesce(edge_index, edge_attr)

    # Four (row, col) pairs, all distinct; the assertions below expect every
    # output to keep the size of its corresponding input.
    edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]])
    edge_attr = torch.tensor([[1], [2], [3], [4]])
    flat_attr = edge_attr.view(-1)  # 1-D variant of the same attributes

    assert index_only(edge_index).size() == edge_index.size()

    # A None attribute must pass through as None.
    index, attr = with_optional_attr(edge_index, None)
    assert index.size() == edge_index.size()
    assert attr is None

    index, attr = with_optional_attr(edge_index, edge_attr)
    assert index.size() == edge_index.size()
    assert attr.size() == edge_attr.size()

    # A list mixing 2-D and 1-D attribute tensors.
    index, attrs = with_attr_list(edge_index, [edge_attr, flat_attr])
    assert index.size() == edge_index.size()
    assert len(attrs) == 2
    assert attrs[0].size() == edge_attr.size()
    assert attrs[1].size() == flat_attr.size()