(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
)
| 34 | """ |
| 35 | |
def __init__(
    self,
    spacial_dim: int,
    embed_dim: int,
    num_heads_channels: int,
    output_dim: int = None,
):
    """Create the positional table, the projections and the attention helper.

    Args:
        spacial_dim: Side length of the spatial grid. The positional table
            holds ``spacial_dim ** 2 + 1`` positions.
        embed_dim: Channel width of the incoming features.
        num_heads_channels: Channels per attention head. The head count is
            ``embed_dim // num_heads_channels``; this is floor division, so
            any remainder is silently discarded.
        output_dim: Channel width produced by ``c_proj``. Any falsy value
            (``None`` or ``0``) falls back to ``embed_dim``.
    """
    super().__init__()
    # One learned column per grid position plus one extra slot. The extra
    # slot is presumably for a pooled token added in forward(); TODO confirm.
    n_positions = spacial_dim ** 2 + 1
    init_scale = embed_dim ** 0.5
    # Created before the convs so RNG consumption order is unchanged.
    self.positional_embedding = nn.Parameter(
        th.randn(embed_dim, n_positions) / init_scale
    )
    # conv_nd(1, in, out, 1) is presumably a 1-D conv with kernel size 1,
    # i.e. a per-position linear map. Confirm against conv_nd.
    self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
    out_channels = output_dim or embed_dim
    self.c_proj = conv_nd(1, embed_dim, out_channels, 1)
    self.num_heads = embed_dim // num_heads_channels
    self.attention = QKVAttention(self.num_heads)
| 50 | def forward(self, x): |
| 51 | b, c, *_spatial = x.shape |
nothing calls this directly
no test coverage detected