MCPcopy
hub / github.com/SizheAn/PanoHead / project

Function project

projector_withseg.py:47–183  ·  view source on GitHub ↗
(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    c: torch.Tensor,
    *,
    num_steps                  = 1000,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.1,
    initial_noise_factor       = 0.05,
    lr_rampdown_length         = 0.25,
    lr_rampup_length           = 0.05,
    noise_ramp_length          = 0.75,
    regularize_noise_weight    = 1e5,
    optimize_noise             = False,
    verbose                    = False,
    device: torch.device
)

Source from the content-addressed store, hash-verified

45
46
47def project(
48 G,
49 target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
50 c: torch.Tensor,
51 *,
52 num_steps = 1000,
53 w_avg_samples = 10000,
54 initial_learning_rate = 0.1,
55 initial_noise_factor = 0.05,
56 lr_rampdown_length = 0.25,
57 lr_rampup_length = 0.05,
58 noise_ramp_length = 0.75,
59 regularize_noise_weight = 1e5,
60 optimize_noise = False,
61 verbose = False,
62 device: torch.device
63):
64 assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
65
66 def logprint(*args):
67 if verbose:
68 print(*args)
69
70 G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
71
72 # Compute w stats.
73 logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
74 z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
75 camera_lookat_point = torch.tensor([0, 0, 0.0], device=device)
76 cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, camera_lookat_point, radius=2.7, device=device)
77 intrinsics = torch.tensor([[4.2647, 0, 0.5], [0, 4.2647, 0.5], [0, 0, 1]], device=device)
78 c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
79 w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples.repeat(w_avg_samples,1)) # [N, L, C]
80 w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
81 w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
82 w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
83
84 # fix delta_c
85
86 delta_c = G.t_mapping(torch.from_numpy(np.mean(z_samples, axis=0, keepdims=True)).to(device), c[:1], truncation_psi=1.0, truncation_cutoff=None, update_emas=False)
87 delta_c = torch.squeeze(delta_c, 1)
88 c[:,3] += delta_c[:,0]
89 c[:,7] += delta_c[:,1]
90 c[:,11] += delta_c[:,2]
91
92 # Setup noise inputs.
93 noise_bufs = { name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name }
94
95 # Load VGG16 feature detector.
96 url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
97 with dnnlib.util.open_url(url) as f:
98 vgg16 = torch.jit.load(f).eval().to(device)
99
100 # Features for target image.
101 target_images = target.unsqueeze(0).to(device).to(torch.float32) / 255.0 * 2 - 1
102 target_images_perc = (target_images + 1) * (255/2)
103 if target_images_perc.shape[2] > 256:
104 target_images_perc = F.interpolate(target_images_perc, size=(256, 256), mode='area')

Callers 1

run_projection (function) · 0.70

Calls 7

mapping (method) · 0.80
mean (method) · 0.80
load (method) · 0.80
synthesis (method) · 0.80
logprint (function) · 0.70
sample (method) · 0.45
backward (method) · 0.45

Tested by

no test coverage detected