Convert rotations given as rotation matrices to quaternions. Args: matrix: Rotation matrices as a tensor of shape (..., 3, 3). Returns: Quaternions with the real part last, as a tensor of shape (..., 4). Quaternion order: XYZW (i.e. ijkr, scalar-last).
(matrix: torch.Tensor)
| 45 | |
| 46 | |
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Four candidate quaternions are computed, each one obtained by dividing by a
    different component magnitude (r, i, j, k). The candidate with the largest
    divisor is returned, since it is the best conditioned numerically.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Quaternions with real part last, as tensor of shape (..., 4).
        Quaternion order: XYZW (i.e. ijkr, scalar-last).

    Raises:
        ValueError: If the last two dimensions of ``matrix`` are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)

    # For a proper rotation these four expressions are 4*r^2, 4*i^2, 4*j^2 and 4*k^2,
    # so q_abs holds twice the magnitude of each quaternion component.
    # NOTE(review): _sqrt_positive_part is defined elsewhere; presumably it returns
    # sqrt(max(0, x)) -- confirm.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Row n holds the desired quaternion (in r, i, j, k order) multiplied by
    # 4 * (n-th component), i.e. by 2 * q_abs[..., n] up to sign.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor the divisor at 0.1 but the exact level is not important; if q_abs is
    # small, the candidate won't be picked. clamp() applies the floor without
    # building a separate scalar tensor and moving it to the right device.
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].clamp(min=0.1))

    # If not for numerical problems, all four candidates would agree up to sign;
    # pick the best-conditioned one (largest divisor). gather() selects the same row
    # as a boolean one-hot mask would, but keeps a static output shape, so no
    # data-dependent indexing (and no host/device sync) is needed.
    best = q_abs.argmax(dim=-1)
    gather_index = best[..., None, None].expand(batch_dim + (1, 4))
    out = torch.gather(quat_candidates, -2, gather_index).squeeze(-2)

    # Convert from rijk to ijkr (scalar-last).
    out = out[..., [1, 2, 3, 0]]

    # Resolve the q / -q sign ambiguity via the file's standardize_quaternion helper.
    out = standardize_quaternion(out)

    return out
| 104 |
no test coverage detected