()
| 613 | |
| 614 | |
| 615 | def test_simple_pool(): |
| 616 | g = dgl.from_networkx(nx.path_graph(15)).to(F.ctx()) |
| 617 | |
| 618 | sum_pool = nn.SumPooling() |
| 619 | avg_pool = nn.AvgPooling() |
| 620 | max_pool = nn.MaxPooling() |
| 621 | sort_pool = nn.SortPooling(10) # k = 10 |
| 622 | print(sum_pool, avg_pool, max_pool, sort_pool) |
| 623 | |
| 624 | # test#1: basic |
| 625 | h0 = F.randn((g.num_nodes(), 5)) |
| 626 | h1 = sum_pool(g, h0) |
| 627 | check_close(F.squeeze(h1, 0), F.sum(h0, 0)) |
| 628 | h1 = avg_pool(g, h0) |
| 629 | check_close(F.squeeze(h1, 0), F.mean(h0, 0)) |
| 630 | h1 = max_pool(g, h0) |
| 631 | check_close(F.squeeze(h1, 0), F.max(h0, 0)) |
| 632 | h1 = sort_pool(g, h0) |
| 633 | assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2 |
| 634 | |
| 635 | # test#2: batched graph |
| 636 | g_ = dgl.from_networkx(nx.path_graph(5)).to(F.ctx()) |
| 637 | bg = dgl.batch([g, g_, g, g_, g]) |
| 638 | h0 = F.randn((bg.num_nodes(), 5)) |
| 639 | h1 = sum_pool(bg, h0) |
| 640 | truth = mx.nd.stack( |
| 641 | F.sum(h0[:15], 0), |
| 642 | F.sum(h0[15:20], 0), |
| 643 | F.sum(h0[20:35], 0), |
| 644 | F.sum(h0[35:40], 0), |
| 645 | F.sum(h0[40:55], 0), |
| 646 | axis=0, |
| 647 | ) |
| 648 | check_close(h1, truth) |
| 649 | |
| 650 | h1 = avg_pool(bg, h0) |
| 651 | truth = mx.nd.stack( |
| 652 | F.mean(h0[:15], 0), |
| 653 | F.mean(h0[15:20], 0), |
| 654 | F.mean(h0[20:35], 0), |
| 655 | F.mean(h0[35:40], 0), |
| 656 | F.mean(h0[40:55], 0), |
| 657 | axis=0, |
| 658 | ) |
| 659 | check_close(h1, truth) |
| 660 | |
| 661 | h1 = max_pool(bg, h0) |
| 662 | truth = mx.nd.stack( |
| 663 | F.max(h0[:15], 0), |
| 664 | F.max(h0[15:20], 0), |
| 665 | F.max(h0[20:35], 0), |
| 666 | F.max(h0[35:40], 0), |
| 667 | F.max(h0[40:55], 0), |
| 668 | axis=0, |
| 669 | ) |
| 670 | check_close(h1, truth) |
| 671 | |
| 672 | h1 = sort_pool(bg, h0) |
no test coverage detected