| 193 | |
| 194 | class CrossAttention(nn.Module): |
def __init__(
    self,
    query_dim,
    context_dim=None,
    heads=8,
    dim_head=64,
    dropout=0.0,
    backend=None,
):
    """Build the projection layers and store the attention settings.

    Args:
        query_dim: Input feature size of the query projection and output
            feature size of the final projection.
        context_dim: Input feature size of the key and value projections.
            Resolved via ``default(context_dim, query_dim)``; presumably
            this falls back to ``query_dim`` when ``None`` (helper is
            defined elsewhere; confirm).
        heads: Number of heads; stored on ``self.heads``.
        dim_head: Per-head width. The q/k/v projections output
            ``dim_head * heads`` features.
        dropout: Probability handed to the ``nn.Dropout`` that follows
            the output projection.
        backend: Stored unchanged on ``self.backend``; how it is used is
            not visible from this method.
    """
    super().__init__()

    context_dim = default(context_dim, query_dim)
    projected_dim = dim_head * heads

    # Plain (non-module) settings.
    self.heads = heads
    self.scale = dim_head ** -0.5
    self.backend = backend

    # Bias-free input projections. Sub-modules are created in the order
    # q, k, v, out so parameter registration and random initialisation
    # order stay the same.
    self.to_q = nn.Linear(query_dim, projected_dim, bias=False)
    self.to_k = nn.Linear(context_dim, projected_dim, bias=False)
    self.to_v = nn.Linear(context_dim, projected_dim, bias=False)

    # Output projection back to query_dim, followed by dropout.
    out_projection = nn.Linear(projected_dim, query_dim)
    self.to_out = nn.Sequential(out_projection, nn.Dropout(dropout))
| 219 | |
| 220 | def forward( |
| 221 | self, |