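Add an identity operation. Parameters: `input` (Tensor) – the input tensor. Returns: the tensor produced by this identity operation. By default this uses the TensorRT network's built-in identity layer (`add_identity`). When `identity_plugin` is enabled in the default network's plugin config, it instead uses the `Identity` plugin (version `'1'`, from `TRT_LLM_PLUGIN_NAMESPACE`).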
(input: Tensor) -> Tensor
| 3276 | |
| 3277 | |
| 3278 | def identity(input: Tensor) -> Tensor: |
| 3279 | ''' |
| 3280 | Add an identity operation. |
| 3281 | |
| 3282 | Parameters: |
| 3283 | input : Tensor |
| 3284 | The input tensor. |
| 3285 | |
| 3286 | Returns: |
| 3287 | The tensor produced by this identity operation. |
| 3288 | ''' |
| 3289 | if not default_net().plugin_config.identity_plugin: |
| 3290 | layer = default_trtnet().add_identity(input.trt_tensor) |
| 3291 | else: |
| 3292 | plg_creator = trt.get_plugin_registry().get_plugin_creator( |
| 3293 | 'Identity', '1', TRT_LLM_PLUGIN_NAMESPACE) |
| 3294 | assert plg_creator is not None |
| 3295 | pfc = trt.PluginFieldCollection() |
| 3296 | id_plug = plg_creator.create_plugin("identity", pfc) |
| 3297 | plug_inputs = [input.trt_tensor] |
| 3298 | layer = default_trtnet().add_plugin_v2(plug_inputs, id_plug) |
| 3299 | _add_plugin_info(layer, plg_creator, "identity", pfc) |
| 3300 | return _create_tensor(layer.get_output(0), layer) |
| 3301 | |
| 3302 | |
| 3303 | def argmax(input: Tensor, dim: int, keepdim: bool = False) -> Tensor: |
no test coverage detected