Gaussian splats
Viser includes a WebGL-based Gaussian splat renderer.
Features:
- `viser.SceneApi.add_gaussian_splats()` to add a Gaussian splat object
- Correct sorting when multiple splat objects are present
- Compositing with other scene objects
Note
This example requires external assets. To download them, run:
cd /path/to/viser/examples/assets
./download_assets.sh
Source: examples/01_scene/09_gaussian_splats.py

Code
1from __future__ import annotations
2
3import time
4from pathlib import Path
5from typing import TypedDict
6
7import numpy as np
8import numpy.typing as npt
9import tyro
10from plyfile import PlyData
11
12import viser
13from viser import transforms as tf
14
15
class SplatFile(TypedDict):
    """Per-Gaussian splat parameters, with one row per Gaussian (N rows).

    This is the structure returned by both ``load_splat_file`` and
    ``load_ply_file``.
    """

    # Gaussian centers, shape (N, 3).
    centers: npt.NDArray[np.floating]
    # Per-Gaussian RGB colors, shape (N, 3).
    rgbs: npt.NDArray[np.floating]
    # Per-Gaussian opacities, shape (N, 1).
    opacities: npt.NDArray[np.floating]
    # Per-Gaussian 3x3 covariance matrices, shape (N, 3, 3).
    covariances: npt.NDArray[np.floating]
22
23
def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:
    """Load Gaussians from a binary ``.splat`` file.

    The file is a flat sequence of fixed-size 32-byte records, one per
    Gaussian; the record layout is documented on ``bytes_per_gaussian`` below.

    Args:
        splat_path: Path to the ``.splat`` file.
        center: If True, translate all centers so their mean is at the origin.

    Returns:
        A ``SplatFile`` dict with ``centers`` (N, 3), ``rgbs`` (N, 3),
        ``opacities`` (N, 1) and ``covariances`` (N, 3, 3).

    Raises:
        ValueError: If the file size is not a whole number of records.
    """
    start_time = time.time()
    splat_buffer = splat_path.read_bytes()
    bytes_per_gaussian = (
        # Each Gaussian is serialized as:
        # - position (vec3, float32), bytes 0-12
        3 * 4
        # - scale (vec3, float32), bytes 12-24
        + 3 * 4
        # - rgba (vec4, uint8), bytes 24-28
        + 4
        # - rotation quaternion, wxyz order (vec4, uint8), bytes 28-32,
        #   where 0 => -1, 255 => 1.
        + 4
    )
    # Raise instead of `assert`: asserts are stripped under `python -O`, which
    # would turn a truncated file into a confusing reshape error below.
    if len(splat_buffer) % bytes_per_gaussian != 0:
        raise ValueError(
            f"{splat_path} has {len(splat_buffer)} bytes, which is not a multiple "
            f"of the {bytes_per_gaussian}-byte record size of a .splat file."
        )
    num_gaussians = len(splat_buffer) // bytes_per_gaussian

    # View the buffer as one row of raw bytes per Gaussian; each field is then
    # sliced out by byte offset and reinterpret-cast to its dtype.
    splat_uint8 = np.frombuffer(splat_buffer, dtype=np.uint8).reshape(
        (num_gaussians, bytes_per_gaussian)
    )
    # `.copy()` makes the column slice contiguous (and writable) so that it can
    # be reinterpreted as float32.
    scales = splat_uint8[:, 12:24].copy().view(np.float32)
    # Map quantized quaternion components from [0, 255] to [-1, 1].
    wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0
    Rs = tf.SO3(wxyzs).as_matrix()
    # Batched covariance: R @ diag(scale**2) @ R^T for every Gaussian.
    covariances = np.einsum(
        "nij,njk,nlk->nil", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs
    )
    centers = splat_uint8[:, 0:12].copy().view(np.float32)
    if center:
        centers -= np.mean(centers, axis=0, keepdims=True)
    print(
        f"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds"
    )
    return {
        "centers": centers,
        # Colors should have shape (N, 3).
        "rgbs": splat_uint8[:, 24:27] / 255.0,
        # Opacity is the alpha byte of the rgba field, shape (N, 1).
        "opacities": splat_uint8[:, 27:28] / 255.0,
        # Covariances should have shape (N, 3, 3).
        "covariances": covariances,
    }
