Discard class token and reshape 1D feature map to a 2D grid.
(
self, embeddings: torch.Tensor, width, height, cls_token_offset=1
)
| 217 | return output |
| 218 | |
def reshape_feature(
    self, embeddings: torch.Tensor, width, height, cls_token_offset=1
):
    """Discard class token and reshape 1D feature map to a 2D grid.

    Args:
        embeddings: Rank-3 token sequence ``(batch, tokens, dim)``. After the
            first ``cls_token_offset`` tokens are dropped, the remaining
            tokens are laid out row-major as a ``height x width`` grid.
        width: Number of grid columns.
        height: Number of grid rows.
        cls_token_offset: Number of leading (class) tokens to drop before
            reshaping. Non-positive values leave the sequence untouched.

    Returns:
        Tensor of shape ``(batch, dim, height, width)``.

    Raises:
        RuntimeError: (from ``torch.Tensor.reshape``) if the number of
            remaining tokens is not ``height * width``.
    """
    # Only batch and channel sizes are needed; the token count is implied by
    # height * width after the prefix is removed.
    batch, _, dim = embeddings.shape

    # Remove class token(s).
    if cls_token_offset > 0:
        embeddings = embeddings[:, cls_token_offset:, :]

    # (batch, height*width, dim) -> (batch, height, width, dim)
    #   -> (batch, dim, height, width)
    return embeddings.reshape(batch, height, width, dim).permute(0, 3, 1, 2)
| 232 | |
| 233 | def forward(self, x: torch.Tensor) -> list[torch.Tensor]: |
| 234 | """Encode input at multiple resolutions. |