(self, codes: torch.Tensor)
| 95 | return spk_embedding.unsqueeze(0).bfloat16() |
| 96 | |
| 97 | def embed_codes(self, codes: torch.Tensor) -> torch.Tensor: |
| 98 | return sum(emb(codes[:, i]) for i, emb in enumerate(self.embeddings)) |
| 99 | |
| 100 | def apply_heads(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 101 | return torch.stack([head(hidden_states) for head in self.heads], dim=1) |
no outgoing calls
no test coverage detected