from __future__ import annotations
import dataclasses
from typing import Tuple, cast
import numpy as onp
import numpy.typing as onpt
from typing_extensions import override
from . import _base
from ._so3 import SO3
from .utils import broadcast_leading_axes, get_epsilon, register_lie_group
def _skew(omega: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]:
"""Returns the skew-symmetric form of a length-3 vector."""
wx, wy, wz = onp.moveaxis(omega, -1, 0)
zeros = onp.zeros_like(wx)
return onp.stack(
[zeros, -wz, wy, wz, zeros, -wx, -wy, wx, zeros],
axis=-1,
).reshape((*omega.shape[:-1], 3, 3))
[docs]
@register_lie_group(
matrix_dim=4,
parameters_dim=7,
tangent_dim=6,
space_dim=3,
)
@dataclasses.dataclass(frozen=True)
class SE3(_base.SEBase[SO3]):
"""Special Euclidean group for proper rigid transforms in 3D. Broadcasting
rules are the same as for numpy.
Ported to numpy from `jaxlie.SE3`.
Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization
is `(vx, vy, vz, omega_x, omega_y, omega_z)`.
"""
# SE3-specific.
wxyz_xyz: onpt.NDArray[onp.floating]
"""Internal parameters. wxyz quaternion followed by xyz translation. Shape should be `(*, 7)`."""
@override
def __repr__(self) -> str:
quat = onp.round(self.wxyz_xyz[..., :4], 5)
trans = onp.round(self.wxyz_xyz[..., 4:], 5)
return f"{self.__class__.__name__}(wxyz={quat}, xyz={trans})"
# SE-specific.
[docs]
@classmethod
@override
def from_rotation_and_translation(
cls,
rotation: SO3,
translation: onpt.NDArray[onp.floating],
) -> SE3:
assert translation.shape[-1:] == (3,)
rotation, translation = broadcast_leading_axes((rotation, translation))
return SE3(wxyz_xyz=onp.concatenate([rotation.wxyz, translation], axis=-1))
[docs]
@override
def rotation(self) -> SO3:
return SO3(wxyz=self.wxyz_xyz[..., :4])
[docs]
@override
def translation(self) -> onpt.NDArray[onp.floating]:
return self.wxyz_xyz[..., 4:]
# Factory.
[docs]
@classmethod
@override
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SE3:
return SE3(
wxyz_xyz=onp.broadcast_to(
onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (*batch_axes, 7)
)
)
[docs]
@classmethod
@override
def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE3:
assert matrix.shape[-2:] == (4, 4) or matrix.shape[-2:] == (3, 4)
# Currently assumes bottom row is [0, 0, 0, 1].
return SE3.from_rotation_and_translation(
rotation=SO3.from_matrix(matrix[..., :3, :3]),
translation=matrix[..., :3, 3],
)
# Accessors.
[docs]
@override
def as_matrix(self) -> onpt.NDArray[onp.floating]:
out = onp.zeros((*self.get_batch_axes(), 4, 4))
out[..., :3, :3] = self.rotation().as_matrix()
out[..., :3, 3] = self.translation()
out[..., 3, 3] = 1.0
return out
[docs]
@override
def parameters(self) -> onpt.NDArray[onp.floating]:
return self.wxyz_xyz
# Operations.
[docs]
@classmethod
@override
def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE3:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761
# (x, y, z, omega_x, omega_y, omega_z)
assert tangent.shape[-1:] == (6,)
rotation = SO3.exp(tangent[..., 3:])
theta_squared = onp.sum(onp.square(tangent[..., 3:]), axis=-1)
use_taylor = theta_squared < get_epsilon(theta_squared.dtype)
# Shim to avoid NaNs in onp.where branches, which cause failures for
# reverse-mode AD in JAX. This isn't needed for vanilla numpy.
theta_squared_safe = cast(
onp.ndarray,
onp.where(
use_taylor,
onp.ones_like(theta_squared), # Any non-zero value should do here.
theta_squared,
),
)
del theta_squared
theta_safe = onp.sqrt(theta_squared_safe)
skew_omega = _skew(tangent[..., 3:])
V = onp.where(
use_taylor[..., None, None],
rotation.as_matrix(),
(
onp.eye(3)
+ ((1.0 - onp.cos(theta_safe)) / (theta_squared_safe))[..., None, None]
* skew_omega
+ (
(theta_safe - onp.sin(theta_safe))
/ (theta_squared_safe * theta_safe)
)[..., None, None]
* onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega)
),
)
return SE3.from_rotation_and_translation(
rotation=rotation,
translation=onp.einsum("...ij,...j->...i", V, tangent[..., :3]),
)
[docs]
@override
def log(self) -> onpt.NDArray[onp.floating]:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223
omega = self.rotation().log()
theta_squared = onp.sum(onp.square(omega), axis=-1)
use_taylor = theta_squared < get_epsilon(theta_squared.dtype)
skew_omega = _skew(omega)
# Shim to avoid NaNs in onp.where branches, which cause failures for
# reverse-mode AD in JAX. This isn't needed for vanilla numpy.
theta_squared_safe = onp.where(
use_taylor,
onp.ones_like(theta_squared), # Any non-zero value should do here.
theta_squared,
)
del theta_squared
theta_safe = onp.sqrt(theta_squared_safe)
half_theta_safe = theta_safe / 2.0
V_inv = onp.where(
use_taylor[..., None, None],
onp.eye(3)
- 0.5 * skew_omega
+ onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) / 12.0,
(
onp.eye(3)
- 0.5 * skew_omega
+ (
(
1.0
- theta_safe
* onp.cos(half_theta_safe)
/ (2.0 * onp.sin(half_theta_safe))
)
/ theta_squared_safe
)[..., None, None]
* onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega)
),
)
return onp.concatenate(
[onp.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1
)
[docs]
@override
def adjoint(self) -> onpt.NDArray[onp.floating]:
R = self.rotation().as_matrix()
return onp.concatenate(
[
onp.concatenate(
[R, onp.einsum("...ij,...jk->...ik", _skew(self.translation()), R)],
axis=-1,
),
onp.concatenate(
[onp.zeros((*self.get_batch_axes(), 3, 3)), R], axis=-1
),
],
axis=-2,
)
# @classmethod
# @override
# def sample_uniform(
# cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = ()
# ) -> SE3:
# key0, key1 = jax.random.split(key)
# return SE3.from_rotation_and_translation(
# rotation=SO3.sample_uniform(key0, batch_axes=batch_axes),
# translation=jax.random.uniform(
# key=key1, shape=(*batch_axes, 3), minval=-1.0, maxval=1.0
# ),
# )