MCPcopy
hub / github.com/openai/improved-diffusion / QKVAttention

Class QKVAttention

improved_diffusion/unet.py:233–275  ·  view source on GitHub ↗

A module which performs QKV attention.

Source from the content-addressed store, hash-verified

231
232
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3
        q, k, v = th.split(qkv, ch, dim=1)
        # q and k are each scaled by ch ** -0.25, so their product carries
        # the overall 1 / sqrt(ch) factor.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        # Softmax is evaluated in float32 and cast back to the input dtype.
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        return th.einsum("bts,bcs->bct", weight, v)

    @staticmethod
    def count_flops(model, _x, y):
        """
        A counter for the `thop` package to count the operations in an
        attention operation.

        Meant to be used like:

            macs, params = thop.profile(
                model,
                inputs=(inputs, timestamps),
                custom_ops={QKVAttention: QKVAttention.count_flops},
            )

        :param model: the module being profiled; its `total_ops` attribute
                      is incremented in place.
        :param _x: the inputs passed to the module (unused).
        :param y: the module output, either an [N x C x ...] tensor or a
                  tuple/list whose first element is such a tensor.
        """
        # forward() returns a bare tensor, so indexing y[0] unconditionally
        # would drop the batch axis and shift every dimension by one. Only
        # unwrap when the output really is a sequence of tensors.
        out = y[0] if isinstance(y, (tuple, list)) else y
        b, c, *spatial = out.shape
        num_spatial = int(np.prod(spatial))
        # We perform two matmuls with the same number of ops.
        # The first computes the weight matrix, the second computes
        # the combination of the value vectors.
        matmul_ops = 2 * b * (num_spatial ** 2) * c
        model.total_ops += th.DoubleTensor([matmul_ops])
276
277
278class UNetModel(nn.Module):

Callers 1

__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected