| 433 | |
| 434 | @parametrize_idtype |
| 435 | def test_in_subgraph(idtype): |
| 436 | hg = dgl.heterograph( |
| 437 | { |
| 438 | ("user", "follow", "user"): ( |
| 439 | [1, 2, 3, 0, 2, 3, 0], |
| 440 | [0, 0, 0, 1, 1, 1, 2], |
| 441 | ), |
| 442 | ("user", "play", "game"): ([0, 0, 1, 3], [0, 1, 2, 2]), |
| 443 | ("game", "liked-by", "user"): ( |
| 444 | [2, 2, 2, 1, 1, 0], |
| 445 | [0, 1, 2, 0, 3, 0], |
| 446 | ), |
| 447 | ("user", "flips", "coin"): ([0, 1, 2, 3], [0, 0, 0, 0]), |
| 448 | }, |
| 449 | idtype=idtype, |
| 450 | num_nodes_dict={"user": 5, "game": 10, "coin": 8}, |
| 451 | ).to(F.ctx()) |
| 452 | subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0}) |
| 453 | assert subg.idtype == idtype |
| 454 | assert len(subg.ntypes) == 3 |
| 455 | assert len(subg.etypes) == 4 |
| 456 | u, v = subg["follow"].edges() |
| 457 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 458 | assert F.array_equal( |
| 459 | hg["follow"].edge_ids(u, v), subg["follow"].edata[dgl.EID] |
| 460 | ) |
| 461 | assert edge_set == {(1, 0), (2, 0), (3, 0), (0, 1), (2, 1), (3, 1)} |
| 462 | u, v = subg["play"].edges() |
| 463 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 464 | assert F.array_equal(hg["play"].edge_ids(u, v), subg["play"].edata[dgl.EID]) |
| 465 | assert edge_set == {(0, 0)} |
| 466 | u, v = subg["liked-by"].edges() |
| 467 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 468 | assert F.array_equal( |
| 469 | hg["liked-by"].edge_ids(u, v), subg["liked-by"].edata[dgl.EID] |
| 470 | ) |
| 471 | assert edge_set == {(2, 0), (2, 1), (1, 0), (0, 0)} |
| 472 | assert subg["flips"].num_edges() == 0 |
| 473 | for ntype in subg.ntypes: |
| 474 | assert dgl.NID not in subg.nodes[ntype].data |
| 475 | |
| 476 | # Test store_ids |
| 477 | subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0}, store_ids=False) |
| 478 | for etype in ["follow", "play", "liked-by"]: |
| 479 | assert dgl.EID not in subg.edges[etype].data |
| 480 | for ntype in subg.ntypes: |
| 481 | assert dgl.NID not in subg.nodes[ntype].data |
| 482 | |
| 483 | # Test relabel nodes |
| 484 | subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0}, relabel_nodes=True) |
| 485 | assert subg.idtype == idtype |
| 486 | assert len(subg.ntypes) == 3 |
| 487 | assert len(subg.etypes) == 4 |
| 488 | |
| 489 | u, v = subg["follow"].edges() |
| 490 | old_u = F.gather_row(subg.nodes["user"].data[dgl.NID], u) |
| 491 | old_v = F.gather_row(subg.nodes["user"].data[dgl.NID], v) |
| 492 | assert F.array_equal( |