The DoRA plugin applies column-wise scaling to the output of a LoRA layer.
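To make "column-wise scaling" concrete, here is a minimal pure-Python sketch. It is not the plugin's CUDA kernel, and it rests on several assumptions:

- The layer output is laid out as [num_tokens, sum(out_hidden_sizes)].
- Each adapter supplies one scale factor per output column. For DoRA, these factors are presumably derived from the adapter's magnitude vector.
- The factors are grouped per adapter following the `out_hidden_sizes` parameter of `dora_plugin`. For a fused qkv projection, the first q_dim columns therefore use the q adapter's factors, and so on.

The function name `scale_columns` and the `magnitudes` argument are hypothetical.

```python
from typing import List, Sequence


def scale_columns(
    rows: Sequence[Sequence[float]],
    out_hidden_sizes: Sequence[int],
    magnitudes: Sequence[Sequence[float]],
) -> List[List[float]]:
    """Hypothetical sketch: scale each output column by a per-column factor.

    rows:             [num_tokens][sum(out_hidden_sizes)] output of a LoRA layer
    out_hidden_sizes: width of each adapter's slice, e.g. [q_dim, k_dim, v_dim]
    magnitudes:       one factor vector per adapter (assumed input format);
                      len(magnitudes[i]) must equal out_hidden_sizes[i]
    """
    if len(magnitudes) != len(out_hidden_sizes):
        raise ValueError("need one magnitude vector per adapter")
    for size, mag in zip(out_hidden_sizes, magnitudes):
        if len(mag) != size:
            raise ValueError("magnitude length must match its out_hidden_size")

    # Concatenate the per-adapter vectors so that column j has exactly one factor.
    factors = [f for mag in magnitudes for f in mag]
    width = len(factors)

    scaled = []
    for row in rows:
        if len(row) != width:
            raise ValueError("row width must equal sum(out_hidden_sizes)")
        # Column-wise: every row multiplies column j by the same factor.
        scaled.append([x * f for x, f in zip(row, factors)])
    return scaled


# Two tokens; three adapters with widths 2, 1 and 1.
out = [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]
print(scale_columns(out, [2, 1, 1], [[0.5, 2.0], [3.0], [1.0]]))
# [[0.5, 2.0, 3.0, 1.0], [1.0, 4.0, 6.0, 2.0]]
```

Every row shares the same factor for a given column. That is what distinguishes column-wise scaling from a per-token or per-element rescale.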
dora_plugin(activations: Tensor,
out_hidden_sizes: list[int],
lora_weights_pointers: list[Tensor],
host_request_types: Tensor,
host_context_lengths: Tensor | None = None) -> Tensor
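With remove_input_padding enabled, the activations arrive as a packed [num_tokens, dim] tensor with no batch axis, so request boundaries cannot be read off the shape. This is presumably why the function asserts that `host_context_lengths` is provided in that mode. The plain-Python sketch below shows that bookkeeping, under these assumptions:

- `host_request_types` uses 0 for the context phase and 1 for the generation phase.
- A generation-phase request contributes exactly one token.

`packed_row_ranges`, `CONTEXT`, and `GENERATION` are hypothetical names, not part of TensorRT-LLM.

```python
from typing import List, Sequence, Tuple

# Assumed encoding of host_request_types (hypothetical constants for this sketch).
CONTEXT = 0
GENERATION = 1


def packed_row_ranges(
    host_request_types: Sequence[int],
    host_context_lengths: Sequence[int],
) -> List[Tuple[int, int]]:
    """Return the [start, end) rows each request occupies in a packed
    [num_tokens, dim] tensor.

    Assumes a context-phase request contributes its full context length and a
    generation-phase request contributes exactly one token.
    """
    if len(host_request_types) != len(host_context_lengths):
        raise ValueError("expected one entry per request in both tensors")
    ranges = []
    start = 0
    for req_type, ctx_len in zip(host_request_types, host_context_lengths):
        if req_type == CONTEXT:
            n = ctx_len
        elif req_type == GENERATION:
            n = 1
        else:
            raise ValueError(f"unknown request type {req_type}")
        ranges.append((start, start + n))
        start += n
    return ranges


# Two context requests (5 and 3 tokens) followed by one generation request.
print(packed_row_ranges([0, 0, 1], [5, 3, 7]))
# [(0, 5), (5, 8), (8, 9)]
```

Without padding, the rows of different requests sit back to back. Only the per-request lengths let a kernel tell which rows belong to which request, and therefore which adapter's factors to apply.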
```python
def dora_plugin(activations: Tensor,
                out_hidden_sizes: list[int],
                lora_weights_pointers: list[Tensor],
                host_request_types: Tensor,
                host_context_lengths: Tensor | None = None) -> Tensor:
    '''
    The DoRA plugin applies column-wise scaling to the output of a LoRA layer.

    Parameters:
        activations : Tensor (On GPU)
            The input tensor. Its shape is [batch_size, seq_len, dim], or
            [num_tokens, dim] when remove_input_padding is enabled.

        out_hidden_sizes : list[int]
            The output hidden size of each adapter in the related LoRA module.
            For example, for a qkv projection, out_hidden_sizes should be
            [q_dim, k_dim, v_dim].

        lora_weights_pointers : list[Tensor]
            Tensors holding pointers to the LoRA weights. They are passed to
            the plugin as inputs after activations and host_request_types.

        host_request_types : Tensor
            The tensor on the host that indicates whether a request is in the
            context or generation phase. Its shape is [batch_size]. See
            Inflight Batching in docs/source/advanced/gpt-attention.md.

        host_context_lengths : Tensor (On CPU), optional
            A host tensor that contains the lengths of the different inputs.
            Required when remove_input_padding is enabled.

    Return:
        The tensor produced by the DoRA plugin layer.
    '''
    # host_context_lengths is mandatory when remove_input_padding is enabled.
    assert host_context_lengths is not None or not default_net(
    ).plugin_config.remove_input_padding

    # Look up the 'Dora' plugin creator (version '1') in the TensorRT-LLM namespace.
    dora_plg_creator = trt.get_plugin_registry().get_creator(
        'Dora', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert dora_plg_creator is not None

    # Plugin fields: per-adapter output widths, padding mode, and LoRA dtype.
    out_hidden_sizes = trt.PluginField(
        "out_hidden_sizes", np.array(out_hidden_sizes, dtype=np.int32),
        trt.PluginFieldType.INT32)

    remove_input_padding = trt.PluginField(
        "remove_input_padding",
        np.array(np.int8(default_net().plugin_config.remove_input_padding),
                 dtype=np.int8), trt.PluginFieldType.INT8)

    lora_dtype = default_net().plugin_config.lora_plugin
    type_id = trt.PluginField(
        "type", np.array(int(str_dtype_to_trt(lora_dtype)), np.int32),
        trt.PluginFieldType.INT32)

    pfc = trt.PluginFieldCollection(
        [type_id, remove_input_padding, out_hidden_sizes])

    dora_plug = dora_plg_creator.create_plugin("dora", pfc,
                                               trt.TensorRTPhase.BUILD)

    # Activations are cast to the LoRA plugin dtype; the weight pointers follow.
    plug_inputs = [activations.cast(lora_dtype), host_request_types
                   ] + lora_weights_pointers
```