SMPL skinned mesh

Visualize SMPL human body models using skinned-mesh deformation. When visualizing human motion sequences, updating only the bone transformations is much faster than re-sending the entire mesh every frame.

See the smplx repository for instructions on downloading the model files:

https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model

Features:

  • SMPL skinned mesh with real-time deformation

  • Interactive pose parameter controls

Note

This example requires external assets. To download them, run:

cd /path/to/viser/examples/assets
./download_assets.sh

Source: examples/04_demos/04_smpl_skinned.py

SMPL skinned mesh

Code

  1# mypy: disable-error-code="assignment"
  2#
  3# Asymmetric properties are supported in Pyright, but not yet in mypy.
  4# - https://github.com/python/mypy/issues/3004
  5# - https://github.com/python/mypy/pull/11643
  6
  7from __future__ import annotations
  8
  9import time
 10from dataclasses import dataclass
 11from pathlib import Path
 12from typing import List, Tuple
 13
 14import numpy as np
 15import tyro
 16
 17import viser
 18import viser.transforms as tf
 19
 20
 21@dataclass(frozen=True)
 22class SmplFkOutputs:
 23    T_world_joint: np.ndarray  # (num_joints, 4, 4)
 24    T_parent_joint: np.ndarray  # (num_joints, 4, 4)
 25
 26
 27class SmplHelper:
 28
 29    def __init__(self, model_path: Path) -> None:
 30        assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
 31        body_dict = dict(**np.load(model_path, allow_pickle=True))
 32
 33        self.J_regressor = body_dict["J_regressor"]
 34        self.weights = body_dict["weights"]
 35        self.v_template = body_dict["v_template"]
 36        self.posedirs = body_dict["posedirs"]
 37        self.shapedirs = body_dict["shapedirs"]
 38        self.faces = body_dict["f"]
 39
 40        self.num_joints: int = self.weights.shape[-1]
 41        self.num_betas: int = self.shapedirs.shape[-1]
 42        self.parent_idx: np.ndarray = body_dict["kintree_table"][0]
 43
 44    def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
 45        # Get shaped vertices + joint positions, when all local poses are identity.
 46        v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
 47        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
 48        return v_tpose, j_tpose
 49
 50    def get_outputs(
 51        self, betas: np.ndarray, joint_rotmats: np.ndarray
 52    ) -> SmplFkOutputs:
 53        # Get shaped vertices + joint positions, when all local poses are identity.
 54        v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
 55        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
 56
 57        # Local SE(3) transforms.
 58        T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
 59        T_parent_joint[:, :3, :3] = joint_rotmats
 60        T_parent_joint[0, :3, 3] = j_tpose[0]
 61        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]
 62
 63        # Forward kinematics.
 64        T_world_joint = T_parent_joint.copy()
 65        for i in range(1, self.num_joints):
 66            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]
 67
 68        return SmplFkOutputs(T_world_joint, T_parent_joint)
 69
 70
 71def main(
 72    model_path: Path = Path(__file__).parent / "../assets/SMPLH_NEUTRAL.npz",
 73) -> None:
 74    server = viser.ViserServer()
 75    server.scene.set_up_direction("+y")
 76
 77    # Main loop. We'll read pose/shape from the GUI elements, compute the mesh,
 78    # and then send the updated mesh in a loop.
 79    model = SmplHelper(model_path)
 80    gui_elements = make_gui_elements(
 81        server,
 82        num_betas=model.num_betas,
 83        num_joints=model.num_joints,
 84        parent_idx=model.parent_idx,
 85    )
 86    v_tpose, j_tpose = model.get_tpose(np.zeros((model.num_betas,)))
 87    mesh_handle = server.scene.add_mesh_skinned(
 88        "/human",
 89        v_tpose,
 90        model.faces,
 91        bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz,
 92        bone_positions=j_tpose,
 93        skin_weights=model.weights,
 94        wireframe=gui_elements.gui_wireframe.value,
 95        color=gui_elements.gui_rgb.value,
 96    )
 97    server.scene.add_grid("/grid", position=(0.0, -1.3, 0.0), plane="xz")
 98
 99    while True:
