Forward pass for training. Args: input: Input audio/text data. input_lengths: Lengths of input.
(
self, input: torch.Tensor, input_lengths: torch.Tensor
)
| 135 | return self.output_dim |
| 136 | |
def forward(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Turn a batch of inputs into padded features plus their lengths.

    Every item of ``input`` is cut down to its entry in ``input_lengths``,
    fed through ``self.upstream`` and then through ``self.featurizer``.
    When ``self.args.tile_factor`` is not 1 the features are additionally
    passed through ``self._tile_representations``. The per-item features
    are finally padded with ``0.0`` by ``pad_list``.

    Args:
        input: Input audio/text data, iterated item by item.
        input_lengths: Lengths of input, one entry per item.

    Returns:
        A pair ``(input_feats, feats_lens)``: the padded features and a
        ``torch.long`` tensor with the size of the first dimension of each
        item's features before padding.
    """
    trimmed = []
    for idx, item in enumerate(input):
        trimmed.append(item[: input_lengths[idx]])

    # The upstream is put into eval mode on every call and is run with
    # gradient tracking disabled.
    self.upstream.eval()
    with torch.no_grad():
        reps = self.upstream(trimmed)
    # NOTE(review): the featurizer call is placed outside the no_grad block;
    # confirm against the original file's indentation.
    reps = self.featurizer(trimmed, reps)

    if self.args.tile_factor != 1:
        reps = self._tile_representations(reps)

    input_feats = pad_list(reps, 0.0)
    lengths = [rep.shape[0] for rep in reps]
    feats_lens = torch.tensor(lengths, dtype=torch.long)

    # Saving CUDA Memory
    del reps

    return input_feats, feats_lens
| 162 | |
| 163 | def reload_pretrained_parameters(self): |
| 164 | """Reload pretrained parameters.""" |
nothing calls this directly
no test coverage detected