Repository: github.com/modelscope/FunASR

Method forward

funasr/frontends/s3prl.py:137–161

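Forward pass for training. Trims each padded waveform to its true length, extracts features with the frozen S3PRL upstream (eval mode, no gradients), combines them with the featurizer, optionally tiles frames by args.tile_factor, and re-pads the batch.
Args: input: padded batch of raw waveforms. input_lengths: number of valid samples in each waveform.
Returns: the padded feature tensor and a LongTensor of per-utterance frame counts.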
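The method begins by turning the padded batch back into variable-length utterances (wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]), because S3PRL upstreams take a list of unpadded waveforms. A minimal pure-Python sketch of that step, assuming plain lists stand in for the torch tensors; trim_padded is a hypothetical name, not a FunASR function.

```python
from typing import List, Sequence


def trim_padded(
    batch: Sequence[Sequence[float]], lengths: Sequence[int]
) -> List[List[float]]:
    """Drop trailing padding so each utterance keeps only its valid samples.

    Hypothetical helper: mirrors the list comprehension at the top of
    forward, with Python lists standing in for torch tensors.
    """
    if len(batch) != len(lengths):
        raise ValueError("batch and lengths must have the same size")
    return [list(wav[:n]) for wav, n in zip(batch, lengths)]


# Two utterances padded to 5 samples; the second has only 3 real samples.
padded = [[0.1, 0.2, 0.3, 0.4, 0.5],
          [0.7, 0.8, 0.9, 0.0, 0.0]]
print(trim_padded(padded, [5, 3]))
# [[0.1, 0.2, 0.3, 0.4, 0.5], [0.7, 0.8, 0.9]]
```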

Signature: forward(self, input: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]

Source

```python
def forward(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward pass for training.

    Args:
        input: Padded batch of raw waveforms.
        input_lengths: Number of valid samples in each waveform.

    Returns:
        Padded features and their per-utterance frame counts.
    """
    # Remove padding: S3PRL upstreams take a list of variable-length waveforms.
    wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
    # Keep the pretrained upstream frozen: eval mode and no gradient tracking.
    self.upstream.eval()
    with torch.no_grad():
        feats = self.upstream(wavs)
    # The featurizer runs outside no_grad, so gradients can still reach it.
    feats = self.featurizer(wavs, feats)

    # Optionally tile frames by args.tile_factor.
    if self.args.tile_factor != 1:
        feats = self._tile_representations(feats)

    # Re-pad the batch and record each utterance's true frame count.
    input_feats = pad_list(feats, 0.0)
    feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)

    # Save CUDA memory: drop the unpadded per-utterance features.
    del feats

    return input_feats, feats_lens
```
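The last steps of forward re-pad the variable-length feature sequences and record their lengths. A minimal pure-Python sketch of that pattern, assuming each utterance's features are a list of equal-width frames; pad_and_lengths is a hypothetical helper, and plain lists stand in for the tensors that pad_list and torch.tensor produce.

```python
from typing import List, Sequence, Tuple

Frame = List[float]


def pad_and_lengths(
    feats: Sequence[Sequence[Frame]], pad_value: float = 0.0
) -> Tuple[List[List[Frame]], List[int]]:
    """Pad frame sequences to a common length and return their true lengths.

    Hypothetical stand-in for `pad_list(feats, 0.0)` plus
    `torch.tensor([f.shape[0] for f in feats], dtype=torch.long)`.
    """
    lengths = [len(seq) for seq in feats]
    max_len = max(lengths)
    dim = len(feats[0][0])  # assumes every frame has the same feature dimension
    padded = [
        [list(frame) for frame in seq]
        + [[pad_value] * dim for _ in range(max_len - len(seq))]
        for seq in feats
    ]
    return padded, lengths


feats = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0]]]
padded, lengths = pad_and_lengths(feats)
print(padded)   # [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [0.0, 0.0]]]
print(lengths)  # [2, 1]
```

Returning the lengths alongside the padded batch lets later layers tell real frames from padding.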

Callers

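Nothing calls this directly. As the forward method of a torch.nn.Module, it is normally run by calling the module instance itself, which dispatches to forward.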
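Why a call graph finds no direct callers: torch.nn.Module.__call__ invokes forward on the caller's behalf, so user code calls the module object and never names forward. A toy sketch of that dispatch, not PyTorch's real implementation (which also runs hooks); ToyModule and ToyFrontend are hypothetical names.

```python
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module's call-to-forward dispatch."""

    def __call__(self, *args, **kwargs):
        # The real torch.nn.Module.__call__ also runs registered hooks; omitted here.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class ToyFrontend(ToyModule):
    def forward(self, input, input_lengths):
        # Stand-in computation: report each utterance's trimmed length.
        return [min(len(wav), n) for wav, n in zip(input, input_lengths)]


frontend = ToyFrontend()
# The caller never writes frontend.forward(...).
print(frontend([[0.1, 0.2, 0.0], [0.3, 0.0, 0.0]], [2, 1]))  # [2, 1]
```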

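Calls (3; the number after each entry is the confidence score reported by the code browser)

_tile_representations · method · 0.95
pad_list · function · 0.90
eval · method · 0.45 (torch.nn.Module.eval, called on self.upstream)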
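_tile_representations is not shown on this page. A minimal sketch of its assumed effect, taking "tiling" to mean repeating each frame tile_factor times in a row along the time axis (so T frames become T * tile_factor), applied to one utterance; tile_frames is a hypothetical name.

```python
from typing import List, Sequence

Frame = List[float]


def tile_frames(frames: Sequence[Frame], tile_factor: int) -> List[Frame]:
    """Repeat each frame `tile_factor` times consecutively (frame-rate upsampling).

    Hypothetical helper illustrating the assumed effect of _tile_representations.
    """
    if tile_factor < 1:
        raise ValueError("tile_factor must be >= 1")
    return [list(frame) for frame in frames for _ in range(tile_factor)]


print(tile_frames([[1.0, 2.0], [3.0, 4.0]], 2))
# [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]
```

In forward, feats_lens is computed after this step, so the reported lengths already include the tiling factor.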

Tested by

no test coverage detected