(
g_ema,
latent: torch.Tensor,
noise,
F,
handle_points,
target_points,
mask,
max_iters=1000,
r1=3,
r2=12,
lam=20,
d=2,
lr=2e-3,
)
| 101 | |
| 102 | |
def drag_gan(
    g_ema,
    latent: torch.Tensor,
    noise,
    F,
    handle_points,
    target_points,
    mask,
    max_iters=1000,
    r1=3,
    r2=12,
    lam=20,
    d=2,
    lr=2e-3,
    num_trainable_layers=6,
):
    """Drag handle points toward target points by optimizing part of the latent.

    This is a generator: after every optimization step it yields the current
    ``(sample, latent, features, handle_points)`` so a caller can display
    progress. Iteration stops once ``torch.allclose(handle_points,
    target_points, atol=d)`` holds, or after ``max_iters`` steps.

    Args:
        g_ema: Generator object whose ``generate(latent, noise)`` returns a
            ``(sample, features)`` pair.
        latent: Latent codes. Only ``latent[:, :num_trainable_layers, :]`` is
            optimized; the remaining entries along dim 1 are held fixed.
            (Presumably StyleGAN W+ codes of shape (batch, n_layers, dim) --
            TODO confirm.)
        noise: Passed unchanged to every ``g_ema.generate`` call.
        F: Features of the starting image. A detached copy (``F0``) is the
            reference for the mask loss and for point tracking.
        handle_points: Sequence of point tensors, combined with ``torch.stack``.
        target_points: Sequence of point tensors, one per handle point.
        mask: Optional mask, or ``None`` to disable the term. Feature changes
            relative to ``F0`` are penalized with weight ``1 - mask``.
        max_iters: Maximum number of optimization steps.
        r1: Radius forwarded to ``motion_supervison``.
        r2: Radius forwarded to ``point_tracking``.
        lam: Weight applied to the mask loss.
        d: Absolute tolerance for the stopping test.
        lr: Adam learning rate.
        num_trainable_layers: How many leading latent entries (along dim 1)
            are optimized. Defaults to 6, the previously hard-coded value.

    Yields:
        tuple: ``(sample, latent, features, handle_points)`` after each step.
    """
    # Initial handle positions are kept fixed; point tracking compares
    # against them on every step.
    handle_points0 = torch.stack(handle_points)
    handle_points = handle_points0.clone()
    target_points = torch.stack(target_points)

    F0 = F.detach().clone()
    device = latent.device

    # Weight for the "do not change" loss. Converted once, up front, so that
    # bool/uint8 masks (where `1 - mask` would fail) and masks on another
    # device also work.
    keep = None
    if mask is not None:
        keep = 1 - torch.as_tensor(mask, dtype=F0.dtype, device=F0.device)

    latent_trainable = (
        latent[:, :num_trainable_layers, :].detach().clone().requires_grad_(True)
    )
    latent_frozen = latent[:, num_trainable_layers:, :].detach().clone()
    optimizer = torch.optim.Adam([latent_trainable], lr=lr)

    for _ in range(max_iters):
        # NOTE: allclose is a per-coordinate test (|p - t| <= d + 1e-5 * |t|),
        # not a Euclidean distance check.
        if torch.allclose(handle_points, target_points, atol=d):
            break

        optimizer.zero_grad()
        latent = torch.cat([latent_trainable, latent_frozen], dim=1)
        _, F2 = g_ema.generate(latent, noise)

        # Motion-supervision loss (computed by `motion_supervison`).
        loss = motion_supervison(handle_points, target_points, F2, r1, device)
        if keep is not None:
            # Out-of-place add: never modify the tensor returned by another
            # autograd graph in place before calling backward().
            loss = loss + ((F2 - F0) * keep).abs().mean() * lam

        loss.backward()
        optimizer.step()

        # Re-render with the updated latent, then relocate the handles.
        with torch.no_grad():
            latent = torch.cat([latent_trainable, latent_frozen], dim=1)
            sample2, F2 = g_ema.generate(latent, noise)
            handle_points = point_tracking(
                F2, F0, handle_points, handle_points0, r2, device
            )

        yield sample2, latent, F2, handle_points
| 156 | |
| 157 | |
| 158 | def motion_supervison(handle_points, target_points, F2, r1, device): |
no test coverage detected