MCPcopy
hub / github.com/dmlc/dgl / test_pinsage_sampling

Function test_pinsage_sampling

tests/python/common/sampling/test_sampling.py:316–398  ·  view source on GitHub ↗
(use_uva)

Source from the content-addressed store, hash-verified

314
315@pytest.mark.parametrize("use_uva", [True, False])
316def test_pinsage_sampling(use_uva):
317 if use_uva and F.ctx() == F.cpu():
318 pytest.skip("UVA sampling requires a GPU.")
319
def _test_sampler(g, sampler, ntype):
    """Run ``sampler`` on seed nodes 0 and 2 and sanity-check its output.

    Parameters
    ----------
    g
        Graph whose ``idtype`` is used as the dtype of the seed tensor.
    sampler
        Callable invoked with the seed tensor; the object it returns must
        expose ``ntypes`` and ``all_edges``.
    ntype
        The single node type the returned graph is expected to report.

    Raises
    ------
    AssertionError
        If the returned graph's ``ntypes`` is not ``[ntype]``, or if the
        expected ``(u, v)`` pairs are missing from its edge list.
    """
    # Build the seed IDs with the graph's index dtype, then move them to
    # the backend's current context.
    seed_ids = F.tensor([0, 2], dtype=g.idtype)
    seed_ids = F.copy_to(seed_ids, F.ctx())

    sampled = sampler(seed_ids)
    assert sampled.ntypes == [ntype]

    # Collect the endpoints as plain Python (u, v) tuples so membership can
    # be checked independently of the backend tensor type.
    src, dst = sampled.all_edges(form="uv", order="eid")
    src_ids = F.asnumpy(src).tolist()
    dst_ids = F.asnumpy(dst).tolist()
    edge_pairs = list(zip(src_ids, dst_ids))

    # For each seed, at least one of the two acceptable pairs ending at that
    # seed must be present.
    assert (1, 0) in edge_pairs or (0, 0) in edge_pairs
    assert (2, 2) in edge_pairs or (3, 2) in edge_pairs
328
329 g = dgl.heterograph(
330 {
331 ("item", "bought-by", "user"): (
332 [0, 0, 1, 1, 2, 2, 3, 3],
333 [0, 1, 0, 1, 2, 3, 2, 3],
334 ),
335 ("user", "bought", "item"): (
336 [0, 1, 0, 1, 2, 3, 2, 3],
337 [0, 0, 1, 1, 2, 2, 3, 3],
338 ),
339 }
340 )
341 if use_uva:
342 g.create_formats_()
343 g.pin_memory_()
344 elif F._default_context_str == "gpu":
345 g = g.to(F.ctx())
346 try:
347 sampler = dgl.sampling.PinSAGESampler(g, "item", "user", 4, 0.5, 3, 2)
348 _test_sampler(g, sampler, "item")
349 sampler = dgl.sampling.RandomWalkNeighborSampler(
350 g, 4, 0.5, 3, 2, ["bought-by", "bought"]
351 )
352 _test_sampler(g, sampler, "item")
353 sampler = dgl.sampling.RandomWalkNeighborSampler(
354 g,
355 4,
356 0.5,
357 3,
358 2,
359 [("item", "bought-by", "user"), ("user", "bought", "item")],
360 )
361 _test_sampler(g, sampler, "item")
362 finally:
363 if g.is_pinned():
364 g.unpin_memory_()
365
366 g = dgl.graph(([0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 0, 1, 2, 3, 2, 3]))
367 if use_uva:
368 g.create_formats_()
369 g.pin_memory_()
370 elif F._default_context_str == "gpu":
371 g = g.to(F.ctx())
372 try:
373 sampler = dgl.sampling.RandomWalkNeighborSampler(g, 4, 0.5, 3, 2)

Callers 1

test_sampling.py File · 0.85

Calls 9

_test_sampler Function · 0.85
ctx Method · 0.45
cpu Method · 0.45
create_formats_ Method · 0.45
pin_memory_ Method · 0.45
to Method · 0.45
is_pinned Method · 0.45
unpin_memory_ Method · 0.45
graph Method · 0.45

Tested by

no test coverage detected