(self, image, text)
| 356 | return x |
| 357 | |
def forward(self, image, text):
    """Return scaled cosine-similarity logits between images and texts.

    Both inputs are embedded with ``self.encode_image`` and
    ``self.encode_text``; each embedding is divided by its norm taken
    along dim 1, and the pairwise dot products are multiplied by
    ``exp(self.logit_scale)``.

    Returns:
        A ``(logits_per_image, logits_per_text)`` tuple, where
        ``logits_per_text`` is the transpose of ``logits_per_image``.
    """
    img_emb = self.encode_image(image)
    txt_emb = self.encode_text(text)

    # Scale every row to unit length so that the dot products below are
    # cosine similarities.
    img_emb = img_emb / img_emb.norm(dim=1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=1, keepdim=True)

    # The scale is stored in log space; exponentiate it before use.
    scale = self.logit_scale.exp()

    # Rows index images, columns index texts.
    # NOTE(review): the original comment gives the shape as
    # [global_batch_size, global_batch_size] - confirm against callers.
    logits_per_image = (scale * img_emb) @ txt_emb.t()
    return logits_per_image, logits_per_image.t()
| 373 | |
| 374 | |
| 375 | def convert_weights(model: nn.Module): |
no test coverage detected