Function point_tracking

draggan/deprecated/api.py:211–244 (github.com/OpenGVLab/DragGAN)
(
    F: torch.Tensor,
    F0: torch.Tensor,
    handle_points: torch.Tensor,
    handle_points0: torch.Tensor,
    r2: int = 3,
    device: torch.device = torch.device("cuda"),
)

```python
def point_tracking(
    F: torch.Tensor,
    F0: torch.Tensor,
    handle_points: torch.Tensor,
    handle_points0: torch.Tensor,
    r2: int = 3,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:

    n = handle_points.shape[0]  # Number of handle points
    new_handle_points = torch.zeros_like(handle_points)

    for i in range(n):
        # Compute the patch around the handle point
        patch = utils.create_square_mask(
            F.shape[2], F.shape[3], center=handle_points[i].tolist(), radius=r2
        ).to(device)

        # Find indices where the patch is True
        patch_coordinates = torch.nonzero(patch)  # shape [num_points, 2]

        # Extract features in the patch: shape [B, C, num_points]
        F_qi = F[:, :, patch_coordinates[:, 0], patch_coordinates[:, 1]]
        # Extract feature of the initial handle point: shape [B, C]
        f_i = F0[:, :, handle_points0[i][0].long(), handle_points0[i][1].long()]

        # Compute the L1 distance between the patch features and the initial
        # handle point feature: shape [B, num_points]
        distances = torch.norm(F_qi - f_i[:, :, None], p=1, dim=1)

        # Find the new handle point as the one with minimum distance.
        # torch.argmin without `dim` searches the flattened tensor, so this
        # indexing into patch_coordinates is only correct for a batch size of 1.
        min_index = torch.argmin(distances)
        new_handle_points[i] = patch_coordinates[min_index]

    return new_handle_points
```
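To experiment with the tracking step without the DragGAN `utils` module or a GPU, the following is a minimal CPU sketch of the same idea: for each handle point, search a square window of the current feature map `F` around the current position, and pick the location whose feature has the smallest L1 distance to the initial handle feature in `F0`. The name `track_points_sketch` is hypothetical, and the window shape (a square of half-width `r2`, clipped to the map) is an assumption about what `utils.create_square_mask` produces; the sketch replaces the mask-plus-`torch.nonzero` lookup with direct slicing and assumes a batch size of 1.

```python
import torch


def track_points_sketch(
    F: torch.Tensor,
    F0: torch.Tensor,
    handle_points: torch.Tensor,
    handle_points0: torch.Tensor,
    r2: int = 3,
) -> torch.Tensor:
    """Hypothetical CPU sketch of nearest-neighbour point tracking.

    F, F0: current and initial feature maps, shape [1, C, H, W].
    handle_points, handle_points0: [n, 2] tensors of (row, col) positions.
    Assumption: the search window is the square of half-width r2 around the
    current handle point, clipped to the feature-map bounds.
    """
    _, _, H, W = F.shape
    new_points = torch.zeros_like(handle_points)
    for i in range(handle_points.shape[0]):
        # Truncate to integer pixel positions (point_tracking uses .long() for F0).
        y, x = handle_points[i].long().tolist()
        y0, x0 = handle_points0[i].long().tolist()
        # Clip the square window to the feature map.
        top, bottom = max(y - r2, 0), min(y + r2, H - 1)
        left, right = max(x - r2, 0), min(x + r2, W - 1)
        window = F[0, :, top : bottom + 1, left : right + 1]  # [C, h, w]
        target = F0[0, :, y0, x0]  # [C]
        # L1 distance from every window feature to the initial handle feature.
        dist = (window - target[:, None, None]).abs().sum(dim=0)  # [h, w]
        # argmin over the flattened window, converted back to (row, col).
        flat = int(torch.argmin(dist))
        dy, dx = divmod(flat, dist.shape[1])
        new_points[i, 0] = top + dy
        new_points[i, 1] = left + dx
    return new_points


# Usage: a distinctive feature at (5, 5) in F0 has moved to (7, 6) in F.
F0 = torch.zeros(1, 4, 16, 16)
F0[0, :, 5, 5] = 1.0
F = torch.zeros(1, 4, 16, 16)
F[0, :, 7, 6] = 1.0
start = torch.tensor([[5.0, 5.0]])
result = track_points_sketch(F, F0, start, start, r2=3)
print(result)  # tensor([[7., 6.]])
```

When several window locations tie for the minimum distance, `torch.argmin` returns the first one in row-major order, which is also the order in which `torch.nonzero` lists mask coordinates in `point_tracking`.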

Callers (1)

drag_gan · Function · 0.70

Calls

utils.create_square_mask (all other calls are PyTorch built-ins)
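The loop in `point_tracking` relies on `utils.create_square_mask(height, width, center=..., radius=...)`, whose source is not shown on this page. Below is a hypothetical stand-in inferred only from the call site: the result is moved with `.to(device)` and passed to `torch.nonzero`, so it must be an `[H, W]` tensor, and `center` arrives as a `(row, col)` list. The clipped square of half-width `radius` is an assumption; the real helper may size or clip the window differently.

```python
import torch


def create_square_mask(height: int, width: int, center, radius: int) -> torch.Tensor:
    """Hypothetical stand-in for utils.create_square_mask.

    Returns a [height, width] bool mask that is True inside the square of
    half-width `radius` around `center` = (row, col), clipped to the grid.
    """
    row, col = (int(c) for c in center)
    mask = torch.zeros(height, width, dtype=torch.bool)
    top, bottom = max(row - radius, 0), min(row + radius + 1, height)
    left, right = max(col - radius, 0), min(col + radius + 1, width)
    mask[top:bottom, left:right] = True
    return mask


# Near the top-right corner the square is clipped to rows 0..3, cols 4..7.
mask = create_square_mask(8, 8, center=[1.0, 6.0], radius=2)
coords = torch.nonzero(mask)  # [num_points, 2] (row, col), as used in point_tracking
print(int(mask.sum()))  # 16
```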

Tested by

no test coverage detected