MCP · Copy · Index your code
hub / github.com/NVlabs/RADIO / SigLIP2Wrapper

Class SigLIP2Wrapper

examples/common/model_loader.py:143–195  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

141 return self._wrap_output(token, features)
142
class SigLIP2Wrapper(CLIPWrapper):
    """Expose a SigLIP2 model through the ``CLIPWrapper`` interface.

    Image features come from ``self.inner.vision_model`` and text embeddings
    from ``self.inner.text_model``. Neither ``self.inner`` nor
    ``self._wrap_output`` is defined in this class; presumably ``CLIPWrapper``
    provides them (NOTE(review): confirm against the base class).

    With ``is_dynamic`` set, ``forward`` patchifies the image itself and also
    passes an all-ones ``attention_mask`` and per-sample ``spatial_shapes``
    to the vision model. This is presumably the calling convention of the
    variable-resolution SigLIP2 vision transformer; confirm against the
    installed transformers version. Without it, the image tensor is forwarded
    unchanged.
    """

    def __init__(self, clip_model, tokenizer, proc, adaptor_name, clip_mode=False, patch_size: int = 16, is_dynamic: bool = True):
        """
        Args:
            clip_model: Forwarded to ``CLIPWrapper.__init__``.
            tokenizer: Forwarded to ``CLIPWrapper.__init__``.
            proc: Stored as ``self._proc``; not read by any method of this class.
            adaptor_name: Forwarded to ``CLIPWrapper.__init__``.
            clip_mode: Forwarded to ``CLIPWrapper.__init__``.
            patch_size: Side length, in pixels, of a square patch. Used to
                compute the token grid and, in dynamic mode, to patchify inputs.
            is_dynamic: If true, ``forward`` patchifies inputs and supplies
                ``attention_mask``/``spatial_shapes``; otherwise raw pixels are
                passed straight to the vision model.
        """
        super().__init__(clip_model, tokenizer, adaptor_name, clip_mode)
        self._patch_size = patch_size
        self._proc = proc
        self._is_dynamic = is_dynamic

        # A (1, 1) int32 tensor of ones that each forward call expands into a
        # full attention mask. As a (persistent) buffer it follows the
        # module's .to()/.cuda() moves and appears in state_dict().
        self.register_buffer('mask', torch.ones(1, 1, dtype=torch.int32))

    @property
    def patch_size(self):
        """Patch side length in pixels, as given to the constructor."""
        return self._patch_size

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run the vision model and return ``self._wrap_output(summary, features)``.

        Args:
            x: Image batch laid out as ``b c H W``. In dynamic mode the
                patchify pattern below requires this layout, with H and W
                divisible by ``patch_size``.
            *args: Ignored.
            **kwargs: Only ``feature_fmt`` is read. ``'NCHW'`` reshapes the
                token features into a ``b c h w`` grid with
                ``h = H // patch_size`` and ``w = W // patch_size``. Any other
                value leaves the features as the vision model returned them.

        Returns:
            The result of ``self._wrap_output`` applied to the model's
            ``pooler_output`` (summary) and ``last_hidden_state`` (features).
        """
        patch = self._patch_size
        grid_h, grid_w = x.shape[-2] // patch, x.shape[-1] // patch

        if not self._is_dynamic:
            pixel_values, model_kwargs = x, {}
        else:
            # One row per patch; each p x p patch is flattened in
            # (row-in-patch, col-in-patch, channel) order.
            pixel_values = rearrange(
                x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                p1=patch, p2=patch, h=grid_h, w=grid_w,
            )
            batch, num_patches = pixel_values.shape[:2]
            model_kwargs = {
                # Every patch is real (no padding), so the mask is all ones.
                # expand() returns a (batch, num_patches) view without copying.
                'attention_mask': self.mask.expand(batch, num_patches),
                # Every sample shares the same (grid_h, grid_w) patch grid.
                'spatial_shapes': torch.tensor(
                    [(grid_h, grid_w)] * batch,
                    dtype=torch.int64,
                    device=x.device,
                ),
            }

        output = self.inner.vision_model(pixel_values=pixel_values, return_dict=True, **model_kwargs)
        summary, features = output.pooler_output, output.last_hidden_state

        if kwargs.get('feature_fmt') == 'NCHW':
            # NOTE(review): in non-dynamic mode this assumes the model emits
            # exactly grid_h * grid_w tokens; confirm for that configuration.
            features = rearrange(features, 'b (h w) c -> b c h w', h=grid_h, w=grid_w)

        return self._wrap_output(summary, features)

    def encode_text(self, text, normalize: bool = False):
        """Return the text model's ``pooler_output`` for tokenized ``text``.

        Args:
            text: Mapping of keyword arguments, unpacked into
                ``self.inner.text_model``.
            normalize: If true, L2-normalize the result along the last dim.
        """
        pooled = self.inner.text_model(**text, return_dict=True).pooler_output
        return F.normalize(pooled, dim=-1) if normalize else pooled

    def zero_shot_postproc(self, logits: torch.Tensor):
        """Calibrate raw similarity logits as ``logits * exp(logit_scale) + logit_bias``.

        The model's ``logit_scale`` and ``logit_bias`` are first moved to
        ``logits.device``.
        """
        device = logits.device
        scale = self.inner.logit_scale.to(device)
        bias = self.inner.logit_bias.to(device)
        return logits * scale.exp() + bias
196
197
198class SAMWrapper(nn.Module):

Callers 1

load_model · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…