Source code for viser._viser

from __future__ import annotations

import contextlib
import dataclasses
import io
import threading
import time
from pathlib import Path
from typing import Callable, Dict, Generator, List, Optional, Tuple

import imageio.v3 as iio
import numpy as onp
import numpy.typing as npt
import rich
from rich import box, style
from rich.panel import Panel
from rich.table import Table
from typing_extensions import Literal, override

from . import _client_autobuild, _messages, infra
from . import transforms as tf
from ._gui_api import GuiApi
from ._message_api import MessageApi, cast_vector
from ._scene_handles import FrameHandle, _SceneNodeHandleState
from ._tunnel import ViserTunnel


@dataclasses.dataclass
class _CameraHandleState:
    """Information about a client's camera state.

    Instances are built positionally when a camera message arrives from the
    client, so the field order below is part of the interface.
    """

    # Client that owns this camera. Temporarily None until the ClientHandle exists.
    client: ClientHandle
    # Camera orientation quaternion; the R in `P_world = [R | t] p_camera`.
    wxyz: npt.NDArray[onp.float64]
    # Camera position; the t in `P_world = [R | t] p_camera`.
    position: npt.NDArray[onp.float64]
    # Vertical field of view, in radians.
    fov: float
    # Canvas width divided by height.
    aspect: float
    # Point the camera is looking at.
    look_at: npt.NDArray[onp.float64]
    # Up direction used to orient the camera around the look direction.
    up_direction: npt.NDArray[onp.float64]
    # `time.time()` of the latest update; 0.0 means no camera message has been
    # received yet (CameraHandle getters assert this is nonzero).
    update_timestamp: float
    # Callbacks run whenever a new camera message is received.
    camera_cb: List[Callable[[CameraHandle], None]]


