MCPcopy Index your code
hub / github.com/OpenGVLab/DragGAN / motion_supervison

Function motion_supervison

draggan/deprecated/api.py:158–208  ·  view source on GitHub ↗
(handle_points, target_points, F2, r1, device)

Source from the content-addressed store, hash-verified

156
157
def motion_supervison(handle_points, target_points, F2, r1, device):
    """Compute the motion-supervision loss summed over all handle points.

    For every handle point, the features of ``F2`` inside a circular patch of
    radius ``r1`` around the point are compared (L1 loss) against the features
    bilinearly sampled from the same patch after shifting it one step towards
    the matching target point. The unshifted patch features are detached
    before the comparison.

    Args:
        handle_points: Indexable collection of point tensors, one per handle.
            NOTE(review): the step is added directly to ``torch.nonzero``
            output, so points are presumably in (row, col) order -- confirm
            against the caller and ``utils.create_circular_mask``.
        target_points: Target point tensors, indexed in step with
            ``handle_points``.
        F2: Feature map whose dims 2 and 3 are used as height and width.
            NOTE(review): the sampling grid is built with batch size 1, so
            this presumably expects a [1, C, H, W] tensor -- confirm.
        r1: Radius forwarded to ``utils.create_circular_mask``.
        device: Device the patch mask is moved to.

    Returns:
        The sum of the per-point L1 losses, or the int ``0`` when
        ``handle_points`` is empty.
    """
    height, width = F2.shape[2], F2.shape[3]
    per_point_losses = []

    for idx, handle in enumerate(handle_points):
        offset = target_points[idx] - handle
        distance = torch.norm(offset)

        # Move one unit step towards the target; if the target is closer than
        # that, use the remaining offset so the patch does not overshoot.
        step = offset / (distance + 1e-7)
        if torch.norm(step) > distance:
            step = offset

        patch_mask = utils.create_circular_mask(
            height, width, center=handle.tolist(), radius=r1
        ).to(device)

        # Indices of the pixels inside the patch, then the shifted patch.
        patch_coords = torch.nonzero(patch_mask).float()  # [num_points, 2]
        shifted_coords = patch_coords + step[None]

        # Reference features at the current patch location.
        # [N, C, num_points], assuming a 2-D [H, W] mask -- TODO confirm
        patch_features = F2[:, :, patch_mask]

        # Rescale pixel indices to the [-1, 1] range used by grid_sample.
        norm_rows = (2.0 * shifted_coords[:, 0] / (height - 1)) - 1
        norm_cols = (2.0 * shifted_coords[:, 1] / (width - 1)) - 1

        # grid_sample expects (x, y) pairs, i.e. (col, row), laid out as
        # [1, 1, num_points, 2]; clamp keeps samples inside the feature map.
        grid = torch.stack((norm_cols, norm_rows), dim=-1)[None, None]
        grid = grid.clamp(-1, 1)

        # Interpolate F2 at the shifted patch and drop the singleton dim:
        # [N, C, 1, num_points] -> [N, C, num_points]
        shifted_features = torch.nn.functional.grid_sample(
            F2, grid, mode="bilinear", align_corners=True
        ).squeeze(2)

        per_point_losses.append(
            torch.nn.functional.l1_loss(patch_features.detach(), shifted_features)
        )

    # sum() starts from the int 0, matching the original accumulator.
    return sum(per_point_losses)
209
210
211def point_tracking(

Callers 1

drag_ganFunction · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected