github.com/modelscope/FunASR

Function initialize

funasr/train_utils/initialize.py, lines 9–55

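Initializes the weights of a neural network module in place. Every parameter with more than one dimension is initialized with the method named by `init`, which must be one of `xavier_uniform`, `xavier_normal`, `kaiming_uniform`, or `kaiming_normal`. An unrecognized value raises `ValueError` at the first such parameter. Every one-dimensional parameter is zeroed.

After these two passes, `Embedding`, `LayerNorm`, and `GroupNorm` submodules are reset to their PyTorch defaults. Any submodule that defines `espnet_initialization_fn` then has that method called, so custom routines override the generic passes. Finally, if the model's `encoder` or `frontend` provides `reload_pretrained_parameters`, that method is called to restore pretrained weights.

Args: `model`, the target module; `init`, the name of the initialization method.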

(model: torch.nn.Module, init: str)
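Each accepted `init` value selects a `torch.nn.init` function. Each of those functions draws from either a uniform distribution U(-a, a) or a normal distribution N(0, s²), with a scale that depends on the tensor's fan-in and fan-out.

The pure-Python sketch below computes that scale from a weight shape. The helpers `fans` and `init_scale` are hypothetical and are not part of FunASR or PyTorch. The sketch assumes the formulas documented for `torch.nn.init`:

- Xavier uses its default gain of 1.
- Kaiming uses `nonlinearity="relu"` and the default `mode="fan_in"`, matching how the source calls it.

The sketch only computes scales; it does not sample tensors.

```python
import math

# Hypothetical helpers; the formulas follow the torch.nn.init documentation.


def fans(shape):
    """Return (fan_in, fan_out) the way torch.nn.init computes them (needs >= 2 dims)."""
    if len(shape) < 2:
        raise ValueError("fan in/out needs at least 2 dimensions")
    receptive_field = math.prod(shape[2:])  # 1 for a plain (out, in) matrix
    fan_in = shape[1] * receptive_field
    fan_out = shape[0] * receptive_field
    return fan_in, fan_out


def init_scale(shape, init):
    """Return ("bound", a) for U(-a, a) or ("std", s) for N(0, s^2)."""
    fan_in, fan_out = fans(shape)
    relu_gain = math.sqrt(2.0)  # value of torch.nn.init.calculate_gain("relu")
    if init == "xavier_uniform":
        return "bound", math.sqrt(6.0 / (fan_in + fan_out))  # gain = 1
    if init == "xavier_normal":
        return "std", math.sqrt(2.0 / (fan_in + fan_out))  # gain = 1
    if init == "kaiming_uniform":
        return "bound", relu_gain * math.sqrt(3.0 / fan_in)  # mode="fan_in"
    if init == "kaiming_normal":
        return "std", relu_gain / math.sqrt(fan_in)  # mode="fan_in"
    raise ValueError("Unknown initialization: " + init)
```

For example, a `torch.nn.Linear(80, 256)` weight has shape (256, 80).

- `init_scale((256, 80), "xavier_uniform")` gives a bound of sqrt(6/336), about 0.134.
- `"kaiming_uniform"` gives sqrt(6/80), about 0.274.

The Kaiming bound is larger because it depends on fan-in alone and includes the ReLU gain √2.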


```python
import torch


def initialize(model: torch.nn.Module, init: str):
    """Initialize weights of a neural network module.

    Parameters are initialized using the given method or distribution.

    Custom initialization routines can be implemented into submodules
    as function `espnet_initialization_fn` within the custom module.

    Args:
        model: Target.
        init: Method of initialization.
    """

    # weight init
    for p in model.parameters():
        if p.dim() > 1:
            if init == "xavier_uniform":
                torch.nn.init.xavier_uniform_(p.data)
            elif init == "xavier_normal":
                torch.nn.init.xavier_normal_(p.data)
            elif init == "kaiming_uniform":
                torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
            elif init == "kaiming_normal":
                torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
            else:
                raise ValueError("Unknown initialization: " + init)
    # bias init
    for p in model.parameters():
        if p.dim() == 1:
            p.data.zero_()

    # reset some modules with default init
    # (this also restores the LayerNorm/GroupNorm weights zeroed above to ones)
    for m in model.modules():
        if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)):
            m.reset_parameters()
        if hasattr(m, "espnet_initialization_fn"):
            m.espnet_initialization_fn()

    # TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization
    if getattr(model, "encoder", None) and getattr(
        model.encoder, "reload_pretrained_parameters", None
    ):
        model.encoder.reload_pretrained_parameters()
    if getattr(model, "frontend", None) and getattr(
        model.frontend, "reload_pretrained_parameters", None
    ):
        model.frontend.reload_pretrained_parameters()
```
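The passes run in a fixed order, so a later pass overwrites an earlier one. For example:

- A `LayerNorm` weight is zeroed by the 1-D pass and then restored by `reset_parameters()`.
- An `Embedding` weight initialized by the chosen method is likewise reset to PyTorch's default.

The pure-Python model below reports which rule last touches a given parameter. The names `ParamSpec` and `final_treatment` are hypothetical. The model is a simplification that rests on these assumptions:

- A module's `espnet_initialization_fn` rewrites all of that module's own parameters.
- Hooks on ancestor modules are ignored.
- The final `reload_pretrained_parameters()` step is not modeled.

```python
from dataclasses import dataclass

# A hypothetical model of initialize()'s pass order; no tensors are involved.
KNOWN_INITS = {"xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"}
DEFAULT_RESET_KINDS = {"Embedding", "LayerNorm", "GroupNorm"}


@dataclass
class ParamSpec:
    owner_kind: str  # class name of the module that owns the parameter
    name: str  # e.g. "weight" or "bias"
    shape: tuple  # len(shape) plays the role of p.dim()
    owner_has_hook: bool = False  # owner defines espnet_initialization_fn (assumed to rewrite it)


def final_treatment(p: ParamSpec, init: str) -> str:
    """Name the last rule in initialize()'s pass order that touches parameter p."""
    state = "untouched"  # 0-D parameters match neither of the first two passes
    if len(p.shape) > 1:
        if init not in KNOWN_INITS:
            raise ValueError("Unknown initialization: " + init)
        state = init  # pass 1: the chosen torch.nn.init function
    elif len(p.shape) == 1:
        state = "zeros"  # pass 2: every 1-D parameter is zeroed
    if p.owner_kind in DEFAULT_RESET_KINDS:
        state = "pytorch default"  # pass 3: m.reset_parameters()
    if p.owner_has_hook:
        state = "custom hook"  # pass 3: m.espnet_initialization_fn(), after any reset
    return state


if __name__ == "__main__":
    params = [
        ParamSpec("Linear", "weight", (256, 80)),
        ParamSpec("Linear", "bias", (256,)),
        ParamSpec("LayerNorm", "weight", (256,)),
        ParamSpec("Embedding", "weight", (5000, 256)),
        ParamSpec("Conv1d", "weight", (256, 256, 3), owner_has_hook=True),
    ]
    for p in params:
        print(f"{p.owner_kind}.{p.name} -> {final_treatment(p, 'xavier_uniform')}")
```

For the five example parameters, the model reports `xavier_uniform`, `zeros`, `pytorch default`, `pytorch default`, and `custom hook`, in that order. Modeling the order this way shows why a hook defined on the owning module always has the last word over the generic passes.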

Callers

nothing calls this directly

Calls 4

parameters (method) · 0.80
reset_parameters (method) · 0.45

Tested by

no test coverage detected
