Convert a single TVM Tensor to PyTorch tensor using faster DLPack converter.
(self, tvm_tensor: Any)
| 426 | return self._convert_single_tvm_to_pytorch(tvm_tensors) |
| 427 | |
| 428 | def _convert_single_tvm_to_pytorch(self, tvm_tensor: Any) -> "torch.Tensor": |
| 429 | """Convert a single TVM Tensor to PyTorch tensor using faster DLPack converter.""" |
| 430 | # pylint: disable=import-outside-toplevel |
| 431 | import torch |
| 432 | |
| 433 | if isinstance(tvm_tensor, torch.Tensor): |
| 434 | return tvm_tensor |
| 435 | if not isinstance(tvm_tensor, Tensor): |
| 436 | return torch.tensor(tvm_tensor) |
| 437 | |
| 438 | # 1. Try faster C++ DLPack converter |
| 439 | if _FASTER_DLPACK_EXTENSION is not None: |
| 440 | try: |
| 441 | return torch.from_dlpack(tvm_tensor) |
| 442 | except (AttributeError, ValueError): |
| 443 | pass # Fall through to the next method |
| 444 | |
| 445 | # 2. Try standard DLPack conversion |
| 446 | try: |
| 447 | return torch.from_dlpack(tvm_tensor) |
| 448 | # pylint: disable=broad-exception-caught |
| 449 | except Exception as error: |
| 450 | print(f"Warning: DLPack conversion from TVM failed ({error}), using numpy fallback") |
| 451 | numpy_array = tvm_tensor.numpy() |
| 452 | return torch.from_numpy(numpy_array) |
| 453 | |
| 454 | def get_function(self, name: str) -> Function | None: |
| 455 | """Get a compiled function by name.""" |
no test coverage detected