MCPcopy
hub / github.com/apache/tvm / _convert_single_pytorch_to_tvm

Method _convert_single_pytorch_to_tvm

python/tvm/relax/base_py_module.py:373–418  ·  view source on GitHub ↗

Convert a single PyTorch tensor to a TVM Tensor using the faster DLPack converter.

(self, tensor: Any)

Source from the content-addressed store, hash-verified

371 return self._convert_single_pytorch_to_tvm(tensors)
372
def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
    """Convert a single PyTorch tensor (or tensor-like value) to a TVM Tensor.

    ``torch.Tensor`` inputs are converted by trying, in order:

    1. ``torch.to_dlpack`` followed by ``tvm.runtime.from_dlpack``;
    2. the legacy ``to_dlpack_legacy`` exporter, when one is available;
    3. a NumPy copy (``tensor.detach().cpu().numpy()``) placed on
       ``self.device``.

    A TVM ``Tensor`` is returned unchanged. Any other value (e.g. a Python
    scalar or a nested list) is converted with
    ``np.array(value, dtype=np.float32)`` and placed on ``self.device``.

    Parameters
    ----------
    tensor : Any
        The value to convert.

    Returns
    -------
    Tensor
        The converted TVM tensor.

    Raises
    ------
    TypeError
        If a non-torch value cannot be converted (the NumPy conversion or the
        TVM tensor construction raised ``TypeError``/``ValueError``).
    """
    # pylint: disable=import-outside-toplevel
    import warnings

    import torch

    if isinstance(tensor, Tensor):
        return tensor

    if isinstance(tensor, torch.Tensor):
        # Exceptions meaning "this DLPack route failed; try the next one".
        # DLPack export/import failures are commonly reported as RuntimeError
        # (torch TORCH_CHECK failures; TVM's error types presumably subclass
        # it -- confirm) or BufferError (the DLPack / array-API convention).
        # Catching only AttributeError/ValueError let those escape, so the
        # NumPy fallback below was never reached for them.
        dlpack_errors = (AttributeError, ValueError, RuntimeError, BufferError)

        # 1. `torch.to_dlpack`. This single attempt replaces a separate
        #    `_FASTER_DLPACK_EXTENSION` branch that made the identical call,
        #    so a failure used to be retried verbatim for no benefit.
        try:
            return tvm.runtime.from_dlpack(torch.to_dlpack(tensor))
        except dlpack_errors:
            pass  # Fall through to the next method

        # 2. Legacy exporter, if one was resolved at module import time.
        if to_dlpack_legacy:
            try:
                return tvm.runtime.from_dlpack(to_dlpack_legacy(tensor))
            except dlpack_errors as error_legacy:
                # Library code should not print to stdout; emit a warning the
                # caller can filter or escalate.
                warnings.warn(
                    f"Legacy DLPack conversion failed ({error_legacy}), "
                    f"using numpy fallback.",
                    stacklevel=2,
                )

        # 3. All DLPack routes failed: copy through host memory. Unlike the
        #    DLPack routes, the result is placed on `self.device`.
        numpy_array = tensor.detach().cpu().numpy()
        return tvm.runtime.tensor(numpy_array, device=self.device)

    # Other values (scalars, lists, ...) go through NumPy.
    # NOTE(review): every such value, including ints and bools, is cast to
    # float32 here -- confirm that is what callers expect.
    try:
        numpy_array = np.array(tensor, dtype=np.float32)
        return tvm.runtime.tensor(numpy_array, device=self.device)
    except (TypeError, ValueError) as error:
        raise TypeError(
            f"Unsupported type for conversion to TVM Tensor: {type(tensor)}"
        ) from error
419
420 def _convert_tvm_to_pytorch(
421 self, tvm_tensors: Any | list[Any]

Callers 1

Calls 3

print · Function · 0.85
numpy · Method · 0.80
cpu · Method · 0.45

Tested by

no test coverage detected