MCPcopy
hub / github.com/dmlc/dgl / test_module_heat_kernel

Function test_module_heat_kernel

tests/python/common/transforms/test_transform.py:2916–2936  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

2914)
@parametrize_idtype
def test_module_heat_kernel(idtype):
    """Exercise the ``dgl.HeatKernel`` transform on two small graphs.

    A single transform instance built with ``avg_degree=1`` is applied to:

    1. a directed, unweighted 6-node graph carrying node feature ``"h"``.
       The result must keep the input's idtype, device, node count and
       node features, and must expose an edge feature ``"w"``.
    2. a 4-node graph made of the reciprocal edge pairs 0<->1 and 2<->3,
       with edge weights stored under ``"w"``. The result's edge set must
       be exactly the four self-loops.
    """
    heat_kernel = dgl.HeatKernel(avg_degree=1)

    # Case1: directed graph
    src_ids, dst_ids = [0, 1, 2, 3, 4], [2, 3, 4, 5, 3]
    directed = dgl.graph((src_ids, dst_ids), idtype=idtype, device=F.ctx())
    directed.ndata["h"] = F.randn((6, 2))
    diffused = heat_kernel(directed)
    assert diffused.idtype == directed.idtype
    assert diffused.device == directed.device
    assert diffused.num_nodes() == directed.num_nodes()
    # Node features must pass through the transform unchanged.
    assert F.allclose(directed.ndata["h"], diffused.ndata["h"])
    assert "w" in diffused.edata

    # Case2: weighted undirected graph (stored as reciprocal directed edges)
    bidirected = dgl.graph(
        ([0, 1, 2, 3], [1, 0, 3, 2]), idtype=idtype, device=F.ctx()
    )
    bidirected.edata["w"] = F.tensor([0.1, 0.2, 0.3, 0.4])
    diffused = heat_kernel(bidirected)
    out_src, out_dst = (F.asnumpy(ids) for ids in diffused.edges())
    edge_pairs = set(zip(out_src.tolist(), out_dst.tolist()))
    assert edge_pairs == {(node, node) for node in range(4)}
2937
2938
2939@unittest.skipIf(

Callers 1

test_transform.py — File · 0.85

Calls 6

transform — Function · 0.85
asnumpy — Method · 0.80
graph — Method · 0.45
ctx — Method · 0.45
num_nodes — Method · 0.45
edges — Method · 0.45

Tested by

no test coverage detected