(self, c: Config)
| 151 | use_rms: bool = False |
| 152 | |
def __init__(self, c: Config):
    """Set up buffers, projections and bookkeeping attributes from ``c``.

    Args:
        c: Configuration object. The fields read here are ``levels``,
            ``scale``, ``num_codebooks``, ``allowed_dtypes`` (names of dtype
            attributes on ``T``), ``keep_num_codebooks_dim``, ``dim``,
            ``channel_first``, ``projection_has_bias``, ``return_indices``
            and ``force_quantization_f32``.

    Raises:
        ValueError: If an entry of ``c.allowed_dtypes`` does not name a
            dtype attribute of ``T``.
        AssertionError: If ``c.num_codebooks > 1`` while
            ``keep_num_codebooks_dim`` resolves to a falsy value.
    """
    super().__init__()

    # Normalise once so tuple-valued ``levels`` work with the list
    # concatenation below.
    levels = list(c.levels)

    _levels = T.tensor(levels, dtype=int32)
    self.register_buffer("_levels", _levels, persistent=False)

    # Cumulative product of [1, levels[0], ..., levels[-2]]: the per-dimension
    # multipliers of a mixed-radix index.
    _basis = T.cumprod(T.tensor([1] + levels[:-1]), dim=0, dtype=int32)
    self.register_buffer("_basis", _basis, persistent=False)

    self.scale = c.scale

    codebook_dim = len(levels)
    self.codebook_dim = codebook_dim

    effective_codebook_dim = codebook_dim * c.num_codebooks
    self.num_codebooks = c.num_codebooks
    self.effective_codebook_dim = effective_codebook_dim

    # Resolve dtype names to dtype objects. A bare ``hasattr`` check would
    # also accept names of functions or submodules on ``T``, so the resolved
    # attribute must itself be a dtype.
    allowed_dtypes = []
    for dtype_str in c.allowed_dtypes:
        resolved = getattr(T, dtype_str, None)
        if not isinstance(resolved, T.dtype):
            raise ValueError(f"Invalid dtype string: {dtype_str}")
        allowed_dtypes.append(resolved)
    # NOTE: this list must not be overwritten with the raw strings from
    # ``c.allowed_dtypes`` later on; doing so discards the resolution above.
    self.allowed_dtypes = allowed_dtypes

    keep_num_codebooks_dim = default(c.keep_num_codebooks_dim, c.num_codebooks > 1)
    assert not (c.num_codebooks > 1 and not keep_num_codebooks_dim), (
        "keep_num_codebooks_dim must be truthy when num_codebooks > 1"
    )
    self.keep_num_codebooks_dim = keep_num_codebooks_dim

    self.dim = default(c.dim, len(_levels) * c.num_codebooks)

    self.channel_first = c.channel_first

    # Projections are only needed when the external feature size differs from
    # the total quantised size; otherwise both directions are identity.
    has_projections = self.dim != effective_codebook_dim
    if has_projections:
        self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias=c.projection_has_bias)
        self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias=c.projection_has_bias)
    else:
        self.project_in = nn.Identity()
        self.project_out = nn.Identity()
    self.has_projections = has_projections

    # Always expose the codebook size: ``latent_metric`` reads it regardless
    # of ``return_indices``.
    self.codebook_size = self._levels.prod().item()

    self.return_indices = c.return_indices
    if c.return_indices:
        # Materialise the code for every index once, as a non-persistent buffer.
        implicit_codebook = self._indices_to_codes(T.arange(self.codebook_size))
        self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)

    self.force_quantization_f32 = c.force_quantization_f32

    self.latent_loss = None
| 202 | |
def latent_metric(self, codes, get_global=False):
    """Return a code-utilization estimate for ``codes`` as a one-entry dict.

    Args:
        codes: Passed through unchanged as the first argument of
            ``get_code_utilization``.
        get_global: Forwarded unchanged as the third argument of
            ``get_code_utilization``. Defaults to ``False``.

    Returns:
        A dict with the single key ``'code_util_estimate'`` mapped to the
        value returned by ``get_code_utilization``.

    Note:
        Reads ``self.codebook_size``, so that attribute must already have
        been set on the instance.
    """
    utilization = get_code_utilization(codes, self.codebook_size, get_global)
    return {'code_util_estimate': utilization}
no test coverage detected