(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None)
| 137 | |
@T.no_grad()
def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
    """Run the layer stack over ``data`` with cached keys/values, optionally sampling.

    The pass runs with gradients disabled. When ``self.c.split`` is set,
    ``data`` is chunked in two along its last dimension; the halves are
    embedded by ``self.io`` and ``self.io2`` and summed. From layer index
    ``self.plex_layer`` onward the split model carries two streams, each
    with its own cache entry.

    Args:
        data: Input passed to ``self.io.input`` (and ``self.io2.input`` in
            split mode) and to ``self.quantize``.
        next_tokens: If ``exists(next_tokens)`` holds, its untokenized
            value overwrites the last position (dim 1) of the rolled,
            quantized input. Indexed as a pair in split mode.
        temps: Forwarded unchanged to ``temp_sample``. If ``isnt(temps)``
            holds (presumably "is None" -- confirm against the helper),
            no sampling is done.

    Returns:
        Without ``temps``: ``(out1, out2)`` in split mode, else ``out``.
        With ``temps``: the sampled value at the last position of dim 1;
        in split mode the two samples are concatenated on the last dim.
    """
    split = self.c.split

    if split:
        half1, half2 = data.chunk(2, dim=-1)
        x = self.io.input(half1) + self.io2.input(half2)
    else:
        x = self.io.input(data)

    # Counts the extra cache slots consumed by the two split streams.
    cache_idx = 0
    for idx, layer in enumerate(self.stack.layers):
        if idx == self.plex_layer:
            # Inject the quantized input, rolled by -plex_roll along dim 1.
            if split:
                plex1, plex2 = self.quantize(data)
                plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
                plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
                if exists(next_tokens):
                    plex1[:, -1:] = self.untokenize(next_tokens[0])
                    plex2[:, -1:] = self.untokenize(next_tokens[1])
                # The shared trunk forks into two streams here.
                x1 = x + self.plex_projection(plex1)
                x2 = x + self.plex_projection2(plex2)
            else:
                plex = self.quantize(data)
                plex = T.roll(plex, -self.c.plex_roll, dims=1)
                if exists(next_tokens):
                    # NOTE(review): the whole ``next_tokens`` is passed here
                    # although the annotation declares a pair -- confirm
                    # what non-split callers supply.
                    plex[:, -1:] = self.untokenize(next_tokens)
                x = x + self.plex_projection(plex)

        if idx < self.plex_layer:
            x = layer(x, kv=self.cache[idx])
        elif split:
            # Each stream uses its own consecutive cache entry.
            x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
            cache_idx += 1
            x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
            cache_idx += 1
        else:
            x = layer(x, kv=self.cache[idx])

    with T.autocast(device_type='cuda', dtype=T.bfloat16):
        if split:
            x1, x2 = self.out_norm(x1), self.out_norm(x2)
            # NOTE(review): both streams go through ``self.io.output`` while
            # input and sampling use ``self.io2`` for the second stream --
            # confirm this is intentional.
            out1, out2 = self.io.output(x1), self.io.output(x2)
        else:
            x = self.out_norm(x)
            out = self.io.output(x)

    # NOTE(review): sampling is placed outside the autocast region -- confirm.
    if isnt(temps):
        if split:
            return out1, out2
        return out

    if split:
        next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
        next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
        return T.cat([next_data1, next_data2], dim=-1)

    next_data = self.io.temp_sample(out, temps)[:, -1:, :]
    # Fix: this branch used to fall off the end and return None.
    return next_data
nothing calls this directly
no test coverage detected