| 26 | from_pretrained: Optional[Tuple[str, str]] = None |
| 27 | |
def __init__(self, c: Config):
    """Build the compressor, feed-forward net and input projection from ``c``.

    Optionally restores weights from a pretrained checkpoint afterwards.

    Args:
        c: Model config. The fields read here are:
            - ``compressor_config``: called with no arguments to build
              ``self.compressor``. Required on every path.
            - ``dim``, ``ff_dim``: passed to ``FFNN``.
            - ``input_dim``: optional. When set, inputs are projected with
              ``nn.Linear(input_dim, dim)``; otherwise ``nn.Identity()``
              is used.
            - ``from_pretrained``: optional pair of strings. When set, it
              is unpacked into ``load_ckpt``, and the result is applied
              with ``self.load_state_dict``.

    Raises:
        ValueError: if ``c.compressor_config`` is missing. It is needed
            even when ``from_pretrained`` is given, because the submodules
            must exist before checkpoint weights can be loaded into them.
    """
    super().__init__()

    # Validate unconditionally: compressor_config() is invoked below on
    # both the fresh and the pretrained path. An `assert` would be stripped
    # under `python -O`. (`exists` is presumably the usual `is not None`
    # helper -- confirm.)
    if not exists(c.compressor_config):
        raise ValueError(
            f'compressor_config is required (also when from_pretrained is set); got config {c}'
        )

    self.compressor = c.compressor_config()
    self.ffnn = FFNN(c.dim, c.ff_dim)
    self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity()

    # Load pretrained weights only after every submodule exists, so the
    # state dict has parameters to populate. The meaning of the two strings
    # in from_pretrained is defined by load_ckpt (not visible here).
    if exists(c.from_pretrained):
        self.load_state_dict(load_ckpt(*c.from_pretrained))
| 42 | |
| 43 | @T.no_grad() |
| 44 | def forward(self, x, return_latent=False, known_latent=None): |