| 156 | |
| 157 | |
def motion_supervison(handle_points, target_points, F2, r1, device):
    """Compute a DragGAN-style motion-supervision loss.

    For every handle point, the features inside a circular patch of radius
    ``r1`` around the handle are compared (L1, mean reduction) with the
    features bilinearly sampled at the same patch shifted one step towards
    the matching target point. The unshifted features are detached, so
    gradients flow only through the shifted samples.

    Args:
        handle_points: Sequence of current handle positions, each a
            2-element floating-point tensor. The step is added to
            ``torch.nonzero`` coordinates, so points are treated as
            ``(row, col)``; assumed to match the center convention of
            ``utils.create_circular_mask`` -- TODO confirm.
        target_points: Sequence of target positions, laid out like
            ``handle_points``.
        F2: Feature map of shape ``[B, C, H, W]``.
        r1: Patch radius passed to ``utils.create_circular_mask``.
        device: Device the patch mask is moved to.

    Returns:
        Sum over handle points of the per-point L1 loss. Returns the integer
        ``0`` when there are no handle points or every patch is empty.
    """
    batch, _, h, w = F2.shape
    loss = 0
    for handle, target in zip(handle_points, target_points):
        target2handle = target - handle
        distance = torch.norm(target2handle)
        d_i = target2handle / (distance + 1e-7)
        # Never overshoot: when the target is closer than one unit step,
        # move exactly onto it instead.
        if torch.norm(d_i) > distance:
            d_i = target2handle

        mask = utils.create_circular_mask(
            h, w, center=handle.tolist(), radius=r1
        ).to(device)

        coordinates = torch.nonzero(mask).float()  # [num_points, 2], (row, col)
        if coordinates.shape[0] == 0:
            # l1_loss over zero elements yields NaN and would poison the sum.
            continue

        # Shift the patch one step along d_i. Cast d_i to the coordinates'
        # dtype/device so CPU or float64 handle points do not break the add.
        shifted = coordinates + d_i.to(coordinates)[None]

        # Map pixel coordinates into grid_sample's [-1, 1] range
        # (align_corners=True convention), ordered as (x, y) as grid_sample
        # expects, and clamp to the feature-map border.
        normalized = torch.stack(
            (
                2.0 * shifted[:, 1] / (w - 1) - 1,  # x (width axis)
                2.0 * shifted[:, 0] / (h - 1) - 1,  # y (height axis)
            ),
            dim=-1,
        ).clamp(-1, 1)

        # grid_sample needs the grid to match F2's dtype/device and batch
        # size: [B, 1, num_points, 2].
        grid = normalized.to(F2)[None, None].expand(batch, -1, -1, -1)

        # Output is [B, C, 1, num_points]; drop the singleton row dimension.
        F_qi_plus_di = torch.nn.functional.grid_sample(
            F2, grid, mode="bilinear", align_corners=True
        ).squeeze(2)  # [B, C, num_points]

        # Boolean-mask indexing gives [B, C, num_points] in the same
        # row-major order as torch.nonzero(mask), so points line up.
        F_qi = F2[:, :, mask]

        loss += torch.nn.functional.l1_loss(F_qi.detach(), F_qi_plus_di)
    return loss
| 209 | |
| 210 | |
| 211 | def point_tracking( |