Construct an RNNDecoder object.
(
self,
vocab_size: int,
embed_size: int = 256,
hidden_size: int = 256,
rnn_type: str = "lstm",
num_layers: int = 1,
dropout_rate: float = 0.0,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
use_embed_mask: bool = False,
)
| 28 | """ |
| 29 | |
def __init__(
    self,
    vocab_size: int,
    embed_size: int = 256,
    hidden_size: int = 256,
    rnn_type: str = "lstm",
    num_layers: int = 1,
    dropout_rate: float = 0.0,
    embed_dropout_rate: float = 0.0,
    embed_pad: int = 0,
    use_embed_mask: bool = False,
) -> None:
    """Construct an RNNDecoder object.

    Args:
        vocab_size: Vocabulary size (number of rows in the embedding table).
        embed_size: Embedding dimension (input size of the first RNN layer).
        hidden_size: Hidden size of every RNN layer; also ``output_size``.
        rnn_type: Recurrent module type, either "lstm" or "gru".
        num_layers: Number of stacked single-layer RNN modules (must be >= 1).
        dropout_rate: Probability of each per-layer ``Dropout`` module.
        embed_dropout_rate: Probability of the embedding ``Dropout`` module.
        embed_pad: ``padding_idx`` passed to ``torch.nn.Embedding``.
        use_embed_mask: If True, also build a ``SpecAug`` module stored as
            ``self._embed_mask``.

    Raises:
        ValueError: If ``rnn_type`` is not "lstm" or "gru", or if
            ``num_layers`` is smaller than 1.
    """
    super().__init__()

    if rnn_type not in ("lstm", "gru"):
        raise ValueError(f"Not supported: rnn_type={rnn_type}")

    # A value < 1 used to be accepted silently: one RNN layer was still
    # built while no dropout module was created and dlayers recorded the
    # bogus count, leaving the module internally inconsistent.
    if num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got num_layers={num_layers}")

    self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
    self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

    rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU

    # One single-layer RNN per level (instead of one multi-layer RNN),
    # presumably so a separate dropout module from dropout_rnn can be applied
    # after each layer. The first layer reads embeddings; the others read the
    # previous layer's hidden states. Layers are created in index order, so
    # parameter initialization and state_dict keys ("0", "1", ...) match the
    # previous loop-based construction.
    self.rnn = torch.nn.ModuleList(
        [
            rnn_class(
                embed_size if layer == 0 else hidden_size,
                hidden_size,
                1,
                batch_first=True,
            )
            for layer in range(num_layers)
        ]
    )

    self.dropout_rnn = torch.nn.ModuleList(
        [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
    )

    self.dlayers = num_layers
    # NOTE: holds the RNN type string ("lstm" / "gru"), not a torch dtype.
    self.dtype = rnn_type

    self.output_size = hidden_size
    self.vocab_size = vocab_size

    # Snapshot of the parameters' device at construction time; it is not
    # refreshed automatically if the module is later moved with .to()/.cuda().
    self.device = next(self.parameters()).device
    # Starts empty; filled outside __init__ (usage not visible here).
    self.score_cache = {}

    self.use_embed_mask = use_embed_mask
    if self.use_embed_mask:
        # Time masking only: frequency masking and time warping are disabled.
        self._embed_mask = SpecAug(
            time_mask_width_range=3,
            num_time_mask=4,
            apply_freq_mask=False,
            apply_time_warp=False,
        )
| 79 | |
| 80 | def forward( |
| 81 | self, |
nothing calls this directly
no test coverage detected