MCPcopy Index your code
hub / github.com/Standard-Intelligence/hertz-dev / __init__

Method __init__

ioblocks.py:153–201  ·  view source on GitHub ↗
(self, c: Config)

Source from the content-addressed store, hash-verified

151 use_rms: bool = False
152
def __init__(self, c: Config):
    """Build the quantizer from config ``c``.

    Registers the non-persistent ``_levels`` and ``_basis`` buffers.
    Resolves ``c.allowed_dtypes`` (dtype names such as ``"float32"``) into
    ``torch.dtype`` objects. Creates input/output projections when ``self.dim``
    differs from ``len(c.levels) * c.num_codebooks``. When
    ``c.return_indices`` is set, precomputes ``implicit_codebook`` from every
    index in ``range(self.codebook_size)``.

    Raises:
        ValueError: if an entry of ``c.allowed_dtypes`` does not name a torch dtype.
        AssertionError: if ``c.num_codebooks > 1`` while
            ``c.keep_num_codebooks_dim`` is explicitly falsy.
    """
    super().__init__()

    # Normalize once so tuple-valued configs work with list concatenation below.
    levels = list(c.levels)

    _levels = T.tensor(levels, dtype=int32)
    self.register_buffer("_levels", _levels, persistent = False)

    # basis[i] = prod(levels[:i]); presumably the place values used to map
    # per-dimension level indices to a flat code index (and back).
    _basis = T.cumprod(T.tensor([1] + levels[:-1]), dim=0, dtype=int32)
    self.register_buffer("_basis", _basis, persistent = False)

    self.scale = c.scale

    codebook_dim = len(levels)
    self.codebook_dim = codebook_dim
    self.num_codebooks = c.num_codebooks

    effective_codebook_dim = codebook_dim * c.num_codebooks
    self.effective_codebook_dim = effective_codebook_dim

    # Resolve dtype names to torch dtype objects. Reject names that exist on
    # torch but are not dtypes (e.g. "nn"). This list is kept as-is; it is no
    # longer overwritten with the raw strings afterwards.
    self.allowed_dtypes = []
    for dtype_str in c.allowed_dtypes:
        dtype = getattr(T, dtype_str, None)
        if not isinstance(dtype, T.dtype):
            raise ValueError(f"Invalid dtype string: {dtype_str}")
        self.allowed_dtypes.append(dtype)

    keep_num_codebooks_dim = default(c.keep_num_codebooks_dim, c.num_codebooks > 1)
    assert not (c.num_codebooks > 1 and not keep_num_codebooks_dim), (
        "keep_num_codebooks_dim must be enabled when num_codebooks > 1"
    )
    self.keep_num_codebooks_dim = keep_num_codebooks_dim

    self.dim = default(c.dim, effective_codebook_dim)

    self.channel_first = c.channel_first

    # Projections are only needed when the model width differs from the
    # total number of quantized scalar dimensions.
    has_projections = self.dim != effective_codebook_dim
    self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
    self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = c.projection_has_bias) if has_projections else nn.Identity()

    self.has_projections = has_projections

    # Number of distinct codes = prod(levels). Always set, because
    # latent_metric reads it regardless of return_indices.
    self.codebook_size = self._levels.prod().item()

    self.return_indices = c.return_indices
    if c.return_indices:
        implicit_codebook = self._indices_to_codes(T.arange(self.codebook_size))
        self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)

    self.force_quantization_f32 = c.force_quantization_f32

    self.latent_loss = None
202
def latent_metric(self, codes, get_global=False):
    """Return a metrics dict for ``codes``.

    The single key ``'code_util_estimate'`` holds
    ``get_code_utilization(codes, self.codebook_size, get_global)``.
    Presumably this is a codebook-usage estimate; see that helper for its
    exact semantics.
    """
    utilization = get_code_utilization(codes, self.codebook_size, get_global)
    return {'code_util_estimate': utilization}

Callers 2

__init__ (Method) · 0.45
__init__ (Method) · 0.45

Calls 2

_indices_to_codes (Method) · 0.95
default (Function) · 0.90

Tested by

no test coverage detected