Repository: github.com/NVIDIA/TensorRT-LLM

Function torch_to_numpy

tensorrt_llm/_utils.py:56–64
(x: torch.Tensor)

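import torch

# np_bfloat16 and np_float8 are placeholder NumPy dtypes defined elsewhere in
# tensorrt_llm/_utils.py; core NumPy has no native bfloat16 or FP8 type.


def torch_to_numpy(x: torch.Tensor):
    assert isinstance(x, torch.Tensor), \
        f'x must be a torch.Tensor object, but got {type(x)}.'
    if x.dtype == torch.bfloat16:
        # Reinterpret the bfloat16 bits as int16 (same width, no value change),
        # copy to host, then relabel the NumPy buffer as np_bfloat16.
        return x.view(torch.int16).detach().cpu().numpy().view(np_bfloat16)
    elif x.dtype == torch.float8_e4m3fn:
        # Same bit-level reinterpretation for FP8 E4M3 via a 1-byte int8 view.
        return x.view(torch.int8).detach().cpu().numpy().view(np_float8)
    else:
        # All other dtypes go through the standard conversion.
        return x.detach().cpu().numpy()


# def numpy_to_torch(x): ...  (next function in _utils.py; not shown here)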
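Because core NumPy has no bfloat16 or FP8 type, the arrays this function returns for those dtypes hold raw bit patterns rather than decoded values. As a minimal sketch of what those bits mean, the two helpers below decode them to float32 using only NumPy. They are hypothetical illustrations, not part of TensorRT-LLM, and assume the standard layouts: bfloat16 is the upper 16 bits of an IEEE-754 float32, and FP8 E4M3FN has 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and NaN only when exponent and mantissa are all ones.

```python
import numpy as np


def bf16_bits_to_float32(bits: np.ndarray) -> np.ndarray:
    """Hypothetical helper: decode raw bfloat16 bits (int16/uint16) to float32.

    bfloat16 is the upper half of an IEEE-754 float32, so moving the 16 bits
    into the high half of a 32-bit word and reinterpreting gives the value.
    """
    words = bits.view(np.uint16).astype(np.uint32) << 16
    return words.view(np.float32)


def e4m3fn_bits_to_float32(bits: np.ndarray) -> np.ndarray:
    """Hypothetical helper: decode raw FP8 E4M3FN bytes (int8/uint8) to float32."""
    b = bits.view(np.uint8).astype(np.int32)
    sign = np.where((b & 0x80) != 0, -1.0, 1.0)
    exp = (b >> 3) & 0xF          # 4 exponent bits
    man = b & 0x7                 # 3 mantissa bits
    normal = np.ldexp(1.0 + man / 8.0, exp - 7)   # exponent bias is 7
    subnormal = np.ldexp(man / 8.0, -6)           # exp == 0: no implicit 1
    out = sign * np.where(exp == 0, subnormal, normal)
    # E4M3FN has no infinities; only S.1111.111 encodes NaN.
    out = np.where((exp == 0xF) & (man == 0x7), np.nan, out)
    return out.astype(np.float32)
```

Decoding through float32 is lossless for both formats, since every bfloat16 and E4M3FN value is exactly representable in float32.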

Callers 15 (9 shown)

_construct_execution · Method · 0.90
_construct_execution · Method · 0.90
_run_matmul · Method · 0.90
create_trt_session · Method · 0.90
MLP · Method · 0.90
set_fp4_scales · Method · 0.90
set_weight_layer · Method · 0.90
create_trt_session · Method · 0.90
_construct_execution · Function · 0.90

Calls 1

view · Method · 0.45
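The single resolved callee, `view`, is why the conversion involves no rounding: on both the PyTorch and NumPy sides, `view(dtype)` reinterprets existing bytes instead of converting values, so the only data movement is the `.cpu()` transfer when the tensor lives on a GPU. The NumPy half can be sketched on its own as below. The dtype `placeholder_bf16` and its `{"dtype": "bfloat16"}` metadata tag are assumptions for illustration; the real `np_bfloat16` definition lives in `tensorrt_llm/_utils.py` and is not shown on this page.

```python
import numpy as np

# Hypothetical stand-in for an opaque bfloat16 dtype: 2 raw bytes per element,
# tagged with metadata so downstream code can tell what the bytes mean.
placeholder_bf16 = np.dtype("V2", metadata={"dtype": "bfloat16"})

# Raw bfloat16 bit patterns for 1.0 and -2.0, stored as 16-bit integers.
raw = np.array([0x3F80, 0xC000], dtype=np.uint16)

# Relabel the same bytes with the placeholder dtype: no values are converted
# and no data is copied.
tagged = raw.view(placeholder_bf16)

# Viewing back as a 16-bit integer type recovers the exact bit patterns.
restored = tagged.view(np.uint16)

print(np.shares_memory(raw, tagged), tagged.dtype.metadata["dtype"])
```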

Tested by 15 (9 shown)

_construct_execution · Method · 0.72
_construct_execution · Method · 0.72
_run_matmul · Method · 0.72
create_trt_session · Method · 0.72
MLP · Method · 0.72
set_fp4_scales · Method · 0.72
set_weight_layer · Method · 0.72
create_trt_session · Method · 0.72
_construct_execution · Function · 0.72