MCPcopy
hub / github.com/XingangPan/DragGAN / _render_drag_impl

Method _render_drag_impl

viz/renderer.py:272–386  ·  view source on GitHub ↗
(self, res,
        points          = [],
        targets         = [],
        mask            = None,
        lambda_mask     = 10,
        reg             = 0,
        feature_idx     = 5,
        r1              = 3,
        r2              = 12,
        random_seed     = 0,
        noise_mode      = 'const',
        trunc_psi       = 0.7,
        force_fp32      = False,
        layer_name      = None,
        sel_channels    = 3,
        base_channel    = 0,
        img_scale_db    = 0,
        img_normalize   = False,
        untransform     = False,
        is_drag         = False,
        reset           = False,
        to_pil          = False,
        **kwargs
    )

Source from the content-addressed store, hash-verified

270 print(' Remain feat_refs and points0_pt')
271
272 def _render_drag_impl(self, res,
273 points = [],
274 targets = [],
275 mask = None,
276 lambda_mask = 10,
277 reg = 0,
278 feature_idx = 5,
279 r1 = 3,
280 r2 = 12,
281 random_seed = 0,
282 noise_mode = 'const',
283 trunc_psi = 0.7,
284 force_fp32 = False,
285 layer_name = None,
286 sel_channels = 3,
287 base_channel = 0,
288 img_scale_db = 0,
289 img_normalize = False,
290 untransform = False,
291 is_drag = False,
292 reset = False,
293 to_pil = False,
294 **kwargs
295 ):
296 G = self.G
297 ws = self.w
298 if ws.dim() == 2:
299 ws = ws.unsqueeze(1).repeat(1,6,1)
300 ws = torch.cat([ws[:,:6,:], self.w0[:,6:,:]], dim=1)
301 if hasattr(self, 'points'):
302 if len(points) != len(self.points):
303 reset = True
304 if reset:
305 self.feat_refs = None
306 self.points0_pt = None
307 self.points = points
308
309 # Run synthesis network.
310 label = torch.zeros([1, G.c_dim], device=self._device)
311 img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True)
312
313 h, w = G.img_resolution, G.img_resolution
314
315 if is_drag:
316 X = torch.linspace(0, h, h)
317 Y = torch.linspace(0, w, w)
318 xx, yy = torch.meshgrid(X, Y)
319 feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear')
320 if self.feat_refs is None:
321 self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear')
322 self.feat_refs = []
323 for point in points:
324 py, px = round(point[0]), round(point[1])
325 self.feat_refs.append(self.feat0_resize[:,:,py,px])
326 self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device) # 1, N, 2
327
328 # Point tracking with feature matching
329 with torch.no_grad():

Callers 3

render · Method · 0.95
init_images · Function · 0.80
on_click_start · Function · 0.80

Calls 1

backward · Method · 0.45

Tested by

no test coverage detected