MCPcopy
hub / github.com/SizheAn/PanoHead / project

Function project

projector.py:20–148  ·  view source on GitHub ↗
(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    c: torch.Tensor,
    *,
    num_steps                  = 1000,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.1,
    initial_noise_factor       = 0.05,
    lr_rampdown_length         = 0.25,
    lr_rampup_length           = 0.05,
    noise_ramp_length          = 0.75,
    regularize_noise_weight    = 1e5,
    optimize_noise             = False,
    verbose                    = False,
    device: torch.device
)

Source from the content-addressed store, hash-verified

18from camera_utils import LookAtPoseSampler
19
20def project(
21 G,
22 target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
23 c: torch.Tensor,
24 *,
25 num_steps = 1000,
26 w_avg_samples = 10000,
27 initial_learning_rate = 0.1,
28 initial_noise_factor = 0.05,
29 lr_rampdown_length = 0.25,
30 lr_rampup_length = 0.05,
31 noise_ramp_length = 0.75,
32 regularize_noise_weight = 1e5,
33 optimize_noise = False,
34 verbose = False,
35 device: torch.device
36):
37 assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
38
39 def logprint(*args):
40 if verbose:
41 print(*args)
42
43 G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
44
45 # Compute w stats.
46 logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
47 z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
48 camera_lookat_point = torch.tensor([0, 0, 0.0], device=device)
49 cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, camera_lookat_point, radius=2.7, device=device)
50 intrinsics = torch.tensor([[4.2647, 0, 0.5], [0, 4.2647, 0.5], [0, 0, 1]], device=device)
51 c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
52 w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples.repeat(w_avg_samples,1)) # [N, L, C]
53 w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
54 w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
55 w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
56
57 # Setup noise inputs.
58 noise_bufs = { name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name }
59
60 # Load VGG16 feature detector.
61 url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
62 with dnnlib.util.open_url(url) as f:
63 vgg16 = torch.jit.load(f).eval().to(device)
64
65 # Features for target image.
66 target_images = target.unsqueeze(0).to(device).to(torch.float32) / 255.0 * 2 - 1
67 target_images_perc = (target_images + 1) * (255/2)
68 if target_images_perc.shape[2] > 256:
69 target_images_perc = F.interpolate(target_images_perc, size=(256, 256), mode='area')
70 target_features = vgg16(target_images_perc, resize_images=False, return_lpips=True)
71
72 w_avg = torch.tensor(w_avg, dtype=torch.float32, device=device).repeat(1, G.backbone.mapping.num_ws, 1)
73 w_opt = w_avg.detach().clone()
74 w_opt.requires_grad = True
75 w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device="cpu")
76 if optimize_noise:
77 optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

Callers 1

run_projection (Function) · 0.70

Calls 7

mapping (Method) · 0.80
mean (Method) · 0.80
load (Method) · 0.80
synthesis (Method) · 0.80
logprint (Function) · 0.70
sample (Method) · 0.45
backward (Method) · 0.45

Tested by

no test coverage detected