imgs: [B, 3, H, W] of torch.float32, with per-channel mean [0.48145466, 0.4578275, 0.40821073] and std [0.26862954, 0.26130258, 0.27577711]. txt_ids: [B, L] of torch.long, encoded by data.CLIPTokenizer.
(self, imgs, txt_ids)
| 404 | self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) |
| 405 | |
def forward(self, imgs, txt_ids):
    """
    Run the visual and textual encoders on a paired batch.

    imgs:       [B, 3, H, W] of torch.float32.
                - mean: [0.48145466, 0.4578275, 0.40821073]
                - std: [0.26862954, 0.26130258, 0.27577711]
    txt_ids:    [B, L] of torch.long.
                Encoded by data.CLIPTokenizer.

    Returns a tuple (image_features, text_features) holding the output of
    self.visual(imgs) and self.textual(txt_ids), in that order. The outputs
    are returned as produced by the encoders; self.log_scale is not applied
    here.
    """
    # The visual encoder is called first, then the textual encoder.
    image_features = self.visual(imgs)
    text_features = self.textual(txt_ids)
    return image_features, text_features
| 417 | |
| 418 | def param_groups(self): |
| 419 | groups = [{ |