(pcd=pcd, feat=feat, radius=radius)
| 53 | |
| 54 | |
def test_pytorch3d_rendering(pcd=pcd, feat=feat, radius=radius):
    """Rasterize and composite a point cloud with PyTorch3D, then save RGB / depth / alpha PNGs.

    The point cloud is transformed with ``rasterizer.transform`` under a PyTorch3D
    ``PerspectiveCameras`` built from the module-level ``camera``. It is then
    rasterized with per-point radii. Finally ``feat`` is composited, extended by an
    accumulation channel and a point-to-``C`` distance channel. Three images are written:

    - ``test_pytorch3d_rendering_rgb.png``: composited features followed by alpha.
    - ``test_pytorch3d_rendering_dpt.png``: distance map (uncovered pixels pushed to the
      largest point distance), replicated to 3 channels, followed by alpha.
    - ``test_pytorch3d_rendering_acc.png``: accumulated alpha replicated to 4 channels.

    Args:
        pcd: point positions, used as (B, N, 3) below (TODO confirm against caller).
        feat: per-point features (B, N, F). The RGBA output assumes F == 3.
        radius: world-space point radius. Either a float or a tensor whose trailing
            dim is a singleton, e.g. (B, N, 1).
    """
    H, W, K, R, T, C = get_pytorch3d_camera_params(add_batch(to_cuda(camera.to_batch())))
    cameras = PerspectiveCameras(device='cuda', R=R, T=T, K=K)
    ndc_pcd = rasterizer.transform(Pointclouds(pcd), cameras=cameras).points_padded()  # B, N, 3

    # Drop the trailing dim so a float radius and a (..., 1) tensor radius are handled alike.
    if isinstance(radius, torch.Tensor):
        radius = radius[..., 0]
    # Scale the world-space radius by K[1, 1] and divide by the point's last NDC coordinate
    # (its depth) to get the radius passed to the rasterizer: (B, 1) * (B, N) -> (B, N).
    radius = abs(K[..., 1, 1][..., None] * radius / (ndc_pcd[..., -1] + 1e-10))

    # Actual forward rasterization.
    # FIXME: constructing a PyTorch3D Pointclouds object forces a CUDA sync.
    idx, zbuf, dists = rasterize_points(Pointclouds(ndc_pcd), (H, W), radius, pts_per_pix, None, None)

    # Prepare for composition.
    # `idx` indexes the *packed* point list (batch b, point n -> b * N + n) and holds -1 for
    # empty slots. The radius is therefore looked up in its flattened (B * N,) form, matching
    # the flattened features handed to the compositor below. Empty slots read entry 0; their
    # weight presumably has no effect because the compositor sees idx == -1 there.
    safe_idx = idx.clamp_min(0).long()  # B, H, W, K
    pix_radius = radius.reshape(-1)[safe_idx]  # B, H, W, K
    pix_weight = 1 - dists / (pix_radius * pix_radius)  # B, H, W, K
    acc = torch.ones_like(feat[..., :1])  # per-point alpha of 1, composited into coverage
    depth: torch.Tensor = (pcd - C.mT).norm(dim=-1, keepdim=True)  # distance from each point to C
    feat = torch.cat([feat, acc, depth], dim=-1)  # B, N, F + 2

    # The actual computation: composite the flattened (packed) per-point features into the image.
    feat = compositor(idx.long().permute(0, 3, 1, 2),
                      pix_weight.permute(0, 3, 1, 2),
                      feat.view(-1, feat.shape[-1]).permute(1, 0)).permute(0, 2, 3, 1)  # B, H, W, F + 2

    # TODO: Implement and return random background here
    # Accumulation and depth are the last two channels; slice from the end so the split
    # stays correct whatever the feature width F is.
    rgb_map, acc_map, dpt_map = feat[..., :-2], feat[..., -2:-1], feat[..., -1:]
    dpt_map = dpt_map + (1 - acc_map) * depth.max()  # fill uncovered pixels with the largest distance

    # Assemble the 4-channel images; acc_map is replaced last because the other two use it.
    rgb_map = torch.cat([rgb_map, acc_map], dim=-1)
    dpt_map = torch.cat([dpt_map, dpt_map, dpt_map, acc_map], dim=-1)
    acc_map = torch.cat([acc_map, acc_map, acc_map, acc_map], dim=-1)

    save_image('test_pytorch3d_rendering_rgb.png', rgb_map[0].detach().cpu().numpy())
    save_image('test_pytorch3d_rendering_dpt.png', dpt_map[0].detach().cpu().numpy())
    save_image('test_pytorch3d_rendering_acc.png', acc_map[0].detach().cpu().numpy())
| 88 | |
| 89 | |
| 90 | def test_hybrid_pulsar_rendering(pcd=pcd, feat=feat, radius=radius): |
nothing calls this directly
no test coverage detected