Merge the patched input into an image with a sliding window.
(self, x: torch.Tensor, batch_size: int, padding: int = 3)
| 188 | return torch.cat(x_patch_list, dim=0) |
| 189 | |
def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
    """Merge the patched input into an image with a sliding window.

    ``x`` holds ``steps * steps`` consecutive groups of ``batch_size`` samples,
    one group per patch. Groups are ordered row by row: group index
    ``row * steps + col`` is the patch at grid position ``(row, col)``.
    NOTE(review): presumably this mirrors the ordering produced by the
    class's patch-splitting method; confirm against that method.

    Every patch edge that borders a neighbouring patch has ``padding``
    elements trimmed off. The trimmed patches are then concatenated along
    the last dimension within a row, and the rows along the second-to-last
    dimension.

    Args:
        x: Stacked patches. The last two dimensions are treated as the
            spatial (height, width) axes. The grid size is
            ``steps = int(sqrt(x.shape[0] // batch_size))``. Any samples
            beyond ``steps**2 * batch_size`` are not used.
        batch_size: Number of samples belonging to each patch.
        padding: Number of elements removed from each interior patch edge.
            ``0`` stitches the patches together without trimming.

    Returns:
        The stitched tensor. Its leading dimensions match a single patch
        group, and its last two dimensions hold the merged grid.
    """
    steps = int(math.sqrt(x.shape[0] // batch_size))

    rows = []
    for row in range(steps):
        row_patches = []
        for col in range(steps):
            idx = row * steps + col
            patch = x[batch_size * idx : batch_size * (idx + 1)]
            height, width = patch.shape[-2], patch.shape[-1]

            # Use explicit start/stop bounds rather than ``[:-padding]``.
            # With padding == 0, ``[:-0]`` equals ``[:0]`` and would empty the
            # patch. The max(..., 0) clamp makes an oversized padding yield an
            # empty slice (as the original chained slicing did) instead of
            # wrapping around to a negative index.
            top = padding if row != 0 else 0
            bottom = max(height - padding, 0) if row != steps - 1 else height
            left = padding if col != 0 else 0
            right = max(width - padding, 0) if col != steps - 1 else width

            row_patches.append(patch[..., top:bottom, left:right])
        # Patches in one grid row sit side by side along the last dimension.
        rows.append(torch.cat(row_patches, dim=-1))
    # Grid rows are stacked along the second-to-last dimension.
    return torch.cat(rows, dim=-2)
| 218 | |
| 219 | def reshape_feature( |
| 220 | self, embeddings: torch.Tensor, width, height, cls_token_offset=1 |