(self, c: Config)
| 335 | from_pretrained: Optional[Tuple[str, int]] = None |
| 336 | |
def __init__(self, c: Config):
    """Construct the block stack and cache bookkeeping described by ``c``.

    When ``exists(c.from_pretrained)`` holds, the checkpoint is fetched with
    ``load_ckpt`` before any submodule is created and is applied through
    ``load_state_dict`` once every submodule has been constructed.

    Args:
        c: Configuration object. This constructor reads ``from_pretrained``,
            ``dim``, ``n_head``, ``kv_heads``, ``seq_len``, ``theta``,
            ``ff_dim``, ``eps``, ``causal`` and ``layers`` from it.
    """
    super().__init__()

    pretrained_ref = c.from_pretrained
    # Fetch the checkpoint first so a loading failure surfaces before any
    # submodule is built; it is only applied at the very end.
    state = load_ckpt(c.from_pretrained) if exists(pretrained_ref) else None

    # Per-head width, used for both the rotator and the cache shape.
    per_head_dim = c.dim // c.n_head

    # A single rotator instance is handed to every block below.
    self.shape_rotator = ShapeRotator(per_head_dim, c.seq_len, theta=c.theta)

    blocks = [
        Block(
            dim=c.dim,
            layer_id=layer_idx,
            n_head=c.n_head,
            kv_heads=c.kv_heads,
            ff_dim=c.ff_dim,
            eps=c.eps,
            causal=c.causal,
            shape_rotator=self.shape_rotator,
        )
        for layer_idx in range(c.layers)
    ]
    self.layers = nn.ModuleList(blocks)

    # Fall back to the number of attention heads when ``kv_heads`` is falsy.
    n_kv = c.kv_heads or c.n_head
    self.cache_shape = [c.layers, c.seq_len, 2, n_kv, per_head_dim]
    # One cache slot per layer, all empty until something fills them.
    self.cache = [None] * c.layers

    if exists(pretrained_ref):
        self.load_state_dict(state)
| 367 | |
| 368 | def init_cache(self, bsize, device, dtype, length:int=None): |
| 369 | if self.cache_shape is None: |
nothing calls this directly
no test coverage detected