| 209 | |
| 210 | |
def point_tracking(
    F: torch.Tensor,
    F0: torch.Tensor,
    handle_points: torch.Tensor,
    handle_points0: torch.Tensor,
    r2: int = 3,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
    """Relocate each handle point to its best feature match in a local patch.

    For every handle point, a square mask centred on the point's current
    position is built with ``utils.create_square_mask``. The feature vector
    of ``F`` at each masked location is compared (L1 distance over dim 1)
    with the feature vector of ``F0`` at the point's initial position, and
    the masked location with the smallest distance becomes the new position.

    Args:
        F: Current feature map. It is indexed with four indices, so it must
            be 4-D; presumably ``[batch, channels, height, width]`` -- confirm
            against the caller.
        F0: Reference feature map, indexed the same way as ``F``.
        handle_points: Current handle points, one per row. Each row is passed
            as ``center`` to ``utils.create_square_mask``.
        handle_points0: Initial handle points, one per row. Element 0 of a
            row indexes dim 2 of ``F0`` and element 1 indexes dim 3.
        r2: Radius forwarded to ``utils.create_square_mask``.
        device: Device the mask is moved to before its coordinates are
            extracted. The default is CUDA, so callers running on CPU must
            pass their device explicitly.

    Returns:
        A tensor with the same shape and dtype as ``handle_points`` holding
        the updated coordinates, ordered as (index into dim 2, index into
        dim 3) of ``F``.
    """
    # The spatial extent of F does not change between handle points.
    height, width = F.shape[2], F.shape[3]
    new_handle_points = torch.zeros_like(handle_points)

    for i, point in enumerate(handle_points):
        # Mask of candidate locations around the current handle point.
        patch = utils.create_square_mask(
            height, width, center=point.tolist(), radius=r2
        ).to(device)

        # Coordinates where the mask is set; shape [num_points, 2].
        patch_coordinates = torch.nonzero(patch)
        num_points = patch_coordinates.shape[0]

        # Features of F at every candidate location.
        F_qi = F[:, :, patch_coordinates[:, 0], patch_coordinates[:, 1]]
        # Feature of F0 at the initial handle point.
        f_i = F0[:, :, handle_points0[i][0].long(), handle_points0[i][1].long()]

        # L1 distance over dim 1 between each candidate and the reference.
        distances = torch.norm(F_qi - f_i[:, :, None], p=1, dim=1)

        # argmin without `dim` indexes the flattened [batch, num_points]
        # tensor; reduce it modulo num_points so it always addresses a row of
        # patch_coordinates, even when the leading dim of F is larger than 1.
        min_index = torch.argmin(distances) % num_points
        new_handle_points[i] = patch_coordinates[min_index]

    return new_handle_points