100        # Do nothing if no change.
101        time.sleep(0.02)
102        if not gui_elements.changed:
103            continue
104
105        # Shapes changed: update vertices / joint positions.
106        if gui_elements.betas_changed:
107            v_tpose, j_tpose = model.get_tpose(
108                np.array([gui_beta.value for gui_beta in gui_elements.gui_betas])
109            )
110            mesh_handle.vertices = v_tpose
111            mesh_handle.bone_positions = j_tpose
112
113        mesh_handle.color = gui_elements.gui_rgb.value
114        gui_elements.changed = False
115        gui_elements.betas_changed = False
116
117        # Render as wireframe?
118        mesh_handle.wireframe = gui_elements.gui_wireframe.value
119
120        # Compute SMPL outputs.
121        smpl_outputs = model.get_outputs(
122            betas=np.array([x.value for x in gui_elements.gui_betas]),
123            joint_rotmats=np.stack(
124                [
125                    tf.SO3.exp(np.array(x.value)).as_matrix()
126                    for x in gui_elements.gui_joints
127                ],
128                axis=0,
129            ),
130        )
131
132        # Match transform control gizmos to joint positions.
133        for i, control in enumerate(gui_elements.transform_controls):
134            control.position = smpl_outputs.T_parent_joint[i, :3, 3]
135            mesh_handle.bones[i].wxyz = tf.SO3.from_matrix(
136                smpl_outputs.T_world_joint[i, :3, :3]
137            ).wxyz
138            mesh_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]
139
140
@dataclass
class GuiElements:
    """Handles to the GUI widgets the main loop reads, plus dirty flags.

    GUI callbacks set the flags; the main loop clears them after reacting.
    """

    # "View" tab: mesh color and wireframe toggle.
    gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]]
    gui_wireframe: viser.GuiInputHandle[bool]
    # "Shape" tab: one slider per shape coefficient.
    gui_betas: List[viser.GuiInputHandle[float]]
    # "Joints" tab: one axis-angle rotation input per joint.
    gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]]
    # Per-joint rotation gizmos in the 3D scene.
    transform_controls: List[viser.TransformControlsHandle]

    # True when something displayed changed and the scene must be refreshed.
    changed: bool
    # True when a shape coefficient changed.
    betas_changed: bool
