march rays to generate points (forward only) Args: rays_o/d: float, [N, 3] bound: float, scalar density_bitfield: uint8: [CHHH // 8] C: int H: int nears/fars: float, [N] step_counter: int32, (2), used to count the actual number of generated points.
(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024)
| 174 | @staticmethod |
| 175 | @custom_fwd(cast_inputs=torch.float32) |
| 176 | def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024): |
| 177 | ''' march rays to generate points (forward only) |
| 178 | Args: |
| 179 | rays_o/d: float, [N, 3] |
| 180 | bound: float, scalar |
| 181 | density_bitfield: uint8: [CHHH // 8] |
| 182 | C: int |
| 183 | H: int |
| 184 | nears/fars: float, [N] |
| 185 | step_counter: int32, (2), used to count the actual number of generated points. |
| 186 | mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) |
| 187 | perturb: bool |
| 188 | align: int, pad output so its size is dividable by align, set to -1 to disable. |
| 189 | force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. |
| 190 | dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) |
| 191 | max_steps: int, max number of sampled points along each ray, also affect min_stepsize. |
| 192 | Returns: |
| 193 | xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) |
| 194 | dirs: float, [M, 3], all generated points' view dirs. |
| 195 | deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth) |
| 196 | rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0] |
| 197 | ''' |
| 198 | |
| 199 | if not rays_o.is_cuda: rays_o = rays_o.cuda() |
| 200 | if not rays_d.is_cuda: rays_d = rays_d.cuda() |
| 201 | if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda() |
| 202 | |
| 203 | rays_o = rays_o.contiguous().view(-1, 3) |
| 204 | rays_d = rays_d.contiguous().view(-1, 3) |
| 205 | density_bitfield = density_bitfield.contiguous() |
| 206 | |
| 207 | N = rays_o.shape[0] # num rays |
| 208 | M = N * max_steps # init max points number in total |
| 209 | |
| 210 | # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) |
| 211 | # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated. |
| 212 | if not force_all_rays and mean_count > 0: |
| 213 | if align > 0: |
| 214 | mean_count += align - mean_count % align |
| 215 | M = mean_count |
| 216 | |
| 217 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) |
| 218 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) |
| 219 | deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) |
| 220 | rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps |
| 221 | |
| 222 | if step_counter is None: |
| 223 | step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter |
| 224 | |
| 225 | if perturb: |
| 226 | noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device) |
| 227 | else: |
| 228 | noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device) |
| 229 | |
| 230 | get_backend().march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number |
| 231 | |
| 232 | #print(step_counter, M) |
| 233 |
nothing calls this directly
no test coverage detected