Scene pointer events

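This example shows how to use scene pointer events to cast rays into the scene and interact with it: clicking places a sphere where the ray hits the mesh (a ray-mesh intersection), and dragging a selection box (rect-select) paints the mesh vertices that fall inside it.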

To get the demo data, see ./assets/download_dragon_mesh.sh.

  1import time
  2from pathlib import Path
  3from typing import List, cast
  4
  5import numpy as onp
  6import trimesh
  7import trimesh.creation
  8import trimesh.ray
  9import viser
 10import viser.transforms as tf
 11
 12server = viser.ViserServer()
 13server.configure_theme(brand_color=(130, 0, 150))
 14server.set_up_direction("+y")
 15
 16mesh = cast(
 17    trimesh.Trimesh, trimesh.load_mesh(str(Path(__file__).parent / "assets/dragon.obj"))
 18)
 19mesh.apply_scale(0.05)
 20
 21mesh_handle = server.add_mesh_trimesh(
 22    name="/mesh",
 23    mesh=mesh,
 24    position=(0.0, 0.0, 0.0),
 25)
 26
 27hit_pos_handles: List[viser.GlbHandle] = []
 28
 29
 30# Buttons + callbacks will operate on a per-client basis, but will modify the global scene! :)
 31@server.on_client_connect
 32def _(client: viser.ClientHandle) -> None:
 33    # Set up the camera -- this gives a nice view of the full mesh.
 34    client.camera.position = onp.array([0.0, 0.0, -10.0])
 35    client.camera.wxyz = onp.array([0.0, 0.0, 0.0, 1.0])
 36
 37    # Tests "click" scenepointerevent.
 38    click_button_handle = client.add_gui_button("Add sphere", icon=viser.Icon.POINTER)
 39
 40    @click_button_handle.on_click
 41    def _(_):
 42        click_button_handle.disabled = True
 43
 44        @client.on_scene_pointer(event_type="click")
 45        def _(event: viser.ScenePointerEvent) -> None:
 46            # Check for intersection with the mesh, using trimesh's ray-mesh intersection.
 47            # Note that mesh is in the mesh frame, so we need to transform the ray.
 48            R_world_mesh = tf.SO3(mesh_handle.wxyz)
 49            R_mesh_world = R_world_mesh.inverse()
 50            origin = (R_mesh_world @ onp.array(event.ray_origin)).reshape(1, 3)
 51            direction = (R_mesh_world @ onp.array(event.ray_direction)).reshape(1, 3)
 52            intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh)
 53            hit_pos, _, _ = intersector.intersects_location(origin, direction)
 54
 55            if len(hit_pos) == 0:
 56                return
 57            client.remove_scene_pointer_callback()
 58
 59            # Get the first hit position (based on distance from the ray origin).
 60            hit_pos = min(hit_pos, key=lambda x: onp.linalg.norm(x - origin))
 61
 62            # Create a sphere at the hit location.
 63            hit_pos_mesh = trimesh.creation.icosphere(radius=0.1)
 64            hit_pos_mesh.vertices += R_world_mesh @ hit_pos
 65            hit_pos_mesh.visual.vertex_colors = (0.5, 0.0, 0.7, 1.0)  # type: ignore
 66            hit_pos_handle = server.add_mesh_trimesh(
 67                name=f"/hit_pos_{len(hit_pos_handles)}", mesh=hit_pos_mesh
 68            )
 69            hit_pos_handles.append(hit_pos_handle)
 70
 71        @client.on_scene_pointer_removed
 72        def _():
 73            click_button_handle.disabled = False
 74
 75    # Tests "rect-select" scenepointerevent.
 76    paint_button_handle = client.add_gui_button("Paint mesh", icon=viser.Icon.PAINT)
 77
 78    @paint_button_handle.on_click
 79    def _(_):
 80        paint_button_handle.disabled = True
 81
 82        @client.on_scene_pointer(event_type="rect-select")
 83        def _(message: viser.ScenePointerEvent) -> None:
 84            client.remove_scene_pointer_callback()
 85
 86            global mesh_handle
 87            camera = message.client.camera
 88
 89            # Put the mesh in the camera frame.
 90            R_world_mesh = tf.SO3(mesh_handle.wxyz)
 91            R_mesh_world = R_world_mesh.inverse()
 92            R_camera_world = tf.SE3.from_rotation_and_translation(
 93                tf.SO3(camera.wxyz), camera.position
 94            ).inverse()
 95            vertices = cast(onp.ndarray, mesh.vertices)
 96            vertices = (R_mesh_world.as_matrix() @ vertices.T).T
 97            vertices = (
 98                R_camera_world.as_matrix()
 99                @ onp.hstack([vertices, onp.ones((vertices.shape[0], 1))]).T
100            ).T[:, :3]
101
102            # Get the camera intrinsics, and project the vertices onto the image plane.
103            fov, aspect = camera.fov, camera.aspect
104            vertices_proj = vertices[:, :2] / vertices[:, 2].reshape(-1, 1)
105            vertices_proj /= onp.tan(fov / 2)
106            vertices_proj[:, 0] /= aspect
107
108            # Move the origin to the upper-left corner, and scale to [0, 1].
109            # ... make sure to match the OpenCV's image coordinates!
110            vertices_proj = (1 + vertices_proj) / 2
111
112            # Select the vertices that lie inside the 2D selected box, once projected.
113            mask = (
114                (vertices_proj > onp.array(message.screen_pos[0]))
115                & (vertices_proj < onp.array(message.screen_pos[1]))
116            ).all(axis=1)[..., None]
117
118            # Update the mesh color based on whether the vertices are inside the box
119            mesh.visual.vertex_colors = onp.where(  # type: ignore
120                mask, (0.5, 0.0, 0.7, 1.0), (0.9, 0.9, 0.9, 1.0)
121            )
122            mesh_handle = server.add_mesh_trimesh(
123                name="/mesh",
124                mesh=mesh,
125                position=(0.0, 0.0, 0.0),
126            )
127
128        @client.on_scene_pointer_removed
129        def _():
130            paint_button_handle.disabled = False
131
132    # Button to clear spheres.
133    clear_button_handle = client.add_gui_button("Clear scene", icon=viser.Icon.X)
134
135    @clear_button_handle.on_click
136    def _(_):
137        """Reset the mesh color and remove all click-generated spheres."""
138        global mesh_handle
139        for handle in hit_pos_handles:
140            handle.remove()
141        hit_pos_handles.clear()
142        mesh.visual.vertex_colors = (0.9, 0.9, 0.9, 1.0)  # type: ignore
143        mesh_handle = server.add_mesh_trimesh(
144            name="/mesh",
145            mesh=mesh,
146            position=(0.0, 0.0, 0.0),
147        )
148
149
# Idle forever so the script does not exit; all interaction happens in the
# callbacks registered above.
while True:
    time.sleep(10.0)