Plots as Images

Examples of sending plots as images to Viser’s GUI panel and 3D scene. Rendering plots directly to images with OpenCV can be faster than using Plotly.

  1from __future__ import annotations
  2
  3import colorsys
  4import time
  5
  6import cv2
  7import numpy as np
  8import tyro
  9
 10import viser
 11import viser.transforms as vtf
 12
 13
 14def get_line_plot(
 15    xs: np.ndarray,
 16    ys: np.ndarray,
 17    height: int,
 18    width: int,
 19    *,
 20    x_bounds: tuple[float, float] | None = None,
 21    y_bounds: tuple[float, float] | None = None,
 22    title: str | None = None,
 23    line_thickness: int = 2,
 24    grid_x_lines: int = 8,
 25    grid_y_lines: int = 5,
 26    font_scale: float = 0.4,
 27    background_color: tuple[int, int, int] = (0, 0, 0),
 28    plot_area_color: tuple[int, int, int] = (0, 0, 0),
 29    grid_color: tuple[int, int, int] = (60, 60, 60),
 30    axes_color: tuple[int, int, int] = (100, 100, 100),
 31    line_color: tuple[int, int, int] = (255, 255, 255),
 32    text_color: tuple[int, int, int] = (200, 200, 200),
 33) -> np.ndarray:
 34    """Create a line plot using OpenCV with axes, labels, and grid.
 35
 36    This is much faster than using libraries like Matplotlib or Plotly, but is
 37    less flexible.
 38    """
 39
 40    if x_bounds is None:
 41        x_bounds = (np.min(xs), np.max(xs.round(decimals=4)))
 42    if y_bounds is None:
 43        y_bounds = (np.min(ys), np.max(ys))
 44
 45    # Calculate text sizes for padding.
 46    font = cv2.FONT_HERSHEY_DUPLEX
 47    sample_y_label = f"{max(abs(y_bounds[0]), abs(y_bounds[1])):.1f}"
 48    y_text_size = cv2.getTextSize(sample_y_label, font, font_scale, 1)[0]
 49
 50    sample_x_label = f"{max(abs(x_bounds[0]), abs(x_bounds[1])):.1f}"
 51    x_text_size = cv2.getTextSize(sample_x_label, font, font_scale, 1)[0]
 52
 53    # Define padding based on font scale.
 54    extra_padding = 8
 55    left_pad = int(y_text_size[0] * 1.5) + extra_padding  # Space for y-axis labels
 56    right_pad = int(10 * font_scale) + extra_padding
 57
 58    # Calculate top padding, accounting for title if present
 59    top_pad = int(10 * font_scale) + extra_padding
 60    title_font_scale = font_scale * 1.5  # Make title slightly larger
 61    if title is not None:
 62        title_size = cv2.getTextSize(title, font, title_font_scale, 1)[0]
 63        top_pad += title_size[1] + int(10 * font_scale)
 64
 65    bottom_pad = int(x_text_size[1] * 2.0) + extra_padding  # Space for x-axis labels
 66
 67    # Create larger image to accommodate padding.
 68    total_height = height
 69    total_width = width
 70    plot_width = width - left_pad - right_pad
 71    plot_height = height - top_pad - bottom_pad
 72    assert plot_width > 0 and plot_height > 0
 73
 74    # Create image with specified background color
 75    img = np.ones((total_height, total_width, 3), dtype=np.uint8)
 76    img[:] = background_color
 77
 78    # Create plot area with specified color
 79    plot_area = np.ones((plot_height, plot_width, 3), dtype=np.uint8)
 80    plot_area[:] = plot_area_color
 81    img[top_pad : top_pad + plot_height, left_pad : left_pad + plot_width] = plot_area
 82
 83    def scale_to_pixels(values, bounds, pixels):
 84        """Scale values from bounds range to pixel coordinates."""
 85        min_val, max_val = bounds
 86        normalized = (values - min_val) / (max_val - min_val)
 87        return (normalized * (pixels - 1)).astype(np.int32)
 88
 89    # Vertical grid lines.
 90    for i in range(grid_x_lines):
 91        x_pos = left_pad + int(plot_width * i / (grid_x_lines - 1))
 92        cv2.line(img, (x_pos, top_pad), (x_pos, top_pad + plot_height), grid_color, 1)
 93
 94    # Horizontal grid lines.
 95    for i in range(grid_y_lines):
 96        y_pos = top_pad + int(plot_height * i / (grid_y_lines - 1))
 97        cv2.line(img, (left_pad, y_pos), (left_pad + plot_width, y_pos), grid_color, 1)
 98
 99    # Draw axes.
