(module: nn.Module, norm: str = 'none')
| 23 | |
| 24 | |
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Apply the weight reparametrization selected by ``norm`` to ``module``.

    Args:
        module: Module whose weights may be reparametrized.
        norm: Name of the normalization; must be a member of
            ``CONV_NORMALIZATIONS``. ``'weight_norm'`` and ``'spectral_norm'``
            trigger a reparametrization; every other accepted value leaves the
            module untouched.

    Returns:
        The result of ``weight_norm(module)`` or ``spectral_norm(module)`` for
        those two choices, otherwise ``module`` itself.

    Raises:
        ValueError: If ``norm`` is not in ``CONV_NORMALIZATIONS``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown norm silently fall through to
    # the "no reparametrization" branch below.
    if norm not in CONV_NORMALIZATIONS:
        raise ValueError(
            f"Unsupported norm {norm!r}; expected one of {CONV_NORMALIZATIONS}."
        )
    if norm == 'weight_norm':
        return weight_norm(module)
    if norm == 'spectral_norm':
        return spectral_norm(module)
    # Any other value already validated against CONV_NORMALIZATIONS needs no
    # weight reparametrization, so the module is returned as-is.
    return module
| 35 | |
| 36 | |
| 37 | def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: |