Convert PyTorch tensors to TVM Tensors using DLPack.
_convert_pytorch_to_tvm(self, tensors: Any | list[Any] | tuple[Any, ...]) -> Tensor | list[Tensor]
| 360 | return dtype_mapping.get(str(tvm_dtype), torch.float32) |
| 361 | |
| 362 | def _convert_pytorch_to_tvm( |
| 363 | self, tensors: Any | list[Any] | tuple[Any, ...] |
| 364 | ) -> Tensor | list[Tensor]: |
| 365 | """Convert PyTorch tensors to TVM Tensors using DLPack.""" |
| 366 | # pylint: disable=import-outside-toplevel |
| 367 | import torch |
| 368 | |
| 369 | if isinstance(tensors, list | tuple): |
| 370 | return [self._convert_single_pytorch_to_tvm(t) for t in tensors] |
| 371 | return self._convert_single_pytorch_to_tvm(tensors) |
| 372 | |
| 373 | def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor: |
| 374 | """Convert a single PyTorch tensor to TVM Tensor with faster DLPack converter.""" |