Encode input at multiple resolutions. Args: x (torch.Tensor): Input image. Returns: Multi-resolution encoded features.
(self, x: torch.Tensor)
| 231 | return embeddings |
| 232 | |
| 233 | def forward(self, x: torch.Tensor) -> list[torch.Tensor]: |
| 234 | """Encode input at multiple resolutions. |
| 235 | |
| 236 | Args: |
| 237 | ---- |
| 238 | x (torch.Tensor): Input image. |
| 239 | |
| 240 | Returns: |
| 241 | ------- |
| 242 | Multi resolution encoded features. |
| 243 | |
| 244 | """ |
| 245 | batch_size = x.shape[0] |
| 246 | |
| 247 | # Step 0: create a 3-level image pyramid. |
| 248 | x0, x1, x2 = self._create_pyramid(x) |
| 249 | |
| 250 | # Step 1: split to create batched overlapped mini-images at the backbone (BeiT/ViT/Dino) |
| 251 | # resolution. |
| 252 | # 5x5 @ 384x384 at the highest resolution (1536x1536). |
| 253 | x0_patches = self.split(x0, overlap_ratio=0.25) |
| 254 | # 3x3 @ 384x384 at the middle resolution (768x768). |
| 255 | x1_patches = self.split(x1, overlap_ratio=0.5) |
| 256 | # 1x1 # 384x384 at the lowest resolution (384x384). |
| 257 | x2_patches = x2 |
| 258 | |
| 259 | # Concatenate all the sliding window patches and form a batch of size (35=5x5+3x3+1x1). |
| 260 | x_pyramid_patches = torch.cat( |
| 261 | (x0_patches, x1_patches, x2_patches), |
| 262 | dim=0, |
| 263 | ) |
| 264 | |
| 265 | # Step 2: Run the backbone (BeiT) model and get the result of large batch size. |
| 266 | x_pyramid_encodings = self.patch_encoder(x_pyramid_patches) |
| 267 | x_pyramid_encodings = self.reshape_feature( |
| 268 | x_pyramid_encodings, self.out_size, self.out_size |
| 269 | ) |
| 270 | |
| 271 | # Step 3: merging. |
| 272 | # Merge highres latent encoding. |
| 273 | x_latent0_encodings = self.reshape_feature( |
| 274 | self.backbone_highres_hook0, |
| 275 | self.out_size, |
| 276 | self.out_size, |
| 277 | ) |
| 278 | x_latent0_features = self.merge( |
| 279 | x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3 |
| 280 | ) |
| 281 | |
| 282 | x_latent1_encodings = self.reshape_feature( |
| 283 | self.backbone_highres_hook1, |
| 284 | self.out_size, |
| 285 | self.out_size, |
| 286 | ) |
| 287 | x_latent1_features = self.merge( |
| 288 | x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3 |
| 289 | ) |
| 290 |
nothing calls this directly
no test coverage detected