MCPcopy
hub / github.com/ermongroup/ddim / AttnBlock

Class AttnBlock

models/diffusion.py:137–189  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

135
136
class AttnBlock(nn.Module):
    """Single-head self-attention across the spatial positions of a feature map.

    The input is normalized, projected to queries, keys and values by three
    1x1 convolutions, mixed with softmax attention over all ``h*w``
    positions, passed through a final 1x1 convolution and added back to the
    input as a residual.

    Args:
        in_channels: Channel count of the input. Every projection keeps this
            width, so the output has the same shape as the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise():
            # Channel-preserving 1x1 convolution.
            return torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=1, stride=1, padding=0
            )

        # Submodules are created in the order norm, q, k, v, proj_out so the
        # parameter ordering and random initialization sequence stay stable.
        self.norm = Normalize(in_channels)  # helper defined elsewhere in this module
        self.q = pointwise()
        self.k = pointwise()
        self.v = pointwise()
        self.proj_out = pointwise()

    def forward(self, x):
        """Return ``x`` plus the projected attention output.

        Args:
            x: 4-D tensor laid out as ``(b, c, h, w)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        normed = self.norm(x)
        queries = self.q(normed)
        keys = self.k(normed)
        values = self.v(normed)

        batch, channels, height, width = queries.shape
        positions = height * width

        # Attention logits: scores[b, i, j] = sum_c queries[b, i, c] * keys[b, c, j]
        queries = queries.reshape(batch, channels, positions).transpose(1, 2)  # b,hw,c
        keys = keys.reshape(batch, channels, positions)  # b,c,hw
        scores = torch.bmm(queries, keys) * (int(channels) ** (-0.5))
        # Normalize over the key axis so each query's weights sum to one.
        attn = torch.nn.functional.softmax(scores, dim=2)

        # Weighted sum of values:
        # mixed[b, c, i] = sum_j values[b, c, j] * attn[b, i, j]
        values = values.reshape(batch, channels, positions)
        mixed = torch.bmm(values, attn.transpose(1, 2))  # b,c,hw (hw of the query)
        mixed = mixed.reshape(batch, channels, height, width)

        # Residual connection around the output projection.
        return x + self.proj_out(mixed)
190
191
192class Model(nn.Module):

Callers 1

__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected