(module: nn.MultiheadAttention)
| 36 | |
@register_module_fixer(nn.MultiheadAttention)
def fix(module: nn.MultiheadAttention) -> DPMultiheadAttention:
    """Build a ``DPMultiheadAttention`` equivalent to an ``nn.MultiheadAttention``.

    The replacement is constructed with the same hyper-parameters as
    ``module``. It is placed on the same device and dtype as ``module``'s
    parameters, receives ``module``'s weights via ``load_state_dict``, and is
    left in the same train/eval mode as ``module``.

    Args:
        module: The attention layer to convert.

    Returns:
        A new ``DPMultiheadAttention`` mirroring ``module``.
    """
    dp_attn = DPMultiheadAttention(
        embed_dim=module.embed_dim,
        num_heads=module.num_heads,
        dropout=module.dropout,
        # The bias / bias_kv flags are not stored directly on the module;
        # they are recovered from whether the corresponding parameters exist.
        bias=module.in_proj_bias is not None,
        add_bias_kv=module.bias_k is not None,
        add_zero_attn=module.add_zero_attn,
        kdim=module.kdim,
        vdim=module.vdim,
        batch_first=module.batch_first,
    )
    # A freshly constructed module lives on the default device with the
    # default dtype, and load_state_dict only copies values. Match the source
    # module first, so that swapping this layer into a model that was already
    # moved (e.g. to CUDA or to half precision) does not leave it behind.
    # out_proj.weight is always present on nn.MultiheadAttention.
    reference = module.out_proj.weight
    dp_attn.to(device=reference.device, dtype=reference.dtype)
    dp_attn.load_state_dict(module.state_dict())
    # New modules start in training mode; preserve eval mode so dropout
    # does not silently turn back on for a model that was in eval().
    dp_attn.train(module.training)
    return dp_attn
Nothing calls `fix` directly; it is registered as the fixer for `nn.MultiheadAttention` via `@register_module_fixer`.
No test coverage detected.