| 521 | |
| 522 | @parametrize_idtype |
| 523 | def test_out_subgraph(idtype): |
| 524 | hg = dgl.heterograph( |
| 525 | { |
| 526 | ("user", "follow", "user"): ( |
| 527 | [1, 2, 3, 0, 2, 3, 0], |
| 528 | [0, 0, 0, 1, 1, 1, 2], |
| 529 | ), |
| 530 | ("user", "play", "game"): ([0, 0, 1, 3], [0, 1, 2, 2]), |
| 531 | ("game", "liked-by", "user"): ( |
| 532 | [2, 2, 2, 1, 1, 0], |
| 533 | [0, 1, 2, 0, 3, 0], |
| 534 | ), |
| 535 | ("user", "flips", "coin"): ([0, 1, 2, 3], [0, 0, 0, 0]), |
| 536 | }, |
| 537 | idtype=idtype, |
| 538 | ).to(F.ctx()) |
| 539 | subg = dgl.out_subgraph(hg, {"user": [0, 1], "game": 0}) |
| 540 | assert subg.idtype == idtype |
| 541 | assert len(subg.ntypes) == 3 |
| 542 | assert len(subg.etypes) == 4 |
| 543 | u, v = subg["follow"].edges() |
| 544 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 545 | assert edge_set == {(1, 0), (0, 1), (0, 2)} |
| 546 | assert F.array_equal( |
| 547 | hg["follow"].edge_ids(u, v), subg["follow"].edata[dgl.EID] |
| 548 | ) |
| 549 | u, v = subg["play"].edges() |
| 550 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 551 | assert edge_set == {(0, 0), (0, 1), (1, 2)} |
| 552 | assert F.array_equal(hg["play"].edge_ids(u, v), subg["play"].edata[dgl.EID]) |
| 553 | u, v = subg["liked-by"].edges() |
| 554 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 555 | assert edge_set == {(0, 0)} |
| 556 | assert F.array_equal( |
| 557 | hg["liked-by"].edge_ids(u, v), subg["liked-by"].edata[dgl.EID] |
| 558 | ) |
| 559 | u, v = subg["flips"].edges() |
| 560 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 561 | assert edge_set == {(0, 0), (1, 0)} |
| 562 | assert F.array_equal( |
| 563 | hg["flips"].edge_ids(u, v), subg["flips"].edata[dgl.EID] |
| 564 | ) |
| 565 | for ntype in subg.ntypes: |
| 566 | assert dgl.NID not in subg.nodes[ntype].data |
| 567 | |
| 568 | # Test store_ids |
| 569 | subg = dgl.out_subgraph(hg, {"user": [0, 1], "game": 0}, store_ids=False) |
| 570 | for etype in subg.canonical_etypes: |
| 571 | assert dgl.EID not in subg.edges[etype].data |
| 572 | for ntype in subg.ntypes: |
| 573 | assert dgl.NID not in subg.nodes[ntype].data |
| 574 | |
| 575 | # Test relabel nodes |
| 576 | subg = dgl.out_subgraph(hg, {"user": [1], "game": 0}, relabel_nodes=True) |
| 577 | assert subg.idtype == idtype |
| 578 | assert len(subg.ntypes) == 3 |
| 579 | assert len(subg.etypes) == 4 |
| 580 | |