MCPcopy
hub / github.com/SesameAILabs/csm / generate_frame

Method generate_frame

models.py:132–184  ·  view source on GitHub ↗

Args: tokens: (batch_size, seq_len, audio_num_codebooks+1); tokens_mask: (batch_size, seq_len, audio_num_codebooks+1); input_pos: (batch_size, seq_len) positions for each token; mask: (batch_size, seq_len, max_seq_len). Returns: (batch_size, audio_num_codebooks) sampled tokens.

(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
    )

Source from the content-addressed store, hash-verified

130 self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
131
def generate_frame(
    self,
    tokens: torch.Tensor,
    tokens_mask: torch.Tensor,
    input_pos: torch.Tensor,
    temperature: float,
    topk: int,
) -> torch.Tensor:
    """
    Sample one frame of audio tokens, one codebook at a time.

    Codebook 0 is sampled from the backbone output at the last sequence
    position. The decoder is then run once per remaining codebook, each
    step being fed the embedding of the token sampled in the previous step.

    Args:
        tokens: (batch_size, seq_len, audio_num_codebooks+1)
        tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
        input_pos: (batch_size, seq_len) positions for each token
        temperature: passed unchanged to sample_topk for every codebook
        topk: passed unchanged to sample_topk for every codebook

    Returns:
        (batch_size, audio_num_codebooks) sampled tokens
    """
    param_dtype = next(self.parameters()).dtype
    # The sizes are unused; the unpack still requires `tokens` to be 3-D.
    _batch, _seq_len, _ = tokens.shape

    assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
    backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
    # Zero the masked-off entries, then sum over dim 2 so each sequence
    # position is represented by a single vector.
    frame_embeds = (self._embed_tokens(tokens) * tokens_mask.unsqueeze(-1)).sum(dim=2)
    backbone_out = self.backbone(frame_embeds, input_pos=input_pos, mask=backbone_mask).to(dtype=param_dtype)

    # Only the last position's hidden state drives the prediction.
    final_hidden = backbone_out[:, -1, :]
    sample0 = sample_topk(self.codebook0_head(final_hidden), topk, temperature)
    embed0 = self._embed_audio(0, sample0)

    # First decoder input: the backbone state followed by the codebook-0 embedding.
    decoder_input = torch.cat([final_hidden.unsqueeze(1), embed0], dim=1)
    frame = sample0.clone()
    n_rows, n_steps = decoder_input.size(0), decoder_input.size(1)
    positions = torch.arange(0, n_steps, device=decoder_input.device).unsqueeze(0).repeat(n_rows, 1)

    # Decoder caches must be reset every frame.
    self.decoder.reset_caches()
    for codebook in range(1, self.config.audio_num_codebooks):
        decoder_mask = _index_causal_mask(self.decoder_causal_mask, positions)
        projected = self.projection(decoder_input)
        decoder_out = self.decoder(projected, input_pos=positions, mask=decoder_mask).to(dtype=param_dtype)
        # audio_head is indexed from 0 for codebook 1, hence the -1.
        logits = torch.mm(decoder_out[:, -1, :], self.audio_head[codebook - 1])
        sample = sample_topk(logits, topk, temperature)

        # After the first step only the newest embedding is fed in, at the next position.
        decoder_input = self._embed_audio(codebook, sample)
        frame = torch.cat([frame, sample], dim=1)
        positions = positions[:, -1:] + 1

    return frame
185
186 def reset_caches(self):
187 self.backbone.reset_caches()

Callers 1

generateMethod · 0.80

Calls 5

_embed_tokensMethod · 0.95
_embed_audioMethod · 0.95
_index_causal_maskFunction · 0.85
sample_topkFunction · 0.85
reset_cachesMethod · 0.80

Tested by

no test coverage detected