@brief: Set input shapes to the given context, and infer the output shapes from the given input shapes. This function should be called every time the input shapes change, before calling run(). Alternatively, call context.set_input_shape manually on all dynamically shaped input tensors.
(
self,
inputs: List[TensorInfo],
context: Optional[trt.IExecutionContext] = None
)
| 202 | ) |
| 203 | |
def infer_shapes(
    self,
    inputs: List[TensorInfo],
    context: Optional[trt.IExecutionContext] = None
) -> List[TensorInfo]:
    '''
    @brief: Apply the given input shapes to an execution context and report the resulting output tensors.
            Call this whenever the input shapes change, before calling run(); alternatively,
            call context.set_input_shape manually on every dynamically shaped input tensor.
    @param inputs: list of TensorInfo objects, one per input tensor (name, dtype and shape are read)
    @param context: TensorRT execution context; when None, self.context is used
    @return: list of TensorInfo objects, one per engine output tensor, in engine I/O tensor order
    @raise ValueError: an entry of `inputs` does not name an input tensor of the engine,
                       or its dtype differs from the dtype the engine reports for that tensor
    @raise RuntimeError: the context rejected the shape for one of the inputs
    '''
    # Fall back to the default context only when the caller did not supply one.
    ctx = self.context if context is None else context

    # Validate each requested input against the engine, then push its shape to the context.
    for tensor in inputs:
        if self.engine.get_tensor_mode(tensor.name) != trt.TensorIOMode.INPUT:
            raise ValueError(f"Tensor:{tensor.name} is not an input tensor")
        if self.engine.get_tensor_dtype(tensor.name) != tensor.dtype:
            raise ValueError(f"Tensor:{tensor.name} has wrong dtype")
        shape_accepted = ctx.set_input_shape(tensor.name, tensor.shape)
        if not shape_accepted:
            raise RuntimeError(
                f"Could not set shape {tensor.shape} for tensor {tensor.name}. Please check the profile range for which your model was build."
            )

    # Walk every I/O tensor of the engine and keep only the outputs.
    results = []
    for index in range(self.engine.num_io_tensors):
        tensor_name = self.engine.get_tensor_name(index)
        is_output = self.engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.OUTPUT
        if not is_output:
            continue
        # Shape comes from the context (it depends on the input shapes just set);
        # dtype comes from the engine.
        out_shape = ctx.get_tensor_shape(tensor_name)
        out_dtype = self.engine.get_tensor_dtype(tensor_name)
        results.append(TensorInfo(tensor_name, out_dtype, out_shape))
    return results
| 238 | |
| 239 | def _set_weight_streaming(self, gpu_weights_percent): |
| 240 | if not self.engine.streamable_weights_size: |