MCPcopy
hub / github.com/openai/guided-diffusion / QKVAttentionLegacy

Class QKVAttentionLegacy

guided_diffusion/unet.py:328–358  ·  view source on GitHub ↗

A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping

Source from the content-addressed store, hash-verified

326
327
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.

    The fused ``qkv`` input is laid out head-major: the channel axis is split
    into ``n_heads`` consecutive groups, and within each group the Q, K and V
    channels are contiguous (``[q_0, k_0, v_0, q_1, k_1, v_1, ...]``).
    """

    def __init__(self, n_heads):
        """
        :param n_heads: number of attention heads the channel axis is split into.
        """
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        :raises ValueError: if the channel dimension of ``qkv`` is not
            divisible by ``3 * n_heads``.
        """
        bs, width, length = qkv.shape
        # Explicit check (instead of a bare assert) so the failure is
        # descriptive and is not stripped when Python runs with -O.
        if width % (3 * self.n_heads) != 0:
            raise ValueError(
                f"qkv channel dimension {width} is not divisible by "
                f"3 * n_heads = {3 * self.n_heads}"
            )
        ch = width // (3 * self.n_heads)
        # Head-major split: each of the bs * n_heads rows holds [q; k; v]
        # for a single head, which split(ch) then separates.
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # Scale q and k by ch ** -0.25 each (product ch ** -0.5) before the
        # contraction rather than dividing the logits afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        # Softmax over the key axis (s) in float32, then cast back to the
        # logits' original dtype.
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        # Delegates to the count_flops_attn helper defined outside this class.
        # NOTE(review): the (model, _x, y) signature suggests a custom-op hook
        # for a FLOP-counting tool — confirm against callers.
        return count_flops_attn(model, _x, y)
359
360
361class QKVAttention(nn.Module):

Callers 1

__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected