()
| 13 | |
| 14 | |
def test_gine_layer():
    """Check that a default ``GINELayer`` forward pass yields a (4, 4) output.

    Builds a graph from a 4x4 identity feature matrix, six edges supplied as
    a pair of index tensors, and random (6, 4) edge attributes, then runs the
    layer once. Only the output shape is asserted, so the random edge
    attribute values do not affect the outcome.
    """
    gine = GINELayer()
    node_feats = torch.eye(4)
    first_idx = torch.tensor([0, 0, 0, 1, 1, 2])
    second_idx = torch.tensor([1, 2, 3, 2, 3, 3])
    edge_feats = torch.randn(6, 4)
    graph = Graph(
        x=node_feats,
        edge_index=(first_idx, second_idx),
        edge_attr=edge_feats,
    )
    out = gine(graph, node_feats)
    assert tuple(out.shape) == (4, 4)
| 22 | |
| 23 | |
| 24 | if __name__ == "__main__": |
no test coverage detected