(self, res,
points = [],
targets = [],
mask = None,
lambda_mask = 10,
reg = 0,
feature_idx = 5,
r1 = 3,
r2 = 12,
random_seed = 0,
noise_mode = 'const',
trunc_psi = 0.7,
force_fp32 = False,
layer_name = None,
sel_channels = 3,
base_channel = 0,
img_scale_db = 0,
img_normalize = False,
untransform = False,
is_drag = False,
reset = False,
to_pil = False,
**kwargs
)
| 270 | print(' Remain feat_refs and points0_pt') |
| 271 | |
| 272 | def _render_drag_impl(self, res, |
| 273 | points = [], |
| 274 | targets = [], |
| 275 | mask = None, |
| 276 | lambda_mask = 10, |
| 277 | reg = 0, |
| 278 | feature_idx = 5, |
| 279 | r1 = 3, |
| 280 | r2 = 12, |
| 281 | random_seed = 0, |
| 282 | noise_mode = 'const', |
| 283 | trunc_psi = 0.7, |
| 284 | force_fp32 = False, |
| 285 | layer_name = None, |
| 286 | sel_channels = 3, |
| 287 | base_channel = 0, |
| 288 | img_scale_db = 0, |
| 289 | img_normalize = False, |
| 290 | untransform = False, |
| 291 | is_drag = False, |
| 292 | reset = False, |
| 293 | to_pil = False, |
| 294 | **kwargs |
| 295 | ): |
| 296 | G = self.G |
| 297 | ws = self.w |
| 298 | if ws.dim() == 2: |
| 299 | ws = ws.unsqueeze(1).repeat(1,6,1) |
| 300 | ws = torch.cat([ws[:,:6,:], self.w0[:,6:,:]], dim=1) |
| 301 | if hasattr(self, 'points'): |
| 302 | if len(points) != len(self.points): |
| 303 | reset = True |
| 304 | if reset: |
| 305 | self.feat_refs = None |
| 306 | self.points0_pt = None |
| 307 | self.points = points |
| 308 | |
| 309 | # Run synthesis network. |
| 310 | label = torch.zeros([1, G.c_dim], device=self._device) |
| 311 | img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True) |
| 312 | |
| 313 | h, w = G.img_resolution, G.img_resolution |
| 314 | |
| 315 | if is_drag: |
| 316 | X = torch.linspace(0, h, h) |
| 317 | Y = torch.linspace(0, w, w) |
| 318 | xx, yy = torch.meshgrid(X, Y) |
| 319 | feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear') |
| 320 | if self.feat_refs is None: |
| 321 | self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear') |
| 322 | self.feat_refs = [] |
| 323 | for point in points: |
| 324 | py, px = round(point[0]), round(point[1]) |
| 325 | self.feat_refs.append(self.feat0_resize[:,:,py,px]) |
| 326 | self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device) # 1, N, 2 |
| 327 | |
| 328 | # Point tracking with feature matching |
| 329 | with torch.no_grad(): |
no test coverage detected