(self, prompt_img)
| 85 | return text_embeddings |
| 86 | |
def get_img_embeds(self, prompt_img):
    """Encode an image prompt into embeddings with ``self.image_encoder``.

    The image is preprocessed by ``self.processor`` and then passed through
    ``self.image_encoder`` under ``torch.no_grad()``. Because the input is
    detached and the forward pass records no graph, no gradient flows back
    to ``prompt_img`` through this method.

    Args:
        prompt_img: Tensor holding the image prompt. A leading dimension of
            size 1 is squeezed away before preprocessing (presumably the
            batch axis -- confirm against callers).

    Returns:
        Element 0 of whatever ``self.image_encoder`` returns for the
        preprocessed pixel values.
    """
    # The processor is handed a NumPy array, so the tensor is detached from
    # the autograd graph and copied to host memory first.
    image_array = prompt_img.squeeze(0).detach().cpu().numpy()
    processed = self.processor(images=image_array, return_tensors='pt')

    with torch.no_grad():
        pixel_values = processed.pixel_values.to(self.device)
        encoder_output = self.image_encoder(pixel_values)
        img_embeddings = encoder_output[0]

    return img_embeddings
| 96 | |
| 97 | def img_clip_loss(self, clip_model, rgb1, rgb2): |
| 98 | image_z_1 = clip_model.encode_image(self.aug(rgb1)) |