MCPcopy
hub / github.com/SesameAILabs/csm / setup_caches

Method setup_caches

models.py:120–130  ·  view source on GitHub ↗

Set up the KV caches and register the backbone and decoder causal masks as buffers (nothing is returned).

(self, max_batch_size: int)

Source from the content-addressed store, hash-verified

118 self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
119
def setup_caches(self, max_batch_size: int) -> None:
    """Set up the backbone/decoder KV caches and register causal-mask buffers.

    Allocates KV caches on ``self.backbone`` and ``self.decoder`` using the
    dtype and device of this module's parameters. It then registers two buffers
    on ``self``:

    * ``backbone_causal_mask``, built for ``self.backbone.max_seq_len``;
    * ``decoder_causal_mask``, built for ``self.config.audio_num_codebooks``.

    Args:
        max_batch_size: Batch size passed to both sub-modules'
            ``setup_caches``.

    Returns:
        None. The masks are exposed as buffers on ``self``, not returned.
    """
    # Take dtype/device from the model's own weights so the caches and masks
    # land wherever the model currently lives.
    # NOTE(review): raises StopIteration if the module has no parameters.
    first_param = next(self.parameters())
    dtype = first_param.dtype
    device = first_param.device

    # The decoder cache length and the decoder mask size must agree; use a
    # single source for both.
    decoder_max_seq_len = self.config.audio_num_codebooks

    # Using a torch.device as a context manager sets it as the default device
    # for tensors created inside the block (PyTorch >= 2.0), so the caches
    # are allocated on `device`.
    with device:
        self.backbone.setup_caches(max_batch_size, dtype)
        self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=decoder_max_seq_len)

    self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
    self.register_buffer("decoder_causal_mask", _create_causal_mask(decoder_max_seq_len, device))
131
132 def generate_frame(
133 self,

Callers 1

__init__ · Method · 0.80

Calls 1

_create_causal_mask · Function · 0.85

Tested by

no test coverage detected