| 2410 | |
| 2411 | @parametrize_idtype |
| 2412 | def test_module_add_self_loop(idtype): |
| 2413 | g = dgl.graph(([1, 1], [1, 2]), idtype=idtype, device=F.ctx()) |
| 2414 | g.ndata["h"] = F.randn((g.num_nodes(), 2)) |
| 2415 | g.edata["w"] = F.randn((g.num_edges(), 3)) |
| 2416 | |
| 2417 | # Case1: add self-loops with the default setting |
| 2418 | transform = dgl.AddSelfLoop() |
| 2419 | new_g = transform(g) |
| 2420 | assert new_g.device == g.device |
| 2421 | assert new_g.idtype == g.idtype |
| 2422 | assert new_g.num_nodes() == g.num_nodes() |
| 2423 | assert new_g.num_edges() == 4 |
| 2424 | src, dst = new_g.edges() |
| 2425 | eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) |
| 2426 | assert eset == {(0, 0), (1, 1), (1, 2), (2, 2)} |
| 2427 | assert "h" in new_g.ndata |
| 2428 | assert "w" in new_g.edata |
| 2429 | |
| 2430 | # Case2: remove self-loops first to avoid duplicate ones |
| 2431 | transform = dgl.AddSelfLoop(allow_duplicate=True) |
| 2432 | new_g = transform(g) |
| 2433 | assert new_g.device == g.device |
| 2434 | assert new_g.idtype == g.idtype |
| 2435 | assert new_g.num_nodes() == g.num_nodes() |
| 2436 | assert new_g.num_edges() == 5 |
| 2437 | src, dst = new_g.edges() |
| 2438 | eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) |
| 2439 | assert eset == {(0, 0), (1, 1), (1, 2), (2, 2)} |
| 2440 | assert "h" in new_g.ndata |
| 2441 | assert "w" in new_g.edata |
| 2442 | |
| 2443 | # Case3: add self-loops for a homogeneous graph (the example in doc) |
| 2444 | transform = dgl.AddSelfLoop(fill_data="sum") |
| 2445 | g = dgl.graph(([0, 0, 2], [2, 1, 0]), idtype=idtype, device=F.ctx()) |
| 2446 | new_g = transform(g) |
| 2447 | assert new_g.device == g.device |
| 2448 | assert new_g.idtype == g.idtype |
| 2449 | assert new_g.num_nodes() == g.num_nodes() |
| 2450 | src, dst = new_g.edges() |
| 2451 | eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) |
| 2452 | assert eset == {(0, 2), (0, 1), (2, 0), (0, 0), (1, 1), (2, 2)} |
| 2453 | |
| 2454 | # Create a heterogeneous graph |
| 2455 | g = dgl.heterograph( |
| 2456 | { |
| 2457 | ("user", "plays", "game"): ([0], [1]), |
| 2458 | ("user", "follows", "user"): ([1], [3]), |
| 2459 | }, |
| 2460 | idtype=idtype, |
| 2461 | device=F.ctx(), |
| 2462 | ) |
| 2463 | g.nodes["user"].data["h1"] = F.randn((4, 2)) |
| 2464 | g.edges["plays"].data["w1"] = F.randn((1, 3)) |
| 2465 | g.nodes["game"].data["h2"] = F.randn((2, 4)) |
| 2466 | g.edges["follows"].data["w2"] = F.randn((1, 5)) |
| 2467 | |
| 2468 | # Case4: add self-loops for a heterogeneous graph |
| 2469 | new_g = transform(g) |