(self, text)
| 341 | return self.visual(image.type(self.dtype)) |
| 342 | |
def encode_text(self, text):
    """Encode a batch of token-id sequences into projected text features.

    Args:
        text: Integer token ids, one row per sequence. The end-of-text
            token is expected to be the largest id in each row, because
            its position is located with ``argmax``.

    Returns:
        One feature row per input sequence: the final-layer activation at
        the end-of-text position, multiplied by ``self.text_projection``.
    """
    target_dtype = self.dtype

    # Token + positional embeddings: [batch_size, n_ctx, d_model].
    tokens = self.token_embedding(text).type(target_dtype)
    tokens = tokens + self.positional_embedding.type(target_dtype)

    # The transformer is fed sequence-first (NLD -> LND) and its output is
    # swapped back to batch-first (LND -> NLD) straight afterwards.
    hidden = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)
    hidden = self.ln_final(hidden).type(target_dtype)

    # hidden: [batch_size, n_ctx, transformer.width]. Pick, for every
    # sequence, the activation at the eot token (highest id in the row).
    batch_index = torch.arange(hidden.shape[0])
    eot_index = text.argmax(dim=-1)
    return hidden[batch_index, eot_index] @ self.text_projection
| 357 | |
| 358 | def forward(self, image, text): |
| 359 | image_features = self.encode_image(image) |
no outgoing calls
no test coverage detected