(self, config: ModelArgs)
| 104 | license="apache-2.0", |
| 105 | ): |
| 106 | def __init__(self, config: ModelArgs): |
| 107 | super().__init__() |
| 108 | self.config = config |
| 109 | |
| 110 | self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]()) |
| 111 | self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]()) |
| 112 | |
| 113 | self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim) |
| 114 | self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim) |
| 115 | |
| 116 | self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False) |
| 117 | self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False) |
| 118 | self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)) |
| 119 | |
| 120 | def setup_caches(self, max_batch_size: int) -> torch.Tensor: |
| 121 | """Setup KV caches and return a causal mask.""" |
nothing calls this directly
no test coverage detected