Apply a multi-view transform to vertices based on the camera transform matrix. Args: vertices (torch.Tensor, shape [N, 3]): vertex coordinates; frame (dict): contains `transform_matrix`. Returns: transformed_vertices (torch.Tensor, shape [N, 3]).
(vertices, frame)
| 23 | # ==================== PBR-specific transform functions ==================== |
| 24 | |
def transform_vertices(vertices, frame):
    """
    Apply multi-view transform to vertices based on camera transform matrix.

    The frame's matrix is inverted and composed with a new camera matrix
    returned by ``get_new_camera_matrix`` (same distance from the origin as
    the frame's translation, yaw of -90 degrees, pitch of 0), then wrapped in
    fixed axis-swap matrices. The resulting 4x4 transform is applied to the
    vertices in homogeneous coordinates.

    Args:
        vertices: torch.Tensor holding vertex coordinates; it is reshaped to
            [N, 3]. Any dtype that ``torch.cat`` can promote against float32
            is accepted (e.g. float32 or float64).
        frame: dict containing ``transform_matrix``, a 4x4 matrix (anything
            ``torch.tensor`` accepts). NOTE(review): presumably the original
            camera-to-world pose -- confirm against the caller.

    Returns:
        transformed_vertices: torch.Tensor, shape [N, 3], on the same device
        as ``vertices``. The dtype is the promotion of the input dtype with
        float32.
    """
    device = vertices.device
    c2w_orig = torch.tensor(frame['transform_matrix'], dtype=torch.float32, device=device)

    # Old and new camera matrices. The new camera keeps the original
    # translation's distance from the origin.
    radius = c2w_orig[:3, 3].norm().item()
    c2w_new = get_new_camera_matrix(
        radius=radius,
        yaw=-90/180.0*math.pi,
        pitch=0.0,
        dtype=torch.float32,
        device=device,
    )
    w2c_orig = torch.inverse(c2w_orig)

    # Initial axis alignment: (x, y, z) -> (x, -z, y).
    R_init = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, -1.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ], dtype=torch.float32, device=device)

    # Inverse axis swap: (x, y, z) -> (x, z, -y). The same matrix is used both
    # as R_ply (applied before the camera change) and as R_back (applied
    # after it), so it is built once and shared; neither use mutates it.
    R_ply = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, -1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ], dtype=torch.float32, device=device)
    R_back = R_ply

    T_cam = c2w_new @ w2c_orig @ R_ply
    T_final = R_back @ T_cam @ R_init

    # Apply transform in homogeneous coordinates.
    vertices = vertices.reshape(-1, 3)
    ones = torch.ones((vertices.shape[0], 1), dtype=torch.float32, device=device)
    verts_h = torch.cat([vertices, ones], dim=1)

    # torch.cat promotes mixed dtypes (e.g. float64 vertices yield a float64
    # verts_h), but matmul requires both operands to share a dtype. Cast the
    # float32 transform to match so non-float32 vertices do not raise.
    T_final = T_final.to(verts_h.dtype)
    verts_trans = (T_final @ verts_h.T).T[:, :3]

    return verts_trans
| 76 | |
| 77 | |
| 78 | def transform_normals(normals, frame): |
no test coverage detected