(dtype, idtype, pinned)
@pytest.mark.parametrize("idtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("pinned", [False, True])
def test_index_select(dtype, idtype, pinned):
    """Compare ``gb.index_select`` with plain torch indexing.

    Rows 0 and 2 of a small 3x2 tensor are gathered by GraphBolt and checked
    against ``source.to(F.ctx())[rows.long()]``. When ``pinned`` is set (only
    on GPU runs), the source lives in pinned host memory. In that case an
    extra call is made with both operands on the host and a pinned index, and
    its output must itself be pinned. Finally, the internal
    ``torch.ops.graphbolt.index_select_async`` op is exercised on host copies.
    """
    # Pinned-memory variants are skipped unless the default context is GPU.
    if pinned and F._default_context_str != "gpu":
        pytest.skip("Pinned tests are available only on GPU.")

    source = torch.tensor([[2, 3], [5, 5], [20, 13]], dtype=dtype)
    if pinned:
        source = source.pin_memory()
    else:
        source = source.to(F.ctx())
    rows = torch.tensor([0, 2], dtype=idtype, device=F.ctx())

    actual = gb.index_select(source, rows)
    # `.long()` gives the reference path int64 indices whatever `idtype` is.
    expected = source.to(F.ctx())[rows.long()]
    assert torch.equal(expected, actual)

    if pinned:
        # Source and index both on the host, index pinned: result must be pinned.
        host_actual = gb.index_select(source.cpu(), rows.cpu().pin_memory())
        assert torch.equal(expected.cpu(), host_actual)
        assert host_actual.is_pinned()

    # Test the internal async API: `.wait()` on the returned handle must yield
    # the same rows as the reference.
    pending = torch.ops.graphbolt.index_select_async(source.cpu(), rows.cpu())
    assert torch.equal(expected.cpu(), pending.wait())
| 275 | |
| 276 | |
| 277 | @pytest.mark.parametrize( |
nothing calls this directly
no test coverage detected