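Initialize OpenAIWhisperDecoderWarp. Loads the named Whisper model on CPU and keeps a deep copy of its decoder. Args: dropout_rate: Probability for the `torch.nn.Dropout` layer added by this wrapper (the original Whisper model does not use dropout); default 0.0. whisper_model: Name of the pretrained Whisper model to load (default "small"); must be one of `whisper.available_models()`. download_dir: Directory passed to `whisper.load_model` as `download_root`; default None. use_padmask: Flag stored as `self.use_padmask` and presumably used in `forward` to apply a padding mask; default False.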
(
self,
dropout_rate: float = 0.0,
whisper_model: str = "small",
download_dir: str = None,
use_padmask: bool = False,
)
| 20 | """ |
| 21 | |
| 22 | def __init__( |
| 23 | self, |
| 24 | dropout_rate: float = 0.0, |
| 25 | whisper_model: str = "small", |
| 26 | download_dir: str = None, |
| 27 | use_padmask: bool = False, |
| 28 | ): |
| 29 | """Initialize OpenAIWhisperDecoderWarp. |
| 30 | |
| 31 | Args: |
| 32 | dropout_rate: TODO. |
| 33 | whisper_model: Whisper Model instance. |
| 34 | download_dir: TODO. |
| 35 | use_padmask: TODO. |
| 36 | """ |
| 37 | super().__init__() |
| 38 | |
| 39 | assert whisper_model in whisper.available_models() |
| 40 | _model = whisper.load_model(whisper_model, download_root=download_dir, device="cpu") |
| 41 | self.decoders = copy.deepcopy(_model.decoder) |
| 42 | attention_dim = self.decoders.token_embedding.embedding_dim |
| 43 | |
| 44 | # note that originally Whisper doesn't use dropouts |
| 45 | self.dropout = torch.nn.Dropout(dropout_rate) |
| 46 | |
| 47 | self.decoders.train() |
| 48 | del _model |
| 49 | self.use_padmask = use_padmask |
| 50 | |
| 51 | def forward( |
| 52 | self, |