(idtype)
| 177 | ) |
@parametrize_idtype
def test_subgraph_mask(idtype):
    """Check that boolean masks can be used to extract heterograph subgraphs.

    Attaches a feature ``"h"`` to the ``user`` nodes and to the ``follows``
    edges of the graph from ``create_test_heterograph``. It then extracts a
    subgraph twice:

    * with ``g.subgraph`` and per-node-type boolean masks;
    * with ``g.edge_subgraph`` and per-edge-type boolean masks.

    Both results are validated by the same ``_check_subgraph`` helper. The
    two selections are therefore expected to yield the same induced node and
    edge IDs and the same sliced features.
    """
    g = create_test_heterograph(idtype)

    # Per-type feature tensors. Row counts match the mask lengths used below
    # (3 "user" nodes, 2 "follows" edges). Presumably these are the counts
    # produced by create_test_heterograph; TODO confirm.
    x = F.randn((3, 5))
    y = F.randn((2, 4))
    g.nodes["user"].data["h"] = x
    g.edges["follows"].data["h"] = y

    def _check_subgraph(g, sg):
        """Assert that ``sg`` is the expected subgraph of ``g``.

        The expected subgraph keeps user nodes 1 and 2, game node 0, and
        edge 1 of each of ``follows``/``plays``/``wishes``. It keeps no
        ``developer`` nodes or ``develops`` edges. Schema, idtype and device
        must be preserved, and features must be sliced from the parent.
        """
        # The subgraph keeps the full type schema of the parent graph.
        assert sg.idtype == g.idtype
        assert sg.device == g.device
        assert sg.ntypes == g.ntypes
        assert sg.etypes == g.etypes
        assert sg.canonical_etypes == g.canonical_etypes

        # dgl.NID / dgl.EID map subgraph nodes/edges back to parent IDs.
        assert F.array_equal(
            F.tensor(sg.nodes["user"].data[dgl.NID]), F.tensor([1, 2], idtype)
        )
        assert F.array_equal(
            F.tensor(sg.nodes["game"].data[dgl.NID]), F.tensor([0], idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["follows"].data[dgl.EID]), F.tensor([1], idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["plays"].data[dgl.EID]), F.tensor([1], idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["wishes"].data[dgl.EID]), F.tensor([1], idtype)
        )

        # Types with nothing selected remain in the schema but are empty.
        assert sg.num_nodes("developer") == 0
        assert sg.num_edges("develops") == 0

        # Features are carried over as the rows of the selected IDs.
        assert F.array_equal(
            sg.nodes["user"].data["h"], g.nodes["user"].data["h"][1:3]
        )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"][1:2]
        )

    # Node-induced subgraph from boolean masks (one mask per node type given).
    sg1 = g.subgraph(
        {
            "user": F.tensor([False, True, True], dtype=F.bool),
            "game": F.tensor([True, False, False, False], dtype=F.bool),
        }
    )
    _check_subgraph(g, sg1)

    # Edge-induced subgraph from boolean masks. The test expects the selected
    # edges to induce exactly the same node set as sg1.
    sg2 = g.edge_subgraph(
        {
            "follows": F.tensor([False, True], dtype=F.bool),
            "plays": F.tensor([False, True, False, False], dtype=F.bool),
            "wishes": F.tensor([False, True], dtype=F.bool),
        }
    )
    _check_subgraph(g, sg2)

| 235 | |
| 236 | @parametrize_idtype |
nothing calls this directly
no test coverage detected