(self, rays_o, rays_d, cond, bg_coords, poses, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs)
| 84 | return alpha, color, dx |
| 85 | |
| 86 | def render(self, rays_o, rays_d, cond, bg_coords, poses, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs): |
| 87 | # rays_o, rays_d: [B, N, 3], assumes B == 1 |
| 88 | # cond: [B, 29, 16] |
| 89 | # bg_coords: [1, N, 2] |
| 90 | # return: pred_rgb: [B, N, 3] |
| 91 | |
| 92 | ### run head nerf with no_grad to get the renderred head |
| 93 | with torch.no_grad(): |
| 94 | prefix = rays_o.shape[:-1] |
| 95 | rays_o = rays_o.contiguous().view(-1, 3) |
| 96 | rays_d = rays_d.contiguous().view(-1, 3) |
| 97 | bg_coords = bg_coords.contiguous().view(-1, 2) |
| 98 | N = rays_o.shape[0] # N = B * N, in fact |
| 99 | device = rays_o.device |
| 100 | results = {} |
| 101 | # pre-calculate near far |
| 102 | nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near) |
| 103 | nears = nears.detach() |
| 104 | fars = fars.detach() |
| 105 | # encode audio |
| 106 | cond_feat = self.cal_cond_feat(cond) # [1, 64] |
| 107 | if self.individual_embedding_dim > 0: |
| 108 | if self.training: |
| 109 | ind_code = self.individual_embeddings[index] |
| 110 | # use a fixed ind code for the unknown test data. |
| 111 | else: |
| 112 | ind_code = self.individual_embeddings[0] |
| 113 | else: |
| 114 | ind_code = None |
| 115 | if self.training: |
| 116 | # setup counter |
| 117 | counter = self.step_counter[self.local_step % 16] |
| 118 | counter.zero_() # set to 0 |
| 119 | self.local_step += 1 |
| 120 | xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps) |
| 121 | # xyzs, dirs, deltas, rays, points2rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps) |
| 122 | sigmas, rgbs, ambient = self(xyzs, dirs, cond_feat, ind_code) |
| 123 | sigmas = self.density_scale * sigmas |
| 124 | #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') |
| 125 | weights_sum, ambient_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ambient.abs().sum(-1), deltas, rays) |
| 126 | # for training only |
| 127 | results['weights_sum'] = weights_sum |
| 128 | results['ambient'] = ambient_sum |
| 129 | else: |
| 130 | dtype = torch.float32 |
| 131 | weights_sum = torch.zeros(N, dtype=dtype, device=device) |
| 132 | depth = torch.zeros(N, dtype=dtype, device=device) |
| 133 | image = torch.zeros(N, 3, dtype=dtype, device=device) |
| 134 | n_alive = N |
| 135 | rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] |
| 136 | rays_t = nears.clone() # [N] |
| 137 | step = 0 |
| 138 | while step < max_steps: |
| 139 | # count alive rays |
| 140 | n_alive = rays_alive.shape[0] |
| 141 | # exit loop |
| 142 | if n_alive <= 0: |
| 143 | break |
nothing calls this directly
no test coverage detected