MCPcopy
hub / github.com/dmlc/dgl / test_module_add_self_loop

Function test_module_add_self_loop

tests/python/common/transforms/test_transform.py:2412–2504  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

2410
2411@parametrize_idtype
2412def test_module_add_self_loop(idtype):
2413 g = dgl.graph(([1, 1], [1, 2]), idtype=idtype, device=F.ctx())
2414 g.ndata["h"] = F.randn((g.num_nodes(), 2))
2415 g.edata["w"] = F.randn((g.num_edges(), 3))
2416
2417 # Case1: add self-loops with the default setting
2418 transform = dgl.AddSelfLoop()
2419 new_g = transform(g)
2420 assert new_g.device == g.device
2421 assert new_g.idtype == g.idtype
2422 assert new_g.num_nodes() == g.num_nodes()
2423 assert new_g.num_edges() == 4
2424 src, dst = new_g.edges()
2425 eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
2426 assert eset == {(0, 0), (1, 1), (1, 2), (2, 2)}
2427 assert "h" in new_g.ndata
2428 assert "w" in new_g.edata
2429
2430 # Case2: remove self-loops first to avoid duplicate ones
2431 transform = dgl.AddSelfLoop(allow_duplicate=True)
2432 new_g = transform(g)
2433 assert new_g.device == g.device
2434 assert new_g.idtype == g.idtype
2435 assert new_g.num_nodes() == g.num_nodes()
2436 assert new_g.num_edges() == 5
2437 src, dst = new_g.edges()
2438 eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
2439 assert eset == {(0, 0), (1, 1), (1, 2), (2, 2)}
2440 assert "h" in new_g.ndata
2441 assert "w" in new_g.edata
2442
2443 # Case3: add self-loops for a homogeneous graph (the example in doc)
2444 transform = dgl.AddSelfLoop(fill_data="sum")
2445 g = dgl.graph(([0, 0, 2], [2, 1, 0]), idtype=idtype, device=F.ctx())
2446 new_g = transform(g)
2447 assert new_g.device == g.device
2448 assert new_g.idtype == g.idtype
2449 assert new_g.num_nodes() == g.num_nodes()
2450 src, dst = new_g.edges()
2451 eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
2452 assert eset == {(0, 2), (0, 1), (2, 0), (0, 0), (1, 1), (2, 2)}
2453
2454 # Create a heterogeneous graph
2455 g = dgl.heterograph(
2456 {
2457 ("user", "plays", "game"): ([0], [1]),
2458 ("user", "follows", "user"): ([1], [3]),
2459 },
2460 idtype=idtype,
2461 device=F.ctx(),
2462 )
2463 g.nodes["user"].data["h1"] = F.randn((4, 2))
2464 g.edges["plays"].data["w1"] = F.randn((1, 3))
2465 g.nodes["game"].data["h2"] = F.randn((2, 4))
2466 g.edges["follows"].data["w2"] = F.randn((1, 5))
2467
2468 # Case4: add self-loops for a heterogeneous graph
2469 new_g = transform(g)

Callers

nothing calls this directly

Calls 7

transform · Function · 0.85
asnumpy · Method · 0.80
graph · Method · 0.45
ctx · Method · 0.45
num_nodes · Method · 0.45
num_edges · Method · 0.45
edges · Method · 0.45

Tested by

no test coverage detected