(self, img_q, cond_or_img_s, seg_s=None, return_features=False)
| 423 | return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s)) |
| 424 | |
def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False):
    """Run the parent forward pass, deriving the conditional from a masked support input if needed.

    Args:
        img_q: Query input, passed unchanged to the parent class's ``forward``.
        cond_or_img_s: If ``seg_s`` is None, this is used directly as the
            conditional. Otherwise it is treated as the support input and
            passed, together with ``seg_s``, to ``self.visual_forward_masked``.
        seg_s: Optional second argument for ``visual_forward_masked``. Its
            presence selects which interpretation of ``cond_or_img_s`` applies.
        return_features: Forwarded as a keyword to the parent ``forward``.

    Returns:
        Whatever the parent class's ``forward`` returns for ``(img_q, cond)``.
    """
    if seg_s is not None:
        # The conditional is computed inside torch.no_grad(), so no gradients
        # flow back through the masked support pass.
        with torch.no_grad():
            # visual_forward_masked must yield exactly three values; only the
            # first is used as the conditional.
            cond, _unused_second, _unused_third = self.visual_forward_masked(
                cond_or_img_s, seg_s
            )
    else:
        cond = cond_or_img_s

    return super().forward(img_q, cond, return_features=return_features)
| 436 | |
| 437 | |
| 438 |
nothing calls this directly
no test coverage detected