Initialize weights of a neural network module. Parameters are initialized using the given method or distribution. Custom initialization routines can be implemented into submodules as function `espnet_initialization_fn` within the custom module. Args: model: Target. init: Method of initialization.
(model: torch.nn.Module, init: str)
| 7 | |
| 8 | |
def initialize(model: torch.nn.Module, init: str):
    """Initialize weights of a neural network module.

    Parameters are initialized using the given method or distribution.

    Custom initialization routines can be implemented into submodules
    as function `espnet_initialization_fn` within the custom module.

    The steps run in this order:

    1. Every parameter with more than one dimension is re-initialized with
       the method selected by ``init``; every one-dimensional parameter is
       set to zero. Zero-dimensional (scalar) parameters are left untouched.
    2. For each submodule (in ``model.modules()`` order), ``Embedding``,
       ``LayerNorm`` and ``GroupNorm`` are restored to PyTorch's default
       initialization via ``reset_parameters()``, then the submodule's
       ``espnet_initialization_fn()`` hook is called if it defines one.
    3. If ``model.encoder`` / ``model.frontend`` expose a callable
       ``reload_pretrained_parameters``, it is called (presumably to restore
       pretrained weights that the steps above overwrote).

    Args:
        model: Target.
        init: Method of initialization. One of ``"xavier_uniform"``,
            ``"xavier_normal"``, ``"kaiming_uniform"`` or ``"kaiming_normal"``
            (the Kaiming variants use ``nonlinearity="relu"``).

    Raises:
        ValueError: If ``init`` is not a supported method. This is checked
            up front, so it is raised regardless of the model's parameters.
    """
    weight_init_fns = {
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "kaiming_uniform": lambda w: torch.nn.init.kaiming_uniform_(
            w, nonlinearity="relu"
        ),
        "kaiming_normal": lambda w: torch.nn.init.kaiming_normal_(
            w, nonlinearity="relu"
        ),
    }
    # Validate once, before touching any parameter, so a bad value is never
    # silently accepted (previously it only failed on the first >1-D param).
    if init not in weight_init_fns:
        raise ValueError("Unknown initialization: " + init)
    weight_init_fn = weight_init_fns[init]

    # Weight init (>1-D) and bias init (1-D) in a single pass. Zeroing draws
    # no random numbers, so the RNG sequence matches two separate passes.
    # no_grad replaces `.data` mutation so autograd bookkeeping stays intact.
    with torch.no_grad():
        for p in model.parameters():
            if p.dim() > 1:
                weight_init_fn(p)
            elif p.dim() == 1:
                p.zero_()

    # reset some modules with default init (undoes the generic rules above,
    # e.g. LayerNorm/GroupNorm scales that were just zeroed as 1-D params)
    for m in model.modules():
        if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)):
            m.reset_parameters()
        if hasattr(m, "espnet_initialization_fn"):
            m.espnet_initialization_fn()

    # TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization
    # Look up the hook explicitly instead of relying on the submodule's
    # truthiness (a Module defining __len__ == 0 would otherwise be skipped).
    for attr_name in ("encoder", "frontend"):
        submodule = getattr(model, attr_name, None)
        reload_fn = getattr(submodule, "reload_pretrained_parameters", None)
        if submodule is not None and callable(reload_fn):
            reload_fn()
nothing calls this directly
no test coverage detected
searching dependent graphs…