Source code for viser.infra._messages

"""Message type definitions. For synchronization with the TypeScript definitions, see
`_typescript_interface_gen.py`."""

from __future__ import annotations

import abc
import dataclasses
import functools
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union, cast

import msgspec.msgpack
import numpy as np
from typing_extensions import get_args, get_origin, get_type_hints

if TYPE_CHECKING:
    from ._infra import ClientId
else:
    ClientId = Any


def _prepare_for_deserialization(value: Any, annotation: Type) -> Any:
    """Coerce a decoded value so that it matches its field annotation.

    Numbers are cast to the annotated ``float``/``int`` type, and sequences
    annotated as tuples are converted to tuples recursively. Values whose
    annotation is not handled here are returned unchanged.

    Args:
        value: Decoded value for a single field.
        annotation: Type annotation of that field.

    Returns:
        The coerced value. When a tuple annotation has a fixed length that
        differs from ``len(value)``, a warning is emitted and ``value`` is
        returned as-is.
    """
    # If annotated as a float but we got an integer, cast to float. These
    # are both `number` in Javascript.
    if annotation is float:
        return float(value)
    if annotation is int:
        return int(value)

    origin = get_origin(annotation)

    if origin is Union:
        # Handle Optional[T] and Union[T1, T2, ...] by finding the best
        # matching inner type. This avoids needing a blanket lists_to_tuple()
        # pass over the entire deserialized message.
        if value is None:
            return None
        members = [m for m in get_args(annotation) if m is not type(None)]

        if isinstance(value, (list, tuple)):
            tuple_members = [m for m in members if get_origin(m) is tuple]
            if tuple_members:
                length = len(value)

                def _accepts_length(member: Any) -> bool:
                    inner = get_args(member)
                    is_variadic = len(inner) >= 2 and inner[1] is ...
                    return is_variadic or len(inner) == length

                # Prefer the tuple alternative whose arity fits the value. If
                # none fits, fall back to the first one, which warns and
                # returns the value unchanged.
                chosen = next(
                    (m for m in tuple_members if _accepts_length(m)),
                    tuple_members[0],
                )
                return _prepare_for_deserialization(value, chosen)
            return value

        # `bool` is a subclass of `int`: when the union explicitly allows
        # booleans, keep the flag instead of turning it into a number.
        if isinstance(value, bool) and bool in members:
            return value

        # The first numeric member (in declaration order) that accepts the
        # value wins.
        for member in members:
            if member is float and isinstance(value, (int, float)):
                return float(value)
            if member is int and isinstance(value, int):
                return int(value)
        return value

    if origin is tuple:
        element_types = get_args(annotation)
        if len(element_types) >= 2 and element_types[1] is ...:
            # Variadic `Tuple[T, ...]`: every element shares the same type.
            element_types = (element_types[0],) * len(value)
        elif len(value) != len(element_types):
            # Hack to be OK with wrong type annotations.
            # https://github.com/nerfstudio-project/nerfstudio/pull/1805
            warnings.warn(f"[viser] {value} does not match annotation {annotation}")
            return value

        # Lengths are guaranteed to agree here, so a plain zip is sufficient.
        return tuple(
            _prepare_for_deserialization(item, item_type)
            for item, item_type in zip(value, element_types)
        )

    return value


def _prepare_for_serialization(
    value: Any,
    annotation: object,
    binary_buffers: Optional[List[memoryview]] = None,
) -> Any:
    """Prepare any special types for serialization.

    If ``binary_buffers`` is provided, numpy arrays are extracted into it and
    replaced with tagged placeholder dicts (``{"__binary_index": i, "dtype": "<f4"}``).
    This pairs with the hybrid wire format where binary data is appended raw
    after the msgpack payload, enabling zero-copy typed array views on the client.

    If ``binary_buffers`` is None, numpy arrays are inlined as memoryviews.

    Args:
        value: The value to prepare.
        annotation: Type annotation for ``value``. ``Any`` means "use the
            runtime type of ``value``".
        binary_buffers: Optional list that extracted array buffers are
            appended to, in traversal order.

    Returns:
        ``value`` with NumPy scalars converted to Python scalars, dataclasses
        converted to dicts, arrays replaced as described above, and tuples,
        lists and dicts processed recursively.
    """
    if annotation is Any:
        annotation = type(value)

    # Coerce some scalar types: if we've annotated as float / int but we get an
    # np.float32 / np.int64, for example, we should cast automatically.
    if annotation is float or isinstance(value, np.floating):
        return float(value)
    if annotation is int or isinstance(value, np.integer):
        return int(value)
    # NumPy booleans are neither `np.integer` nor `np.floating`; convert them
    # as well so that no NumPy scalar type reaches the encoder.
    if isinstance(value, np.bool_):
        return bool(value)

    if dataclasses.is_dataclass(annotation):
        return _prepare_for_serialization(vars(value), dict, binary_buffers)

    # Recursively handle tuples.
    if isinstance(value, tuple):
        if get_origin(annotation) is tuple:
            element_types = get_args(annotation)
            if len(element_types) >= 2 and element_types[1] is ...:
                # Variadic `Tuple[T, ...]`: every element shares the same type.
                element_types = (element_types[0],) * len(value)
            elif len(value) != len(element_types):
                # Hack to be OK with wrong type annotations.
                # https://github.com/nerfstudio-project/nerfstudio/pull/1805
                warnings.warn(f"[viser] {value} does not match annotation {annotation}")
                return value
        else:
            element_types = (Any,) * len(value)

        # Lengths are guaranteed to agree here, so a plain zip is sufficient.
        return tuple(
            _prepare_for_serialization(item, item_type, binary_buffers)
            for item, item_type in zip(value, element_types)
        )

    # Handle numpy arrays: extract or inline depending on mode.
    if isinstance(value, np.ndarray):
        # Non-contiguous arrays are copied first so that the exposed buffer is
        # C-contiguous.
        data = value.data if value.data.c_contiguous else value.copy().data
        if binary_buffers is not None:
            # Extract into separate buffer with tagged placeholder.
            idx = len(binary_buffers)
            binary_buffers.append(data)
            return {"__binary_index": idx, "dtype": value.dtype.str}
        # Inline as memoryview (used by API v0).
        return data

    if isinstance(value, list):
        return [_prepare_for_serialization(v, Any, binary_buffers) for v in value]

    if isinstance(value, dict):
        return {
            k: _prepare_for_serialization(v, Any, binary_buffers)
            for k, v in value.items()
        }  # type: ignore

    return value


