(self, rays_o, rays_d, cond, bg_coords, poses, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, cond_mask=None,eye_area_percent=None,**kwargs)
| 284 | self.local_step = 0 |
| 285 | |
| 286 | def render(self, rays_o, rays_d, cond, bg_coords, poses, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, cond_mask=None,eye_area_percent=None,**kwargs): |
| 287 | # rays_o, rays_d: [B, N, 3], assumes B == 1 |
| 288 | # cond: [B, 29, 16] |
| 289 | # bg_coords: [1, N, 2] |
| 290 | # return: pred_rgb: [B, N, 3] |
| 291 | |
| 292 | prefix = rays_o.shape[:-1] |
| 293 | rays_o = rays_o.contiguous().view(-1, 3) |
| 294 | rays_d = rays_d.contiguous().view(-1, 3) |
| 295 | bg_coords = bg_coords.contiguous().view(-1, 2) |
| 296 | |
| 297 | N = rays_o.shape[0] # N = B * N, in fact |
| 298 | device = rays_o.device |
| 299 | |
| 300 | results = {} |
| 301 | |
| 302 | # pre-calculate near far |
| 303 | nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near) |
| 304 | nears = nears.detach() |
| 305 | fars = fars.detach() |
| 306 | |
| 307 | # encode audio |
| 308 | cond_feat = self.cal_cond_feat(cond, eye_area_percent=eye_area_percent) # [1, 64] |
| 309 | |
| 310 | if self.individual_embedding_dim > 0: |
| 311 | if self.training: |
| 312 | ind_code = self.individual_embeddings[index] |
| 313 | # use a fixed ind code for the unknown test data. |
| 314 | else: |
| 315 | ind_code = self.individual_embeddings[0] |
| 316 | else: |
| 317 | ind_code = None |
| 318 | |
| 319 | if self.training: |
| 320 | # setup counter |
| 321 | counter = self.step_counter[self.local_step % 16] |
| 322 | counter.zero_() # set to 0 |
| 323 | self.local_step += 1 |
| 324 | |
| 325 | align = 128 |
| 326 | xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, align, force_all_rays, dt_gamma, max_steps) |
| 327 | # xyzs, dirs, deltas, rays, points2rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, align, force_all_rays, dt_gamma, max_steps) |
| 328 | |
| 329 | sigmas, rgbs, ambient = self.forward(xyzs, dirs, cond_feat, ind_code, cond_mask=cond_mask) |
| 330 | sigmas = self.density_scale * sigmas |
| 331 | |
| 332 | #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') |
| 333 | |
| 334 | weights_sum, ambient_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ambient.abs().sum(-1), deltas, rays) |
| 335 | |
| 336 | # for training only |
| 337 | results['weights_sum'] = weights_sum |
| 338 | results['ambient'] = ambient_sum |
| 339 | results['position'] = xyzs |
| 340 | else: |
| 341 | |
| 342 | dtype = torch.float32 |
| 343 |
no test coverage detected