153
154
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Create the GUI tabs and the per-joint transform gizmos.

    Joint vector inputs and gizmos are wired so that editing either one
    updates the other. Relevant edits raise the dirty flags on the returned
    ``GuiElements`` so the main loop knows to re-render.

    Args:
        server: Server to add GUI widgets and scene nodes to.
        num_betas: Number of shape sliders to create.
        num_joints: Number of joints; one vector input and one gizmo each.
        parent_idx: Parent joint index per joint. Entry 0 (the root) is not
            read. Assumes parent_idx[i] < i, so a parent's name exists before
            its children's.

    Returns:
        Handles to the created widgets, with ``changed=True`` so the first
        main-loop iteration renders the initial pose.
    """
    tab_group = server.gui.add_tab_group()

    # The callbacks below refer to ``out``, ``gui_betas``, ``gui_joints``,
    # ``transform_controls`` and ``prefixed_joint_names`` before those names
    # are assigned. Closures resolve names when they run, which is expected
    # to be only after the user starts interacting, i.e. after this function
    # has returned.

    def set_changed(_) -> None:
        out.changed = True

    def set_betas_changed(_) -> None:
        out.betas_changed = True
        out.changed = True

    def handle_scale(prefixed_joint_name: str) -> float:
        # Gizmos shrink by 25% per level in the kinematic tree, where depth is
        # the number of "/" in the parent-prefixed joint name, then are
        # multiplied by the "Handle size" slider.
        depth = prefixed_joint_name.count("/")
        return 0.2 * (0.75**depth) * gui_control_size.value

    # GUI elements: mesh settings + visibility.
    with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)
        gui_control_size = server.gui.add_slider(
            "Handle size", min=0.0, max=10.0, step=0.01, initial_value=1.0
        )

        gui_rgb.on_update(set_changed)
        gui_wireframe.on_update(set_changed)

        @gui_show_controls.on_update
        def _(_):
            for control in transform_controls:
                control.visible = gui_show_controls.value

        @gui_control_size.on_update
        def _(_):
            # Take depth from the stored joint names, not ``control.name``.
            # The scene path carries an extra "/smpl/" prefix whose two
            # slashes would shrink every gizmo to 0.75**2 of its initial size.
            for control, joint_name in zip(transform_controls, prefixed_joint_names):
                control.scale = handle_scale(joint_name)

    # GUI elements: shape parameters.
    with tab_group.add_tab("Shape", viser.Icon.BOX):
        gui_reset_shape = server.gui.add_button("Reset Shape")
        gui_random_shape = server.gui.add_button("Random Shape")

        @gui_reset_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = 0.0

        @gui_random_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = np.random.normal(loc=0.0, scale=1.0)

        gui_betas = []
        for i in range(num_betas):
            beta = server.gui.add_slider(
                f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
            )
            gui_betas.append(beta)
            beta.on_update(set_betas_changed)

    # GUI elements: joint angles (axis-angle vectors).
    with tab_group.add_tab("Joints", viser.Icon.ANGLE):
        gui_reset_joints = server.gui.add_button("Reset Joints")
        gui_random_joints = server.gui.add_button("Random Joints")

        @gui_reset_joints.on_click
        def _(_):
            for joint in gui_joints:
                joint.value = (0.0, 0.0, 0.0)

        @gui_random_joints.on_click
        def _(_):
            rng = np.random.default_rng()
            for joint in gui_joints:
                # Store plain Python floats, matching the declared
                # Tuple[float, float, float] value type.
                x, y, z = tf.SO3.sample_uniform(rng).log()
                joint.value = (float(x), float(y), float(z))

        gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = []
        for i in range(num_joints):
            gui_joint = server.gui.add_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            gui_joints.append(gui_joint)

            # Helper function so each callback captures its own joint index.
            def set_callback_in_closure(i: int) -> None:
                @gui_joint.on_update
                def _(_):
                    # Mirror the edited axis-angle onto the matching gizmo.
                    transform_controls[i].wxyz = tf.SO3.exp(
                        np.array(gui_joints[i].value)
                    ).wxyz
                    out.changed = True

            set_callback_in_closure(i)

    # Transform control gizmos on joints, nested in the scene tree by parent.
    transform_controls: List[viser.TransformControlsHandle] = []
    prefixed_joint_names: List[str] = []  # Joint names, prefixed with parents.
    for i in range(num_joints):
        prefixed_joint_name = f"joint_{i}"
        if i > 0:
            prefixed_joint_name = (
                prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
            )
        prefixed_joint_names.append(prefixed_joint_name)
        controls = server.scene.add_transform_controls(
            f"/smpl/{prefixed_joint_name}",
            depth_test=False,
            scale=handle_scale(prefixed_joint_name),
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(controls)

        def set_callback_in_closure(i: int) -> None:
            @controls.on_update
            def _(_) -> None:
                # Mirror the dragged gizmo rotation back into the joint input.
                axisangle = tf.SO3(transform_controls[i].wxyz).log()
                gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2])

        set_callback_in_closure(i)

    out = GuiElements(
        gui_rgb,
        gui_wireframe,
        gui_betas,
        gui_joints,
        transform_controls=transform_controls,
        changed=True,
        betas_changed=False,
    )
    return out
294
295
# Script entry point: tyro builds a command-line interface from ``main``'s
# signature, and the module docstring serves as the CLI description.
if __name__ == "__main__":
    tyro.cli(
        main,
        description=__doc__,
    )