MCPcopy Index your code
hub / github.com/OpenGVLab/DragGAN / drag_gan

Function drag_gan

draggan/deprecated/api.py:103–155  ·  view source on GitHub ↗
(
    g_ema,
    latent: torch.Tensor,
    noise,
    F,
    handle_points,
    target_points,
    mask,
    max_iters=1000,
    r1=3,
    r2=12,
    lam=20,
    d=2,
    lr=2e-3,
)

Source from the content-addressed store, hash-verified

101
102
def drag_gan(
    g_ema,
    latent: torch.Tensor,
    noise,
    F,
    handle_points,
    target_points,
    mask,
    max_iters=1000,
    r1=3,
    r2=12,
    lam=20,
    d=2,
    lr=2e-3,
    trainable_layers=6,
):
    """Iteratively optimise ``latent`` so the handle points move onto the targets.

    This is a generator. Each iteration takes one Adam step on the leading
    ``trainable_layers`` slices (along dim 1) of ``latent``. It then re-renders
    with the updated latent, re-locates the handle points via
    ``point_tracking``, and yields the intermediate result.

    Args:
        g_ema: Generator object exposing ``generate(latent, noise)``, which
            returns an ``(image, features)`` pair.
        latent: Starting latent code. It is split along dim 1 into a trainable
            head and a frozen tail (presumably a ``(batch, n_layers, dim)`` W+
            code -- TODO confirm).
        noise: Forwarded unchanged to ``g_ema.generate``.
        F: Reference feature map (presumably of the starting image). A
            detached copy is kept fixed and used both for the mask loss and
            for point tracking.
        handle_points: Sequence of point tensors, combined with ``torch.stack``.
        target_points: Sequence of point tensors, combined with ``torch.stack``.
        mask: ``None``, or a tensor broadcastable against the features.
            Feature changes are penalised wherever ``(1 - mask)`` is non-zero.
        max_iters: Maximum number of optimisation steps.
        r1: Forwarded to ``motion_supervison`` (presumably a patch radius).
        r2: Forwarded to ``point_tracking`` (presumably a search radius).
        lam: Weight of the masked feature-preservation term.
        d: ``atol`` for the ``torch.allclose`` stopping test. It is checked per
            coordinate, not as a Euclidean distance.
        lr: Adam learning rate.
        trainable_layers: Number of leading slices along dim 1 of ``latent``
            that are optimised; the remainder stays fixed. The default of 6
            matches the previously hard-coded value.

    Yields:
        ``(image, latent, features, handle_points)`` after every step. The
        image and features come from re-rendering the updated latent, and
        ``handle_points`` are the tracked positions. Nothing is yielded if the
        handles already start within tolerance of the targets.
    """
    handle_points = torch.stack(handle_points)
    # Independent snapshot of the starting positions. clone() also works for
    # tensors that are part of an autograd graph, where deepcopy would raise.
    handle_points0 = handle_points.clone()
    target_points = torch.stack(target_points)

    F0 = F.detach().clone()
    device = latent.device

    latent_trainable = (
        latent[:, :trainable_layers, :].detach().clone().requires_grad_(True)
    )
    # detach() already yields a tensor with requires_grad=False.
    latent_frozen = latent[:, trainable_layers:, :].detach().clone()
    optimizer = torch.optim.Adam([latent_trainable], lr=lr)

    for _ in range(max_iters):
        if torch.allclose(handle_points, target_points, atol=d):
            break

        optimizer.zero_grad()
        latent = torch.cat([latent_trainable, latent_frozen], dim=1)
        _, F2 = g_ema.generate(latent, noise)

        # Motion-supervision loss.
        loss = motion_supervison(handle_points, target_points, F2, r1, device)
        if mask is not None:
            # Out-of-place add: an in-place update of a graph output can break
            # backward depending on which op produced ``loss``.
            loss = loss + ((F2 - F0) * (1 - mask)).abs().mean() * lam

        loss.backward()
        optimizer.step()

        # Re-render with the stepped latent and relocate the handles. This
        # pass needs no gradients.
        with torch.no_grad():
            latent = torch.cat([latent_trainable, latent_frozen], dim=1)
            sample2, F2 = g_ema.generate(latent, noise)
            handle_points = point_tracking(
                F2, F0, handle_points, handle_points0, r2, device
            )

        yield sample2, latent, F2, handle_points
156
157
158def motion_supervison(handle_points, target_points, F2, r1, device):

Callers 1

on_drag — Function · 0.70

Calls 5

generate — Method · 0.80
mean — Method · 0.80
motion_supervison — Function · 0.70
point_tracking — Function · 0.70
backward — Method · 0.45

Tested by

no test coverage detected