github.com/meta-pytorch/opacus

Function fix

opacus/validators/multihead_attention.py:38–51
(module: nn.MultiheadAttention)

Source from the content-addressed store, hash-verified

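# Imports this excerpt relies on (defined at the top of the module)
import torch.nn as nn
from opacus.layers import DPMultiheadAttention
from opacus.validators.utils import register_module_fixer


# Registered as the fixer for nn.MultiheadAttention: builds a DP-compatible
# DPMultiheadAttention with the same configuration and copies the weights over.
@register_module_fixer(nn.MultiheadAttention)
def fix(module: nn.MultiheadAttention) -> DPMultiheadAttention:
    # Mirror the original layer's configuration; the bias flags are recovered
    # from whether the optional bias parameters exist on the original module.
    dp_attn = DPMultiheadAttention(
        embed_dim=module.embed_dim,
        num_heads=module.num_heads,
        dropout=module.dropout,
        bias=module.in_proj_bias is not None,
        add_bias_kv=module.bias_k is not None,
        add_zero_attn=module.add_zero_attn,
        kdim=module.kdim,
        vdim=module.vdim,
        batch_first=module.batch_first,
    )
    # Copy the original layer's parameters into the replacement.
    dp_attn.load_state_dict(module.state_dict())
    return dp_attn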

Callers

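Nothing calls this directly. The @register_module_fixer decorator registers it as the fixer for nn.MultiheadAttention, and it is reached through that registry (for example, by ModuleValidator.fix).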
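The decorator-based registration can be illustrated with a small, self-contained sketch. Everything here is hypothetical: FIXERS, register_fixer, apply_fixers, PlainAttention and PrivateAttention are stand-ins invented for illustration, not opacus's actual API. The sketch only shows the general pattern of a decorator recording a replacement function per type and a separate pass dispatching on each object's type.

```python
from typing import Callable, Dict, List, Type

# Hypothetical registry: maps a module type to the function that replaces it.
FIXERS: Dict[Type, Callable] = {}


def register_fixer(target: Type) -> Callable:
    """Decorator that records fn as the fixer for target and returns fn unchanged."""
    def decorator(fn: Callable) -> Callable:
        FIXERS[target] = fn
        return fn
    return decorator


class PlainAttention:  # hypothetical stand-in for nn.MultiheadAttention
    def __init__(self, embed_dim: int, num_heads: int) -> None:
        self.embed_dim = embed_dim
        self.num_heads = num_heads


class PrivateAttention(PlainAttention):  # hypothetical stand-in for DPMultiheadAttention
    pass


@register_fixer(PlainAttention)
def fix_attention(module: PlainAttention) -> PrivateAttention:
    # Rebuild the replacement from the original's configuration.
    return PrivateAttention(module.embed_dim, module.num_heads)


def apply_fixers(modules: List[object]) -> List[object]:
    """Replace every object whose exact type has a registered fixer."""
    fixed = []
    for m in modules:
        fixer = FIXERS.get(type(m))  # exact-type lookup: subclasses are not re-fixed
        fixed.append(fixer(m) if fixer is not None else m)
    return fixed


layers = [PlainAttention(16, 4), "untouched"]
result = apply_fixers(layers)
print(type(result[0]).__name__, result[0].embed_dim, result[1])
```

Because the decorator returns the function unchanged, the fixer remains callable by name while no code needs to reference it directly.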

Calls 3

load_state_dict · Method · 0.95
state_dict · Method · 0.45

Tested by

no test coverage detected