@dataclasses.dataclass
class CameraHandle:
    """Handle for reading and writing one client's viewport camera.

    Reads return the most recently received (or locally assigned) state. Writes
    update the local state immediately and queue a message to the client.
    """

    _state: _CameraHandleState

    @property
    def client(self) -> ClientHandle:
        """Client that this camera corresponds to."""
        return self._state.client

    @property
    def wxyz(self) -> npt.NDArray[onp.float64]:
        """Corresponds to the R in `P_world = [R | t] p_camera`. Synchronized
        automatically when assigned."""
        assert self._state.update_timestamp != 0.0
        return self._state.wxyz

    # Note: asymmetric properties are supported in Pyright, but not yet in mypy.
    # - https://github.com/python/mypy/issues/3004
    # - https://github.com/python/mypy/pull/11643
    @wxyz.setter
    def wxyz(self, wxyz: Tuple[float, float, float, float] | onp.ndarray) -> None:
        R_world_camera = tf.SO3(onp.asarray(wxyz)).as_matrix()
        look_distance = onp.linalg.norm(self.look_at - self.position)

        # We're following OpenCV conventions: look_direction is +Z, up_direction is -Y,
        # right_direction is +X.
        look_direction = R_world_camera[:, 2]
        up_direction = -R_world_camera[:, 1]
        right_direction = R_world_camera[:, 0]

        # Minimize our impact on the orbit controls by keeping the new up direction as
        # close to the old one as possible: remove the old up direction's component
        # along the new right direction.
        projected_up_direction = (
            self.up_direction
            - float(self.up_direction @ right_direction) * right_direction
        )
        # If that projection is nearly orthogonal to the new up direction, or points
        # away from it, fall back to the new up direction itself.
        up_cosine = float(up_direction @ projected_up_direction)
        if up_cosine < 0.05:
            projected_up_direction = up_direction

        new_look_at = look_direction * look_distance + self.position

        # Update lookat and up direction. Each setter recomputes `_state.wxyz` (which
        # should then match `wxyz` up to sign) and queues a message to the client.
        self.look_at = new_look_at
        self.up_direction = projected_up_direction

    @property
    def position(self) -> npt.NDArray[onp.float64]:
        """Corresponds to the t in `P_world = [R | t] p_camera`. Synchronized
        automatically when assigned.

        Assigning `position` translates the `look_at` point by the same offset and
        keeps `up_direction`, so the viewing direction (and therefore `wxyz`) is
        preserved.
        """
        assert self._state.update_timestamp != 0.0
        return self._state.position

    @position.setter
    def position(self, position: Tuple[float, float, float] | onp.ndarray) -> None:
        # Store as float64 so `_update_wxyz` never divides integer arrays in place.
        new_position = onp.asarray(position, dtype=onp.float64)
        offset = new_position - self.position
        self._state.position = new_position
        self.look_at = self.look_at + offset
        self._state.update_timestamp = time.time()
        self._state.client._queue(
            _messages.SetCameraPositionMessage(cast_vector(position, 3))
        )

    def _update_wxyz(self) -> None:
        """Compute and update the camera orientation from the internal look_at,
        position, and up vectors."""
        # Camera +Z points from the camera position toward the look-at point.
        # Out-of-place division: in-place `/=` fails on integer-typed arrays.
        z = self._state.look_at - self._state.position
        z = z / onp.linalg.norm(z)
        # Flip the up vector about z and orthogonalize it against z, so y points
        # opposite `up_direction` (OpenCV convention: up is -Y).
        y = tf.SO3.exp(z * onp.pi) @ self._state.up_direction
        y = y - onp.dot(z, y) * z
        y = y / onp.linalg.norm(y)
        x = onp.cross(y, z)
        self._state.wxyz = tf.SO3.from_matrix(onp.stack([x, y, z], axis=1)).wxyz

    @property
    def fov(self) -> float:
        """Vertical field of view of the camera, in radians. Synchronized automatically
        when assigned."""
        assert self._state.update_timestamp != 0.0
        return self._state.fov

    @fov.setter
    def fov(self, fov: float) -> None:
        self._state.fov = fov
        self._state.update_timestamp = time.time()
        self._state.client._queue(_messages.SetCameraFovMessage(fov))

    @property
    def aspect(self) -> float:
        """Canvas width divided by height. Not assignable."""
        assert self._state.update_timestamp != 0.0
        return self._state.aspect

    @property
    def update_timestamp(self) -> float:
        """`time.time()` of the most recent camera update, whether received from the
        client or assigned locally."""
        assert self._state.update_timestamp != 0.0
        return self._state.update_timestamp

    @property
    def look_at(self) -> npt.NDArray[onp.float64]:
        """Look at point for the camera. Synchronized automatically when set."""
        assert self._state.update_timestamp != 0.0
        return self._state.look_at

    @look_at.setter
    def look_at(self, look_at: Tuple[float, float, float] | onp.ndarray) -> None:
        self._state.look_at = onp.asarray(look_at, dtype=onp.float64)
        self._state.update_timestamp = time.time()
        self._update_wxyz()
        self._state.client._queue(
            _messages.SetCameraLookAtMessage(cast_vector(look_at, 3))
        )

    @property
    def up_direction(self) -> npt.NDArray[onp.float64]:
        """Up direction for the camera. Synchronized automatically when set."""
        assert self._state.update_timestamp != 0.0
        return self._state.up_direction

    @up_direction.setter
    def up_direction(
        self, up_direction: Tuple[float, float, float] | onp.ndarray
    ) -> None:
        self._state.up_direction = onp.asarray(up_direction, dtype=onp.float64)
        self._update_wxyz()
        self._state.update_timestamp = time.time()
        self._state.client._queue(
            _messages.SetCameraUpDirectionMessage(cast_vector(up_direction, 3))
        )

    def on_update(
        self, callback: Callable[[CameraHandle], None]
    ) -> Callable[[CameraHandle], None]:
        """Attach a callback to run when a new camera message is received.

        Returns the callback unchanged, so this can be used as a decorator.
        """
        self._state.camera_cb.append(callback)
        return callback

    def get_render(
        self, height: int, width: int, transport_format: Literal["png", "jpeg"] = "jpeg"
    ) -> onp.ndarray:
        """Request a render from a client, block until it's done and received, then
        return it as a numpy array.

        Args:
            height: Height of rendered image. Should be <= the browser height.
            width: Width of rendered image. Should be <= the browser width.
            transport_format: Image transport format. JPEG will return a lossy
                (H, W, 3) RGB array. PNG will return a lossless (H, W, 4) RGBA array,
                but can cause memory issues on the frontend if called too quickly for
                higher-resolution images.

        Returns:
            The decoded image.

        Raises:
            ValueError: If `transport_format` is neither "png" nor "jpeg".
            Exception: Any error raised while decoding the received image is re-raised
                in the calling thread.

        Note: there is no timeout; this blocks until the client responds.
        """
        if transport_format not in ("png", "jpeg"):
            raise ValueError(
                f"transport_format must be 'png' or 'jpeg', got {transport_format!r}"
            )

        # Listen for a render response message, which should contain the rendered
        # image.
        render_ready_event = threading.Event()
        out: Optional[onp.ndarray] = None
        decode_error: Optional[BaseException] = None
        connection = self.client._state.connection

        def got_render_cb(
            client_id: int, message: _messages.GetRenderResponseMessage
        ) -> None:
            nonlocal out, decode_error
            del client_id
            connection.unregister_handler(
                _messages.GetRenderResponseMessage, got_render_cb
            )
            try:
                out = iio.imread(
                    io.BytesIO(message.payload),
                    extension=f".{transport_format}",
                )
            except Exception as e:
                # Hand the error to the waiting thread; otherwise it would block
                # forever on an event that never gets set.
                decode_error = e
            finally:
                render_ready_event.set()

        connection.register_handler(_messages.GetRenderResponseMessage, got_render_cb)
        self.client._queue(
            _messages.GetRenderRequestMessage(
                "image/jpeg" if transport_format == "jpeg" else "image/png",
                height=height,
                width=width,
                # Only used for JPEG. The main reason to use a lower quality version
                # value is (unfortunately) to make life easier for the Javascript
                # garbage collector.
                quality=80,
            )
        )
        render_ready_event.wait()
        if decode_error is not None:
            raise decode_error
        assert out is not None
        return out
