()
| 67 | F._default_context_str == "cpu", reason="stream only runs on GPU." |
| 68 | ) |
| 69 | def test_set_get_stream(): |
| 70 | current_stream = torch.cuda.current_stream() |
| 71 | # test setting another stream |
| 72 | s = torch.cuda.Stream(device=F.ctx()) |
| 73 | torch.cuda.set_stream(s) |
| 74 | assert ( |
| 75 | to_dgl_stream_handle(s).value |
| 76 | == _dgl_get_stream(to_dgl_context(F.ctx())).value |
| 77 | ) |
| 78 | # revert to default stream |
| 79 | torch.cuda.set_stream(current_stream) |
| 80 | |
| 81 | |
| 82 | @unittest.skipIf( |
no test coverage detected