(self, c: Config)
| 71 | from_pretrained: Optional[Tuple[str, str]] = None |
| 72 | |
def __init__(self, c: Config):
    """Build the submodules described by ``c`` and optionally load weights.

    Args:
        c: Model configuration. ``io_config``, ``stack_config`` and
            ``quantizer_config`` are required on every path, because they are
            called or dereferenced unconditionally below. If
            ``c.from_pretrained`` is set, it is unpacked into ``load_ckpt``.
            The resulting state dict is then loaded non-strictly, once every
            submodule except the quantizer exists.

    Raises:
        AssertionError: if a required sub-config is missing, or if the
            quantizer dim or the stack dim is ``None``.
    """
    super().__init__()

    # Validate before loading a checkpoint. Previously this ran only on the
    # non-pretrained path, even though both paths use all three configs.
    assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}'

    pretrained = exists(c.from_pretrained)
    checkpoint = load_ckpt(*c.from_pretrained) if pretrained else None

    stack_cfg = c.stack_config

    self.io = c.io_config()
    self.stack = stack_cfg()

    # Index of the stack's middle layer (layers // 2). It also sizes the
    # extra cache layers in split mode below.
    self.plex_layer = stack_cfg.layers // 2
    self.plex_roll = c.plex_roll
    self.plex_dim = c.quantizer_config.dim

    assert self.plex_dim is not None and stack_cfg.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
    # Maps quantizer-dim vectors (plex_dim) to the stack's model dim.
    self.plex_projection = nn.Linear(self.plex_dim, stack_cfg.dim)
    self.out_norm = Norm(stack_cfg.dim)

    if c.split:
        # A second IO module with its own projection.
        self.io2 = c.io_config()
        self.plex_projection2 = nn.Linear(self.plex_dim, stack_cfg.dim)

        # Disable these three attributes on io2. Presumably they are heads
        # that only the primary io needs -- confirm against io_config.
        self.io2.fc_loc = None
        self.io2.fc_scale = None
        self.io2.fc_weight = None

    # Fall back to n_head when kv_heads is unset (None/0).
    kv_heads = stack_cfg.kv_heads or stack_cfg.n_head
    head_dim = stack_cfg.dim // stack_cfg.n_head
    # In split mode the layers from plex_layer upward get a second set of
    # cache slots. Presumably this is for a second pass on the io2 branch;
    # confirm against the forward code.
    self.cache_num_layers = stack_cfg.layers + ((stack_cfg.layers - self.plex_layer) if c.split else 0)
    # Axis 2 (size 2) presumably holds key and value.
    cache_shape = [self.cache_num_layers, stack_cfg.seq_len, 2, kv_heads, head_dim]
    self.cache_shape = cache_shape
    self.cache = [None] * self.cache_num_layers

    if pretrained:
        # strict=False: missing/unexpected keys are reported, not fatal.
        result = self.load_state_dict(checkpoint, strict=False)
        print0_colored(result, 'yellow')

    # NOTE(review): the quantizer is created *after* load_state_dict, so no
    # quantizer weights from the checkpoint are loaded into it. Any such keys
    # would presumably show up as unexpected keys in `result` above. This may
    # be intentional (quantizer_config supplying its own weights) -- confirm.
    self.quantizer = c.quantizer_config().eval()
    # Freeze the quantizer's parameters. Module.requires_grad_(False) sets
    # requires_grad on every parameter. Assigning `module.requires_grad = False`
    # would only create a plain attribute and leave the parameters trainable.
    self.quantizer.requires_grad_(False)
    # NOTE(review): .eval() is reverted if the parent module later calls
    # .train(). Re-apply eval mode there if the quantizer must stay in
    # inference mode (dropout/batch-norm behaviour).
| 113 | |
| 114 | @T.no_grad() |
| 115 | def quantize(self, x): |
nothing calls this directly
no test coverage detected