import abc
from typing import ClassVar, Generic, Tuple, TypeVar, Union, overload
import numpy as onp
import numpy.typing as onpt
from typing_extensions import Self, final, get_args, override
[docs]
class MatrixLieGroup(abc.ABC):
"""Interface definition for matrix Lie groups."""
# Class properties.
# > These will be set in `_utils.register_lie_group()`.
matrix_dim: ClassVar[int]
"""Dimension of square matrix output from `.as_matrix()`."""
parameters_dim: ClassVar[int]
"""Dimension of underlying parameters, `.parameters()`."""
tangent_dim: ClassVar[int]
"""Dimension of tangent space."""
space_dim: ClassVar[int]
"""Dimension of coordinates that can be transformed."""
def __init__(
# Notes:
# - For the constructor signature to be consistent with subclasses, `parameters`
# should be marked as positional-only. But this isn't possible in Python 3.7.
# - This method is implicitly overriden by the dataclass decorator and
# should _not_ be marked abstract.
self,
parameters: onp.ndarray,
):
"""Construct a group object from its underlying parameters."""
raise NotImplementedError()
# Shared implementations.
@overload
def __matmul__(self, other: Self) -> Self: ...
@overload
def __matmul__(
self, other: onpt.NDArray[onp.floating]
) -> onpt.NDArray[onp.floating]: ...
[docs]
def __matmul__(
self, other: Union[Self, onpt.NDArray[onp.floating]]
) -> Union[Self, onpt.NDArray[onp.floating]]:
"""Overload for the `@` operator.
Switches between the group action (`.apply()`) and multiplication
(`.multiply()`) based on the type of `other`.
"""
if isinstance(other, onp.ndarray):
return self.apply(target=other)
elif isinstance(other, MatrixLieGroup):
assert self.space_dim == other.space_dim
return self.multiply(other=other)
else:
assert False, f"Invalid argument type for `@` operator: {type(other)}"
# Factory.
[docs]
@classmethod
@abc.abstractmethod
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self:
"""Returns identity element.
Args:
batch_axes: Any leading batch axes for the output transform.
Returns:
Identity element.
"""
[docs]
@classmethod
@abc.abstractmethod
def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> Self:
"""Get group member from matrix representation.
Args:
matrix: Matrix representaiton.
Returns:
Group member.
"""
# Accessors.
[docs]
@abc.abstractmethod
def as_matrix(self) -> onpt.NDArray[onp.floating]:
"""Get transformation as a matrix. Homogeneous for SE groups."""
[docs]
@abc.abstractmethod
def parameters(self) -> onpt.NDArray[onp.floating]:
"""Get underlying representation."""
# Operations.
[docs]
@abc.abstractmethod
def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]:
"""Applies group action to a point.
Args:
target: Point to transform.
Returns:
Transformed point.
"""
[docs]
@abc.abstractmethod
def multiply(self, other: Self) -> Self:
"""Composes this transformation with another.
Returns:
self @ other
"""
[docs]
@classmethod
@abc.abstractmethod
def exp(cls, tangent: onpt.NDArray[onp.floating]) -> Self:
"""Computes `expm(wedge(tangent))`.
Args:
tangent: Tangent vector to take the exponential of.
Returns:
Output.
"""
[docs]
@abc.abstractmethod
def log(self) -> onpt.NDArray[onp.floating]:
"""Computes `vee(logm(transformation matrix))`.
Returns:
Output. Shape should be `(tangent_dim,)`.
"""
[docs]
@abc.abstractmethod
def adjoint(self) -> onpt.NDArray[onp.floating]:
"""Computes the adjoint, which transforms tangent vectors between tangent
spaces.
More precisely, for a transform `GroupType`:
```
GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType
```
In robotics, typically used for transforming twists, wrenches, and Jacobians
across different reference frames.
Returns:
Output. Shape should be `(tangent_dim, tangent_dim)`.
"""
[docs]
@abc.abstractmethod
def inverse(self) -> Self:
"""Computes the inverse of our transform.
Returns:
Output.
"""
[docs]
@abc.abstractmethod
def normalize(self) -> Self:
"""Normalize/projects values and returns.
Returns:
Normalized group member.
"""
# @classmethod
# @abc.abstractmethod
# def sample_uniform(cls, key: onp.ndarray, batch_axes: Tuple[int, ...] = ()) -> Self:
# """Draw a uniform sample from the group. Translations (if applicable) are in the
# range [-1, 1].
#
# Args:
# key: PRNG key, as returned by `jax.random.PRNGKey()`.
# batch_axes: Any leading batch axes for the output transforms. Each
# sampled transform will be different.
#
# Returns:
# Sampled group member.
# """
[docs]
@final
def get_batch_axes(self) -> Tuple[int, ...]:
"""Return any leading batch axes in contained parameters. If an array of shape
`(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will
return `(100,)`."""
return self.parameters().shape[:-1]
[docs]
class SOBase(MatrixLieGroup):
"""Base class for special orthogonal groups."""
ContainedSOType = TypeVar("ContainedSOType", bound=SOBase)
[docs]
class SEBase(Generic[ContainedSOType], MatrixLieGroup):
"""Base class for special Euclidean groups.
Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional
translation vector.
"""
# SE-specific interface.
[docs]
@classmethod
@abc.abstractmethod
def from_rotation_and_translation(
cls,
rotation: ContainedSOType,
translation: onpt.NDArray[onp.floating],
) -> Self:
"""Construct a rigid transform from a rotation and a translation.
Args:
rotation: Rotation term.
translation: translation term.
Returns:
Constructed transformation.
"""
[docs]
@final
@classmethod
def from_rotation(cls, rotation: ContainedSOType) -> Self:
return cls.from_rotation_and_translation(
rotation=rotation,
translation=onp.zeros(
(*rotation.get_batch_axes(), cls.space_dim),
dtype=rotation.parameters().dtype,
),
)
[docs]
@final
@classmethod
def from_translation(cls, translation: onpt.NDArray[onp.floating]) -> Self:
# Extract rotation class from type parameter.
assert len(cls.__orig_bases__) == 1 # type: ignore
return cls.from_rotation_and_translation(
rotation=get_args(cls.__orig_bases__[0])[0].identity(), # type: ignore
translation=translation,
)
[docs]
@abc.abstractmethod
def rotation(self) -> ContainedSOType:
"""Returns a transform's rotation term."""
[docs]
@abc.abstractmethod
def translation(self) -> onpt.NDArray[onp.floating]:
"""Returns a transform's translation term."""
# Overrides.
[docs]
@final
@override
def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]:
return self.rotation() @ target + self.translation() # type: ignore
[docs]
@final
@override
def multiply(self, other: Self) -> Self:
return type(self).from_rotation_and_translation(
rotation=self.rotation() @ other.rotation(),
translation=(self.rotation() @ other.translation()) + self.translation(),
)
[docs]
@final
@override
def inverse(self) -> Self:
R_inv = self.rotation().inverse()
return type(self).from_rotation_and_translation(
rotation=R_inv,
translation=-(R_inv @ self.translation()),
)
[docs]
@final
@override
def normalize(self) -> Self:
return type(self).from_rotation_and_translation(
rotation=self.rotation().normalize(),
translation=self.translation(),
)