"""Message type definitions. For synchronization with the TypeScript definitions, see
`_typescript_interface_gen.py.`"""
from __future__ import annotations
import abc
import dataclasses
import functools
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Type,
    TypeVar,
    cast,
)

import msgspec
import numpy as np
from typing_extensions import get_args, get_origin, get_type_hints

if TYPE_CHECKING:
    from ._infra import ClientId
else:
    ClientId = Any


def _prepare_for_deserialization(value: Any, annotation: Type) -> Any:
    # If annotated as a float but we got an integer, cast to float. These are both
    # `number` in JavaScript.
    if annotation is float:
        return float(value)
    elif annotation is int:
        return int(value)
    elif get_origin(annotation) is tuple:
        out = []
        args = get_args(annotation)
        if len(args) >= 2 and args[1] == ...:
            args = (args[0],) * len(value)
        elif len(value) != len(args):
            warnings.warn(f"[viser] {value} does not match annotation {annotation}")
            return value
        for i, v in enumerate(value):
            out.append(
                # Hack to be OK with wrong type annotations.
                # https://github.com/nerfstudio-project/nerfstudio/pull/1805
                _prepare_for_deserialization(v, args[i]) if i < len(args) else v
            )
        return tuple(out)
    return value
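
# Example (illustrative sketch): the coercion above lets msgpack-decoded integers be
# reused for float-annotated fields, including inside variable-length tuple
# annotations. Assumes `Tuple` is imported from `typing`.
#
#     _prepare_for_deserialization(3, float)                      # -> 3.0
#     _prepare_for_deserialization((1, 2, 3), Tuple[float, ...])  # -> (1.0, 2.0, 3.0)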


def _prepare_for_serialization(value: Any, annotation: object) -> Any:
    """Prepare any special types for serialization."""
    if annotation is Any:
        annotation = type(value)

    # Coerce some scalar types: if we've annotated as float / int but we get an
    # np.float32 / np.int64, for example, we should cast automatically.
    if annotation is float or isinstance(value, np.floating):
        return float(value)
    if annotation is int or isinstance(value, np.integer):
        return int(value)

    if dataclasses.is_dataclass(annotation):
        return _prepare_for_serialization(vars(value), dict)

    # Recursively handle tuples.
    if isinstance(value, tuple):
        if isinstance(value, np.ndarray):
            assert False, (
                "Expected a tuple, but got an array... missing a cast somewhere?"
                f" {value}"
            )
        out = []
        if get_origin(annotation) is tuple:
            args = get_args(annotation)
            if len(args) >= 2 and args[1] == ...:
                args = (args[0],) * len(value)
            elif len(value) != len(args):
                warnings.warn(
                    f"[viser] {value} does not match annotation {annotation}"
                )
                return value
        else:
            args = [Any] * len(value)
        for i, v in enumerate(value):
            out.append(
                # Hack to be OK with wrong type annotations.
                # https://github.com/nerfstudio-project/nerfstudio/pull/1805
                _prepare_for_serialization(v, args[i]) if i < len(args) else v
            )
        return tuple(out)

    # For arrays, we serialize the underlying data directly. The client is responsible
    # for reading it using the correct dtype.
    if isinstance(value, np.ndarray):
        return value.data if value.data.c_contiguous else value.copy().data

    if isinstance(value, dict):
        return {k: _prepare_for_serialization(v, Any) for k, v in value.items()}  # type: ignore

    return value
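
# Example (illustrative sketch): NumPy scalars collapse to plain Python numbers, and
# arrays are handed to the msgpack encoder as raw, C-contiguous memoryviews.
#
#     _prepare_for_serialization(np.float32(0.5), float)            # -> 0.5
#     _prepare_for_serialization(np.zeros(3, dtype=np.uint8), Any)  # -> memoryview of 3 bytes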
T = TypeVar("T", bound="Message")
@functools.lru_cache(maxsize=None)
def get_type_hints_cached(cls: Type[Any]) -> Dict[str, Any]:
return get_type_hints(cls) # type: ignore


class Message(abc.ABC):
    """Base message type for server/client communication."""

    excluded_self_client: Optional[ClientId] = None
    """Don't send this message to a particular client. Useful when a client wants to
    send synchronization information to other clients."""

    def as_serializable_dict(self) -> Dict[str, Any]:
        """Convert a Python Message object into a serializable dictionary."""
        message_type = type(self)
        hints = get_type_hints_cached(message_type)
        out = {
            k: _prepare_for_serialization(v, hints[k]) for k, v in vars(self).items()
        }
        out["type"] = message_type.__name__
        return out

    @classmethod
    def _from_serializable_dict(cls, mapping: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a serialized message dict into keyword arguments for reconstructing
        a Python Message object."""
        hints = get_type_hints_cached(cls)
        mapping = {
            k: _prepare_for_deserialization(v, hints[k]) for k, v in mapping.items()
        }
        return mapping

    @classmethod
    def deserialize(cls, message: bytes) -> Message:
        """Convert bytes into a Python Message object."""
        mapping = msgspec.msgpack.decode(message)

        # msgpack deserializes to lists by default, but all of our annotations use
        # tuples.
        def lists_to_tuple(obj: Any) -> Any:
            if isinstance(obj, list):
                return tuple(lists_to_tuple(x) for x in obj)
            elif isinstance(obj, dict):
                return {k: lists_to_tuple(v) for k, v in obj.items()}
            else:
                return obj

        mapping = lists_to_tuple(mapping)
        message_type = cls._subclass_from_type_string()[cast(str, mapping.pop("type"))]
        message_kwargs = message_type._from_serializable_dict(mapping)
        return message_type(**message_kwargs)

    @classmethod
    @functools.lru_cache(maxsize=100)
    def _subclass_from_type_string(cls: Type[T]) -> Dict[str, Type[T]]:
        subclasses = cls.get_subclasses()
        return {s.__name__: s for s in subclasses}

    @classmethod
    def get_subclasses(cls: Type[T]) -> List[Type[T]]:
        """Recursively get message subclasses."""

        def _get_subclasses(typ: Type[T]) -> List[Type[T]]:
            out = []
            for sub in typ.__subclasses__():
                if not sub.__name__.startswith("_"):
                    out.append(sub)
                out.extend(_get_subclasses(sub))
            return out

        return _get_subclasses(cls)

    @abc.abstractmethod
    def redundancy_key(self) -> str:
        """Returns a unique key for this message, used for detecting redundant
        messages.

        For example: if we send 1000 "set value" messages for the same GUI element, we
        should only keep the latest message.
        """