SMPL skinned mesh

Visualize SMPL human body models using skinned-mesh deformation. When visualizing human motion sequences, updating only the bone transformations each frame should be much faster than re-sending the entire deformed mesh.

See here for download instructions:

https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model

Features:

  • SMPL skinned mesh with real-time deformation

  • Interactive pose parameter controls

Note

This example requires external assets. To download them, run:

git clone -b v1.0.26 https://github.com/viser-project/viser.git
cd viser/examples
./assets/download_assets.sh
python 04_demos/04_smpl_skinned.py  # With viser installed.

Source: examples/04_demos/04_smpl_skinned.py

SMPL skinned mesh

Code

  1# mypy: disable-error-code="assignment"
  2#
  3# Asymmetric properties are supported in Pyright, but not yet in mypy.
  4# - https://github.com/python/mypy/issues/3004
  5# - https://github.com/python/mypy/pull/11643
  6
  7from __future__ import annotations
  8
  9import time
 10from dataclasses import dataclass
 11from pathlib import Path
 12from typing import List, Tuple
 13
 14import numpy as np
 15import tyro
 16
 17import viser
 18import viser.transforms as tf
 19
 20
 21@dataclass(frozen=True)
 22class SmplFkOutputs:
 23    T_world_joint: np.ndarray  # (num_joints, 4, 4)
 24    T_parent_joint: np.ndarray  # (num_joints, 4, 4)
 25
 26
 27class SmplHelper:
 28
 29    def __init__(self, model_path: Path) -> None:
 30        assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
 31        body_dict = dict(**np.load(model_path, allow_pickle=True))
 32
 33        self.J_regressor = body_dict["J_regressor"]
 34        self.weights = body_dict["weights"]
 35        self.v_template = body_dict["v_template"]
 36        self.posedirs = body_dict["posedirs"]
 37        self.shapedirs = body_dict["shapedirs"]
 38        self.faces = body_dict["f"]
 39
 40        self.num_joints: int = self.weights.shape[-1]
 41        self.num_betas: int = self.shapedirs.shape[-1]
 42        self.parent_idx: np.ndarray = body_dict["kintree_table"][0]
 43
 44    def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
 45        # Get shaped vertices + joint positions, when all local poses are identity.
 46        v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
 47        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
 48        return v_tpose, j_tpose
 49
 50    def get_outputs(
 51        self, betas: np.ndarray, joint_rotmats: np.ndarray
 52    ) -> SmplFkOutputs:
 53        # Get shaped vertices + joint positions, when all local poses are identity.
 54        v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
 55        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
 56
 57        # Local SE(3) transforms.
 58        T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
 59        T_parent_joint[:, :3, :3] = joint_rotmats
 60        T_parent_joint[0, :3, 3] = j_tpose[0]
 61        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]
 62
 63        # Forward kinematics.
 64        T_world_joint = T_parent_joint.copy()
 65        for i in range(1, self.num_joints):
 66            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]
 67
 68        return SmplFkOutputs(T_world_joint, T_parent_joint)
 69
 70
 71def main(
 72    model_path: Path = Path(__file__).parent / "../assets/SMPLH_NEUTRAL.npz",
 73) -> None:
 74    server = viser.ViserServer()
 75    server.scene.set_up_direction("+y")
 76    server.initial_camera.position = (2.5, 1.0, 2.5)
 77
 78    # Main loop. We'll read pose/shape from the GUI elements, compute the mesh,
 79    # and then send the updated mesh in a loop.
 80    model = SmplHelper(model_path)
 81    gui_elements = make_gui_elements(
 82        server,
 83        num_betas=model.num_betas,
 84        num_joints=model.num_joints,
 85        parent_idx=model.parent_idx,
 86    )
 87    v_tpose, j_tpose = model.get_tpose(np.zeros((model.num_betas,)))
 88    mesh_handle = server.scene.add_mesh_skinned(
 89        "/human",
 90        v_tpose,
 91        model.faces,
 92        bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz,
 93        bone_positions=j_tpose,
 94        skin_weights=model.weights,
 95        wireframe=gui_elements.gui_wireframe.value,
 96        color=gui_elements.gui_rgb.value,
 97    )
 98    server.scene.add_grid("/grid", position=(0.0, -1.3, 0.0), plane="xz")
 99
