COLMAP visualizer
Visualize COLMAP sparse reconstruction outputs. To get the demo data, run `./assets/download_colmap_garden.sh`.
1import random
2import time
3from pathlib import Path
4from typing import List
5
6import imageio.v3 as iio
7import numpy as np
8import tyro
9from tqdm.auto import tqdm
10
11import viser
12import viser.transforms as vtf
13from viser.extras.colmap import (
14 read_cameras_binary,
15 read_images_binary,
16 read_points3d_binary,
17)
18
19
def _rotation_to_z_up(up_direction: np.ndarray) -> "vtf.SO3":
    """Return the rotation that maps ``up_direction`` onto the +z axis.

    Args:
        up_direction: A nonzero 3-vector; it does not need to be normalized.

    Returns:
        An SO3 rotation ``R`` such that ``R @ (up_direction / |up_direction|)``
        is ``[0, 0, 1]``.
    """
    z_axis = np.array([0.0, 0.0, 1.0])
    up_direction = up_direction / np.linalg.norm(up_direction)

    # Clip so rounding (|dot| slightly above 1) cannot make arccos return NaN.
    angle = np.arccos(np.clip(np.dot(up_direction, z_axis), -1.0, 1.0))

    axis = np.cross(up_direction, z_axis)
    axis_norm = np.linalg.norm(axis)
    if axis_norm < 1e-8:
        # up is (anti-)parallel to +z, so the cross product has no direction and
        # normalizing it would divide by zero. Any axis perpendicular to z works:
        # the angle is either 0 (identity) or pi (a flip about x).
        axis = np.array([1.0, 0.0, 0.0])
    else:
        axis = axis / axis_norm
    return vtf.SO3.exp(axis * angle)


def _focal_length_y(model: str, params: np.ndarray) -> float:
    """Return the vertical focal length (in pixels) from COLMAP camera params.

    COLMAP's single-focal-length models store ``(f, cx, cy, ...)``; the other
    models store ``(fx, fy, cx, cy, ...)``.
    """
    single_focal_models = (
        "SIMPLE_PINHOLE",
        "SIMPLE_RADIAL",
        "RADIAL",
        "SIMPLE_RADIAL_FISHEYE",
        "RADIAL_FISHEYE",
    )
    if model in single_focal_models:
        return float(params[0])
    return float(params[1])