65
66
def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:
    """Load Gaussians from the ``vertex`` element of a ``.ply`` file.

    The stored per-vertex parameters are converted as follows:
    ``scale_*`` fields are exponentiated, ``f_dc_*`` fields become RGB via
    ``0.5 + SH_C0 * f_dc``, and ``opacity`` is passed through a sigmoid.

    Args:
        ply_file_path: Path to the ``.ply`` file.
        center: If True, translate all centers so their mean is at the origin.

    Returns:
        A ``SplatFile`` dict with ``centers`` (N, 3), ``rgbs`` (N, 3),
        ``opacities`` (N, 1) and ``covariances`` (N, 3, 3).
    """
    tic = time.time()

    # Zeroth-order spherical-harmonic basis constant, 1 / (2 * sqrt(pi)).
    SH_C0 = 0.28209479177387814

    vertex = PlyData.read(ply_file_path)["vertex"]

    def columns(*names: str) -> np.ndarray:
        # Stack the named per-vertex properties into an (N, len(names)) array.
        return np.stack([vertex[name] for name in names], axis=-1)

    positions = columns("x", "y", "z")
    scales = np.exp(columns("scale_0", "scale_1", "scale_2"))
    wxyzs = columns("rot_0", "rot_1", "rot_2", "rot_3")
    colors = 0.5 + SH_C0 * columns("f_dc_0", "f_dc_1", "f_dc_2")
    opacities = 1.0 / (1.0 + np.exp(-vertex["opacity"][:, None]))

    # Batched covariance: R @ diag(scale**2) @ R^T for every Gaussian.
    rotations = tf.SO3(wxyzs).as_matrix()
    scale_sq_diag = np.eye(3)[None, :, :] * scales[:, None, :] ** 2
    covariances = np.einsum("nij,njk,nlk->nil", rotations, scale_sq_diag, rotations)

    if center:
        positions -= positions.mean(axis=0, keepdims=True)

    num_gaussians = len(vertex)
    print(
        f"PLY file with {num_gaussians=} loaded in {time.time() - tic} seconds"
    )
    return {
        "centers": positions,
        "rgbs": colors,
        "opacities": opacities,
        "covariances": covariances,
    }
97
98
def main(
    splat_paths: tuple[Path, ...] = (
        # Path(__file__).absolute().parent.parent / "assets" / "train.splat",
        Path(__file__).absolute().parent.parent / "assets" / "nike.splat",
    ),
) -> None:
    """Display one or more Gaussian splat files in a viser scene.

    Each file is loaded and centered, then added under its own transform
    control. A GUI button is also created for each file that removes it
    again.

    Args:
        splat_paths: Paths to ``.splat`` or ``.ply`` files to display. The
            suffix is matched case-insensitively.
    """
    # Validate every input before starting the server or reading any file.
    # A bad argument then fails immediately, instead of after earlier files
    # have already been loaded and served.
    suffixes = [path.suffix.lower() for path in splat_paths]
    if any(suffix not in (".splat", ".ply") for suffix in suffixes):
        raise SystemExit("Please provide a filepath to a .splat or .ply file.")

    server = viser.ViserServer()

    for i, (splat_path, suffix) in enumerate(zip(splat_paths, suffixes)):
        if suffix == ".splat":
            splat_data = load_splat_file(splat_path, center=True)
        else:
            splat_data = load_ply_file(splat_path, center=True)

        # Each splat object is nested under its own `/{i}` transform-controls node.
        server.scene.add_transform_controls(f"/{i}")
        gs_handle = server.scene.add_gaussian_splats(
            f"/{i}/gaussian_splats",
            centers=splat_data["centers"],
            rgbs=splat_data["rgbs"],
            opacities=splat_data["opacities"],
            covariances=splat_data["covariances"],
        )

        remove_button = server.gui.add_button(f"Remove splat object {i}")

        # Bind this iteration's handles as default arguments. A plain closure
        # binds late and would always remove the last object created.
        @remove_button.on_click
        def _(_, gs_handle=gs_handle, remove_button=remove_button) -> None:
            gs_handle.remove()
            remove_button.remove()

    # Block forever; returning from main would end the script.
    while True:
        time.sleep(10.0)
133
134
if __name__ == "__main__":
    # Run main() with its arguments supplied from the command line by tyro.
    tyro.cli(main)