from __future__ import annotations
import dataclasses
from typing import NamedTuple, Tuple
import numpy as onp
import numpy.typing as onpt
from typing_extensions import override
from . import _base, hints
from .utils import broadcast_leading_axes, get_epsilon
class RollPitchYaw(NamedTuple):
    """Euler angles, in radians, returned by `SO3.as_rpy_radians()`.

    Fields are ordered `(roll, pitch, yaw)`.
    """

    # Rotation about the x axis.
    roll: onpt.NDArray[onp.floating]
    # Rotation about the y axis.
    pitch: onpt.NDArray[onp.floating]
    # Rotation about the z axis.
    yaw: onpt.NDArray[onp.floating]
@dataclasses.dataclass(frozen=True)
class SO3(
    _base.SOBase,
    matrix_dim=3,
    parameters_dim=4,
    tangent_dim=3,
    space_dim=3,
):
    """Special orthogonal group SO(3): rotations in three dimensions.

    Numpy port of `jaxlie.SO3`; batch axes broadcast exactly as in numpy.

    Parameters are stored as a scalar-first quaternion `(qw, qx, qy, qz)`.
    Tangent vectors are `(omega_x, omega_y, omega_z)`.
    """

    wxyz: onpt.NDArray[onp.floating]
    """Scalar-first `(w, x, y, z)` quaternion parameters, shape `(*, 4)`."""
@override
def __repr__(self) -> str:
wxyz = onp.round(self.wxyz, 5)
return f"{self.__class__.__name__}(wxyz={wxyz})"
@staticmethod
def from_x_radians(theta: hints.Scalar) -> SO3:
    """Generates an x-axis rotation.

    Args:
        theta: X rotation, in radians. May be a scalar or an array; an array
            yields a batch of rotations with the same leading shape.

    Returns:
        The rotation as an `SO3`.
    """
    # Tangent vector (theta, 0, 0): a rotation by theta about the x axis.
    zeros = onp.zeros_like(theta)
    return SO3.exp(onp.stack([theta, zeros, zeros], axis=-1))
@staticmethod
def from_y_radians(theta: hints.Scalar) -> SO3:
    """Generates a y-axis rotation.

    Args:
        theta: Y rotation, in radians. May be a scalar or an array; an array
            yields a batch of rotations with the same leading shape.

    Returns:
        The rotation as an `SO3`.
    """
    # Tangent vector (0, theta, 0): a rotation by theta about the y axis.
    zeros = onp.zeros_like(theta)
    return SO3.exp(onp.stack([zeros, theta, zeros], axis=-1))
@staticmethod
def from_z_radians(theta: hints.Scalar) -> SO3:
    """Generates a z-axis rotation.

    Args:
        theta: Z rotation, in radians. May be a scalar or an array; an array
            yields a batch of rotations with the same leading shape.

    Returns:
        The rotation as an `SO3`.
    """
    # Tangent vector (0, 0, theta): a rotation by theta about the z axis.
    zeros = onp.zeros_like(theta)
    return SO3.exp(onp.stack([zeros, zeros, theta], axis=-1))
@staticmethod
def from_rpy_radians(
    roll: hints.Scalar,
    pitch: hints.Scalar,
    yaw: hints.Scalar,
) -> SO3:
    """Generates a rotation from a set of Euler angles. Uses the ZYX mobile
    robot convention.

    Args:
        roll: X rotation, in radians. Applied first.
        pitch: Y rotation, in radians. Applied second.
        yaw: Z rotation, in radians. Applied last.

    Returns:
        The composed rotation as an `SO3`.
    """
    # Rotations compose right-to-left: roll (x) acts first, yaw (z) last.
    return (
        SO3.from_z_radians(yaw)
        @ SO3.from_y_radians(pitch)
        @ SO3.from_x_radians(roll)
    )
@staticmethod
def from_quaternion_xyzw(xyzw: onpt.NDArray[onp.floating]) -> SO3:
"""Construct a rotation from an `xyzw` quaternion.
Note that `wxyz` quaternions can be constructed using the default dataclass
constructor.
Args:
xyzw: xyzw quaternion. Shape should be (*, 4).
Returns:
Output.
"""
assert xyzw.shape[-1:] == (4,)
return SO3(onp.roll(xyzw, axis=-1, shift=1))
[docs]
def as_quaternion_xyzw(self) -> onpt.NDArray[onp.floating]:
"""Grab parameters as xyzw quaternion."""
return onp.roll(self.wxyz, axis=-1, shift=-1)
def as_rpy_radians(self) -> RollPitchYaw:
    """Decompose into Euler angles using the ZYX mobile robot convention.

    Returns:
        `RollPitchYaw` tuple; every angle is in radians.
    """
    roll = self.compute_roll_radians()
    pitch = self.compute_pitch_radians()
    yaw = self.compute_yaw_radians()
    return RollPitchYaw(roll, pitch, yaw)