# Type variable bound to `Message`; the classmethods that enumerate subclasses
# use it to annotate `cls` and their return types.
T = TypeVar("T", bound="Message")


@functools.lru_cache(maxsize=None)
def get_type_hints_cached(cls: Type[Any]) -> Dict[str, Any]:
    """Return the resolved type hints of ``cls``, memoized per class.

    The result dict is cached and shared between callers, so it should be
    treated as read-only.
    """
    resolved_hints = get_type_hints(cls)  # type: ignore
    return resolved_hints


class Message(abc.ABC):
    """Base message type for server/client communication."""

    excluded_self_client: Optional[ClientId] = None
    """Don't send this message to a particular client. Useful when a client wants to
    send synchronization information to other clients."""

    def as_serializable_dict(
        self, binary_buffers: Optional[List[memoryview]] = None
    ) -> Dict[str, Any]:
        """Convert a Python Message object into a serializable dict.

        If ``binary_buffers`` is provided, numpy arrays are extracted into it and
        replaced with tagged placeholder dicts for the hybrid wire format.
        Otherwise, arrays are inlined as memoryviews (used by API v0).

        The returned dict additionally carries a ``"type"`` entry holding the
        message class name."""
        message_type = type(self)
        hints = get_type_hints_cached(message_type)
        # Filter to type-hinted fields only — excludes dynamic attributes
        # like cached values that shouldn't be serialized.
        out = {
            k: _prepare_for_serialization(v, hints[k], binary_buffers)
            for k, v in vars(self).items()
            if k in hints
        }
        out["type"] = message_type.__name__
        return out

    @classmethod
    def _from_serializable_dict(cls, mapping: Dict[str, Any]) -> Dict[str, Any]:
        """Coerce a decoded message mapping into constructor keyword arguments.

        Each value is passed through `_prepare_for_deserialization` together
        with the type hint of the field of the same name. The mapping must no
        longer contain the ``"type"`` entry."""
        hints = get_type_hints_cached(cls)
        mapping = {
            k: _prepare_for_deserialization(v, hints[k]) for k, v in mapping.items()
        }
        return mapping

    @classmethod
    def deserialize(cls, message: bytes) -> Message:
        """Convert bytes into a Python Message object."""
        mapping = msgspec.msgpack.decode(message)

        # List-to-tuple conversion is handled per-field in
        # _prepare_for_deserialization (called from _from_serializable_dict),
        # which uses type annotations to convert only where needed. This avoids
        # a blanket recursive traversal of the entire message tree.
        message_type = cls._subclass_from_type_string()[cast(str, mapping.pop("type"))]
        message_kwargs = message_type._from_serializable_dict(mapping)
        return message_type(**message_kwargs)

    @classmethod
    @functools.lru_cache(maxsize=100)
    def _subclass_from_type_string(cls: Type[T]) -> Dict[str, Type[T]]:
        """Map subclass names to subclasses. The result is cached per class, so
        subclasses defined after the first call are not picked up."""
        subclasses = cls.get_subclasses()
        return {s.__name__: s for s in subclasses}

    @classmethod
    def get_subclasses(cls: Type[T]) -> List[Type[T]]:
        """Recursively get message subclasses."""

        def _get_subclasses(typ: Type[T]) -> List[Type[T]]:
            out = []
            for sub in typ.__subclasses__():
                # Underscore-prefixed classes are left out of the result, but
                # we still descend into them to collect their subclasses.
                if not sub.__name__.startswith("_"):
                    out.append(sub)
                out.extend(_get_subclasses(sub))
            return out

        return _get_subclasses(cls)

    @abc.abstractmethod
    def redundancy_key(self) -> str:
        """Returns a unique key for this message, used for detecting redundant
        messages.

        For example: if we send 1000 "set value" messages for the same GUI element, we
        should only keep the latest message.
        """