def main(
    colmap_path: Path = Path(__file__).parent / "assets/colmap_garden/sparse/0",
    images_path: Path = Path(__file__).parent / "assets/colmap_garden/images_8",
    downsample_factor: int = 2,
    reorient_scene: bool = True,
) -> None:
    """Visualize COLMAP sparse reconstruction outputs.

    Starts a viser server that shows the sparse point cloud and a random subset
    of registered images as camera frustums. It then loops forever, redrawing
    the frustums whenever the "Max frames" slider changes.

    Args:
        colmap_path: Path to the COLMAP reconstruction directory.
        images_path: Path to the COLMAP images directory.
        downsample_factor: Downsample factor for the images.
        reorient_scene: If True, rotate the scene so the average camera "up"
            direction points along +z.
    """
    server = viser.ViserServer()
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    server.scene.enable_default_lights(cast_shadow=True)

    # Load the colmap info.
    cameras = read_cameras_binary(colmap_path / "cameras.bin")
    images = read_images_binary(colmap_path / "images.bin")
    points3d = read_points3d_binary(colmap_path / "points3D.bin")

    points = np.array([points3d[p_id].xyz for p_id in points3d])
    colors = np.array([points3d[p_id].rgb for p_id in points3d])

    gui_reset_up = server.gui.add_button(
        "Reset up direction",
        hint="Set the camera control 'up' direction to the current camera's 'up'.",
    )

    # Let's rotate the scene so the average camera direction is pointing up.
    if reorient_scene:
        # img.qvec is the world->camera rotation (it is inverted below to build
        # T_world_camera), so its inverse carries the camera-local up vector into
        # world coordinates.
        average_up = (
            vtf.SO3(np.array([img.qvec for img in images.values()])).inverse()
            @ np.array([0.0, -1.0, 0.0])  # -y is up in the local frame!
        ).mean(axis=0)
        R_scene_colmap = _rotation_to_z_up(average_up)
        # Every "/colmap/..." node is placed relative to this rotated frame.
        server.scene.add_frame(
            "/colmap",
            show_axes=False,
            wxyz=R_scene_colmap.wxyz,
        )
    else:
        R_scene_colmap = vtf.SO3.identity()

    # Get transformed z-coordinates and place grid at 5th percentile height.
    transformed_z = (R_scene_colmap @ points)[..., 2]
    grid_height = float(np.percentile(transformed_z, 5))
    server.scene.add_grid(name="/grid", position=(0.0, 0.0, grid_height))

    @gui_reset_up.on_click
    def _(event: viser.GuiEvent) -> None:
        client = event.client
        assert client is not None
        client.camera.up_direction = vtf.SO3(client.camera.wxyz) @ np.array(
            [0.0, -1.0, 0.0]
        )

    gui_points = server.gui.add_slider(
        "Max points",
        min=1,
        max=len(points3d),
        step=1,
        initial_value=min(len(points3d), 50_000),
    )
    gui_frames = server.gui.add_slider(
        "Max frames",
        min=1,
        max=len(images),
        step=1,
        initial_value=min(len(images), 100),
    )
    gui_point_size = server.gui.add_slider(
        "Point size", min=0.01, max=0.1, step=0.001, initial_value=0.05
    )

    point_mask = np.random.choice(points.shape[0], gui_points.value, replace=False)
    point_cloud = server.scene.add_point_cloud(
        name="/colmap/pcd",
        points=points[point_mask],
        colors=colors[point_mask],
        point_size=gui_point_size.value,
    )
    frames: List[viser.FrameHandle] = []

    def visualize_frames() -> None:
        """Send all COLMAP elements to viser for visualization. This could be optimized
        a ton!"""

        # Remove existing image frames.
        for frame in frames:
            frame.remove()
        frames.clear()

        # Interpret the images and cameras.
        img_ids = [im.id for im in images.values()]
        random.shuffle(img_ids)
        img_ids = sorted(img_ids[: gui_frames.value])

        def attach_callback(
            frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle
        ) -> None:
            # Defined as a function so each click handler binds its own frame.
            @frustum.on_click
            def _(_) -> None:
                # frame.wxyz / frame.position are relative to the "/colmap" parent,
                # which is rotated by R_scene_colmap; compose them to get the
                # world-frame pose matching what is displayed.
                wxyz_world = (R_scene_colmap @ vtf.SO3(frame.wxyz)).wxyz
                position_world = R_scene_colmap @ np.asarray(frame.position)
                for client in server.get_clients().values():
                    client.camera.wxyz = wxyz_world
                    client.camera.position = position_world

        for img_id in tqdm(img_ids):
            img = images[img_id]
            cam = cameras[img.camera_id]

            # Skip images that don't exist.
            image_filename = images_path / img.name
            if not image_filename.exists():
                continue

            T_world_camera = vtf.SE3.from_rotation_and_translation(
                vtf.SO3(img.qvec), img.tvec
            ).inverse()
            frame = server.scene.add_frame(
                f"/colmap/frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.1,
                axes_radius=0.005,
            )
            frames.append(frame)

            # Only the focal length is used below; lens distortion of
            # non-pinhole models is ignored, hence the warning.
            if cam.model not in ("PINHOLE", "SIMPLE_PINHOLE"):
                print(f"Expected pinhole camera, but got {cam.model}")

            H, W = cam.height, cam.width
            fy = _focal_length_y(cam.model, cam.params)
            image = iio.imread(image_filename)
            image = image[::downsample_factor, ::downsample_factor]
            frustum = server.scene.add_camera_frustum(
                f"/colmap/frame_{img_id}/frustum",
                fov=2 * np.arctan2(H / 2, fy),
                aspect=W / H,
                scale=0.15,
                image=image,
            )
            attach_callback(frustum, frame)

    need_update = True

    @gui_points.on_update
    def _(_) -> None:
        point_mask = np.random.choice(points.shape[0], gui_points.value, replace=False)
        point_cloud.points = points[point_mask]
        point_cloud.colors = colors[point_mask]

    @gui_frames.on_update
    def _(_) -> None:
        # Defer the (slow) redraw to the main loop instead of the GUI callback.
        nonlocal need_update
        need_update = True

    @gui_point_size.on_update
    def _(_) -> None:
        point_cloud.point_size = gui_point_size.value

    while True:
        if need_update:
            need_update = False
            visualize_frames()

        time.sleep(1e-3)
196
if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via tyro.
    tyro.cli(main)