MCPcopy Index your code
hub / github.com/modelscope/FunASR / __init__

Method __init__

funasr/models/transducer/rnnt_decoder.py:30–78  ·  view source on GitHub ↗

Construct an RNNDecoder object.

(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
        use_embed_mask: bool = False,
    )

Source from the content-addressed store, hash-verified

28 """
29
def __init__(
    self,
    vocab_size: int,
    embed_size: int = 256,
    hidden_size: int = 256,
    rnn_type: str = "lstm",
    num_layers: int = 1,
    dropout_rate: float = 0.0,
    embed_dropout_rate: float = 0.0,
    embed_pad: int = 0,
    use_embed_mask: bool = False,
) -> None:
    """Build the embedding, the stacked RNN layers and their dropouts.

    Args:
        vocab_size: Number of rows in the embedding table; also stored
            as ``self.vocab_size``.
        embed_size: Embedding dimension and input size of the first
            RNN layer.
        hidden_size: Hidden size of every RNN layer; also stored as
            ``self.output_size``.
        rnn_type: Either ``"lstm"`` or ``"gru"``.
        num_layers: Number of dropout modules created and the value
            stored in ``self.dlayers``. One RNN layer is always
            created; each layer beyond the first adds another.
        dropout_rate: Probability used by the per-layer dropouts.
        embed_dropout_rate: Probability used by the embedding dropout.
        embed_pad: Index passed to the embedding as ``padding_idx``.
        use_embed_mask: When true, a ``SpecAug`` instance is created
            and kept in ``self._embed_mask``.

    Raises:
        ValueError: If ``rnn_type`` is neither ``"lstm"`` nor ``"gru"``.
    """
    super().__init__()

    is_lstm = rnn_type == "lstm"
    if not (is_lstm or rnn_type == "gru"):
        raise ValueError(f"Not supported: rnn_type={rnn_type}")

    self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
    self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

    rnn_class = torch.nn.LSTM if is_lstm else torch.nn.GRU

    # The first layer consumes the embedding; every later layer consumes
    # the previous layer's hidden state. Each entry is a separate
    # single-layer RNN module.
    layer_input_sizes = [embed_size] + [hidden_size] * (num_layers - 1)
    self.rnn = torch.nn.ModuleList(
        [
            rnn_class(input_size, hidden_size, 1, batch_first=True)
            for input_size in layer_input_sizes
        ]
    )

    self.dropout_rnn = torch.nn.ModuleList(
        [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
    )

    self.dlayers = num_layers
    self.dtype = rnn_type  # holds the RNN kind string, not a tensor dtype

    self.output_size = hidden_size
    self.vocab_size = vocab_size

    # Device of the first registered parameter at construction time.
    self.device = next(self.parameters()).device
    self.score_cache = {}  # starts empty

    self.use_embed_mask = use_embed_mask
    if use_embed_mask:
        self._embed_mask = SpecAug(
            time_mask_width_range=3,
            num_time_mask=4,
            apply_freq_mask=False,
            apply_time_warp=False,
        )
79
80 def forward(
81 self,

Callers

nothing calls this directly

Calls 2

SpecAug (Class) · 0.90
parameters (Method) · 0.80

Tested by

no test coverage detected