100    cv2.line(
101        img,
102        (left_pad, top_pad + plot_height),
103        (left_pad + plot_width, top_pad + plot_height),
104        axes_color,
105        1,
106    )  # x-axis
107    cv2.line(
108        img, (left_pad, top_pad), (left_pad, top_pad + plot_height), axes_color, 1
109    )  # y-axis
110
111    # Scale and plot the data.
112    x_scaled = scale_to_pixels(xs, x_bounds, plot_width) + left_pad
113    y_scaled = top_pad + plot_height - 1 - scale_to_pixels(ys, y_bounds, plot_height)
114    pts = np.column_stack((x_scaled, y_scaled)).reshape((-1, 1, 2))
115
116    # Draw the main plot line.
117    cv2.polylines(
118        img, [pts], False, line_color, thickness=line_thickness, lineType=cv2.LINE_AA
119    )
120
121    # Draw title if specified
122    if title is not None:
123        title_size = cv2.getTextSize(title, font, title_font_scale, 1)[0]
124        title_x = left_pad + (plot_width - title_size[0]) // 2
125        title_y = int(top_pad / 2) + title_size[1] // 2 - 1
126        cv2.putText(
127            img,
128            title,
129            (title_x, title_y),
130            font,
131            title_font_scale,
132            text_color,
133            1,
134            cv2.LINE_AA,
135        )
136
137    # X-axis labels.
138    for i in range(grid_x_lines):
139        x_val = x_bounds[0] + (x_bounds[1] - x_bounds[0]) * i / (grid_x_lines - 1)
140        x_pos = left_pad + int(plot_width * i / (grid_x_lines - 1))
141        label = f"{x_val:.1f}"
142        if label == "-0.0":
143            label = "0.0"
144        text_size = cv2.getTextSize(label, font, font_scale, 1)[0]
145        cv2.putText(
146            img,
147            label,
148            (x_pos - text_size[0] // 2, top_pad + plot_height + text_size[1] + 10),
149            font,
150            font_scale,
151            text_color,
152            1,
153            cv2.LINE_AA,
154        )
155
156    # Y-axis labels.
157    for i in range(grid_y_lines):
158        y_val = y_bounds[0] + (y_bounds[1] - y_bounds[0]) * (grid_y_lines - 1 - i) / (
159            grid_y_lines - 1
160        )
161        y_pos = top_pad + int(plot_height * i / (grid_y_lines - 1))
162        label = f"{y_val:.1f}"
163        if label == "-0.0":
164            label = "0.0"
165        text_size = cv2.getTextSize(label, font, font_scale, 1)[0]
166        cv2.putText(
167            img,
168            label,
169            (left_pad - text_size[0] - 5, y_pos + 5),
170            font,
171            font_scale,
172            text_color,
173            1,
174            cv2.LINE_AA,
175        )
176
177    return img
178
179
def create_sine_plot(title: str, counter: int) -> np.ndarray:
    """Render a 150x350 sine-wave plot animated by ``counter``.

    ``counter`` shifts the wave's phase (by ``counter / 20`` radians) and
    slowly cycles the line's hue (one full cycle every 4000 steps).
    """
    sample_xs = np.linspace(0, 2 * np.pi, 20)
    hue = counter / 4000 % 1
    red, green, blue = (int(channel * 255) for channel in colorsys.hsv_to_rgb(hue, 1, 1))
    wave = np.sin(sample_xs + counter / 20)
    return get_line_plot(
        xs=sample_xs,
        ys=wave,
        height=150,
        width=350,
        title=title,
        line_color=(red, green, blue),
    )
192
193
def _smoothed_ms(previous: str, elapsed_s: float, num_items: int) -> str:
    """Blend a new per-item timing into a running average shown in a text field.

    Args:
        previous: The field's current value: a float, in milliseconds, as text.
        elapsed_s: Seconds spent processing ``num_items`` items this frame.
        num_items: Number of items covered by ``elapsed_s``; clamped to >= 1 so
            an empty batch does not divide by zero.

    Returns:
        The exponential moving average (weight 0.02 on the new sample) of the
        per-item time, in milliseconds, formatted with two decimals.
    """
    per_item_ms = elapsed_s / max(num_items, 1) * 1000
    return f"{0.98 * float(previous) + 0.02 * per_item_ms:.2f}"


def _scene_position(
    index: int,
    num_plots: int,
    cell_width: float = 3.5,
    cell_height: float = 1.5,
) -> tuple[float, float, float]:
    """Position of plot ``index`` in a two-column grid centered on the origin.

    Plots fill rows two at a time; the rows are centered vertically for both
    even and odd ``num_plots``.
    """
    num_rows = (num_plots + 1) // 2
    center_row = (num_rows - 1) / 2
    return (
        (index % 2 - 0.5) * cell_width,
        (index // 2 - center_row) * cell_height,
        0.0,
    )


def main(num_plots: int = 8) -> None:
    """Serve ``num_plots`` animated sine plots in both the GUI and the 3D scene.

    Each frame (while "Update plots" is checked) re-renders every plot, pushes
    it to the GUI image handles and then to the scene image handles, and shows
    a smoothed per-plot time for each of those three steps. Runs until
    interrupted.
    """
    server = viser.ViserServer()

    # Create GUI elements for display runtimes.
    with server.gui.add_folder("Runtime"):
        draw_time = server.gui.add_text("Draw / plot (ms)", "0.00", disabled=True)
        send_gui_time = server.gui.add_text(
            "Gui update / plot (ms)", "0.00", disabled=True
        )
        send_scene_time = server.gui.add_text(
            "Scene update / plot (ms)", "0.00", disabled=True
        )

    # Add 2D plots to the GUI.
    with server.gui.add_folder("Plots"):
        plots_cb = server.gui.add_checkbox("Update plots", True)
        gui_image_handles = [
            server.gui.add_image(
                create_sine_plot(f"Plot {i}", counter=0),
                label=f"Image {i}",
                format="jpeg",
            )
            for i in range(num_plots)
        ]

    # Add 2D plots to the scene. We flip them with a parent coordinate frame.
    server.scene.add_frame(
        "/images", wxyz=vtf.SO3.from_y_radians(np.pi).wxyz, show_axes=False
    )
    render_width, render_height = 3.5, 1.5
    scene_image_handles = [
        server.scene.add_image(
            f"/images/plot{i}",
            image=gui_image_handles[i].image,
            render_width=render_width,
            render_height=render_height,
            format="jpeg",
            position=_scene_position(i, num_plots, render_width, render_height),
        )
        for i in range(num_plots)
    ]

    counter = 0
    while True:
        if plots_cb.value:
            # perf_counter is monotonic and high-resolution, unlike time.time(),
            # so it is the right clock for measuring short intervals.
            start = time.perf_counter()
            images = [
                create_sine_plot(f"Plot {i}", counter=counter * (i + 1))
                for i in range(num_plots)
            ]
            draw_time.value = _smoothed_ms(
                draw_time.value, time.perf_counter() - start, num_plots
            )

            # Update all plot images.
            start = time.perf_counter()
            for handle, image in zip(gui_image_handles, images):
                handle.image = image
            send_gui_time.value = _smoothed_ms(
                send_gui_time.value, time.perf_counter() - start, num_plots
            )

            # Update all scene images from the GUI handles' current images.
            start = time.perf_counter()
            for scene_handle, gui_handle in zip(scene_image_handles, gui_image_handles):
                scene_handle.image = gui_handle.image
            send_scene_time.value = _smoothed_ms(
                send_scene_time.value, time.perf_counter() - start, num_plots
            )

        # Sleep a bit before continuing.
        time.sleep(0.02)
        counter += 1


if __name__ == "__main__":
    # Let tyro build a command-line interface from main()'s parameters.
    tyro.cli(main)