| 32 | |
| 33 | |
| 34 | class Parameter: |
| 35 | _DEFAULT_DTYPE = trt.DataType.FLOAT |
| 36 | |
def __init__(self,
             value: Optional[Union[np.ndarray, torch.Tensor]] = None,
             shape: Optional[Sequence[int]] = None,
             dtype: Optional[Union[str, trt.DataType]] = None,
             is_buffer: bool = False,
             prefer_managed: bool = False):
    """Create a parameter from either a concrete value or an explicit shape.

    Args:
        value: Initial contents. When given, the shape is taken from
            ``value.shape`` and the stored value is the result of
            ``self._regularize_value(value)``. When ``None``, ``shape`` is
            required and no value is stored.
        shape: Parameter shape as a list or tuple. Only used (and then
            required) when ``value`` is ``None``.
        dtype: TensorRT data type, or its string name (converted with
            ``str_dtype_to_trt``). When ``None``, a warning is logged and
            ``_DEFAULT_DTYPE`` is used instead.
        is_buffer: Stored as ``self.is_buffer``.
        prefer_managed: Stored as ``self._prefer_managed``.

    Raises:
        AssertionError: If ``value`` is ``None`` and ``shape`` is not a
            list or tuple.
    """
    # Resolve the dtype exactly once: fall back to the class default
    # (with a warning) and translate string names to TensorRT enums.
    if dtype is None:
        logger.warning(
            f'Parameter dtype is None, using default dtype: {self._DEFAULT_DTYPE}, it is recommended to always specify dtype explicitly'
        )
        dtype = self._DEFAULT_DTYPE
    if isinstance(dtype, str):
        dtype = str_dtype_to_trt(dtype)
    self._dtype: trt.DataType = dtype

    if value is None:
        # NOTE(review): `assert` is stripped under `python -O`; it is kept so
        # the exception type seen by existing callers does not change.
        assert isinstance(shape, (
            list,
            tuple)), f"shape must be list or tuple, receive {(type(shape))}"
        self._shape = tuple(shape)
        self._value = None
    else:
        # Normalise to a plain tuple so `_shape` has the same type on both
        # branches (a torch.Tensor reports its shape as torch.Size).
        self._shape = tuple(value.shape)
        self._value = self._regularize_value(value)

    self.is_buffer = is_buffer
    self._prefer_managed = prefer_managed
    # Populated later; None until then.
    self._tensor: Optional[Tensor] = None
    self._network: Optional[weakref.ref] = None
    self._name = None
    self.need_transpose = False
| 66 | |
@property
def shape(self):
    """Shape recorded for this parameter in ``self._shape``.

    ``__init__`` derives it from the supplied value's shape, or from the
    ``shape`` argument when no value is given.
    """
    recorded_shape = self._shape
    return recorded_shape
| 70 | |
@property
def dtype(self):
    """Data type stored in ``self._dtype``.

    ``__init__`` sets it from the ``dtype`` argument: string names are
    converted with ``str_dtype_to_trt``, and ``_DEFAULT_DTYPE`` is used when
    no dtype is given.
    """
    resolved_dtype = self._dtype
    return resolved_dtype
| 74 | |
@property
def name(self):
    """Name stored in ``self._name``.

    ``None`` until something assigns it; for example,
    ``_create_managed_tensor`` sets it to ``managed_constant_<n>``.
    """
    current_name = self._name
    return current_name
| 78 | |
| 79 | def _create_managed_tensor(self, network) -> Tensor: |
| 80 | num = len(network._inputs) |
| 81 | self._name = f"managed_constant_{num}" |
| 82 | |
| 83 | if self._value is None or (isinstance(self._value, np.ndarray) |
| 84 | and not self._value.flags['C_CONTIGUOUS']): |
| 85 | value_old = self._value |
| 86 | self._value = np.empty(self._shape, trt_dtype_to_np(self._dtype)) |
| 87 | network._register_unfilled_weights( |
| 88 | # use updated self._shape here |
| 89 | self._name, |
| 90 | self._value, |
| 91 | value_old) |
no outgoing calls