MCPcopy
hub / github.com/dmlc/dgl / test_uniform_random_walk

Function test_uniform_random_walk

tests/python/common/sampling/test_sampling.py:150–266  ·  view source on GitHub ↗
(use_uva)

Source from the content-addressed store, hash-verified

148
149@pytest.mark.parametrize("use_uva", [True, False])
150def test_uniform_random_walk(use_uva):
151 if use_uva and F.ctx() == F.cpu():
152 pytest.skip("UVA random walk requires a GPU.")
153 g1 = dgl.heterograph({("user", "follow", "user"): ([0, 1, 2], [1, 2, 0])})
154 g2 = dgl.heterograph(
155 {("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])}
156 )
157 g3 = dgl.heterograph(
158 {
159 ("user", "follow", "user"): ([0, 1, 2], [1, 2, 0]),
160 ("user", "view", "item"): ([0, 1, 2], [0, 1, 2]),
161 ("item", "viewed-by", "user"): ([0, 1, 2], [0, 1, 2]),
162 }
163 )
164 g4 = dgl.heterograph(
165 {
166 ("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),
167 ("user", "view", "item"): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),
168 ("item", "viewed-by", "user"): (
169 [0, 1, 1, 2, 2, 1],
170 [0, 0, 1, 2, 3, 3],
171 ),
172 }
173 )
174
175 if use_uva:
176 for g in (g1, g2, g3, g4):
177 g.create_formats_()
178 g.pin_memory_()
179 elif F._default_context_str == "gpu":
180 g1 = g1.to(F.ctx())
181 g2 = g2.to(F.ctx())
182 g3 = g3.to(F.ctx())
183 g4 = g4.to(F.ctx())
184
185 try:
186 traces, eids, ntypes = dgl.sampling.random_walk(
187 g1,
188 F.tensor([0, 1, 2, 0, 1, 2], dtype=g1.idtype),
189 length=4,
190 return_eids=True,
191 )
192 check_random_walk(g1, ["follow"] * 4, traces, ntypes, trace_eids=eids)
193 if F._default_context_str == "cpu":
194 with pytest.raises(dgl.DGLError):
195 dgl.sampling.random_walk(
196 g1,
197 F.tensor([0, 1, 2, 10], dtype=g1.idtype),
198 length=4,
199 return_eids=True,
200 )
201 traces, eids, ntypes = dgl.sampling.random_walk(
202 g1,
203 F.tensor([0, 1, 2, 0, 1, 2], dtype=g1.idtype),
204 length=4,
205 restart_prob=0.0,
206 return_eids=True,
207 )

Callers 1

test_sampling.py — File · 0.85

Calls 9

check_random_walk — Function · 0.85
asnumpy — Method · 0.80
ctx — Method · 0.45
cpu — Method · 0.45
create_formats_ — Method · 0.45
pin_memory_ — Method · 0.45
to — Method · 0.45
is_pinned — Method · 0.45
unpin_memory_ — Method · 0.45

Tested by

no test coverage detected