(ctx, grad_weights_sum, grad_depth, grad_image)
@staticmethod
@custom_bwd
def backward(ctx, grad_weights_sum, grad_depth, grad_image):
    """Compute the gradients for sigmas and rgbs through the backend kernel.

    Args:
        ctx: autograd context; ``ctx.saved_tensors`` holds seven tensors
            (sigmas, rgbs, deltas, rays, weights_sum, depth, image) and
            ``ctx.dims`` holds ``(M, N, T_thresh)``.
        grad_weights_sum: incoming gradient paired with ``weights_sum``.
        grad_depth: incoming gradient paired with ``depth``. Ignored here,
            so nothing flows back to sigmas through the depth output.
        grad_image: incoming gradient paired with ``image``.

    Returns:
        A 5-tuple ``(grad_sigmas, grad_rgbs, None, None, None)``; the three
        trailing ``None`` entries mean no gradient is produced for those
        positions.
    """
    # NOTE: grad_depth is not used now! It won't be propagated to sigmas.

    # Make the incoming gradients contiguous before handing them to the backend.
    d_weights_sum = grad_weights_sum.contiguous()
    d_image = grad_image.contiguous()

    # The saved depth tensor is unpacked only to keep positions aligned.
    sigmas, rgbs, deltas, rays, weights_sum, _depth, image = ctx.saved_tensors
    M, N, T_thresh = ctx.dims

    # Zero-initialised result buffers; presumably written in place by the
    # backend call below -- TODO confirm against the backend implementation.
    d_sigmas = torch.zeros_like(sigmas)
    d_rgbs = torch.zeros_like(rgbs)

    backend = get_backend()
    backend.composite_sdf_rays_train_backward(
        d_weights_sum,
        d_image,
        sigmas,
        rgbs,
        deltas,
        rays,
        weights_sum,
        image,
        M,
        N,
        T_thresh,
        d_sigmas,
        d_rgbs,
    )

    return d_sigmas, d_rgbs, None, None, None
| 357 | |
| 358 | |
# Module-level callable alias bound to `_composite_sdf_rays_train.apply`, so
# callers can invoke the operation like a plain function.
composite_sdf_rays_train = _composite_sdf_rays_train.apply
nothing calls this directly
no test coverage detected