[docs]
def compute_roll_radians(self) -> onpt.NDArray[onp.floating]:
"""Compute roll angle. Uses the ZYX mobile robot convention.
Returns:
Euler angle in radians.
"""
# https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion
q0, q1, q2, q3 = onp.moveaxis(self.wxyz, -1, 0)
return onp.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1**2 + q2**2))
[docs]
def compute_pitch_radians(self) -> onpt.NDArray[onp.floating]:
"""Compute pitch angle. Uses the ZYX mobile robot convention.
Returns:
Euler angle in radians.
"""
# https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion
q0, q1, q2, q3 = onp.moveaxis(self.wxyz, -1, 0)
return onp.arcsin(2 * (q0 * q2 - q3 * q1))
[docs]
def compute_yaw_radians(self) -> onpt.NDArray[onp.floating]:
"""Compute yaw angle. Uses the ZYX mobile robot convention.
Returns:
Euler angle in radians.
"""
# https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion
q0, q1, q2, q3 = onp.moveaxis(self.wxyz, -1, 0)
return onp.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2**2 + q3**2))
# Factory.
[docs]
@classmethod
@override
def identity(
cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64
) -> SO3:
return SO3(
wxyz=onp.broadcast_to(
onp.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 4)
)
)
[docs]
@classmethod
@override
def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SO3:
assert matrix.shape[-2:] == (3, 3)
# Modified from:
# > "Converting a Rotation Matrix to a Quaternion" from Mike Day
# > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf
def case0(m):
t = 1 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2]
q = onp.stack(
[
m[..., 2, 1] - m[..., 1, 2],
t,
m[..., 1, 0] + m[..., 0, 1],
m[..., 0, 2] + m[..., 2, 0],
],
axis=-1,
)
return t, q
def case1(m):
t = 1 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2]
q = onp.stack(
[
m[..., 0, 2] - m[..., 2, 0],
m[..., 1, 0] + m[..., 0, 1],
t,
m[..., 2, 1] + m[..., 1, 2],
],
axis=-1,
)
return t, q
def case2(m):
t = 1 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2]
q = onp.stack(
[
m[..., 1, 0] - m[..., 0, 1],
m[..., 0, 2] + m[..., 2, 0],
m[..., 2, 1] + m[..., 1, 2],
t,
],
axis=-1,
)
return t, q
def case3(m):
t = 1 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2]
q = onp.stack(
[
t,
m[..., 2, 1] - m[..., 1, 2],
m[..., 0, 2] - m[..., 2, 0],
m[..., 1, 0] - m[..., 0, 1],
],
axis=-1,
)
return t, q
# Compute four cases, then pick the most precise one.
# Probably worth revisiting this!
case0_t, case0_q = case0(matrix)
case1_t, case1_q = case1(matrix)
case2_t, case2_q = case2(matrix)
case3_t, case3_q = case3(matrix)
cond0 = matrix[..., 2, 2] < 0
cond1 = matrix[..., 0, 0] > matrix[..., 1, 1]
cond2 = matrix[..., 0, 0] < -matrix[..., 1, 1]
t = onp.where(
cond0,
onp.where(cond1, case0_t, case1_t),
onp.where(cond2, case2_t, case3_t),
)
q = onp.where(
cond0[..., None],
onp.where(cond1[..., None], case0_q, case1_q),
onp.where(cond2[..., None], case2_q, case3_q),
)
return SO3(wxyz=(q * 0.5 / onp.sqrt(t[..., None])).astype(matrix.dtype))
# Accessors.
[docs]
@override
def as_matrix(self) -> onpt.NDArray[onp.floating]:
norm_sq = onp.sum(onp.square(self.wxyz), axis=-1, keepdims=True)
q = self.wxyz * onp.sqrt(2.0 / norm_sq) # (*, 4)
q_outer = onp.einsum("...i,...j->...ij", q, q) # (*, 4, 4)
return (
onp.stack(
[
1.0 - q_outer[..., 2, 2] - q_outer[..., 3, 3],
q_outer[..., 1, 2] - q_outer[..., 3, 0],
q_outer[..., 1, 3] + q_outer[..., 2, 0],
q_outer[..., 1, 2] + q_outer[..., 3, 0],
1.0 - q_outer[..., 1, 1] - q_outer[..., 3, 3],
q_outer[..., 2, 3] - q_outer[..., 1, 0],
q_outer[..., 1, 3] - q_outer[..., 2, 0],
q_outer[..., 2, 3] + q_outer[..., 1, 0],
1.0 - q_outer[..., 1, 1] - q_outer[..., 2, 2],
],
axis=-1,
)
.reshape(*q.shape[:-1], 3, 3)
.astype(self.wxyz.dtype)
)
[docs]
@override
def parameters(self) -> onpt.NDArray[onp.floating]:
return self.wxyz
# Operations.
@override
def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]:
    """Rotate 3D point(s) by this rotation.

    Args:
        target: Points with shape `(*, 3)`. Leading axes are reconciled with
            the rotation's batch axes by `broadcast_leading_axes`.

    Returns:
        Rotated points: the vector part of `q * (0, target) * q^-1`.
    """
    assert target.shape[-1:] == (3,)
    rotation, points = broadcast_leading_axes((self, target))

    # Embed the points as pure quaternions (zero scalar part), then
    # sandwich them between the rotation and its inverse.
    zero_scalar = onp.zeros((*rotation.get_batch_axes(), 1), dtype=points.dtype)
    pure = SO3(wxyz=onp.concatenate([zero_scalar, points], axis=-1))
    rotated = rotation @ pure @ rotation.inverse()
    return rotated.wxyz[..., 1:]