@dataclasses.dataclass
class _ClientHandleState:
    """Server-side references held by a `ClientHandle`.

    Built positionally as `(viser_server, server, connection)` when a client
    connects, so the field order is part of the interface.
    """

    # The ViserServer that created this client handle.
    viser_server: ViserServer
    # Underlying server; ClientHandle uses its `_thread_executor` and `flush_client`.
    server: infra.Server
    # Connection to this client; used to send messages and register handlers.
    connection: infra.ClientConnection
@dataclasses.dataclass
class ClientHandle(MessageApi, GuiApi):
    """Handle for interacting with a specific client. Can be used to send messages to
    individual clients and read/write camera information."""

    client_id: int
    """Unique ID for this client."""
    camera: CameraHandle
    """Handle for reading from and manipulating the client's viewport camera."""
    _state: _ClientHandleState

    def __post_init__(self) -> None:
        super().__init__(self._state.connection, self._state.server._thread_executor)

    @override
    def _get_api(self) -> MessageApi:
        """Message API to use."""
        return self

    @override
    def _queue_unsafe(self, message: _messages.Message) -> None:
        """Define how the message API should send messages."""
        self._state.connection.send(message)

    @override
    @contextlib.contextmanager
    def atomic(self) -> Generator[None, None, None]:
        """Returns a context where: all outgoing messages are grouped and applied by
        clients atomically.

        This should be treated as a soft constraint that's helpful for things like
        animations, or when we want position and orientation updates to happen
        synchronously.

        Returns:
            Context manager.
        """
        # If called multiple times in the same thread, we ignore inner calls.
        thread_id = threading.get_ident()
        if thread_id == self._locked_thread_id:
            yield
            return

        self._atomic_lock.acquire()
        self._locked_thread_id = thread_id
        try:
            yield
        finally:
            # Always release, even if the body raised. Clear the owner id *before*
            # releasing so we can't overwrite the id recorded by the next thread that
            # acquires the lock.
            self._locked_thread_id = -1
            self._atomic_lock.release()

    @override
    def flush(self) -> None:
        """Flush the outgoing message buffer. Any buffered messages will immediately be
        sent. (by default they are windowed)"""
        self._state.server.flush_client(self.client_id)
