(dtype, idtype)
| 291 | ) |
| 292 | @pytest.mark.parametrize("idtype", [torch.int32, torch.int64]) |
def test_scatter_async(dtype, idtype):
    """Check ``torch.ops.graphbolt.scatter_async`` on a small 3x2 tensor.

    Row 2 of ``data`` is passed as the source and index ``[1]`` as the
    destination. After waiting on the returned handle, the expected result
    is the original tensor with row 1 replaced by row 2.
    """
    # Named ``data`` rather than ``input`` to avoid shadowing the builtin.
    data = torch.tensor([[2, 3], [5, 5], [20, 13]], dtype=dtype)
    # Single-element index tensor holding the value 1 (destination row).
    index = torch.ones([1], dtype=idtype)
    res = torch.ops.graphbolt.scatter_async(data, index, data[2:3])
    expected = torch.tensor([[2, 3], [20, 13], [20, 13]], dtype=dtype)
    # ``wait()`` yields the scattered tensor produced by the async op.
    assert torch.equal(expected, res.wait())
| 300 | |
| 301 | |
| 302 | def torch_expand_indptr(indptr, dtype, nodes=None): |