(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
)
| 25 | """ |
| 26 | |
def __init__(
    self,
    spacial_dim: int,
    embed_dim: int,
    num_heads_channels: int,
    output_dim: int = None,
):
    """Build the positional embedding, projections and attention module.

    :param spacial_dim: spatial size; the positional embedding gets
        ``spacial_dim ** 2 + 1`` columns (presumably the side length of a
        square feature map -- confirm against callers).
    :param embed_dim: feature width; sizes the positional embedding and the
        input side of both 1-D projections.
    :param num_heads_channels: channels per attention head; the head count is
        ``embed_dim // num_heads_channels``.
    :param output_dim: output width of ``c_proj``; any falsy value (``None``
        by default) falls back to ``embed_dim``.
    :raises ValueError: if ``num_heads_channels`` is not positive, or
        ``embed_dim`` is not a positive multiple of it.
    """
    # Validate head geometry before building anything. Without this, the
    # floor division below would silently drop a remainder (e.g. 10 // 4
    # -> 2 heads) or produce zero heads when num_heads_channels > embed_dim.
    if num_heads_channels <= 0:
        raise ValueError(
            f"num_heads_channels must be positive, got {num_heads_channels}"
        )
    num_heads, leftover = divmod(embed_dim, num_heads_channels)
    if leftover or num_heads < 1:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be a positive multiple of "
            f"num_heads_channels ({num_heads_channels})"
        )

    super().__init__()
    # Learned embedding of shape (embed_dim, spacial_dim ** 2 + 1). Dividing
    # randn by sqrt(embed_dim) gives each initial entry variance 1 / embed_dim.
    # Construction order is kept unchanged so seeded RNG draws match.
    self.positional_embedding = nn.Parameter(
        th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
    )
    # One projection emitting 3 * embed_dim channels (presumably q, k and v
    # stacked; the split happens outside this block).
    self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
    self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
    self.num_heads = num_heads
    self.attention = QKVAttention(self.num_heads)
| 42 | |
| 43 | def forward(self, x): |
| 44 | b, c, *_spatial = x.shape |
nothing calls this directly
no test coverage detected