MCPcopy
hub / github.com/dmlc/dgl / test_simple_pool

Function test_simple_pool

tests/python/mxnet/test_nn.py:615–673  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

613
614
615def test_simple_pool():
616 g = dgl.from_networkx(nx.path_graph(15)).to(F.ctx())
617
618 sum_pool = nn.SumPooling()
619 avg_pool = nn.AvgPooling()
620 max_pool = nn.MaxPooling()
621 sort_pool = nn.SortPooling(10) # k = 10
622 print(sum_pool, avg_pool, max_pool, sort_pool)
623
624 # test#1: basic
625 h0 = F.randn((g.num_nodes(), 5))
626 h1 = sum_pool(g, h0)
627 check_close(F.squeeze(h1, 0), F.sum(h0, 0))
628 h1 = avg_pool(g, h0)
629 check_close(F.squeeze(h1, 0), F.mean(h0, 0))
630 h1 = max_pool(g, h0)
631 check_close(F.squeeze(h1, 0), F.max(h0, 0))
632 h1 = sort_pool(g, h0)
633 assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2
634
635 # test#2: batched graph
636 g_ = dgl.from_networkx(nx.path_graph(5)).to(F.ctx())
637 bg = dgl.batch([g, g_, g, g_, g])
638 h0 = F.randn((bg.num_nodes(), 5))
639 h1 = sum_pool(bg, h0)
640 truth = mx.nd.stack(
641 F.sum(h0[:15], 0),
642 F.sum(h0[15:20], 0),
643 F.sum(h0[20:35], 0),
644 F.sum(h0[35:40], 0),
645 F.sum(h0[40:55], 0),
646 axis=0,
647 )
648 check_close(h1, truth)
649
650 h1 = avg_pool(bg, h0)
651 truth = mx.nd.stack(
652 F.mean(h0[:15], 0),
653 F.mean(h0[15:20], 0),
654 F.mean(h0[20:35], 0),
655 F.mean(h0[35:40], 0),
656 F.mean(h0[40:55], 0),
657 axis=0,
658 )
659 check_close(h1, truth)
660
661 h1 = max_pool(bg, h0)
662 truth = mx.nd.stack(
663 F.max(h0[:15], 0),
664 F.max(h0[15:20], 0),
665 F.max(h0[20:35], 0),
666 F.max(h0[35:40], 0),
667 F.max(h0[40:55], 0),
668 axis=0,
669 )
670 check_close(h1, truth)
671
672 h1 = sort_pool(bg, h0)

Callers 1

test_nn.py · File · 0.70

Calls 4

check_close · Function · 0.85
to · Method · 0.45
ctx · Method · 0.45
num_nodes · Method · 0.45

Tested by

no test coverage detected