(self, *args: Union[Sequence[TensorWrapper], Sequence[torch.Tensor]])
| 416 | self.current_stream = -1 |
| 417 | |
def __call__(self, *args: Union[Sequence[TensorWrapper], Sequence[torch.Tensor]]):
    """Run the plugin eagerly on torch tensors, or add it to the TensorRT network.

    If every positional argument is a ``Tensor`` (the ``tensorrt_llm.Tensor``
    named in the error message), the plugin is added to the network returned
    by ``default_trtnet()`` and the layer outputs are wrapped and returned.
    Otherwise every argument must be a ``torch.Tensor``: output buffers are
    allocated from the result of ``shape_dtype_inference`` and ``forward`` is
    called immediately.

    Args:
        *args: Plugin inputs; either all ``Tensor`` or all ``torch.Tensor``.

    Returns:
        The single output when there is exactly one, otherwise the list of
        outputs.

    Raises:
        AssertionError: On the eager path, if any input is not a
            ``torch.Tensor`` (this includes mixing the two input kinds).
    """
    # all() over no inputs is True, so a call without arguments still takes
    # the network-building path, exactly like the original accumulation loop.
    is_trtllm = all(isinstance(arg, Tensor) for arg in args)

    if not is_trtllm:
        for arg in args:
            assert isinstance(arg, torch.Tensor), (
                "Plugin inputs must be `tensorrt_llm.Tensor`s or `torch.Tensor`s"
            )
        sym_tensors = self.shape_dtype_inference(
            [SymTensor(arg.dtype, list(arg.shape)) for arg in args]
        )
        sym_tensors = _convert_return_value_to_list(sym_tensors)
        # NOTE(review): no device is passed, so these buffers use torch's
        # default device -- confirm that is what forward() expects.
        ret = [
            torch.empty(sym_tensor.shape, dtype=trt_dtype_to_torch(sym_tensor.dtype))
            for sym_tensor in sym_tensors
        ]
        self.current_stream = torch.cuda.current_stream().cuda_stream
        # The workspace attribute is replaced by the allocated buffer. Without
        # this guard a second eager call would hand that tensor back to
        # torch.empty(), which raises TypeError; reuse the buffer instead.
        if not isinstance(self.workspace, torch.Tensor):
            self.workspace = torch.empty(self.workspace)
        self.forward(args, ret)
    else:
        trt_inputs = [arg.trt_tensor for arg in args]
        layer_plugin = default_trtnet().add_plugin_v3(trt_inputs, [], self)
        ret = [
            _create_tensor(layer_plugin.get_output(i), layer_plugin)
            for i in range(self.num_outputs)
        ]

    # Unwrap a lone output so single-output plugins return the tensor itself.
    if len(ret) == 1:
        return ret[0]

    return ret
| 451 | |
def on_shape_change(self, input_desc, output_desc):
    """Shape-change hook; this implementation does nothing.

    Both ``input_desc`` and ``output_desc`` are accepted and ignored.
    """
    return None
nothing calls this directly
no test coverage detected