| 141 | return self._wrap_output(token, features) |
| 142 | |
class SigLIP2Wrapper(CLIPWrapper):
    """Wrapper exposing a SigLIP2 model through the ``CLIPWrapper`` interface.

    Two input paths are supported by ``forward``:

    * dynamic (``is_dynamic=True``): the image is cut into flattened patches
      and passed to the vision model together with an attention mask and the
      per-sample patch-grid shape;
    * static (``is_dynamic=False``): the image tensor is passed through as-is.

    NOTE(review): ``self.inner`` and ``self._wrap_output`` are not defined in
    this class; presumably they are provided by ``CLIPWrapper`` - confirm.
    """

    def __init__(self, clip_model, tokenizer, proc, adaptor_name, clip_mode = False, patch_size: int = 16, is_dynamic: bool = True):
        """Store the SigLIP2-specific settings on top of the base wrapper.

        Args:
            clip_model: Forwarded to ``CLIPWrapper.__init__``.
            tokenizer: Forwarded to ``CLIPWrapper.__init__``.
            proc: Kept on the instance as ``_proc``; not used in this class.
            adaptor_name: Forwarded to ``CLIPWrapper.__init__``.
            clip_mode: Forwarded to ``CLIPWrapper.__init__``.
            patch_size: Side length, in pixels, of one square patch.
            is_dynamic: Selects the patchified input path in ``forward``.
        """
        super().__init__(clip_model, tokenizer, adaptor_name, clip_mode)
        self._patch_size = patch_size
        self._proc = proc
        self._is_dynamic = is_dynamic

        # A 1x1 mask initialised to ones; ``forward`` expands it to
        # (batch, num_patches) instead of allocating a new mask per call.
        self.register_buffer('mask', torch.ones(1, 1, dtype=torch.int32))

    @property
    def patch_size(self):
        """Patch side length given at construction time."""
        return self._patch_size

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run the vision model on ``x`` and wrap its pooled/token outputs.

        Args:
            x: Image tensor; its last two dimensions are treated as height
                and width. The dynamic path's rearrange pattern additionally
                requires the layout ``b c h w``.
            *args: Accepted but unused.
            **kwargs: Only ``feature_fmt`` is read; when it equals ``'NCHW'``
                the token features are reshaped to ``b c h w`` on the patch
                grid.

        Returns:
            ``self._wrap_output(summary, features)`` where ``summary`` is the
            vision model's ``pooler_output`` and ``features`` is its
            ``last_hidden_state`` (optionally reshaped).
        """
        patch = self._patch_size
        grid_h, grid_w = x.shape[-2] // patch, x.shape[-1] // patch

        if self._is_dynamic:
            # Flatten every patch_size x patch_size patch into one token
            # vector, channel-last within the patch.
            patches = rearrange(
                x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                p1=patch, p2=patch, h=grid_h, w=grid_w,
            )
            batch, num_patches = patches.shape[:2]
            vision_kwargs = {
                'attention_mask': self.mask.expand(batch, num_patches),
                # Every sample in the batch shares the same patch grid.
                'spatial_shapes': torch.tensor(
                    [(grid_h, grid_w)] * batch, dtype=torch.int64, device=x.device,
                ),
            }
        else:
            patches = x
            vision_kwargs = {}

        output = self.inner.vision_model(pixel_values=patches, return_dict=True, **vision_kwargs)
        summary, features = output.pooler_output, output.last_hidden_state

        if kwargs.get('feature_fmt') == 'NCHW':
            features = rearrange(features, 'b (h w) c -> b c h w', h=grid_h, w=grid_w)

        return self._wrap_output(summary, features)

    def encode_text(self, text, normalize: bool = False):
        """Return the text model's pooled output for ``text``.

        Args:
            text: Mapping unpacked as keyword arguments into
                ``self.inner.text_model``.
            normalize: When true, L2-normalise the result along the last
                dimension.
        """
        pooled = self.inner.text_model(**text, return_dict=True).pooler_output
        return F.normalize(pooled, dim=-1) if normalize else pooled

    def zero_shot_postproc(self, logits: torch.Tensor):
        """Scale ``logits`` by ``exp(logit_scale)`` and add ``logit_bias``.

        Both parameters are read from ``self.inner`` and moved to the device
        of ``logits`` before use.
        """
        device = logits.device
        scale = self.inner.logit_scale.to(device)
        bias = self.inner.logit_bias.to(device)
        return logits * scale.exp() + bias
| 196 | |
| 197 | |
| 198 | class SAMWrapper(nn.Module): |
no outgoing calls
no test coverage detected
searching dependent graphs…