[docs]
@override
def multiply(self, other: SO3) -> SO3:
w0, x0, y0, z0 = onp.moveaxis(self.wxyz, -1, 0)
w1, x1, y1, z1 = onp.moveaxis(other.wxyz, -1, 0)
return SO3(
wxyz=onp.stack(
[
-x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1,
x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1,
-x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1,
x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1,
],
axis=-1,
)
)
@classmethod
@override
def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO3:
    """Exponential map from a rotation vector to a unit quaternion.

    Args:
        tangent: `(omega_x, omega_y, omega_z)` with shape `(*, 3)`. Its
            direction is the rotation axis and its norm the angle in radians.

    Returns:
        `SO3` whose parameters have `tangent`'s dtype. Non-floating tangents
        are promoted to float64 first.
    """
    # Reference:
    # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583
    assert tangent.shape[-1:] == (3,)

    # The result is cast back to the tangent's dtype at the end; an integer
    # tangent would truncate the quaternion components, so promote it here.
    if not onp.issubdtype(tangent.dtype, onp.floating):
        tangent = tangent.astype(onp.float64)

    theta_squared = onp.sum(onp.square(tangent), axis=-1)
    theta_pow_4 = theta_squared * theta_squared
    use_taylor = theta_squared < get_epsilon(tangent.dtype)

    # Shim to avoid NaNs in onp.where branches, which cause failures for
    # reverse-mode AD in JAX. This isn't needed for vanilla numpy.
    safe_theta = onp.sqrt(
        onp.where(
            use_taylor,
            onp.ones_like(theta_squared),  # Any constant value should do here.
            theta_squared,
        )
    )

    # Fun fact: when safe_theta is a `float32` _scalar_, this
    # multiplication will promote `safe_half_theta` to `float64`. We'll
    # cast at the end to make sure our input/output dtypes match.
    safe_half_theta = 0.5 * safe_theta

    # cos(theta / 2), replaced by its Taylor series near zero.
    real_factor = onp.where(
        use_taylor,
        1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0,
        onp.cos(safe_half_theta),
    )
    # sin(theta / 2) / theta, replaced by its Taylor series near zero.
    imaginary_factor = onp.where(
        use_taylor,
        0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0,
        onp.sin(safe_half_theta) / safe_theta,
    )

    return SO3(
        wxyz=onp.concatenate(
            [
                real_factor[..., None],
                imaginary_factor[..., None] * tangent,
            ],
            axis=-1,
        ).astype(tangent.dtype)
    )
@override
def log(self) -> onpt.NDArray[onp.floating]:
    """Logarithmic map from a quaternion to a rotation vector.

    Returns:
        `(omega_x, omega_y, omega_z)` with shape `(*, 3)`, cast to the
        parameters' dtype.
    """
    # Reference:
    # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247
    w = self.wxyz[..., 0]
    xyz = self.wxyz[..., 1:]
    norm_sq = onp.sum(onp.square(xyz), axis=-1)
    use_taylor = norm_sq < get_epsilon(norm_sq.dtype)

    # Shim to avoid NaNs in onp.where branches, which cause failures for
    # reverse-mode AD in JAX. This isn't needed for vanilla numpy.
    norm_safe = onp.sqrt(onp.where(use_taylor, 1.0, norm_sq))  # Any non-zero value works.
    w_safe = onp.where(use_taylor, w, 1.0)

    # Small-angle series for 2 * atan(n / w) / n.
    taylor_factor = 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3
    # |w| ~ 0: atan(n / w) -> +/- pi / 2, so the factor is +/- pi / n.
    sign = onp.where(w > 0, 1.0, -1.0).astype(dtype=w.dtype)
    near_pi_factor = sign * onp.pi / norm_safe
    # General case; arctan2 against |w| equals atan(n / w) for w != 0.
    atan_n_over_w = onp.arctan2(onp.where(w < 0, -norm_safe, norm_safe), onp.abs(w))
    general_factor = 2.0 * atan_n_over_w / norm_safe

    atan_factor = onp.where(
        use_taylor,
        taylor_factor,
        onp.where(onp.abs(w) < get_epsilon(w.dtype), near_pi_factor, general_factor),
    )
    return (atan_factor[..., None] * xyz).astype(self.wxyz.dtype)
[docs]
@override
def adjoint(self) -> onpt.NDArray[onp.floating]:
return self.as_matrix()
[docs]
@override
def inverse(self) -> SO3:
# Negate complex terms.
wxyz = self.wxyz.copy()
wxyz[..., 1:] *= -1
return SO3(wxyz)
[docs]
@override
def normalize(self) -> SO3:
return SO3(wxyz=self.wxyz / onp.linalg.norm(self.wxyz, axis=-1, keepdims=True))