COLMAP visualizer

Visualize COLMAP sparse reconstruction outputs. To get demo data, see ./assets/download_colmap_garden.sh.

  1import random
  2import time
  3from pathlib import Path
  4from typing import List
  5
  6import imageio.v3 as iio
  7import numpy as np
  8import tyro
  9from tqdm.auto import tqdm
 10
 11import viser
 12import viser.transforms as vtf
 13from viser.extras.colmap import (
 14    read_cameras_binary,
 15    read_images_binary,
 16    read_points3d_binary,
 17)
 18
 19
 20def main(
 21    colmap_path: Path = Path(__file__).parent / "assets/colmap_garden/sparse/0",
 22    images_path: Path = Path(__file__).parent / "assets/colmap_garden/images_8",
 23    downsample_factor: int = 2,
 24    reorient_scene: bool = True,
 25) -> None:
 26    """Visualize COLMAP sparse reconstruction outputs.
 27
 28    Args:
 29        colmap_path: Path to the COLMAP reconstruction directory.
 30        images_path: Path to the COLMAP images directory.
 31        downsample_factor: Downsample factor for the images.
 32    """
 33    server = viser.ViserServer()
 34    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
 35
 36    server.scene.enable_default_lights(cast_shadow=True)
 37
 38    # Load the colmap info.
 39    cameras = read_cameras_binary(colmap_path / "cameras.bin")
 40    images = read_images_binary(colmap_path / "images.bin")
 41    points3d = read_points3d_binary(colmap_path / "points3D.bin")
 42
 43    points = np.array([points3d[p_id].xyz for p_id in points3d])
 44    colors = np.array([points3d[p_id].rgb for p_id in points3d])
 45
 46    gui_reset_up = server.gui.add_button(
 47        "Reset up direction",
 48        hint="Set the camera control 'up' direction to the current camera's 'up'.",
 49    )
 50
 51    # Let's rotate the scene so the average camera direction is pointing up.
 52    if reorient_scene:
 53        average_up = (
 54            vtf.SO3(np.array([img.qvec for img in images.values()]))
 55            @ np.array([0.0, -1.0, 0.0])  # -y is up in the local frame!
 56        ).mean(axis=0)
 57        average_up /= np.linalg.norm(average_up)
 58
 59        rotate_axis = np.cross(average_up, np.array([0.0, 0.0, 1.0]))
 60        rotate_axis /= np.linalg.norm(rotate_axis)
 61        rotate_angle = np.arccos(np.dot(average_up, np.array([0.0, 0.0, 1.0])))
 62        R_scene_colmap = vtf.SO3.exp(rotate_axis * rotate_angle)
 63        server.scene.add_frame(
 64            "/colmap",
 65            show_axes=False,
 66            wxyz=R_scene_colmap.wxyz,
 67        )
 68    else:
 69        R_scene_colmap = vtf.SO3.identity()
 70
 71    # Get transformed z-coordinates and place grid at 5th percentile height.
 72    transformed_z = (R_scene_colmap @ points)[..., 2]
 73    grid_height = float(np.percentile(transformed_z, 5))
 74    server.scene.add_grid(name="/grid", position=(0.0, 0.0, grid_height))
 75
 76    @gui_reset_up.on_click
 77    def _(event: viser.GuiEvent) -> None:
 78        client = event.client
 79        assert client is not None
 80        client.camera.up_direction = vtf.SO3(client.camera.wxyz) @ np.array(
 81            [0.0, -1.0, 0.0]
 82        )
 83
 84    gui_points = server.gui.add_slider(
 85        "Max points",
 86        min=1,
 87        max=len(points3d),
 88        step=1,
 89        initial_value=min(len(points3d), 50_000),
 90    )
 91    gui_frames = server.gui.add_slider(
 92        "Max frames",
 93        min=1,
 94        max=len(images),
 95        step=1,
 96        initial_value=min(len(images), 100),
 97    )
 98    gui_point_size = server.gui.add_slider(
 99        "Point size", min=0.01, max=0.1, step=0.001, initial_value=0.05
100    )
101
102    point_mask = np.random.choice(points.shape[0], gui_points.value, replace=False)
103    point_cloud = server.scene.add_point_cloud(
104        name="/colmap/pcd",
105        points=points[point_mask],
106        colors=colors[point_mask],
107        point_size=gui_point_size.value,
108    )
109    frames: List[viser.FrameHandle] = []
110
111    def visualize_frames() -> None:
112        """Send all COLMAP elements to viser for visualization. This could be optimized
113        a ton!"""
114
115        # Remove existing image frames.
116        for frame in frames:
117            frame.remove()
118        frames.clear()
119
120        # Interpret the images and cameras.
121        img_ids = [im.id for im in images.values()]
122        random.shuffle(img_ids)
123        img_ids = sorted(img_ids[: gui_frames.value])
124
125        def attach_callback(
126            frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle
127        ) -> None:
128            @frustum.on_click
129            def _(_) -> None:
130                for client in server.get_clients().values():
131                    client.camera.wxyz = frame.wxyz
132                    client.camera.position = frame.position
133
134        for img_id in tqdm(img_ids):
135            img = images[img_id]
136            cam = cameras[img.camera_id]
137
138            # Skip images that don't exist.
139            image_filename = images_path / img.name
140            if not image_filename.exists():
141                continue
142
143            T_world_camera = vtf.SE3.from_rotation_and_translation(
144                vtf.SO3(img.qvec), img.tvec
145            ).inverse()
146            frame = server.scene.add_frame(
147                f"/colmap/frame_{img_id}",
148                wxyz=T_world_camera.rotation().wxyz,
149                position=T_world_camera.translation(),
150                axes_length=0.1,
151                axes_radius=0.005,
152            )
153            frames.append(frame)
154
155            # For pinhole cameras, cam.params will be (fx, fy, cx, cy).
156            if cam.model != "PINHOLE":
157                print(f"Expected pinhole camera, but got {cam.model}")
158
159            H, W = cam.height, cam.width
160            fy = cam.params[1]
161            image = iio.imread(image_filename)
162            image = image[::downsample_factor, ::downsample_factor]
163            frustum = server.scene.add_camera_frustum(
164                f"/colmap/frame_{img_id}/frustum",
165                fov=2 * np.arctan2(H / 2, fy),
166                aspect=W / H,
167                scale=0.15,
168                image=image,
169            )
170            attach_callback(frustum, frame)
171
172    need_update = True
173
174    @gui_points.on_update
175    def _(_) -> None:
176        point_mask = np.random.choice(points.shape[0], gui_points.value, replace=False)
177        point_cloud.points = points[point_mask]
178        point_cloud.colors = colors[point_mask]
179
180    @gui_frames.on_update
181    def _(_) -> None:
182        nonlocal need_update
183        need_update = True
184
185    @gui_point_size.on_update
186    def _(_) -> None:
187        point_cloud.point_size = gui_point_size.value
188
189    while True:
190        if need_update:
191            need_update = False
192            visualize_frames()
193
194        time.sleep(1e-3)
195
196
if __name__ == "__main__":
    # Script entry point: hand `main` to tyro's CLI runner.
    tyro.cli(main)