```python
        )

    def forward(self, inputs: Sequence[TensorWrapper], outputs: Sequence[TensorWrapper]):
        """Users are expected to override this method to define the computation.

        A few special attributes give access to runtime resources:

        `self.workspace`: The address of the TensorRT-managed workspace.
        `self.current_stream`: The CUDA stream this plugin is expected to execute on.
            By default, `PluginBase` makes this stream torch's current stream (the one
            returned by `torch.cuda.current_stream()`). The attribute is intended for
            toolkits that do not work with torch streams.
        """
        raise NotImplementedError

    def shape_dtype_inference(self, inputs: Sequence[SymTensor]):
        """Users are expected to override this method to infer the shapes and dtypes of the output tensors."""
```
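The override contract described by these docstrings can be illustrated with a small, self-contained sketch. It does not import TensorRT-LLM: `MiniPluginBase`, `HypotheticalTensor`, `HypotheticalSymSpec`, the `execute` driver, and `AddPlugin` are all hypothetical stand-ins, and plain integers take the place of the real workspace address and CUDA stream handle. It only shows the pattern: the base class records the resources on `self` before calling `forward`, and a user subclass overrides `forward` and `shape_dtype_inference`.

```python
from typing import List, Optional, Sequence


class HypotheticalTensor:
    """Hypothetical stand-in for TensorWrapper: a flat list of floats."""

    def __init__(self, data: List[float]) -> None:
        self.data = data


class HypotheticalSymSpec:
    """Hypothetical stand-in for SymTensor: a concrete shape plus a dtype name."""

    def __init__(self, shape: Sequence[int], dtype: str) -> None:
        self.shape = tuple(shape)
        self.dtype = dtype


class MiniPluginBase:
    """Hypothetical, simplified stand-in for PluginBase (not the real class)."""

    def __init__(self) -> None:
        # Filled in by the hypothetical driver before forward() is called.
        self.workspace: Optional[int] = None
        self.current_stream: Optional[int] = None

    def forward(self, inputs: Sequence[HypotheticalTensor],
                outputs: Sequence[HypotheticalTensor]) -> None:
        raise NotImplementedError

    def shape_dtype_inference(
            self, inputs: Sequence[HypotheticalSymSpec]) -> HypotheticalSymSpec:
        raise NotImplementedError

    def execute(self, inputs, outputs, workspace: int, stream: int) -> None:
        # Hypothetical driver: record the resources, then dispatch to forward().
        self.workspace = workspace
        self.current_stream = stream
        self.forward(inputs, outputs)


class AddPlugin(MiniPluginBase):
    """Example user plugin: element-wise sum of two equally sized inputs."""

    def shape_dtype_inference(self, inputs):
        # Element-wise add: the output mirrors the first input's shape and dtype.
        a, _ = inputs
        return HypotheticalSymSpec(a.shape, a.dtype)

    def forward(self, inputs, outputs):
        a, b = inputs
        (out,) = outputs
        # Fill the preallocated output in place rather than returning a new one.
        for i, (x, y) in enumerate(zip(a.data, b.data)):
            out.data[i] = x + y
        # A toolkit that ignores torch streams would launch its work on
        # self.current_stream; this sketch only records which one it saw.
        self.last_stream_used = self.current_stream


plugin = AddPlugin()
spec = plugin.shape_dtype_inference(
    [HypotheticalSymSpec((3,), "float32"), HypotheticalSymSpec((3,), "float32")])
out = HypotheticalTensor([0.0] * 3)
plugin.execute([HypotheticalTensor([1.0, 2.0, 3.0]),
                HypotheticalTensor([10.0, 20.0, 30.0])],
               [out], workspace=4096, stream=7)
print(spec.shape, spec.dtype)  # (3,) float32
print(out.data)                # [11.0, 22.0, 33.0]
```

Keeping the resources as attributes on `self`, rather than extra `forward` parameters, lets the method signature stay fixed while the base class decides what to expose, which is how the docstrings describe `self.workspace` and `self.current_stream`.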