MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / Parameter

Class Parameter

tensorrt_llm/parameter.py:34–274  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

32
33
34class Parameter:
35 _DEFAULT_DTYPE = trt.DataType.FLOAT
36
def __init__(self,
             value: Optional[Union[np.ndarray, torch.Tensor]] = None,
             shape: Optional[Sequence[int]] = None,
             dtype: Optional[Union[str, trt.DataType]] = None,
             is_buffer: bool = False,
             prefer_managed: bool = False):
    """Create a parameter either from an initial value or from a shape.

    Args:
        value: Optional initial contents. When given, the parameter's shape
            is taken from ``value.shape`` (the ``shape`` argument is then
            ignored) and the value is stored after passing through
            ``self._regularize_value``.
        shape: Required when ``value`` is None; must be a list or tuple and
            is stored as a tuple. No contents are stored in that case.
        dtype: A ``trt.DataType`` or a dtype name understood by
            ``str_dtype_to_trt``. When None, ``_DEFAULT_DTYPE`` is used and
            a warning is logged.
        is_buffer: Stored unchanged as ``self.is_buffer``.
        prefer_managed: Stored unchanged as ``self._prefer_managed``.

    Raises:
        AssertionError: If ``value`` is None and ``shape`` is not a list or
            tuple. This is an ``assert``, so the check is skipped under
            ``python -O``, and ``tuple(shape)`` then accepts any iterable.
    """
    if dtype is None:
        logger.warning(
            f'Parameter dtype is None, using default dtype: {self._DEFAULT_DTYPE}, it is recommended to always specify dtype explicitly'
        )
        dtype = self._DEFAULT_DTYPE
    # A separate ``if`` (not ``elif``) so that a string ``_DEFAULT_DTYPE``
    # (e.g. overridden on a subclass) is still converted.
    if isinstance(dtype, str):
        dtype = str_dtype_to_trt(dtype)
    # Assigned before ``_regularize_value`` runs below; that helper is not
    # shown here and may depend on ``self._dtype``, so keep this ordering.
    self._dtype: trt.DataType = dtype

    if value is None:
        assert isinstance(shape, (list, tuple)), \
            f"shape must be list or tuple, receive {(type(shape))}"
        self._shape = tuple(shape)
        self._value = None
    else:
        # The value dictates the shape; the ``shape`` argument is ignored.
        self._shape = value.shape
        self._value = self._regularize_value(value)

    self.is_buffer = is_buffer
    self._prefer_managed = prefer_managed
    # Binding state starts empty. ``_name`` is assigned by
    # ``_create_managed_tensor``; presumably ``_tensor`` and ``_network`` are
    # set by other methods of this class (not visible here).
    self._tensor: Tensor = None
    self._network: weakref.ref = None
    self._name = None
    self.need_transpose = False
66
@property
def shape(self):
    """The parameter's shape.

    Initialised by ``__init__``. It comes from ``value.shape`` when a value
    is supplied, and otherwise from the ``shape`` argument converted to a
    tuple.
    """
    return self._shape
70
@property
def dtype(self):
    """The parameter's TensorRT data type.

    Resolved in ``__init__``. A string name is converted with
    ``str_dtype_to_trt``, and ``_DEFAULT_DTYPE`` is used when no dtype is
    given.
    """
    return self._dtype
74
@property
def name(self):
    """The parameter's name, or None if it has not been assigned.

    ``__init__`` sets it to None, and ``_create_managed_tensor`` assigns a
    ``managed_constant_<n>`` name. Other assignment sites, if any, are not
    visible here.
    """
    return self._name
78
79 def _create_managed_tensor(self, network) -> Tensor:
80 num = len(network._inputs)
81 self._name = f"managed_constant_{num}"
82
83 if self._value is None or (isinstance(self._value, np.ndarray)
84 and not self._value.flags['C_CONTIGUOUS']):
85 value_old = self._value
86 self._value = np.empty(self._shape, trt_dtype_to_np(self._dtype))
87 network._register_unfilled_weights(
88 # use updated self._shape here
89 self._name,
90 self._value,
91 value_old)

Callers 15

__init__Method · 0.90
_run_matmul_pluginMethod · 0.90
test_w4a8_linearFunction · 0.90
create_weightsMethod · 0.90
create_weightsMethod · 0.90
create_weightsMethod · 0.90
create_weightsMethod · 0.90
create_weightsMethod · 0.90
load_weights_vanillaMethod · 0.90

Calls

no outgoing calls

Tested by 4

_run_matmul_pluginMethod · 0.72
test_w4a8_linearFunction · 0.72