(self, c: Config)
| 212 | from_pretrained: Optional[Tuple[str, str]] = None |
| 213 | |
def __init__(self, c: Config):
    """Build the model layers from ``c`` and optionally load pretrained weights.

    Args:
        c: Model configuration. Fields read here: ``from_pretrained``,
            ``stack_config`` (``dim``, ``n_head``, ``kv_heads``, ``ff_dim``,
            ``eps``, ``seq_len``, ``theta``, ``layers``), ``latent_size``,
            ``dim``, ``vocab_size``, ``split`` and ``resynthesizer_config``.

    Raises:
        AssertionError: If ``c.stack_config`` is not set.
    """
    super().__init__()

    # Every layer below is built from ``stack_config``, so it is required even
    # when the weights come from a pretrained checkpoint. Checking first avoids
    # loading a checkpoint only to fail later with an opaque AttributeError.
    assert exists(c.stack_config), f'hmm {c}'
    sc = c.stack_config

    checkpoint = None
    if exists(c.from_pretrained):
        # ``from_pretrained`` is unpacked into load_ckpt's positional args; the
        # result is applied via load_state_dict once all submodules exist.
        checkpoint = load_ckpt(*c.from_pretrained)

    # Input projection(s) from latent_size to dim. ``split`` is read from the
    # ``c`` argument: nothing in this method assigns ``self.c``.
    self.input = nn.Linear(c.latent_size, c.dim)
    if c.split:
        self.input2 = nn.Linear(c.latent_size, c.dim)

    # A single ShapeRotator instance, shared by every block below.
    self.shape_rotator = ShapeRotator(sc.dim // sc.n_head, sc.seq_len, theta=sc.theta)

    self.layers = nn.ModuleList([
        PerfBlock(
            dim=sc.dim,
            layer_id=layer_id,
            n_head=sc.n_head,
            kv_heads=sc.kv_heads,
            ff_dim=sc.ff_dim,
            eps=sc.eps,
            shape_rotator=self.shape_rotator,
        )
        for layer_id in range(sc.layers)
    ])

    self.output = GPTOutput(c.dim, c.vocab_size)
    if c.split:
        self.output2 = GPTOutput(c.dim, c.vocab_size)

    # One cache slot per layer (presumably a per-layer KV cache -- confirm).
    self.cache = [None] * sc.layers
    # kv_heads falls back to n_head when it is unset/falsy.
    self.kv_heads = sc.kv_heads or sc.n_head
    self.head_dim = sc.dim // sc.n_head

    if exists(c.from_pretrained):
        # strict=False: missing/unexpected keys are reported (printed below)
        # instead of raising.
        result = self.load_state_dict(checkpoint, strict=False)
        print0_colored(result, 'yellow')

    # Created after the checkpoint load, so the resynthesizer's weights come
    # from ``c.resynthesizer_config()`` rather than from ``checkpoint``.
    self.resynthesizer = c.resynthesizer_config().eval()
    # Freeze its parameters. Assigning ``module.requires_grad = False`` would
    # only set a plain attribute on the Module; ``requires_grad_`` updates
    # every parameter. Assumes resynthesizer_config() returns an nn.Module
    # (it is ``.eval()``-ed above).
    # NOTE(review): as a registered submodule it will be switched back to
    # training mode by ``self.train()`` -- confirm that is intended.
    self.resynthesizer.requires_grad_(False)

    # The audio tokenizer is built on CPU; the audio caches start empty and
    # disabled.
    self.audio_tokenizer = make_tokenizer(device='cpu')
    self.audio_cache = None
    self.audio_latent_cache = None
    self.use_audio_cache = False
| 260 | @T.no_grad() |
| 261 | def tokenize(self, audio_data): |
no test coverage detected