MCPcopy
hub / github.com/SizheAn/PanoHead / project_pti

Function project_pti

projector_withseg.py:186–259  ·  view source on GitHub ↗
(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    w_pivot: torch.Tensor,
    c: torch.Tensor,
    *,
    num_steps                  = 1000,
    initial_learning_rate      = 3e-4,
    lr_rampdown_length         = 0.25,
    lr_rampup_length           = 0.05,
    verbose                    = False,
    device: torch.device
)

Source from the content-addressed store, hash-verified

184
185
186def project_pti(
187 G,
188 target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
189 w_pivot: torch.Tensor,
190 c: torch.Tensor,
191 *,
192 num_steps = 1000,
193 initial_learning_rate = 3e-4,
194 lr_rampdown_length = 0.25,
195 lr_rampup_length = 0.05,
196 verbose = False,
197 device: torch.device
198):
199 assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
200
201 def logprint(*args):
202 if verbose:
203 print(*args)
204
205 G = copy.deepcopy(G).train().requires_grad_(True).to(device) # type: ignore
206
207 # Load VGG16 feature detector.
208 url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
209 with dnnlib.util.open_url(url) as f:
210 vgg16 = torch.jit.load(f).eval().to(device)
211
212 # Features for target image.
213 target_images = target.unsqueeze(0).to(device).to(torch.float32) / 255.0 * 2 - 1
214 target_images_perc = (target_images + 1) * (255/2)
215 if target_images_perc.shape[2] > 256:
216 target_images_perc = F.interpolate(target_images_perc, size=(256, 256), mode='area')
217 target_features = vgg16(target_images_perc, resize_images=False, return_lpips=True)
218
219 w_pivot = w_pivot.to(device).detach()
220 optimizer = torch.optim.Adam(G.parameters(), betas=(0.9, 0.999), lr=initial_learning_rate)
221
222 out_params = []
223
224 for step in range(num_steps):
225 # Learning rate schedule.
226 # t = step / num_steps
227 # lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
228 # lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
229 # lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
230 # lr = initial_learning_rate * lr_ramp
231 # for param_group in optimizer.param_groups:
232 # param_group['lr'] = lr
233
234 # Synth images from the fixed w_pivot (only G's weights are optimized).
235 synth_images = G.synthesis(w_pivot, c=c, noise_mode='const')['image']
236
237 # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
238 synth_images_perc = (synth_images + 1) * (255/2)
239 if synth_images_perc.shape[2] > 256:
240 synth_images_perc = F.interpolate(synth_images_perc, size=(256, 256), mode='area')
241
242 # Features for synth images.
243 synth_features = vgg16(synth_images_perc, resize_images=False, return_lpips=True)

Callers 1

run_projection (Function) · 0.70

Calls 6

load (Method) · 0.80
synthesis (Method) · 0.80
mean (Method) · 0.80
append (Method) · 0.80
logprint (Function) · 0.70
backward (Method) · 0.45

Tested by

no test coverage detected