`__init__(self, channels, num_heads=1, use_checkpoint=False)`
| 206 | """ |
| 207 | |
| 208 | def __init__(self, channels, num_heads=1, use_checkpoint=False): |
| 209 | super().__init__() |
| 210 | self.channels = channels |
| 211 | self.num_heads = num_heads |
| 212 | self.use_checkpoint = use_checkpoint |
| 213 | |
| 214 | self.norm = normalization(channels) |
| 215 | self.qkv = conv_nd(1, channels, channels * 3, 1) |
| 216 | self.attention = QKVAttention() |
| 217 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) |
| 218 | |
| 219 | def forward(self, x): |
| 220 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) |
Nothing calls this directly.
No test coverage detected.