| 283 | |
| 284 | |
class OpenAI_CLIP_VisionAdapter(nn.Module):
    """Wrap a CLIP-style vision tower so it also returns patch-token features.

    The adapter re-registers the sub-modules and parameters of ``model``. The
    same objects are assigned, not copies, so weights stay shared with
    ``model``. It re-implements the forward pass so that, besides the pooled
    class-token embedding, the per-patch tokens coming out of the transformer
    are returned as well.

    Args:
        model: The vision module to wrap. It must expose ``input_resolution``,
            ``output_dim``, ``conv1``, ``class_embedding``,
            ``positional_embedding``, ``ln_pre``, ``transformer``, ``ln_post``
            and ``proj``; ``proj`` may be ``None``. Presumably this is the
            ``VisionTransformer`` from OpenAI's ``clip`` package — confirm
            against callers.
    """

    def __init__(self, model):
        super().__init__()
        # Metadata passed through for callers; forward() does not read these.
        self.input_resolution = model.input_resolution
        self.output_dim = model.output_dim

        # Patch-embedding layer; forward() flattens its spatial output into tokens.
        self.conv1 = model.conv1

        # Learned class token, positional table and pre-transformer norm.
        self.class_embedding = model.class_embedding
        self.positional_embedding = model.positional_embedding
        self.ln_pre = model.ln_pre

        # Fed sequence-first (L, N, D) input by forward().
        self.transformer = model.transformer

        # Post-norm and optional projection, applied to the class token only.
        self.ln_post = model.ln_post
        self.proj = model.proj

    @property
    def patch_size(self):
        """Return the kernel size of the patch-embedding convolution.

        This is simply ``self.conv1.kernel_size``; for ``nn.Conv2d`` it is an
        ``(h, w)`` tuple.
        """
        return self.conv1.kernel_size

    def forward(self, x: torch.Tensor):
        """Encode an image batch into a pooled embedding and patch features.

        Args:
            x: Image batch fed to ``conv1``. Presumably ``(N, C, H, W)`` with a
                spatial size that yields exactly
                ``positional_embedding.shape[0] - 1`` patches — TODO confirm.

        Returns:
            A ``(pooled, feats)`` tuple:

            * ``pooled`` — the class token after ``ln_post``. When ``proj`` is
              not ``None`` it is also right-multiplied by ``proj``; otherwise
              it has shape ``(N, width)``.
            * ``feats`` — the patch tokens from the transformer output, shape
              ``(N, grid ** 2, width)``. They are NOT passed through
              ``ln_post`` or ``proj``.
        """
        x = self.conv1(x)  # [N, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [N, width, grid ** 2]
        x = x.permute(0, 2, 1)  # [N, grid ** 2, width]

        # Prepend the class token. expand() yields a broadcast view of the
        # learned embedding for every sample, so unlike "embedding + zeros"
        # no per-call zero tensor has to be allocated.
        cls_token = self.class_embedding.to(x.dtype).expand(x.shape[0], 1, -1)
        x = torch.cat([cls_token, x], dim=1)  # [N, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        # The wrapped transformer consumes sequence-first tensors.
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Patch tokens are returned raw (no ln_post / proj).
        feats = x[:, 1:]

        # The pooled representation comes from the class token alone.
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj

        return x, feats
| 328 | |
| 329 | @dataclass |
| 330 | class ModelInfo: |
no outgoing calls
no test coverage detected
searching dependent graphs…