MCPcopy
hub / github.com/dmlc/dgl / test_non_uniform_random_walk

Function test_non_uniform_random_walk

tests/python/common/sampling/test_sampling.py:42–146  ·  view source on GitHub ↗
(use_uva)

Source from the content-addressed store, hash-verified

40
41@pytest.mark.parametrize("use_uva", [True, False])
42def test_non_uniform_random_walk(use_uva):
43 if use_uva:
44 if F.ctx() == F.cpu():
45 pytest.skip("UVA biased random walk requires a GPU.")
46 if dgl.backend.backend_name != "pytorch":
47 pytest.skip(
48 "UVA biased random walk is only supported with PyTorch."
49 )
50 g2 = dgl.heterograph(
51 {("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])}
52 )
53 g4 = dgl.heterograph(
54 {
55 ("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),
56 ("user", "view", "item"): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),
57 ("item", "viewed-by", "user"): (
58 [0, 1, 1, 2, 2, 1],
59 [0, 0, 1, 2, 3, 3],
60 ),
61 }
62 )
63
64 g2.edata["p"] = F.copy_to(
65 F.tensor([3, 0, 3, 3, 3], dtype=F.float32), F.cpu()
66 )
67 g2.edata["p2"] = F.copy_to(
68 F.tensor([[3], [0], [3], [3], [3]], dtype=F.float32), F.cpu()
69 )
70 g4.edges["follow"].data["p"] = F.copy_to(
71 F.tensor([3, 0, 3, 3, 3], dtype=F.float32), F.cpu()
72 )
73 g4.edges["viewed-by"].data["p"] = F.copy_to(
74 F.tensor([1, 1, 1, 1, 1, 1], dtype=F.float32), F.cpu()
75 )
76
77 if use_uva:
78 for g in (g2, g4):
79 g.create_formats_()
80 g.pin_memory_()
81 elif F._default_context_str == "gpu":
82 g2 = g2.to(F.ctx())
83 g4 = g4.to(F.ctx())
84
85 try:
86 traces, eids, ntypes = dgl.sampling.random_walk(
87 g2,
88 F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g2.idtype),
89 length=4,
90 prob="p",
91 return_eids=True,
92 )
93 check_random_walk(
94 g2, ["follow"] * 4, traces, ntypes, "p", trace_eids=eids
95 )
96
97 with pytest.raises(dgl.DGLError):
98 traces, ntypes = dgl.sampling.random_walk(
99 g2,

Callers 1

test_sampling.py (File · 0.85)

Calls 9

check_random_walk (Function · 0.85)
asnumpy (Method · 0.80)
ctx (Method · 0.45)
cpu (Method · 0.45)
copy_to (Method · 0.45)
create_formats_ (Method · 0.45)
pin_memory_ (Method · 0.45)
to (Method · 0.45)
unpin_memory_ (Method · 0.45)

Tested by

no test coverage detected