Create a 3-level image pyramid.
`_create_pyramid(self, x: torch.Tensor)`
| 149 | return self.patch_encoder.patch_embed.img_size[0] * 4 |
| 150 | |
def _create_pyramid(
    self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the three-level image pyramid for the input.

    Args:
        x: Input image tensor. Bilinear interpolation is applied to it, which
            PyTorch supports for 4-D inputs (presumably (N, C, H, W) here;
            confirm against the caller).

    Returns:
        A tuple ``(full, half, quarter)``: the input itself (same object,
        not a copy), followed by the input resampled by factors of 0.5 and
        0.25. With the default configuration these are 1536, 768 and 384
        pixels respectively, the last matching the backbone resolution.
    """
    # Both reduced levels are resampled directly from the full-resolution
    # input with the same settings; only the scale factor differs.
    resample_opts = dict(size=None, mode="bilinear", align_corners=False)

    # Middle level: half resolution (768 by default).
    half = F.interpolate(x, scale_factor=0.5, **resample_opts)

    # Lowest level: quarter resolution (384 by default, the backbone size).
    quarter = F.interpolate(x, scale_factor=0.25, **resample_opts)

    # Top level is the untouched input (1536 by default).
    return x, half, quarter
| 169 | |
| 170 | def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor: |
| 171 | """Split the input into small patches with sliding window.""" |