(self,
hidden_states: Tensor,
attention_mask=None,
input_lengths=None,
max_input_length=None,
lora_layer_params=None)
| 1718 | dtype=dtype) |
| 1719 | |
| 1720 | def forward(self, |
| 1721 | hidden_states: Tensor, |
| 1722 | attention_mask=None, |
| 1723 | input_lengths=None, |
| 1724 | max_input_length=None, |
| 1725 | lora_layer_params=None): |
| 1726 | assert isinstance(hidden_states, Tensor) |
| 1727 | |
| 1728 | qkv_lora_params = None |
| 1729 | if lora_layer_params is not None: |
| 1730 | qkv_lora_params = lora_layer_params.get_runtime_params( |
| 1731 | 0, "attn_qkv") |
| 1732 | |
| 1733 | qkv = self.qkv(hidden_states, qkv_lora_params) |
| 1734 | |
| 1735 | if default_net().plugin_config.remove_input_padding: |
| 1736 | assert qkv.ndim() == 2 |
| 1737 | |
| 1738 | if default_net( |
| 1739 | ).plugin_config.lora_plugin and qkv_lora_params is None and lora_layer_params is not None: |
| 1740 | q_lora_params = lora_layer_params.get_runtime_params(0, "attn_q") |
| 1741 | k_lora_params = lora_layer_params.get_runtime_params(0, "attn_k") |
| 1742 | v_lora_params = lora_layer_params.get_runtime_params(0, "attn_v") |
| 1743 | |
| 1744 | assert (q_lora_params is not None and k_lora_params is not None and v_lora_params is not None) or \ |
| 1745 | (q_lora_params is None and k_lora_params is None and v_lora_params is None), "q_lora_params, k_lora_params and v_lora_params should be all enabled or all disabled at the same time." |
| 1746 | |
| 1747 | if q_lora_params is not None and k_lora_params is not None and v_lora_params is not None: |
| 1748 | qkv_lora_params = LoraRuntimeParams( |
| 1749 | lora_ranks=[ |
| 1750 | q_lora_params.lora_ranks[0], |
| 1751 | k_lora_params.lora_ranks[0], |
| 1752 | v_lora_params.lora_ranks[0], |
| 1753 | ], |
| 1754 | lora_weights_pointers=[ |
| 1755 | q_lora_params.lora_weights_pointers[0], |
| 1756 | k_lora_params.lora_weights_pointers[0], |
| 1757 | v_lora_params.lora_weights_pointers[0], |
| 1758 | ], |
| 1759 | host_request_types=q_lora_params.host_request_types, |
| 1760 | host_context_lengths=q_lora_params.host_context_lengths) |
| 1761 | |
| 1762 | q_lora, k_lora, v_lora = self.qkv_lora(hidden_states, |
| 1763 | qkv_lora_params) |
| 1764 | qkv_lora = concat([q_lora, k_lora, v_lora], |
| 1765 | dim=q_lora.rank() - 1) |
| 1766 | qkv = qkv + qkv_lora |
| 1767 | |
| 1768 | if default_net().plugin_config.bert_attention_plugin: |
| 1769 | # TRT plugin mode |
| 1770 | assert input_lengths is not None |
| 1771 | context = bert_attention( |
| 1772 | qkv, |
| 1773 | input_lengths, |
| 1774 | self.num_attention_heads, |
| 1775 | self.attention_head_size, |
| 1776 | q_scaling=self.q_scaling, |
| 1777 | relative_attention=self.relative_attention, |
Nothing calls this directly.
No test coverage detected.