SMPL visualizer (Skinned Mesh)

Requires a .npz model file.

See here for download instructions:

https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model

  1from __future__ import annotations
  2
  3import time
  4from dataclasses import dataclass
  5from pathlib import Path
  6from typing import List, Tuple
  7
  8import numpy as np
  9import tyro
 10
 11import viser
 12import viser.transforms as tf
 13
 14
 15@dataclass(frozen=True)
 16class SmplFkOutputs:
 17    T_world_joint: np.ndarray  # (num_joints, 4, 4)
 18    T_parent_joint: np.ndarray  # (num_joints, 4, 4)
 19
 20
 21class SmplHelper:
 22    """Helper for models in the SMPL family, implemented in numpy. Does not include blend skinning."""
 23
 24    def __init__(self, model_path: Path) -> None:
 25        assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
 26        body_dict = dict(**np.load(model_path, allow_pickle=True))
 27
 28        self.J_regressor = body_dict["J_regressor"]
 29        self.weights = body_dict["weights"]
 30        self.v_template = body_dict["v_template"]
 31        self.posedirs = body_dict["posedirs"]
 32        self.shapedirs = body_dict["shapedirs"]
 33        self.faces = body_dict["f"]
 34
 35        self.num_joints: int = self.weights.shape[-1]
 36        self.num_betas: int = self.shapedirs.shape[-1]
 37        self.parent_idx: np.ndarray = body_dict["kintree_table"][0]
 38
 39    def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
 40        # Get shaped vertices + joint positions, when all local poses are identity.
 41        v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
 42        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
 43        return v_tpose, j_tpose
 44
 45    def get_outputs(
 46        self, betas: np.ndarray, joint_rotmats: np.ndarray
 47    ) -> SmplFkOutputs:
 48        # Get shaped vertices + joint positions, when all local poses are identity.
 49        v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
 50        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
 51
 52        # Local SE(3) transforms.
 53        T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
 54        T_parent_joint[:, :3, :3] = joint_rotmats
 55        T_parent_joint[0, :3, 3] = j_tpose[0]
 56        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]
 57
 58        # Forward kinematics.
 59        T_world_joint = T_parent_joint.copy()
 60        for i in range(1, self.num_joints):
 61            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]
 62
 63        return SmplFkOutputs(T_world_joint, T_parent_joint)
 64
 65
 66def main(model_path: Path) -> None:
 67    server = viser.ViserServer()
 68    server.scene.set_up_direction("+y")
 69
 70    # Main loop. We'll read pose/shape from the GUI elements, compute the mesh,
 71    # and then send the updated mesh in a loop.
 72    model = SmplHelper(model_path)
 73    gui_elements = make_gui_elements(
 74        server,
 75        num_betas=model.num_betas,
 76        num_joints=model.num_joints,
 77        parent_idx=model.parent_idx,
 78    )
 79    v_tpose, j_tpose = model.get_tpose(np.zeros((model.num_betas,)))
 80    mesh_handle = server.scene.add_mesh_skinned(
 81        "/human",
 82        v_tpose,
 83        model.faces,
 84        bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz,
 85        bone_positions=j_tpose,
 86        skin_weights=model.weights,
 87        wireframe=gui_elements.gui_wireframe.value,
 88        color=gui_elements.gui_rgb.value,
 89    )
 90    server.scene.add_grid("/grid", position=(0.0, -1.3, 0.0), plane="xz")
 91
 92    while True:
 93        # Do nothing if no change.
 94        time.sleep(0.02)
 95        if not gui_elements.changed:
 96            continue
 97
 98        # Shapes changed: update vertices / joint positions.
 99        if gui_elements.betas_changed:
100            v_tpose, j_tpose = model.get_tpose(
101                np.array([gui_beta.value for gui_beta in gui_elements.gui_betas])
102            )
103            mesh_handle.vertices = v_tpose
104            mesh_handle.bone_positions = j_tpose
105
106        mesh_handle.color = gui_elements.gui_rgb.value
107        gui_elements.changed = False
108        gui_elements.betas_changed = False
109
110        # Render as wireframe?
111        mesh_handle.wireframe = gui_elements.gui_wireframe.value
112
113        # Compute SMPL outputs.
114        smpl_outputs = model.get_outputs(
115            betas=np.array([x.value for x in gui_elements.gui_betas]),
116            joint_rotmats=np.stack(
117                [
118                    tf.SO3.exp(np.array(x.value)).as_matrix()
119                    for x in gui_elements.gui_joints
120                ],
121                axis=0,
122            ),
123        )
124
125        # Match transform control gizmos to joint positions.
126        for i, control in enumerate(gui_elements.transform_controls):
127            control.position = smpl_outputs.T_parent_joint[i, :3, 3]
128            mesh_handle.bones[i].wxyz = tf.SO3.from_matrix(
129                smpl_outputs.T_world_joint[i, :3, :3]
130            ).wxyz
131            mesh_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]
132
133
@dataclass
class GuiElements:
    """Handles to the interactive GUI inputs, plus dirty flags.

    GUI callbacks raise the flags; the polling loop reads the inputs and
    lowers the flags again once it has handled an update.
    """

    gui_rgb: viser.GuiInputHandle[tuple[int, int, int]]  # mesh color picker
    gui_wireframe: viser.GuiInputHandle[bool]  # wireframe checkbox
    gui_betas: list[viser.GuiInputHandle[float]]  # one slider per shape coefficient
    # One 3-vector input per joint: the rotation vector passed to SO3.exp.
    gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]]
    transform_controls: list[viser.TransformControlsHandle]  # one gizmo per joint

    # Raised (set to True) whenever any input is changed.
    changed: bool
    # Raised (set to True) whenever a shape coefficient changes.
    betas_changed: bool
