Split the input into small patches with a sliding window.
(self, x: torch.Tensor, overlap_ratio: float = 0.25)
| 168 | return x0, x1, x2 |
| 169 | |
def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
    """Split the input into small patches with a sliding window.

    The last two (spatial) dimensions of ``x`` are tiled with square windows
    of a fixed 384-pixel side, moved by a stride of
    ``int(384 * (1 - overlap_ratio))``. Windows are visited row by row (outer
    loop over the vertical offset, inner loop over the horizontal offset) and
    the slices are concatenated along dim 0. With a leading batch dimension
    (assumed layout ``(batch, ..., H, W)`` -- TODO confirm against callers),
    the result is therefore grouped window-first: the first ``x.shape[0]``
    entries are the top-left window of every batch element, then the next
    window, and so on.

    Args:
        x: Tensor whose last two dimensions are the (equal) height and width.
        overlap_ratio: Fraction of a window shared with its neighbour; must
            lie in ``[0, 1)``.

    Returns:
        All windows concatenated along dim 0.

    Raises:
        ValueError: If ``overlap_ratio`` is outside ``[0, 1)`` or yields a
            non-positive stride, if the input is not square, or if the windows
            do not tile the image exactly (the image must be at least one
            window wide and ``image_size - 384`` must be a multiple of the
            stride). Previously these cases either crashed inside
            ``torch.cat`` or silently produced wrong/undersized patches.
    """
    patch_size = 384

    if not 0.0 <= overlap_ratio < 1.0:
        raise ValueError(f"overlap_ratio must be in [0, 1), got {overlap_ratio}")
    patch_stride = int(patch_size * (1 - overlap_ratio))
    if patch_stride <= 0:
        # int() truncation can reach 0 for ratios very close to 1.
        raise ValueError(
            f"overlap_ratio={overlap_ratio} gives a non-positive stride ({patch_stride})"
        )

    height, width = x.shape[-2], x.shape[-1]
    if height != width:
        # The same offsets are used for both axes, so only square inputs
        # can be tiled correctly.
        raise ValueError(f"split expects a square input, got {height}x{width}")
    image_size = width

    # The windows must end exactly on the image border; otherwise the last
    # slices come out truncated and cannot be concatenated.
    overshoot = image_size - patch_size
    if overshoot < 0 or overshoot % patch_stride != 0:
        raise ValueError(
            f"image size {image_size} cannot be tiled exactly by "
            f"{patch_size}px windows with stride {patch_stride}"
        )
    steps = overshoot // patch_stride + 1

    offsets = [k * patch_stride for k in range(steps)]
    # Row-major order (vertical offset outer, horizontal inner). merge()
    # presumably reassembles windows in this same order -- keep them in sync.
    patches = [
        x[..., top : top + patch_size, left : left + patch_size]
        for top in offsets
        for left in offsets
    ]
    return torch.cat(patches, dim=0)
| 189 | |
| 190 | def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor: |
| 191 | """Merge the patched input into a image with sliding window.""" |