100    while True:
101        # Do nothing if no change.
102        time.sleep(0.02)
103        if not gui_elements.changed:
104            continue
105
106        # Shapes changed: update vertices / joint positions.
107        if gui_elements.betas_changed:
108            v_tpose, j_tpose = model.get_tpose(
109                np.array([gui_beta.value for gui_beta in gui_elements.gui_betas])
110            )
111            mesh_handle.vertices = v_tpose
112            mesh_handle.bone_positions = j_tpose
113
114        mesh_handle.color = gui_elements.gui_rgb.value
115        gui_elements.changed = False
116        gui_elements.betas_changed = False
117
118        # Render as wireframe?
119        mesh_handle.wireframe = gui_elements.gui_wireframe.value
120
121        # Compute SMPL outputs.
122        smpl_outputs = model.get_outputs(
123            betas=np.array([x.value for x in gui_elements.gui_betas]),
124            joint_rotmats=np.stack(
125                [
126                    tf.SO3.exp(np.array(x.value)).as_matrix()
127                    for x in gui_elements.gui_joints
128                ],
129                axis=0,
130            ),
131        )
132
133        # Match transform control gizmos to joint positions.
134        for i, control in enumerate(gui_elements.transform_controls):
135            control.position = smpl_outputs.T_parent_joint[i, :3, 3]
136            mesh_handle.bones[i].wxyz = tf.SO3.from_matrix(
137                smpl_outputs.T_world_joint[i, :3, :3]
138            ).wxyz
139            mesh_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]
140
141
@dataclass
class GuiElements:
    """GUI widget handles plus the dirty flags shared with the main loop."""

    # Color picker for the mesh.
    gui_rgb: viser.GuiInputHandle[tuple[int, int, int]]
    # Toggle for wireframe rendering.
    gui_wireframe: viser.GuiInputHandle[bool]
    # One slider per shape coefficient (beta).
    gui_betas: list[viser.GuiInputHandle[float]]
    # One 3-vector input per joint (fed to SO3.exp by the main loop).
    gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]]
    # One transform-control gizmo per joint.
    transform_controls: list[viser.TransformControlsHandle]
    # Raised by GUI callbacks whenever anything affecting the render changes.
    changed: bool
    # Raised by GUI callbacks whenever a shape coefficient changes.
    betas_changed: bool
