| 10 | |
| 11 | |
| 12 | class SpeakerEncoder(nn.Module): |
def __init__(self, device, loss_device):
    """
    Builds the network layers, the similarity scaling parameters and the loss function.

    :param device: device the LSTM, the linear layer and the ReLU are moved to
    :param loss_device: device the similarity weight/bias and the loss function are placed
    on. It is also stored as self.loss_device.
    """
    super().__init__()
    self.loss_device = loss_device

    # Network definition
    self.lstm = nn.LSTM(input_size=mel_n_channels,
                        hidden_size=model_hidden_size,
                        num_layers=model_num_layers,
                        batch_first=True).to(device)
    self.linear = nn.Linear(in_features=model_hidden_size,
                            out_features=model_embedding_size).to(device)
    self.relu = torch.nn.ReLU().to(device)

    # Cosine similarity scaling (with fixed initial parameter values).
    # The tensor is moved to the loss device *before* being wrapped: calling .to() on an
    # nn.Parameter returns a plain non-leaf tensor whenever a copy is needed, and such a
    # tensor is neither registered in self.parameters() nor given a .grad.
    self.similarity_weight = nn.Parameter(torch.tensor([10.]).to(loss_device))
    self.similarity_bias = nn.Parameter(torch.tensor([-5.]).to(loss_device))

    # Loss
    self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
| 32 | |
| 33 | def do_gradient_ops(self): |
| 34 | # Gradient scale |
| 35 | self.similarity_weight.grad *= 0.01 |
| 36 | self.similarity_bias.grad *= 0.01 |
| 37 | |
| 38 | # Gradient clipping |
| 39 | clip_grad_norm_(self.parameters(), 3, norm_type=2) |
| 40 | |
| 41 | def forward(self, utterances, hidden_init=None): |
| 42 | """ |
| 43 | Computes the embeddings of a batch of utterance spectrograms. |
| 44 | |
| 45 | :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape |
| 46 | (batch_size, n_frames, n_channels) |
| 47 | :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, |
| 48 | batch_size, hidden_size). Will default to a tensor of zeros if None. |
| 49 | :return: the embeddings as a tensor of shape (batch_size, embedding_size) |
| 50 | """ |
| 51 | # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state |
| 52 | # and the final cell state. |
| 53 | out, (hidden, cell) = self.lstm(utterances, hidden_init) |
| 54 | |
| 55 | # We take only the hidden state of the last layer |
| 56 | embeds_raw = self.relu(self.linear(hidden[-1])) |
| 57 | |
| 58 | # L2-normalize it |
| 59 | embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5) |
| 60 | |
| 61 | return embeds |
| 62 | |
| 63 | def similarity_matrix(self, embeds): |
| 64 | """ |
| 65 | Computes the similarity matrix according the section 2.1 of GE2E. |
| 66 | |
| 67 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch, |
| 68 | utterances_per_speaker, embedding_size) |
| 69 | :return: the similarity matrix as a tensor of shape (speakers_per_batch, |
no outgoing calls
no test coverage detected