MCPcopy Index your code
hub / github.com/Standard-Intelligence/hertz-dev / __init__

Method __init__

model.py:73–112  ·  view source on GitHub ↗
(self, c: Config)

Source from the content-addressed store, hash-verified

71 from_pretrained: Optional[Tuple[str, str]] = None
72
def __init__(self, c: Config):
    """Build the model's submodules from ``c`` and optionally load pretrained weights.

    Instantiates the IO module (``c.io_config()``), the layer stack
    (``c.stack_config()``), a linear projection from the quantizer dim to
    the stack dim, and an output norm. When ``c.split`` is set, it also
    builds a second IO module (with its ``fc_loc``/``fc_scale``/``fc_weight``
    set to None) and a second projection. It records the cache layout
    (``cache_shape``) and an empty per-layer ``cache`` list. If
    ``c.from_pretrained`` is set, it loads that checkpoint non-strictly.
    Finally, it builds the quantizer in eval mode with its parameters frozen.

    Args:
        c: Model configuration. ``io_config``, ``stack_config`` and
            ``quantizer_config`` must be set, because they are called on
            every path. ``from_pretrained``, if set, is unpacked as the
            arguments to ``load_ckpt``.

    Raises:
        ValueError: If a required sub-config is missing, or if the
            quantizer dim or the stack dim is None.
    """
    super().__init__()

    # The sub-configs are invoked whether or not we load a checkpoint, so
    # validate them unconditionally. Explicit raises (unlike `assert`) also
    # survive `python -O`.
    missing = [name for name in ('io_config', 'stack_config', 'quantizer_config')
               if not exists(getattr(c, name))]
    if missing:
        raise ValueError(f'Config is missing required sub-config(s) {missing}: {c}')

    # Fetch the checkpoint before allocating modules, so a bad path or a
    # failed download fails fast.
    checkpoint = load_ckpt(*c.from_pretrained) if exists(c.from_pretrained) else None

    self.io = c.io_config()
    self.stack = c.stack_config()

    stack_dim = c.stack_config.dim
    self.plex_layer = c.stack_config.layers // 2
    self.plex_roll = c.plex_roll
    self.plex_dim = c.quantizer_config.dim

    if self.plex_dim is None or stack_dim is None:
        raise ValueError(
            f'One of the following are None: self.plex_dim: {self.plex_dim}, '
            f'c.stack_config.dim: {stack_dim}'
        )
    self.plex_projection = nn.Linear(self.plex_dim, stack_dim)
    self.out_norm = Norm(stack_dim)

    if c.split:
        # Second IO module plus projection for the split configuration.
        # Its fc_* heads are removed (set to None).
        # NOTE(review): presumably only self.io produces those outputs — confirm.
        self.io2 = c.io_config()
        self.plex_projection2 = nn.Linear(self.plex_dim, stack_dim)

        self.io2.fc_loc = None
        self.io2.fc_scale = None
        self.io2.fc_weight = None

    kv_heads = c.stack_config.kv_heads or c.stack_config.n_head  # fall back to n_head when kv_heads is unset
    head_dim = stack_dim // c.stack_config.n_head
    # With split, reserve (layers - plex_layer) extra cache slots.
    # NOTE(review): presumably one per layer above plex_layer for the second
    # branch — confirm against the forward/generation code.
    self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
    self.cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
    self.cache = [None] * self.cache_num_layers

    if exists(c.from_pretrained):
        # strict=False tolerates key mismatches. The returned missing and
        # unexpected keys are printed rather than raised.
        result = self.load_state_dict(checkpoint, strict=False)
        print0_colored(result, 'yellow')
    # load_state_dict copies the values into the parameters. Drop our
    # reference so the checkpoint tensors can be freed (if nothing else
    # holds them) before the quantizer is built.
    del checkpoint

    # The quantizer is built after load_state_dict, so the checkpoint above
    # is not applied to it.
    # NOTE(review): presumably quantizer_config() supplies its own weights — confirm.
    # A later .train() on this model would switch it back to training mode.
    self.quantizer = c.quantizer_config().eval()
    # requires_grad_() freezes every parameter. Assigning
    # `.requires_grad = False` on an nn.Module would only create a plain
    # attribute and leave the parameters trainable.
    self.quantizer.requires_grad_(False)
113
114 @T.no_grad()
115 def quantize(self, x):

Callers

nothing calls this directly

Calls 5

exists — Function · 0.90
load_ckpt — Function · 0.90
Norm — Class · 0.90
print0_colored — Function · 0.90
__init__ — Method · 0.45

Tested by

no test coverage detected