Yield the raw handle of the current CUDA stream for a scope, and synchronize that stream when the context exits
()
| 30 | |
@contextlib.contextmanager
def _scoped_stream():
    """Expose the current CUDA stream as a raw handle for one ``with`` block.

    No new stream is created: the stream returned by
    ``torch.cuda.current_stream()`` is reused. The raw handle
    (``cuda_stream``) is yielded instead of the torch stream object. The
    stream is synchronized when the block exits, whether it exits normally
    or by an exception.

    Yields:
        The raw ``cuda_stream`` handle of the current torch CUDA stream.
    """
    # TODO: drop the torch dependency in favour of the native CUDA Python bindings.
    import torch  # deferred import: only runs when the context is entered

    active = torch.cuda.current_stream()
    try:
        # TensorRT and other libraries do not understand torch.cuda.Stream
        # objects, so hand out the underlying raw handle instead.
        raw_handle = active.cuda_stream
        yield raw_handle
    finally:
        # Wait for all work queued on this stream before leaving the scope.
        active.synchronize()
| 43 | |
| 44 | |
| 45 | @dataclass |
no test coverage detected