MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / render

Method render

modules/radnerfs/radnerf_torso.py:86–199  ·  view source on GitHub ↗
(self, rays_o, rays_d, cond, bg_coords, poses, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs)

Source from the content-addressed store, hash-verified

84 return alpha, color, dx
85
86 def render(self, rays_o, rays_d, cond, bg_coords, poses, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
87 # rays_o, rays_d: [B, N, 3], assumes B == 1
88 # cond: [B, 29, 16]
89 # bg_coords: [1, N, 2]
90 # return: pred_rgb: [B, N, 3]
91
92 ### run head nerf with no_grad to get the renderred head
93 with torch.no_grad():
94 prefix = rays_o.shape[:-1]
95 rays_o = rays_o.contiguous().view(-1, 3)
96 rays_d = rays_d.contiguous().view(-1, 3)
97 bg_coords = bg_coords.contiguous().view(-1, 2)
98 N = rays_o.shape[0] # N = B * N, in fact
99 device = rays_o.device
100 results = {}
101 # pre-calculate near far
102 nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)
103 nears = nears.detach()
104 fars = fars.detach()
105 # encode audio
106 cond_feat = self.cal_cond_feat(cond) # [1, 64]
107 if self.individual_embedding_dim > 0:
108 if self.training:
109 ind_code = self.individual_embeddings[index]
110 # use a fixed ind code for the unknown test data.
111 else:
112 ind_code = self.individual_embeddings[0]
113 else:
114 ind_code = None
115 if self.training:
116 # setup counter
117 counter = self.step_counter[self.local_step % 16]
118 counter.zero_() # set to 0
119 self.local_step += 1
120 xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
121 # xyzs, dirs, deltas, rays, points2rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
122 sigmas, rgbs, ambient = self(xyzs, dirs, cond_feat, ind_code)
123 sigmas = self.density_scale * sigmas
124 #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})')
125 weights_sum, ambient_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ambient.abs().sum(-1), deltas, rays)
126 # for training only
127 results['weights_sum'] = weights_sum
128 results['ambient'] = ambient_sum
129 else:
130 dtype = torch.float32
131 weights_sum = torch.zeros(N, dtype=dtype, device=device)
132 depth = torch.zeros(N, dtype=dtype, device=device)
133 image = torch.zeros(N, 3, dtype=dtype, device=device)
134 n_alive = N
135 rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
136 rays_t = nears.clone() # [N]
137 step = 0
138 while step < max_steps:
139 # count alive rays
140 n_alive = rays_alive.shape[0]
141 # exit loop
142 if n_alive <= 0:
143 break

Callers

nothing calls this directly

Calls 3

forward_torso — Method · 0.95
view — Method · 0.80
cal_cond_feat — Method · 0.45

Tested by

no test coverage detected