(self,
x,
lora_runtime_params: LoraRuntimeParams = None,
is_cross_attention: bool = False)
| 103 | self.out_hidden_sizes = out_hidden_sizes |
| 104 | |
def forward(self,
            x,
            lora_runtime_params: LoraRuntimeParams = None,
            is_cross_attention: bool = False):
    """Apply DoRA to ``x`` by delegating to ``dora_plugin``.

    Only the plugin path exists: both ``lora_plugin`` and ``dora_plugin``
    must be enabled in ``default_net().plugin_config``.

    Args:
        x: Input forwarded unchanged as the first argument of
            ``dora_plugin``.
        lora_runtime_params: LoRA runtime parameters. Required despite the
            ``None`` default; its ``weight_index`` must be 0.
        is_cross_attention: When True, ``host_encoder_input_lengths`` is
            passed to the plugin as ``host_context_lengths``; otherwise
            ``host_context_lengths`` is passed.

    Returns:
        The value returned by ``dora_plugin``.

    Raises:
        ValueError: If ``lora_runtime_params`` is None or its
            ``weight_index`` is not 0.
        NotImplementedError: If ``lora_plugin`` and ``dora_plugin`` are not
            both enabled.
    """
    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``. With the old ``assert False`` in the fallback branch,
    # that let execution reach ``return result`` with ``result`` unbound.
    if lora_runtime_params is None:
        raise ValueError("DoRA requires lora_runtime_params")
    if lora_runtime_params.weight_index != 0:
        raise ValueError("DoRA does not support weight_index != 0")

    plugin_config = default_net().plugin_config
    if not (plugin_config.lora_plugin and plugin_config.dora_plugin):
        raise NotImplementedError("Not support dora without plugin")

    # For cross attention the encoder input lengths are handed to the
    # plugin in place of the regular context lengths.
    if is_cross_attention:
        host_context_lengths = lora_runtime_params.host_encoder_input_lengths
    else:
        host_context_lengths = lora_runtime_params.host_context_lengths

    return dora_plugin(
        x,
        out_hidden_sizes=self.out_hidden_sizes,
        host_request_types=lora_runtime_params.host_request_types,
        host_context_lengths=host_context_lengths,
        lora_weights_pointers=lora_runtime_params.lora_weights_pointers,
    )
| 125 | |
| 126 | |
| 127 | class LoraParams(object): |
nothing calls this directly
no test coverage detected