github.com/meta-pytorch/opacus

Function fix

opacus/validators/lstm.py:38–49
(module: nn.LSTM)

Source from the content-addressed store, hash-verified

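```python
import torch.nn as nn
from opacus.layers import DPLSTM
from opacus.validators.utils import register_module_fixer


# Registered as the fixer for nn.LSTM: ModuleValidator.fix() looks it up by
# type and swaps each nn.LSTM submodule for Opacus's DP-compatible DPLSTM.
@register_module_fixer(nn.LSTM)
def fix(module: nn.LSTM) -> DPLSTM:
    # Rebuild the layer with the original LSTM's hyperparameters.
    dplstm = DPLSTM(
        input_size=module.input_size,
        hidden_size=module.hidden_size,
        num_layers=module.num_layers,
        bias=module.bias,
        batch_first=module.batch_first,
        dropout=module.dropout,
        bidirectional=module.bidirectional,
    )
    # DPLSTM keeps nn.LSTM-compatible state_dict keys, so the weights copy over as-is.
    dplstm.load_state_dict(module.state_dict())
    return dplstm
```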

Callers

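nothing calls this directly (it is invoked through the fixer registry that @register_module_fixer populates, when ModuleValidator.fix encounters an nn.LSTM)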
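Because the decorator only records fix in a registry, the call site lives in a generic fixing loop rather than in any code that names the function. The plain-Python sketch below mimics that type-keyed dispatch; FIXERS, register_fixer, fix_tree, Node, OldLayer and NewLayer are hypothetical stand-ins, not Opacus's implementation.

```python
from typing import Callable, Dict

# Hypothetical registry: maps a layer class to the function that replaces it.
FIXERS: Dict[type, Callable] = {}


def register_fixer(target: type) -> Callable:
    """Decorator that records `fn` as the fixer for `target` and returns it unchanged."""
    def decorator(fn: Callable) -> Callable:
        FIXERS[target] = fn
        return fn
    return decorator


class Node:
    """Stand-in for nn.Module: a node with named children."""
    def __init__(self, **children):
        self.children = dict(children)


class OldLayer(Node):
    def __init__(self, width: int):
        super().__init__()
        self.width = width


class NewLayer(Node):
    def __init__(self, width: int):
        super().__init__()
        self.width = width


@register_fixer(OldLayer)
def fix_old_layer(layer: OldLayer) -> NewLayer:
    # Same idea as fix(): rebuild the replacement from the original's configuration.
    return NewLayer(width=layer.width)


def fix_tree(node: Node) -> Node:
    """Recursively replace every child whose exact type has a registered fixer."""
    for name, child in node.children.items():
        fixer = FIXERS.get(type(child))
        # Reassigning an existing key while iterating is safe (the dict size is unchanged).
        node.children[name] = fixer(child) if fixer else fix_tree(child)
    return node


# No code calls fix_old_layer by name; the registry lookup reaches it.
model = Node(encoder=OldLayer(width=16), head=Node(inner=OldLayer(width=4)))
fix_tree(model)
```

Dispatching on the exact type keeps each fixer independent of the traversal code, which is why a call graph shows no direct callers.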

Calls (3)

DPLSTM (class) · 0.90
load_state_dict (method) · 0.45
state_dict (method) · 0.45

Tested by

no test coverage detected