# The state of a ViserServer can be serialized as a tuple of
# (serialized message, timestamp) pairs.
SerializedServerState = Tuple[Tuple[bytes, float], ...]


def dummy_process() -> None:
    """Do nothing; takes no arguments and returns None."""
    return None


@dataclasses.dataclass
class _ViserServerState:
    """Mutable bookkeeping owned by a ViserServer."""

    # Underlying server instance.
    connection: infra.Server
    # Registered clients, keyed by client ID.
    connected_clients: Dict[int, ClientHandle] = dataclasses.field(
        default_factory=dict
    )
    # Lock guarding `connected_clients`; a fresh one per instance.
    client_lock: threading.Lock = dataclasses.field(
        default_factory=threading.Lock
    )
class ViserServer(MessageApi, GuiApi):
    """Viser server class. The primary interface for functionality in `viser`.

    Commands on a server object (`add_frame`, `add_gui_*`, ...) will be sent to all
    clients, including new clients that connect after a command is called.

    Args:
        host: Host to bind server to.
        port: Port to bind server to.
        label: Label shown at the top of the GUI panel.
    """

    world_axes: FrameHandle
    """Handle for manipulating the world frame axes (/WorldAxes), which is instantiated
    and then hidden by default."""

    # Hide deprecated arguments from docstring and type checkers.
    def __init__(
        self, host: str = "0.0.0.0", port: int = 8080, label: Optional[str] = None
    ):
        ...

    def _actual_init(
        self,
        host: str = "0.0.0.0",
        port: int = 8080,
        label: Optional[str] = None,
        **_deprecated_kwargs,
    ):
        # Create server.
        server = infra.Server(
            host=host,
            port=port,
            message_class=_messages.Message,
            http_server_root=Path(__file__).absolute().parent / "client" / "build",
            client_api_version=1,
        )
        self._server = server
        super().__init__(server, server._thread_executor)

        _client_autobuild.ensure_client_is_built()

        state = _ViserServerState(server)
        self._state = state
        self._client_connect_cb: List[Callable[[ClientHandle], None]] = []
        self._client_disconnect_cb: List[Callable[[ClientHandle], None]] = []

        # Initialize before the server starts: the share-URL message handlers
        # registered below read this attribute as soon as a client sends a request.
        self._share_tunnel: Optional[ViserTunnel] = None

        # For new clients, register and add a handler for camera messages.
        @server.on_client_connect
        def _(conn: infra.ClientConnection) -> None:
            camera = CameraHandle(
                _CameraHandleState(
                    # TODO: values are initially not valid.
                    client=None,  # type: ignore
                    wxyz=onp.zeros(4),
                    position=onp.zeros(3),
                    fov=0.0,
                    aspect=0.0,
                    look_at=onp.zeros(3),
                    up_direction=onp.zeros(3),
                    update_timestamp=0.0,
                    camera_cb=[],
                )
            )
            client = ClientHandle(
                conn.client_id,
                camera,
                _ClientHandleState(self, server, conn),
            )
            camera._state.client = client
            first = True

            def handle_camera_message(
                client_id: infra.ClientId, message: _messages.ViewerCameraMessage
            ) -> None:
                nonlocal first

                assert client_id == client.client_id

                # Update the client's camera.
                with client.atomic():
                    client.camera._state = _CameraHandleState(
                        client,
                        onp.array(message.wxyz),
                        onp.array(message.position),
                        message.fov,
                        message.aspect,
                        onp.array(message.look_at),
                        onp.array(message.up_direction),
                        time.time(),
                        camera_cb=client.camera._state.camera_cb,
                    )

                    # We consider a client to be connected after the first camera
                    # message is received.
                    if first:
                        first = False
                        # Register the client and snapshot the callbacks under the
                        # lock, but run the callbacks outside it. `client_lock` is not
                        # reentrant, so a callback that calls `get_clients()` would
                        # otherwise deadlock. Snapshotting under the lock still
                        # guarantees each callback sees each client exactly once when
                        # racing with `on_client_connect`.
                        with self._state.client_lock:
                            state.connected_clients[conn.client_id] = client
                            connect_cbs = list(self._client_connect_cb)
                        for cb in connect_cbs:
                            cb(client)

                    for camera_cb in client.camera._state.camera_cb:
                        camera_cb(client.camera)

            conn.register_handler(_messages.ViewerCameraMessage, handle_camera_message)

        # Remove clients when they disconnect.
        @server.on_client_disconnect
        def _(conn: infra.ClientConnection) -> None:
            with self._state.client_lock:
                handle = state.connected_clients.pop(conn.client_id, None)
                disconnect_cbs = list(self._client_disconnect_cb)
            # Clients that never sent a camera message were never registered.
            if handle is None:
                return
            # Run callbacks outside the non-reentrant lock (see the connect handler).
            for cb in disconnect_cbs:
                cb(handle)

        # Start the server.
        server.start()
        server.register_handler(
            _messages.ShareUrlDisconnect,
            lambda client_id, msg: self.disconnect_share_url(),
        )
        server.register_handler(
            _messages.ShareUrlRequest, lambda client_id, msg: self.request_share_url()
        )

        # Form status print.
        port = server._port  # Port may have changed.
        http_url = f"http://{host}:{port}"
        ws_url = f"ws://{host}:{port}"
        table = Table(
            title=None,
            show_header=False,
            box=box.MINIMAL,
            title_style=style.Style(bold=True),
        )
        table.add_row("HTTP", http_url)
        table.add_row("Websocket", ws_url)
        rich.print(Panel(table, title="[bold]viser[/bold]", expand=False))

        # Create share tunnel if requested.
        # This is deprecated: use request_share_url() instead.
        share = _deprecated_kwargs.get("share", False)
        if share:
            self.request_share_url()

        self.reset_scene()
        self.set_gui_panel_label(label)

        # Create a handle for the world axes, which are hardcoded to exist in the client.
        self.world_axes = FrameHandle(
            _SceneNodeHandleState(
                "/WorldAxes",
                self,
                wxyz=onp.array([1.0, 0.0, 0.0, 0.0]),
                position=onp.zeros(3),
            )
        )
        self.world_axes.visible = False

    def get_host(self) -> str:
        """Returns the host address of the Viser server.

        Returns:
            Host address as string.
        """
        return self._server._host

    def get_port(self) -> int:
        """Returns the port of the Viser server. This could be different from the
        originally requested one.

        Returns:
            Port as integer.
        """
        return self._server._port

    def request_share_url(self, verbose: bool = True) -> Optional[str]:
        """Request a share URL for the Viser server, which allows for public access.
        On the first call, will block until a connection with the share URL server is
        established. Afterwards, the URL will be returned directly.

        This is an experimental feature that relies on an external server; it shouldn't
        be relied on for critical applications.

        Args:
            verbose: Whether to print status messages about the share URL.

        Returns:
            Share URL as string, or None if connection fails or is closed.
        """
        # Read the attribute once: the disconnect callback may reset it to None from
        # another thread at any moment.
        existing_tunnel = self._share_tunnel
        if existing_tunnel is not None:
            # Tunnel already exists; wait until it has finished setting up.
            while existing_tunnel.get_status() in ("ready", "connecting"):
                time.sleep(0.05)
            return existing_tunnel.get_url()

        # Create a new tunnel!
        if verbose:
            rich.print("[bold](viser)[/bold] Share URL requested!")

        connect_event = threading.Event()

        tunnel = ViserTunnel("share.viser.studio", self._server._port)
        self._share_tunnel = tunnel

        @tunnel.on_disconnect
        def _() -> None:
            rich.print("[bold](viser)[/bold] Disconnected from share URL")
            # Only clear the attribute if it still refers to this tunnel.
            if self._share_tunnel is tunnel:
                self._share_tunnel = None
            self._server.broadcast(_messages.ShareUrlUpdated(None))
            # In case the tunnel disconnects without ever connecting, don't leave the
            # waiter below blocked forever.
            connect_event.set()

        @tunnel.on_connect
        def _(max_clients: int) -> None:
            share_url = tunnel.get_url()
            if verbose:
                if share_url is None:
                    rich.print("[bold](viser)[/bold] Could not generate share URL")
                else:
                    rich.print(
                        f"[bold](viser)[/bold] Generated share URL (expires in 24 hours, max {max_clients} clients): {share_url}"
                    )
            self._server.broadcast(_messages.ShareUrlUpdated(share_url))
            connect_event.set()

        connect_event.wait()
        # If the tunnel was closed in the meantime, report no URL.
        return tunnel.get_url() if self._share_tunnel is tunnel else None

    def disconnect_share_url(self) -> None:
        """Disconnect from the share URL server."""
        tunnel = self._share_tunnel
        if tunnel is None:
            rich.print(
                "[bold](viser)[/bold] Tried to disconnect from share URL, but already disconnected"
            )
            return
        tunnel.close()

    def stop(self) -> None:
        """Stop the Viser server and associated threads and tunnels."""
        self._server.stop()
        tunnel = self._share_tunnel
        if tunnel is not None:
            tunnel.close()

    def get_clients(self) -> Dict[int, ClientHandle]:
        """Creates and returns a copy of the mapping from connected client IDs to
        handles.

        Returns:
            Dictionary of clients.
        """
        with self._state.client_lock:
            return self._state.connected_clients.copy()

    def on_client_connect(
        self, cb: Callable[[ClientHandle], None]
    ) -> Callable[[ClientHandle], None]:
        """Attach a callback to run for newly connected clients.

        The callback is also run immediately for clients that are already connected.
        Returns the callback unchanged, so this can be used as a decorator.
        """
        with self._state.client_lock:
            clients = list(self._state.connected_clients.values())
            self._client_connect_cb.append(cb)

        # Trigger callback on any already-connected clients.
        # If we have:
        #
        #     server = viser.ViserServer()
        #     server.on_client_connect(...)
        #
        # This makes sure that the callback is applied to any clients that
        # connect between the two lines.
        for client in clients:
            cb(client)
        return cb

    def on_client_disconnect(
        self, cb: Callable[[ClientHandle], None]
    ) -> Callable[[ClientHandle], None]:
        """Attach a callback to run when clients disconnect.

        Returns the callback unchanged, so this can be used as a decorator.
        """
        with self._state.client_lock:
            self._client_disconnect_cb.append(cb)
        return cb

    @override
    @contextlib.contextmanager
    def atomic(self) -> Generator[None, None, None]:
        """Returns a context where: all outgoing messages are grouped and applied by
        clients atomically.

        This should be treated as a soft constraint that's helpful for things like
        animations, or when we want position and orientation updates to happen
        synchronously.

        Returns:
            Context manager.
        """
        # If called multiple times in the same thread, we ignore inner calls.
        thread_id = threading.get_ident()
        if thread_id == self._locked_thread_id:
            yield
            return

        # Acquire the global atomic lock.
        self._atomic_lock.acquire()
        self._locked_thread_id = thread_id
        try:
            with contextlib.ExitStack() as stack:
                # Grab each client's atomic lock.
                # We don't need to do anything with `client._locked_thread_id`.
                for client in self.get_clients().values():
                    stack.enter_context(client._atomic_lock)
                yield
        finally:
            # Always release, even if the body raised. Clear the owner id *before*
            # releasing so we can't overwrite the id recorded by the next thread that
            # acquires the lock.
            self._locked_thread_id = -1
            self._atomic_lock.release()

    @override
    def flush(self) -> None:
        """Flush the outgoing message buffer. Any buffered messages will immediately be
        sent. (by default they are windowed)"""
        self._server.flush()

    @override
    def _get_api(self) -> MessageApi:
        """Message API to use."""
        return self

    @override
    def _queue_unsafe(self, message: _messages.Message) -> None:
        """Define how the message API should send messages."""
        self._server.broadcast(message)
# Install the real constructor. The `__init__` declared in the class body only
# exposes the public signature, keeping deprecated keyword arguments hidden from
# the docstring and type checkers.
setattr(ViserServer, "__init__", ViserServer._actual_init)