149
150
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Make GUI elements for interacting with the model.

    Creates a tab group with "View", "Shape" and "Joints" tabs, plus one
    rotation-only transform gizmo per joint at ``/smpl/<path>``, where
    ``<path>`` nests each joint's gizmo under its parent's. Joint inputs and
    gizmos update each other, and every input marks the returned
    ``GuiElements`` as changed.

    Args:
        server: Server to add the GUI elements and gizmos to.
        num_betas: Number of shape (beta) sliders to create.
        num_joints: Number of joint inputs / gizmos to create.
        parent_idx: Parent joint index for each joint (entry 0, the root, is
            not used). NOTE(review): assumes ``parent_idx[i] < i`` for
            ``i > 0``, since gizmo paths are built in index order.

    Returns:
        Handles to the created inputs, with ``changed=True`` (so the first
        poll renders the initial state) and ``betas_changed=False``.
    """

    tab_group = server.gui.add_tab_group()

    def set_changed(_) -> None:
        out.changed = True  # `out` is assigned at the end of this function.

    def set_betas_changed(_) -> None:
        out.betas_changed = True
        out.changed = True

    def control_scale(prefixed_joint_name: str, size: float) -> float:
        # Gizmos shrink by 0.75x per nesting level below the root joint.
        # Takes the joint path *without* the "/smpl/" scene prefix so that the
        # initial scale and the "Handle size" slider agree.
        return 0.2 * (0.75 ** prefixed_joint_name.count("/")) * size

    # GUI elements: mesh settings + visibility.
    with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)
        gui_control_size = server.gui.add_slider(
            "Handle size", min=0.0, max=10.0, step=0.01, initial_value=1.0
        )

        gui_rgb.on_update(set_changed)
        gui_wireframe.on_update(set_changed)

        @gui_show_controls.on_update
        def _(_):
            for control in transform_controls:
                control.visible = gui_show_controls.value

        @gui_control_size.on_update
        def _(_):
            # Previously this used `control.name`, which carries the "/smpl/"
            # prefix (two extra "/"), shrinking every gizmo by 0.75**2.
            for control, prefixed_joint_name in zip(
                transform_controls, prefixed_joint_names
            ):
                control.scale = control_scale(
                    prefixed_joint_name, gui_control_size.value
                )

    # GUI elements: shape parameters.
    with tab_group.add_tab("Shape", viser.Icon.BOX):
        gui_reset_shape = server.gui.add_button("Reset Shape")
        gui_random_shape = server.gui.add_button("Random Shape")

        @gui_reset_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = 0.0

        @gui_random_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = np.random.normal(loc=0.0, scale=1.0)

        gui_betas = []
        for i in range(num_betas):
            beta = server.gui.add_slider(
                f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
            )
            gui_betas.append(beta)
            beta.on_update(set_betas_changed)

    # GUI elements: joint angles.
    with tab_group.add_tab("Joints", viser.Icon.ANGLE):
        gui_reset_joints = server.gui.add_button("Reset Joints")
        gui_random_joints = server.gui.add_button("Random Joints")

        @gui_reset_joints.on_click
        def _(_):
            for joint in gui_joints:
                joint.value = (0.0, 0.0, 0.0)

        @gui_random_joints.on_click
        def _(_):
            rng = np.random.default_rng()
            for joint in gui_joints:
                joint.value = tf.SO3.sample_uniform(rng).log()

        def bind_joint_input(i: int) -> None:
            # Joint input -> gizmo rotation; also requests a re-render. A
            # function scope pins `i` per joint (avoids late-binding closures).
            @gui_joints[i].on_update
            def _(_) -> None:
                transform_controls[i].wxyz = tf.SO3.exp(
                    np.array(gui_joints[i].value)
                ).wxyz
                out.changed = True

        gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = []
        for i in range(num_joints):
            gui_joint = server.gui.add_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            gui_joints.append(gui_joint)
            bind_joint_input(i)

    def bind_gizmo(i: int) -> None:
        # Gizmo rotation -> joint input (as a rotation vector via SO3.log).
        @transform_controls[i].on_update
        def _(_) -> None:
            axisangle = tf.SO3(transform_controls[i].wxyz).log()
            gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2])

    # Transform control gizmos on joints.
    transform_controls: List[viser.TransformControlsHandle] = []
    prefixed_joint_names: List[str] = []  # Joint names, prefixed with parents.
    for i in range(num_joints):
        prefixed_joint_name = f"joint_{i}"
        if i > 0:
            prefixed_joint_name = (
                prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
            )
        prefixed_joint_names.append(prefixed_joint_name)
        controls = server.scene.add_transform_controls(
            f"/smpl/{prefixed_joint_name}",
            depth_test=False,
            scale=control_scale(prefixed_joint_name, gui_control_size.value),
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(controls)
        bind_gizmo(i)

    out = GuiElements(
        gui_rgb,
        gui_wireframe,
        gui_betas,
        gui_joints,
        transform_controls=transform_controls,
        changed=True,
        betas_changed=False,
    )
    return out
291
292
if __name__ == "__main__":
    # Parse command-line arguments into main()'s parameters via tyro.
    # NOTE(review): no module docstring is visible in this listing, so
    # __doc__ may be None here (tyro then shows no description).
    tyro.cli(main, description=__doc__)