MCPcopy Index your code
hub / github.com/NVlabs/RADIO / SAMWrapper

Class SAMWrapper

examples/common/model_loader.py:198–227  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

196
197
class SAMWrapper(nn.Module):
    """Adapt a SAM-style ViT image encoder to a ``(summary, features)`` interface.

    The wrapped encoder must expose ``patch_embed`` (with a ``proj`` submodule
    providing ``out_channels`` and ``kernel_size``), ``pos_embed`` (``None`` or a
    tensor/parameter) and an iterable ``blocks``. ``forward`` runs only the patch
    embedding, the positional embedding and the blocks; no other submodule of
    the encoder is invoked.
    """

    def __init__(self, sam_encoder: nn.Module):
        super().__init__()
        self.inner = sam_encoder

    @property
    def embed_dim(self) -> int:
        """Output channel count of the patch-embedding projection."""
        return self.inner.patch_embed.proj.out_channels

    @property
    def patch_size(self) -> int:
        """First dimension of the patch-embedding projection's kernel size."""
        return self.inner.patch_embed.proj.kernel_size[0]

    @property
    def vision_encoder(self) -> nn.Module:
        """The wrapped encoder module itself."""
        return self.inner

    def forward(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` and return ``(summary, features)``.

        Args:
            x: Input batch passed straight to ``inner.patch_embed``.
            **kwargs: Accepted for interface compatibility with other wrappers;
                ignored here.

        Returns:
            ``summary``: the token features averaged over the token dimension.
            ``features``: the block outputs with dims 1 and 2 flattened into one
            token dimension.
        """
        x = self.inner.patch_embed(x)
        pos_embed = self.inner.pos_embed
        if pos_embed is not None:
            # The positional embedding has a fixed grid size; resample it when
            # the input produced a different token grid, otherwise the add
            # below would fail to broadcast.
            x = x + self._resize_pos_embed(pos_embed, tuple(x.shape[1:3]))

        for blk in self.inner.blocks:
            x = blk(x)

        features = x.flatten(1, 2)

        summary = features.mean(dim=1)

        return summary, features

    @staticmethod
    def _resize_pos_embed(pos_embed: torch.Tensor, grid_size: Tuple[int, int]) -> torch.Tensor:
        """Return ``pos_embed`` resampled to ``grid_size`` along dims 1 and 2.

        Assumes a channels-last ``(1, H, W, C)`` layout, consistent with
        ``forward`` treating dims 1 and 2 of the patch embedding as the token
        grid (``flatten(1, 2)``) — NOTE(review): confirm for non-SAM encoders.
        The embedding is returned unchanged when it already matches, or when it
        is not 4-D (the plain add is then left to broadcast as before).
        """
        if pos_embed.ndim != 4 or tuple(pos_embed.shape[1:3]) == tuple(grid_size):
            return pos_embed

        import torch.nn.functional as F

        # Interpolate in float32: bicubic resampling is not implemented for
        # every reduced-precision dtype on every device.
        resized = F.interpolate(
            pos_embed.permute(0, 3, 1, 2).float(),
            size=tuple(grid_size),
            mode="bicubic",
            align_corners=False,
        )
        return resized.permute(0, 2, 3, 1).to(pos_embed.dtype)
228
229
230class InternViTWrapper(nn.Module):

Callers 1

load_modelFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…