MCPcopy Index your code
hub / github.com/Standard-Intelligence/hertz-dev / __init__

Method __init__

transformer.py:337–366  ·  view source on GitHub ↗
(self, c: Config)

Source from the content-addressed store, hash-verified

335 from_pretrained: Optional[Tuple[str, int]] = None
336
def __init__(self, c: Config):
    """Construct the transformer layer stack described by ``c``.

    Args:
        c: Model configuration. This constructor reads ``from_pretrained``,
            ``dim``, ``n_head``, ``kv_heads``, ``seq_len``, ``theta``,
            ``layers``, ``ff_dim``, ``eps`` and ``causal`` from it.

    If ``c.from_pretrained`` passes ``exists``, the checkpoint is loaded with
    ``load_ckpt`` before any submodule is built. It is then applied through
    ``load_state_dict`` once every layer has been registered.
    """
    super().__init__()

    pretrained_src = c.from_pretrained
    if exists(pretrained_src):
        # Fetched first; it can only be applied after the submodules exist.
        pretrained_state = load_ckpt(pretrained_src)

    head_dim = c.dim // c.n_head

    # A single ShapeRotator instance is created here and handed to every Block.
    self.shape_rotator = ShapeRotator(head_dim, c.seq_len, theta=c.theta)

    blocks = []
    for layer_idx in range(c.layers):
        blocks.append(
            Block(
                dim=c.dim,
                layer_id=layer_idx,
                n_head=c.n_head,
                # NOTE(review): passed through unmodified (possibly falsy);
                # only the cache shape below applies the n_head fallback.
                kv_heads=c.kv_heads,
                ff_dim=c.ff_dim,
                eps=c.eps,
                causal=c.causal,
                shape_rotator=self.shape_rotator,
            )
        )
    self.layers = nn.ModuleList(blocks)

    # Use c.n_head KV heads when c.kv_heads is falsy (e.g. None).
    n_kv = c.kv_heads or c.n_head
    # [layers, seq_len, 2, kv_heads, head_dim]. The axis of size 2
    # presumably holds keys and values; confirm against init_cache.
    self.cache_shape = [c.layers, c.seq_len, 2, n_kv, head_dim]
    # One empty cache slot per layer.
    self.cache = [None for _ in range(c.layers)]

    if exists(pretrained_src):
        self.load_state_dict(pretrained_state)
367
368 def init_cache(self, bsize, device, dtype, length:int=None):
369 if self.cache_shape is None:

Callers

nothing calls this directly

Calls 5

exists — Function · 0.90
load_ckpt — Function · 0.90
ShapeRotator — Class · 0.85
Block — Class · 0.85
__init__ — Method · 0.45

Tested by

no test coverage detected