(use_uva)
| 148 | |
| 149 | @pytest.mark.parametrize("use_uva", [True, False]) |
| 150 | def test_uniform_random_walk(use_uva): |
| 151 | if use_uva and F.ctx() == F.cpu(): |
| 152 | pytest.skip("UVA random walk requires a GPU.") |
| 153 | g1 = dgl.heterograph({("user", "follow", "user"): ([0, 1, 2], [1, 2, 0])}) |
| 154 | g2 = dgl.heterograph( |
| 155 | {("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])} |
| 156 | ) |
| 157 | g3 = dgl.heterograph( |
| 158 | { |
| 159 | ("user", "follow", "user"): ([0, 1, 2], [1, 2, 0]), |
| 160 | ("user", "view", "item"): ([0, 1, 2], [0, 1, 2]), |
| 161 | ("item", "viewed-by", "user"): ([0, 1, 2], [0, 1, 2]), |
| 162 | } |
| 163 | ) |
| 164 | g4 = dgl.heterograph( |
| 165 | { |
| 166 | ("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]), |
| 167 | ("user", "view", "item"): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]), |
| 168 | ("item", "viewed-by", "user"): ( |
| 169 | [0, 1, 1, 2, 2, 1], |
| 170 | [0, 0, 1, 2, 3, 3], |
| 171 | ), |
| 172 | } |
| 173 | ) |
| 174 | |
| 175 | if use_uva: |
| 176 | for g in (g1, g2, g3, g4): |
| 177 | g.create_formats_() |
| 178 | g.pin_memory_() |
| 179 | elif F._default_context_str == "gpu": |
| 180 | g1 = g1.to(F.ctx()) |
| 181 | g2 = g2.to(F.ctx()) |
| 182 | g3 = g3.to(F.ctx()) |
| 183 | g4 = g4.to(F.ctx()) |
| 184 | |
| 185 | try: |
| 186 | traces, eids, ntypes = dgl.sampling.random_walk( |
| 187 | g1, |
| 188 | F.tensor([0, 1, 2, 0, 1, 2], dtype=g1.idtype), |
| 189 | length=4, |
| 190 | return_eids=True, |
| 191 | ) |
| 192 | check_random_walk(g1, ["follow"] * 4, traces, ntypes, trace_eids=eids) |
| 193 | if F._default_context_str == "cpu": |
| 194 | with pytest.raises(dgl.DGLError): |
| 195 | dgl.sampling.random_walk( |
| 196 | g1, |
| 197 | F.tensor([0, 1, 2, 10], dtype=g1.idtype), |
| 198 | length=4, |
| 199 | return_eids=True, |
| 200 | ) |
| 201 | traces, eids, ntypes = dgl.sampling.random_walk( |
| 202 | g1, |
| 203 | F.tensor([0, 1, 2, 0, 1, 2], dtype=g1.idtype), |
| 204 | length=4, |
| 205 | restart_prob=0.0, |
| 206 | return_eids=True, |
| 207 | ) |
no test coverage detected