| 196 | |
| 197 | |
class SAMWrapper(nn.Module):
    """Adapter exposing a SAM-style image encoder through a small common API.

    The wrapped encoder is stored as ``self.inner`` and must provide the
    attributes this class reads: ``patch_embed`` (with a ``proj`` layer that
    has ``out_channels`` and ``kernel_size``), ``pos_embed`` (may be ``None``)
    and an iterable ``blocks``.
    """

    def __init__(self, sam_encoder: nn.Module):
        super().__init__()
        self.inner = sam_encoder

    @property
    def embed_dim(self):
        """Output channel count of the patch-embedding projection."""
        projection = self.inner.patch_embed.proj
        return projection.out_channels

    @property
    def patch_size(self):
        """First entry of the patch-embedding projection's kernel size."""
        kernel = self.inner.patch_embed.proj.kernel_size
        return kernel[0]

    @property
    def vision_encoder(self):
        """The wrapped encoder module itself."""
        return self.inner

    def forward(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder trunk and return ``(summary, features)``.

        Args:
            x: Input tensor handed directly to ``inner.patch_embed``.
            **kwargs: Accepted for call-site compatibility; not used here.

        Returns:
            A ``(summary, features)`` tuple. ``features`` is the output of the
            last block with dimensions 1 and 2 merged; ``summary`` is the mean
            of ``features`` over dimension 1.
        """
        encoder = self.inner

        tokens = encoder.patch_embed(x)
        positional = encoder.pos_embed
        if positional is not None:
            tokens = tokens + positional

        # Only the blocks are applied; nothing else on the encoder is invoked.
        for block in encoder.blocks:
            tokens = block(tokens)

        # NOTE(review): presumably tokens are (batch, H, W, C) at this point,
        # so this yields (batch, H*W, C) -- confirm against the encoder.
        features = tokens.flatten(start_dim=1, end_dim=2)
        summary = features.mean(1)
        return summary, features
| 228 | |
| 229 | |
| 230 | class InternViTWrapper(nn.Module): |
no outgoing calls
no test coverage detected
searching dependent graphs…