Image to Patch Embedding.
| 63 | |
| 64 | |
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.

    Applies a single strided ``nn.Conv2d`` projection to an image batch and
    returns the result in channels-last layout.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, ...] = (7, 7),
        stride: Tuple[int, ...] = (4, 4),
        padding: Tuple[int, ...] = (3, 3),
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension (output channels of the projection).
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project ``x`` and move the channel axis last.

        Args:
            x (torch.Tensor): 4-D image batch laid out as (B, in_chans, H, W);
                the 4-axis permute below requires a batched input.

        Returns:
            torch.Tensor: the projection laid out as (B, H', W', embed_dim),
            where H' and W' follow the usual ``nn.Conv2d`` output-size formula.
            The result is a permuted view and is not made contiguous here.
        """
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x