import abc
from typing import ClassVar, Generic, Tuple, TypeVar, Union, overload
import numpy as np
import numpy.typing as npt
from typing_extensions import Self, final, get_args, override
class MatrixLieGroup(abc.ABC):
    """Interface definition for matrix Lie groups.

    Concrete groups declare their dimensions as class keyword arguments
    (`matrix_dim`, `parameters_dim`, `tangent_dim`, `space_dim`), which are
    stored as class attributes by `__init_subclass__`.
    """

    matrix_dim: ClassVar[int]
    """Dimension of square matrix output from `.as_matrix()`."""
    parameters_dim: ClassVar[int]
    """Dimension of underlying parameters, `.parameters()`."""
    tangent_dim: ClassVar[int]
    """Dimension of tangent space."""
    space_dim: ClassVar[int]
    """Dimension of coordinates that can be transformed."""

    def __init__(
        # Notes:
        # - For the constructor signature to be consistent with subclasses, `parameters`
        #   should be marked as positional-only. But this isn't possible in Python 3.7.
        # - This method is implicitly overridden by the dataclass decorator and
        #   should _not_ be marked abstract.
        self,
        parameters: np.ndarray,
    ):
        """Construct a group object from its underlying parameters."""
        raise NotImplementedError()

    def __init_subclass__(
        cls,
        matrix_dim: int = 0,
        parameters_dim: int = 0,
        tangent_dim: int = 0,
        space_dim: int = 0,
    ) -> None:
        """Set class properties for subclasses. We default to dummy values.

        Args:
            matrix_dim: Dimension of the square matrix from `.as_matrix()`.
            parameters_dim: Dimension of the underlying parameters.
            tangent_dim: Dimension of the tangent space.
            space_dim: Dimension of the coordinates the group acts on.
        """
        # Keep subclass creation cooperative with other classes in the MRO
        # (e.g. `typing.Generic`) that also hook `__init_subclass__`.
        super().__init_subclass__()
        cls.matrix_dim = matrix_dim
        cls.parameters_dim = parameters_dim
        cls.tangent_dim = tangent_dim
        cls.space_dim = space_dim

    # Shared implementations.

    @overload
    def __matmul__(self, other: Self) -> Self: ...

    @overload
    def __matmul__(
        self, other: npt.NDArray[np.floating]
    ) -> npt.NDArray[np.floating]: ...

    def __matmul__(
        self, other: Union[Self, npt.NDArray[np.floating]]
    ) -> Union[Self, npt.NDArray[np.floating]]:
        """Overload for the `@` operator.

        Switches between the group action (`.apply()`) and multiplication
        (`.multiply()`) based on the type of `other`.

        Args:
            other: Either an array of points to transform, or another group
                member to compose with.

        Returns:
            `self.apply(target=other)` when `other` is an `np.ndarray`, or
            `self.multiply(other=other)` when `other` is a `MatrixLieGroup`.
            For any other operand type, `NotImplemented` is returned so Python
            raises `TypeError` (after trying `other.__rmatmul__`, if defined).

        Raises:
            ValueError: If `other` is a group acting on a different `space_dim`.
        """
        if isinstance(other, np.ndarray):
            return self.apply(target=other)
        if isinstance(other, MatrixLieGroup):
            # Explicit check instead of `assert`, so it is not stripped by `python -O`.
            if self.space_dim != other.space_dim:
                raise ValueError(
                    "Cannot compose groups acting on different space dimensions: "
                    f"{self.space_dim} vs {other.space_dim}"
                )
            return self.multiply(other=other)
        # Unsupported operand: defer to Python's binary-operator protocol rather
        # than `assert False`, which under `-O` would silently return None.
        return NotImplemented

    # Factory.

    @classmethod
    @abc.abstractmethod
    def identity(
        cls, batch_axes: Tuple[int, ...] = (), dtype: npt.DTypeLike = np.float64
    ) -> Self:
        """Returns identity element.

        Args:
            batch_axes: Any leading batch axes for the output transform.
            dtype: Datatype for the output.

        Returns:
            Identity element.
        """

    @classmethod
    @abc.abstractmethod
    def from_matrix(cls, matrix: npt.NDArray[np.floating]) -> Self:
        """Get group member from matrix representation.

        Args:
            matrix: Matrix representation.

        Returns:
            Group member.
        """

    # Accessors.

    @abc.abstractmethod
    def as_matrix(self) -> npt.NDArray[np.floating]:
        """Get transformation as a matrix. Homogeneous for SE groups."""

    @abc.abstractmethod
    def parameters(self) -> npt.NDArray[np.floating]:
        """Get underlying representation."""

    # Operations.

    @abc.abstractmethod
    def apply(self, target: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
        """Applies group action to a point.

        Args:
            target: Point to transform.

        Returns:
            Transformed point.
        """

    @abc.abstractmethod
    def multiply(self, other: Self) -> Self:
        """Composes this transformation with another.

        Returns:
            self @ other
        """

    @classmethod
    @abc.abstractmethod
    def exp(cls, tangent: npt.NDArray[np.floating]) -> Self:
        """Computes `expm(wedge(tangent))`.

        Args:
            tangent: Tangent vector to take the exponential of.

        Returns:
            Output.
        """

    @abc.abstractmethod
    def log(self) -> npt.NDArray[np.floating]:
        """Computes `vee(logm(transformation matrix))`.

        Returns:
            Output. Its trailing dimension should be `tangent_dim`; shape is
            `(tangent_dim,)` for an unbatched group member.
        """

    @abc.abstractmethod
    def adjoint(self) -> npt.NDArray[np.floating]:
        """Computes the adjoint, which transforms tangent vectors between tangent
        spaces.

        More precisely, for a transform `T` with `Adj_T = T.adjoint()`:
        ```
        T @ exp(omega) = exp(Adj_T @ omega) @ T
        ```

        In robotics, typically used for transforming twists, wrenches, and Jacobians
        across different reference frames.

        Returns:
            Output. Shape should be `(tangent_dim, tangent_dim)`.
        """

    @abc.abstractmethod
    def inverse(self) -> Self:
        """Computes the inverse of our transform.

        Returns:
            Output.
        """

    @abc.abstractmethod
    def normalize(self) -> Self:
        """Normalizes/projects values and returns the result.

        Returns:
            Normalized group member.
        """

    @final
    def get_batch_axes(self) -> Tuple[int, ...]:
        """Return any leading batch axes in contained parameters. If an array of shape
        `(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will
        return `(100,)`."""
        # All axes except the trailing parameter axis are batch axes.
        return self.parameters().shape[:-1]
class SOBase(MatrixLieGroup):
    """Base class for special orthogonal groups, SO(N).

    Adds no behavior beyond `MatrixLieGroup`; it exists so rotation groups can be
    referred to (and type-bounded) as a family.
    """


# Type parameter for the SO(N) rotation group contained in an SE(N) group
# (see `SEBase`); bounded so only `SOBase` subclasses are accepted.
ContainedSOType = TypeVar("ContainedSOType", bound=SOBase)
class SEBase(Generic[ContainedSOType], MatrixLieGroup):
    """Base class for special Euclidean groups.

    Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional
    translation vector.
    """

    # SE-specific interface.

    @classmethod
    @abc.abstractmethod
    def from_rotation_and_translation(
        cls,
        rotation: ContainedSOType,
        translation: npt.NDArray[np.floating],
    ) -> Self:
        """Construct a rigid transform from a rotation and a translation.

        Args:
            rotation: Rotation term.
            translation: Translation term.

        Returns:
            Constructed transformation.
        """

    @final
    @classmethod
    def from_rotation(cls, rotation: ContainedSOType) -> Self:
        """Construct a transform from a rotation, with zero translation.

        The zero translation has the rotation's batch axes and parameter dtype.

        Args:
            rotation: Rotation term.

        Returns:
            Constructed transformation.
        """
        return cls.from_rotation_and_translation(
            rotation=rotation,
            translation=np.zeros(
                (*rotation.get_batch_axes(), cls.space_dim),
                dtype=rotation.parameters().dtype,
            ),
        )

    @final
    @classmethod
    def from_translation(cls, translation: npt.NDArray[np.floating]) -> Self:
        """Construct a transform from a translation, with identity rotation.

        The identity rotation gets the same leading batch axes as `translation`
        and, when `translation` is floating-point, the same dtype. This mirrors
        `from_rotation`. Previously a batched or float32 translation was paired
        with an unbatched float64 rotation.

        Args:
            translation: Translation term; trailing axis is the spatial dimension.

        Returns:
            Constructed transformation.
        """
        # Extract rotation class from type parameter, e.g. the `SO3` in a
        # subclass declared as `class SE3(SEBase[SO3], ...)`.
        assert len(cls.__orig_bases__) == 1  # type: ignore
        rotation_cls = get_args(cls.__orig_bases__[0])[0]  # type: ignore

        # Only read shape/dtype here; `translation` itself is passed through unchanged.
        translation_array = np.asarray(translation)
        dtype = (
            translation_array.dtype
            if np.issubdtype(translation_array.dtype, np.floating)
            else np.float64
        )
        return cls.from_rotation_and_translation(
            rotation=rotation_cls.identity(
                batch_axes=translation_array.shape[:-1], dtype=dtype
            ),
            translation=translation,
        )

    @abc.abstractmethod
    def rotation(self) -> ContainedSOType:
        """Returns a transform's rotation term."""

    @abc.abstractmethod
    def translation(self) -> npt.NDArray[np.floating]:
        """Returns a transform's translation term."""

    # Overrides.

    @final
    @override
    def apply(self, target: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
        """Rotate `target`, then translate it."""
        return self.rotation() @ target + self.translation()  # type: ignore

    @final
    @override
    def multiply(self, other: Self) -> Self:
        """Compose transforms: (R1, t1) @ (R2, t2) = (R1 @ R2, R1 @ t2 + t1)."""
        rotation = self.rotation()
        return type(self).from_rotation_and_translation(
            rotation=rotation @ other.rotation(),
            translation=(rotation @ other.translation()) + self.translation(),
        )

    @final
    @override
    def inverse(self) -> Self:
        """Invert the transform: (R, t)^-1 = (R^-1, -(R^-1 @ t))."""
        R_inv = self.rotation().inverse()
        return type(self).from_rotation_and_translation(
            rotation=R_inv,
            translation=-(R_inv @ self.translation()),
        )

    @final
    @override
    def normalize(self) -> Self:
        """Normalize the rotation term; the translation is passed through unchanged."""
        return type(self).from_rotation_and_translation(
            rotation=self.rotation().normalize(),
            translation=self.translation(),
        )