A module which performs QKV attention.
| 231 | |
| 232 | |
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.

    The class defines no parameters of its own: it only applies scaled
    dot-product attention to a tensor that already packs the queries, keys
    and values along the channel axis.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        :raises ValueError: if the channel axis of `qkv` cannot be split
                            into three equal, non-empty parts.
        """
        width = qkv.shape[1]
        if width < 3 or width % 3 != 0:
            raise ValueError(
                f"qkv channel dimension must be a positive multiple of 3, "
                f"got {width}"
            )
        ch = width // 3
        q, k, v = th.split(qkv, ch, dim=1)
        # Scale q and k by ch ** -0.25 each so that their product carries
        # the usual 1 / sqrt(ch) factor.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        # Run the softmax in float32, then cast back to the input precision.
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        return th.einsum("bts,bcs->bct", weight, v)

    @staticmethod
    def count_flops(model, _x, y):
        """
        A counter for the `thop` package to count the operations in an
        attention operation.

        Meant to be used like:

            macs, params = thop.profile(
                model,
                inputs=(inputs, timestamps),
                custom_ops={QKVAttention: QKVAttention.count_flops},
            )

        :param model: the profiled module; its `total_ops` attribute is
                      incremented in place.
        :param _x: the module inputs (unused).
        :param y: the module output, either the [N x C x T] tensor itself or
                  a tuple/list whose first element is that tensor.
        """
        # `forward` returns a bare tensor, so indexing `y[0]` unconditionally
        # would strip the batch axis and collapse the spatial size to 1.
        # Only unwrap when the output really is a sequence of tensors.
        out = y[0] if isinstance(y, (tuple, list)) else y
        b, c, *spatial = out.shape
        num_spatial = int(np.prod(spatial))
        # We perform two matmuls with the same number of ops.
        # The first computes the weight matrix, the second computes
        # the combination of the value vectors.
        matmul_ops = 2 * b * (num_spatial ** 2) * c
        model.total_ops += th.DoubleTensor([matmul_ops])
| 276 | |
| 277 | |
| 278 | class UNetModel(nn.Module): |