(use_uva)
| 314 | |
| 315 | @pytest.mark.parametrize("use_uva", [True, False]) |
| 316 | def test_pinsage_sampling(use_uva): |
| 317 | if use_uva and F.ctx() == F.cpu(): |
| 318 | pytest.skip("UVA sampling requires a GPU.") |
| 319 | |
def _test_sampler(g, sampler, ntype):
    """Sample neighbors for seed nodes 0 and 2 and sanity-check the result.

    Parameters
    ----------
    g
        Graph the sampler was built on; only its ``idtype`` is read here,
        to create the seed tensor with a matching ID dtype.
    sampler
        Callable that takes a tensor of seed node IDs and returns the
        sampled neighbor graph.
    ntype
        The single node type the sampled graph is expected to contain.

    Raises
    ------
    AssertionError
        If the sampled graph has unexpected node types, or if neither of
        the accepted edges is present for a seed.
    """
    # Build the seeds with the graph's ID dtype, then move them to the
    # backend's current context.
    seed_ids = F.tensor([0, 2], dtype=g.idtype)
    seed_ids = F.copy_to(seed_ids, F.ctx())

    sampled = sampler(seed_ids)
    assert sampled.ntypes == [ntype]

    # Collect the sampled edges as plain Python (u, v) integer pairs.
    src, dst = sampled.all_edges(form="uv", order="eid")
    src_ids = F.asnumpy(src).tolist()
    dst_ids = F.asnumpy(dst).tolist()
    edge_pairs = list(zip(src_ids, dst_ids))

    # Sampling is random, so for each seed accept either of two edges:
    # seed 0 must be reached from node 1 or node 0, and seed 2 from
    # node 2 or node 3.
    assert any(pair in edge_pairs for pair in ((1, 0), (0, 0)))
    assert any(pair in edge_pairs for pair in ((2, 2), (3, 2)))
| 328 | |
| 329 | g = dgl.heterograph( |
| 330 | { |
| 331 | ("item", "bought-by", "user"): ( |
| 332 | [0, 0, 1, 1, 2, 2, 3, 3], |
| 333 | [0, 1, 0, 1, 2, 3, 2, 3], |
| 334 | ), |
| 335 | ("user", "bought", "item"): ( |
| 336 | [0, 1, 0, 1, 2, 3, 2, 3], |
| 337 | [0, 0, 1, 1, 2, 2, 3, 3], |
| 338 | ), |
| 339 | } |
| 340 | ) |
| 341 | if use_uva: |
| 342 | g.create_formats_() |
| 343 | g.pin_memory_() |
| 344 | elif F._default_context_str == "gpu": |
| 345 | g = g.to(F.ctx()) |
| 346 | try: |
| 347 | sampler = dgl.sampling.PinSAGESampler(g, "item", "user", 4, 0.5, 3, 2) |
| 348 | _test_sampler(g, sampler, "item") |
| 349 | sampler = dgl.sampling.RandomWalkNeighborSampler( |
| 350 | g, 4, 0.5, 3, 2, ["bought-by", "bought"] |
| 351 | ) |
| 352 | _test_sampler(g, sampler, "item") |
| 353 | sampler = dgl.sampling.RandomWalkNeighborSampler( |
| 354 | g, |
| 355 | 4, |
| 356 | 0.5, |
| 357 | 3, |
| 358 | 2, |
| 359 | [("item", "bought-by", "user"), ("user", "bought", "item")], |
| 360 | ) |
| 361 | _test_sampler(g, sampler, "item") |
| 362 | finally: |
| 363 | if g.is_pinned(): |
| 364 | g.unpin_memory_() |
| 365 | |
| 366 | g = dgl.graph(([0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 0, 1, 2, 3, 2, 3])) |
| 367 | if use_uva: |
| 368 | g.create_formats_() |
| 369 | g.pin_memory_() |
| 370 | elif F._default_context_str == "gpu": |
| 371 | g = g.to(F.ctx()) |
| 372 | try: |
| 373 | sampler = dgl.sampling.RandomWalkNeighborSampler(g, 4, 0.5, 3, 2) |
no test coverage detected