Repository: github.com/Standard-Intelligence/hertz-dev

Method forward

model.py:139–197
(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None)
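In split mode, the method's first step treats the last dimension of `data` as two equal halves, one per stream. It embeds each half with its own input head (`self.io.input` and `self.io2.input`) and sums the results into one residual stream. The numpy sketch below illustrates only that step and assumes each input head acts like a plain linear map. The matrices `w1` and `w2` are hypothetical stand-ins, and `np.split(..., 2, axis=-1)` plays the role of `data.chunk(2, dim=-1)` for an even-sized last axis.

```python
import numpy as np

def split_embed(data: np.ndarray, w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """Embed a [batch, time, 2*d_in] array as the sum of two per-half projections.

    w1 and w2 are hypothetical [d_in, d_model] matrices standing in for
    self.io.input and self.io2.input (assumed linear for this sketch).
    """
    # Equivalent of data.chunk(2, dim=-1) when the last axis has even length.
    x1, x2 = np.split(data, 2, axis=-1)
    # Each stream gets its own projection; the results share one residual stream.
    return x1 @ w1 + x2 @ w2

rng = np.random.default_rng(0)
data = rng.standard_normal((2, 5, 8))   # batch=2, time=5, two streams of width 4
w1 = rng.standard_normal((4, 16))
w2 = rng.standard_normal((4, 16))
x = split_embed(data, w1, w2)
print(x.shape)  # (2, 5, 16)
```

Summing rather than concatenating keeps the residual width at `d_model` regardless of the number of streams. The two streams only diverge again at the plex layer, where each gets its own plex projection.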


```python
import torch as T
from typing import Optional, Tuple
# exists() and isnt() are small helper functions defined elsewhere in the hertz-dev repo.

@T.no_grad()
def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
    # Embed the input. In split mode the last dim holds two halves (one per
    # stream) that are embedded by separate input heads and summed.
    if self.c.split:
        x1, x2 = data.chunk(2, dim=-1)
        x = self.io.input(x1) + self.io2.input(x2)
    else:
        x = self.io.input(data)

    cache_idx = 0
    for l, layer in enumerate(self.stack.layers):
        if l == self.plex_layer:
            # Quantize the input, roll it by -plex_roll along dim 1 so position t
            # holds the token from t + plex_roll, optionally overwrite the last
            # position with next_tokens, then add the projected tokens.
            if self.c.split:
                plex1, plex2 = self.quantize(data)
                plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
                plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
                if exists(next_tokens):
                    plex1[:, -1:] = self.untokenize(next_tokens[0])
                    plex2[:, -1:] = self.untokenize(next_tokens[1])
                x1 = x + self.plex_projection(plex1)
                x2 = x + self.plex_projection2(plex2)
            else:
                plex = self.quantize(data)
                plex = T.roll(plex, -self.c.plex_roll, dims=1)
                if exists(next_tokens):
                    plex[:, -1:] = self.untokenize(next_tokens)
                x = x + self.plex_projection(plex)

        if l < self.plex_layer:
            # Shared trunk: a single stream with one KV-cache slot per layer.
            x = layer(x, kv=self.cache[l])
        else:
            if self.c.split:
                # Both streams run through the same layer weights, each with
                # its own KV-cache slot.
                x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
                cache_idx += 1
                x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
                cache_idx += 1
            else:
                x = layer(x, kv=self.cache[l])

    # Final norm and output head under bf16 autocast.
    with T.autocast(device_type='cuda', dtype=T.bfloat16):
        if self.c.split:
            x1, x2 = self.out_norm(x1), self.out_norm(x2)
            out1, out2 = self.io.output(x1), self.io.output(x2)
        else:
            x = self.out_norm(x)
            out = self.io.output(x)

    if isnt(temps):
        # No temperatures given: return the raw output-head results.
        if self.c.split:
            return out1, out2
        else:
            return out
    else:
        # Sample via temp_sample and keep only the last time step; in split
        # mode the two streams are re-joined along the last dim.
        if self.c.split:
            next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
            next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
            next_data = T.cat([next_data1, next_data2], dim=-1)
            return next_data
        else:
            next_data = self.io.temp_sample(out, temps)[:, -1:, :]
            return next_data
```
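At `self.plex_layer`, the quantized input is rolled by `-self.c.plex_roll` along dim 1. As a result, position `t` holds the token from position `t + plex_roll`, and the last `plex_roll` positions wrap around from the start of the sequence. When `next_tokens` is provided, the final position is then overwritten. The numpy sketch below mirrors just that indexing, since `np.roll` uses the same shift convention as `torch.roll`. The integer array is a hypothetical stand-in for the output of `quantize`, `next_token` stands in for `untokenize(next_tokens)`, and the time axis is assumed to be dim 1.

```python
import numpy as np

def shift_plex(plex: np.ndarray, plex_roll: int, next_token=None) -> np.ndarray:
    """Mirror T.roll(plex, -plex_roll, dims=1) plus the optional last-step overwrite."""
    # A negative shift moves every entry plex_roll steps toward index 0;
    # the first plex_roll entries wrap around to the end. np.roll returns a copy.
    shifted = np.roll(plex, -plex_roll, axis=1)
    if next_token is not None:
        # Same as plex[:, -1:] = ... : replace only the final time step.
        shifted[:, -1:] = next_token
    return shifted

plex = np.arange(6).reshape(1, 6)                    # one sequence of token ids 0..5
print(shift_plex(plex, 1).tolist())                  # [[1, 2, 3, 4, 5, 0]]
print(shift_plex(plex, 1, next_token=99).tolist())   # [[1, 2, 3, 4, 5, 99]]
```

With `plex_roll` greater than 1, only the very last position is replaced. The other wrapped positions keep tokens from the start of the sequence.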

Callers

nothing calls this directly
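If this class is a `torch.nn.Module`, as the `@T.no_grad()` decorator and layer stack suggest, `forward` is normally reached through `model(...)` rather than called by name. A static call graph may therefore report no callers. When `temps` is given, the method returns only the newly sampled final step, with shape [batch, 1, features] after the `[:, -1:, :]` slice. That contract fits a one-call-per-step generation loop. The sketch below drives a hypothetical numpy stub that follows the same contract. It is not hertz-dev's sampling code, it ignores how the real KV cache (`self.cache`) advances between calls, and the `temps` values are placeholders whose meaning is defined by `temp_sample`.

```python
import numpy as np

class StubModel:
    """Hypothetical stand-in honouring forward's sampling contract:
    given temps, return one new step of shape [batch, 1, features]."""

    def forward(self, data: np.ndarray, temps=None) -> np.ndarray:
        if temps is None:
            raise NotImplementedError("stub only models the sampling path")
        # Placeholder "sample": the last frame plus one.
        return data[:, -1:, :] + 1.0

    # Calling the object dispatches to forward, as nn.Module.__call__ does.
    __call__ = forward

def generate(model, prompt: np.ndarray, steps: int, temps) -> np.ndarray:
    """Append one sampled step per call, `steps` times."""
    seq = prompt
    for _ in range(steps):
        next_step = model(seq, temps=temps)              # [batch, 1, features]
        seq = np.concatenate([seq, next_step], axis=1)   # grow the time axis
    return seq

prompt = np.zeros((1, 2, 3))
out = generate(StubModel(), prompt, steps=4, temps=(1.0, (0.9, 0.9)))
print(out.shape)               # (1, 6, 3)
print(out[0, :, 0].tolist())   # [0.0, 0.0, 1.0, 2.0, 3.0, 4.0]
```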

Calls (7)

quantize · method · 0.95
untokenize · method · 0.95
exists · function · 0.90
isnt · function · 0.90
input · method · 0.80
output · method · 0.80
temp_sample · method · 0.80
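Two of these callees, `exists` and `isnt`, are repo helpers whose definitions are not shown on this page. The sketch below assumes they are the usual `is not None` / `is None` checks. It pairs them with the split-mode output selection at the end of `forward`. With no `temps`, the raw head outputs are returned. Otherwise, each sampled result is cut to its last time step and the two streams are concatenated along the last dimension. `identity_sampler` is a hypothetical placeholder for `temp_sample`.

```python
import numpy as np

def exists(x) -> bool:
    # Assumed definition: a value "exists" when it is not None.
    return x is not None

def isnt(x) -> bool:
    # Assumed definition: the complement of exists.
    return x is None

def select_output(out1, out2, temps, sample):
    """Mirror the split-mode tail of forward.

    `sample` is a hypothetical stand-in for io.temp_sample / io2.temp_sample.
    """
    if isnt(temps):
        # No temperatures: hand back the raw head outputs unchanged.
        return out1, out2
    # [:, -1:, :] keeps a length-1 time axis instead of dropping it.
    next1 = sample(out1, temps)[:, -1:, :]
    next2 = sample(out2, temps)[:, -1:, :]
    # Re-join the two streams along the feature axis.
    return np.concatenate([next1, next2], axis=-1)

def identity_sampler(out, temps):
    # Placeholder "sampler" that returns its input unchanged.
    return out

out1 = np.ones((2, 7, 4))
out2 = np.zeros((2, 7, 4))
raw = select_output(out1, out2, None, identity_sampler)
step = select_output(out1, out2, (1.0, (0.9, 0.9)), identity_sampler)
print(len(raw), raw[0].shape)   # 2 (2, 7, 4)
print(step.shape)               # (2, 1, 8)
print(exists(0), isnt(None))    # True True
```

Because the checks compare against `None` rather than testing truthiness, falsy values such as `0` still count as present.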

Tested by

no test coverage detected
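One property that is cheap to check in isolation is the KV-cache indexing in the layer loop. Below `plex_layer`, layer `l` uses slot `l`. From `plex_layer` on, split mode gives each of the two streams its own slot at `plex_layer + cache_idx`, with `cache_idx` advancing by one per stream. The hypothetical helper below replays that bookkeeping without running any layers, so it can be asserted on directly. It is a sketch of a possible test, not part of the repository.

```python
def cache_slots(n_layers: int, plex_layer: int, split: bool):
    """Return the cache slot(s) each layer reads, following forward's loop.

    Hypothetical helper: it reproduces only the index arithmetic, not the layers.
    """
    slots = []
    cache_idx = 0
    for l in range(n_layers):
        if l < plex_layer:
            slots.append((l,))                    # shared trunk
        elif split:
            first = plex_layer + cache_idx        # stream 1
            cache_idx += 1
            second = plex_layer + cache_idx       # stream 2
            cache_idx += 1
            slots.append((first, second))
        else:
            slots.append((l,))                    # single stream continues
    return slots

print(cache_slots(4, 2, split=True))    # [(0,), (1,), (2, 3), (4, 5)]
print(cache_slots(4, 2, split=False))   # [(0,), (1,), (2,), (3,)]
```

In split mode the cache therefore needs `plex_layer + 2 * (n_layers - plex_layer)` slots, 6 in the example above. Each slot is used exactly once, so the two streams never share cache state.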