Gaussian splats

Viser includes a WebGL-based Gaussian splat renderer.

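Features:

- Loading Gaussian splats from `.splat` or `.ply` files.
- Adding them to the scene with `server.scene.add_gaussian_splats()`.
- Transform controls for moving each splat object, plus a GUI button for removing it.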

Note

This example requires external assets. To download them, run:

cd /path/to/viser/examples/assets
./download_assets.sh

Source: examples/01_scene/09_gaussian_splats.py

Gaussian splats

Code

  1from __future__ import annotations
  2
  3import time
  4from pathlib import Path
  5from typing import TypedDict
  6
  7import numpy as np
  8import numpy.typing as npt
  9import tyro
 10from plyfile import PlyData
 11
 12import viser
 13from viser import transforms as tf
 14
 15
 16class SplatFile(TypedDict):
 17
 18    centers: npt.NDArray[np.floating]
 19    rgbs: npt.NDArray[np.floating]
 20    opacities: npt.NDArray[np.floating]
 21    covariances: npt.NDArray[np.floating]
 22
 23
 24def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:
 25    start_time = time.time()
 26    splat_buffer = splat_path.read_bytes()
 27    bytes_per_gaussian = (
 28        # Each Gaussian is serialized as:
 29        # - position (vec3, float32)
 30        3 * 4
 31        # - xyz (vec3, float32)
 32        + 3 * 4
 33        # - rgba (vec4, uint8)
 34        + 4
 35        # - ijkl (vec4, uint8), where 0 => -1, 255 => 1.
 36        + 4
 37    )
 38    assert len(splat_buffer) % bytes_per_gaussian == 0
 39    num_gaussians = len(splat_buffer) // bytes_per_gaussian
 40
 41    # Reinterpret cast to dtypes that we want to extract.
 42    splat_uint8 = np.frombuffer(splat_buffer, dtype=np.uint8).reshape(
 43        (num_gaussians, bytes_per_gaussian)
 44    )
 45    scales = splat_uint8[:, 12:24].copy().view(np.float32)
 46    wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0
 47    Rs = tf.SO3(wxyzs).as_matrix()
 48    covariances = np.einsum(
 49        "nij,njk,nlk->nil", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs
 50    )
 51    centers = splat_uint8[:, 0:12].copy().view(np.float32)
 52    if center:
 53        centers -= np.mean(centers, axis=0, keepdims=True)
 54    print(
 55        f"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds"
 56    )
 57    return {
 58        "centers": centers,
 59        # Colors should have shape (N, 3).
 60        "rgbs": splat_uint8[:, 24:27] / 255.0,
 61        "opacities": splat_uint8[:, 27:28] / 255.0,
 62        # Covariances should have shape (N, 3, 3).
 63        "covariances": covariances,
 64    }
 65
 66
 67def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:
 68    start_time = time.time()
 69
 70    SH_C0 = 0.28209479177387814
 71
 72    plydata = PlyData.read(ply_file_path)
 73    v = plydata["vertex"]
 74    positions = np.stack([v["x"], v["y"], v["z"]], axis=-1)
 75    scales = np.exp(np.stack([v["scale_0"], v["scale_1"], v["scale_2"]], axis=-1))
 76    wxyzs = np.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1)
 77    colors = 0.5 + SH_C0 * np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1)
 78    opacities = 1.0 / (1.0 + np.exp(-v["opacity"][:, None]))
 79
 80    Rs = tf.SO3(wxyzs).as_matrix()
 81    covariances = np.einsum(
 82        "nij,njk,nlk->nil", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs
 83    )
 84    if center:
 85        positions -= np.mean(positions, axis=0, keepdims=True)
 86
 87    num_gaussians = len(v)
 88    print(
 89        f"PLY file with {num_gaussians=} loaded in {time.time() - start_time} seconds"
 90    )
 91    return {
 92        "centers": positions,
 93        "rgbs": colors,
 94        "opacities": opacities,
 95        "covariances": covariances,
 96    }
 97
 98
 99def main(
100    splat_paths: tuple[Path, ...] = (
101        # Path(__file__).absolute().parent.parent / "assets" / "train.splat",
102        Path(__file__).absolute().parent.parent / "assets" / "nike.splat",
103    ),
104) -> None:
105    server = viser.ViserServer()
106
107    for i, splat_path in enumerate(splat_paths):
108        if splat_path.suffix == ".splat":
109            splat_data = load_splat_file(splat_path, center=True)
110        elif splat_path.suffix == ".ply":
111            splat_data = load_ply_file(splat_path, center=True)
112        else:
113            raise SystemExit("Please provide a filepath to a .splat or .ply file.")
114
115        server.scene.add_transform_controls(f"/{i}")
116        gs_handle = server.scene.add_gaussian_splats(
117            f"/{i}/gaussian_splats",
118            centers=splat_data["centers"],
119            rgbs=splat_data["rgbs"],
120            opacities=splat_data["opacities"],
121            covariances=splat_data["covariances"],
122        )
123
124        remove_button = server.gui.add_button(f"Remove splat object {i}")
125
126        @remove_button.on_click
127        def _(_, gs_handle=gs_handle, remove_button=remove_button) -> None:
128            gs_handle.remove()
129            remove_button.remove()
130
131    while True:
132        time.sleep(10.0)
133
134
if __name__ == "__main__":
    # Parse command-line arguments into main()'s parameters and run it.
    tyro.cli(main)