| 198 | |
| 199 | |
def get_patches_batch(self, x, p):
    """Cut each sample of ``x`` into tiles sized like ``p`` and stack them on dim 1.

    The tile height and width are read from ``p.shape[2:]``.

    For every sample ``x[i]``:

    * the last axis is split into column strips of width ``tile_w``;
    * each strip is split along the second-to-last axis into tiles of
      height ``tile_h``.

    Tiles are visited column by column (all rows of the first strip, then
    the next strip, ...). Each tile gets a new leading axis of size 1, and
    the tiles are concatenated along dim 1. The per-sample results are then
    concatenated along dim 0.

    NOTE(review): ``torch.split`` leaves a smaller trailing chunk when a
    dimension is not a multiple of the tile size. The subsequent
    ``torch.cat`` may then fail on mismatched shapes. Confirm that callers
    pass sizes that divide evenly.

    Args:
        self: unused by this method.
        x: batched tensor indexed as ``x[i]``; presumably (B, C, H, W) —
            TODO confirm against callers.
        p: tensor whose ``shape[2:]`` supplies the tile (height, width).

    Returns:
        The batch-wise concatenation of the per-sample tile stacks. For a
        (B, C, H, W) input with evenly dividing tiles, this is
        (B, n_tiles * C, tile_h, tile_w).
    """
    tile_h, tile_w = p.shape[2:]

    def _tiles_of(sample):
        # Column strips first, then rows within each strip (column-major order).
        return [
            tile.unsqueeze(0)
            for strip in torch.split(sample, split_size_or_sections=tile_w, dim=-1)
            for tile in torch.split(strip, split_size_or_sections=tile_h, dim=-2)
        ]

    per_sample = [torch.cat(_tiles_of(x[i]), dim=1) for i in range(x.shape[0])]
    return torch.cat(per_sample, dim=0)
| 211 | |
| 212 | def forward(self, features): |
| 213 | if self.training and self.config.out_ref: |