| 135 | |
| 136 | |
class AttnBlock(nn.Module):
    """Single-head self-attention across the spatial positions of a feature map.

    The input is normalized, projected to queries, keys and values with 1x1
    convolutions, and every spatial position attends to every other position
    of the same sample.  The attended values go through a final 1x1
    convolution and are added back onto the input (residual connection), so
    the output has the same shape as the input.

    Args:
        in_channels: Number of channels of the input; every projection maps
            ``in_channels -> in_channels``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        # NOTE(review): `Normalize` is defined outside this block; presumably
        # it builds a normalization layer for `in_channels` channels - confirm.
        self.norm = Normalize(in_channels)

        # 1x1 convolutions are per-position linear maps over the channels.
        # Attribute names and creation order are kept unchanged because they
        # determine the parameter names and the initialization order.
        pointwise = dict(kernel_size=1, stride=1, padding=0)
        self.q = torch.nn.Conv2d(in_channels, in_channels, **pointwise)
        self.k = torch.nn.Conv2d(in_channels, in_channels, **pointwise)
        self.v = torch.nn.Conv2d(in_channels, in_channels, **pointwise)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, **pointwise)

    def forward(self, x):
        """Apply self-attention to ``x`` and return ``x`` plus the result.

        Args:
            x: Feature map whose projections unpack as ``(b, c, h, w)``.

        Returns:
            Tensor produced by ``x + proj_out(attention(norm(x)))``.
        """
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # Flatten the spatial grid so each of the h*w positions is one token.
        b, c, h, w = query.shape
        tokens = h * w
        query = query.reshape(b, c, tokens).permute(0, 2, 1)  # b, hw, c
        key = key.reshape(b, c, tokens)                       # b, c, hw
        value = value.reshape(b, c, tokens)                   # b, c, hw

        # scores[b, i, j] = sum_c query[b, i, c] * key[b, c, j], scaled by
        # 1/sqrt(c); softmax over j turns each query row into weights.
        scores = torch.bmm(query, key)
        scores = scores * (int(c) ** (-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # Transpose so the key index comes first, then mix the values:
        # attended[b, c, i] = sum_j value[b, c, j] * weights[b, i, j].
        weights = weights.permute(0, 2, 1)
        attended = torch.bmm(value, weights)
        attended = attended.reshape(b, c, h, w)

        attended = self.proj_out(attended)

        return x + attended
| 190 | |
| 191 | |
| 192 | class Model(nn.Module): |