| 134 | stride=self.patch_size) |
| 135 | |
def forward_vace(self, x, vace_context, seq_len, kwargs):
    """Run the VACE blocks over the embedded context and collect their hints.

    Args:
        x: Value forwarded to every VACE block as the ``x`` keyword
            argument (unless ``kwargs`` supplies its own ``x``).
        vace_context: Iterable of tensors; each one gets a leading
            dimension added and is fed to ``self.vace_patch_embedding``.
        seq_len: Length every embedded sequence is zero-padded to along
            dim 1. Presumably >= each sequence's own length -- a negative
            pad size would be rejected by ``new_zeros``.
        kwargs: Extra keyword arguments for the blocks. Merged on top of
            ``x``, so an ``x`` entry here takes precedence.

    Returns:
        A list with one entry per block in ``self.vace_blocks``: the
        second element of each block's return value.
    """
    # Embed each context item, flatten everything from dim 2 onwards into a
    # single axis, move that axis to dim 1, and right-pad with zeros so all
    # items share the same length ``seq_len``.
    padded = []
    for ctx in vace_context:
        tokens = self.vace_patch_embedding(ctx.unsqueeze(0))
        tokens = tokens.flatten(2).transpose(1, 2)
        filler = tokens.new_zeros(
            1, seq_len - tokens.size(1), tokens.size(2))
        padded.append(torch.cat([tokens, filler], dim=1))
    # Stack the padded items along dim 0.
    state = torch.cat(padded)

    # Keyword arguments shared by every block call; entries in ``kwargs``
    # are applied last and therefore win over the default ``x``.
    block_kwargs = {"x": x}
    block_kwargs.update(kwargs)

    # Each block returns (next state, skip output): the state is threaded
    # into the following block and the skip output is kept as a hint.
    hints = []
    for layer in self.vace_blocks:
        state, skip = layer(state, **block_kwargs)
        hints.append(skip)
    return hints
| 154 | |
| 155 | def forward( |
| 156 | self, |