154
155
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Build the GUI tabs and per-joint transform gizmos for the SMPL viewer.

    Args:
        server: Viser server that receives the GUI elements and scene nodes.
        num_betas: Number of shape-coefficient (beta) sliders to create.
        num_joints: Number of joints; each gets a vector input and a gizmo.
        parent_idx: Parent index of every joint. Entry 0 (the root) is unused.

    Returns:
        Handles to the created elements, with ``changed=True`` so the first
        pass of the main loop renders the initial state.
    """
    # Create the containers up front. Several callbacks below close over them,
    # and iterating a (possibly still empty) list is safe at any time.
    gui_betas: List[viser.GuiInputHandle[float]] = []
    gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = []
    transform_controls: List[viser.TransformControlsHandle] = []
    # Joint names prefixed with all of their ancestors ("joint_0/joint_3/...").
    # Nesting the scene paths this way parents each gizmo to its parent joint.
    prefixed_joint_names: List[str] = []

    def control_scale(prefixed_joint_name: str) -> float:
        # Gizmos shrink by 0.75x per level of depth in the kinematic tree,
        # scaled by the user-controlled handle size.
        depth = prefixed_joint_name.count("/")
        return 0.2 * (0.75**depth) * gui_control_size.value

    tab_group = server.gui.add_tab_group()

    def set_changed(_) -> None:
        out.changed = True  # `out` is assigned at the end of this function.

    def set_betas_changed(_) -> None:
        out.betas_changed = True
        out.changed = True

    # GUI elements: mesh settings + visibility.
    with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)
        gui_control_size = server.gui.add_slider(
            "Handle size", min=0.0, max=10.0, step=0.01, initial_value=1.0
        )

        gui_rgb.on_update(set_changed)
        gui_wireframe.on_update(set_changed)

        @gui_show_controls.on_update
        def _(_) -> None:
            for control in transform_controls:
                control.visible = gui_show_controls.value

        @gui_control_size.on_update
        def _(_) -> None:
            # Use the prefixed joint name, not `control.name`. The scene path
            # ("/smpl/joint_0/...") has two extra "/" separators, which would
            # shrink every gizmo by an extra 0.75**2 versus its initial scale.
            for control, prefixed_joint_name in zip(
                transform_controls, prefixed_joint_names
            ):
                control.scale = control_scale(prefixed_joint_name)

    # GUI elements: shape parameters.
    with tab_group.add_tab("Shape", viser.Icon.BOX):
        gui_reset_shape = server.gui.add_button("Reset Shape")
        gui_random_shape = server.gui.add_button("Random Shape")

        @gui_reset_shape.on_click
        def _(_) -> None:
            for beta in gui_betas:
                beta.value = 0.0

        @gui_random_shape.on_click
        def _(_) -> None:
            for beta in gui_betas:
                beta.value = np.random.normal(loc=0.0, scale=1.0)

        for i in range(num_betas):
            beta = server.gui.add_slider(
                f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
            )
            gui_betas.append(beta)
            beta.on_update(set_betas_changed)

    # GUI elements: joint angles.
    with tab_group.add_tab("Joints", viser.Icon.ANGLE):
        gui_reset_joints = server.gui.add_button("Reset Joints")
        gui_random_joints = server.gui.add_button("Random Joints")

        @gui_reset_joints.on_click
        def _(_) -> None:
            for joint in gui_joints:
                joint.value = (0.0, 0.0, 0.0)

        @gui_random_joints.on_click
        def _(_) -> None:
            rng = np.random.default_rng()
            for joint in gui_joints:
                joint.value = tf.SO3.sample_uniform(rng).log()

        for i in range(num_joints):
            gui_joint = server.gui.add_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            gui_joints.append(gui_joint)

            # Factory so each callback captures its own joint index `i`.
            def set_callback_in_closure(i: int) -> None:
                @gui_joint.on_update
                def _(_) -> None:
                    # Keep the gizmo orientation in sync with the GUI value.
                    transform_controls[i].wxyz = tf.SO3.exp(
                        np.array(gui_joints[i].value)
                    ).wxyz
                    out.changed = True

            set_callback_in_closure(i)

    # Transform control gizmos on joints.
    for i in range(num_joints):
        prefixed_joint_name = f"joint_{i}"
        if i > 0:
            prefixed_joint_name = (
                prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
            )
        prefixed_joint_names.append(prefixed_joint_name)
        controls = server.scene.add_transform_controls(
            f"/smpl/{prefixed_joint_name}",
            depth_test=False,
            scale=control_scale(prefixed_joint_name),
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(controls)

        # Factory so each callback captures its own joint index `i`.
        def set_callback_in_closure(i: int) -> None:
            @controls.on_update
            def _(_) -> None:
                # Dragging a gizmo writes its rotation back as a 3-vector
                # (SO3 log), which in turn fires the joint GUI callback.
                axisangle = tf.SO3(transform_controls[i].wxyz).log()
                gui_joints[i].value = (
                    float(axisangle[0]),
                    float(axisangle[1]),
                    float(axisangle[2]),
                )

        set_callback_in_closure(i)

    out = GuiElements(
        gui_rgb,
        gui_wireframe,
        gui_betas,
        gui_joints,
        transform_controls=transform_controls,
        changed=True,
        betas_changed=False,
    )
    return out
295
296
if __name__ == "__main__":
    # tyro derives the command-line interface (e.g. ``--model-path``) from
    # main()'s signature, then calls it with the parsed arguments.
    tyro.cli(main, description=__doc__)