SMPL human model

Visualize SMPL human body models with pose and shape parameter controls.

Requires an SMPL body model saved as a .npz file. Download instructions are in the smplx repository:

https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model

Features:

  • SMPL model loading and mesh generation

  • Interactive pose parameter sliders

  • Real-time body shape and pose updates

  • 3D mesh visualization with viser.SceneApi.add_mesh_simple()

Note

This example requires external assets. To download them, run:

git clone -b v1.0.26 https://github.com/viser-project/viser.git
cd viser/examples
./assets/download_assets.sh
python 04_demos/03_smpl_visualizer.py  # With viser installed.

Source: examples/04_demos/03_smpl_visualizer.py

SMPL human model

Code

  1from __future__ import annotations
  2
  3import time
  4from dataclasses import dataclass
  5from pathlib import Path
  6
  7import numpy as np
  8import trimesh
  9import tyro
 10
 11import viser
 12import viser.transforms as tf
 13
 14
 15@dataclass(frozen=True)
 16class SmplOutputs:
 17    vertices: np.ndarray
 18    faces: np.ndarray
 19    T_world_joint: np.ndarray  # (num_joints, 4, 4)
 20    T_parent_joint: np.ndarray  # (num_joints, 4, 4)
 21
 22
 23class SmplHelper:
 24
 25    def __init__(self, model_path: Path) -> None:
 26        assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
 27        body_dict = dict(**np.load(model_path, allow_pickle=True))
 28
 29        self.J_regressor = body_dict["J_regressor"]
 30        self.weights = body_dict["weights"]
 31        self.v_template = body_dict["v_template"]
 32        self.posedirs = body_dict["posedirs"]
 33        self.shapedirs = body_dict["shapedirs"]
 34        self.faces = body_dict["f"]
 35
 36        self.num_joints: int = self.weights.shape[-1]
 37        self.num_betas: int = self.shapedirs.shape[-1]
 38        self.parent_idx: np.ndarray = body_dict["kintree_table"][0]
 39
 40    def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs:
 41        # Get shaped vertices + joint positions, when all local poses are identity.
 42        v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
 43        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
 44
 45        # Local SE(3) transforms.
 46        T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
 47        T_parent_joint[:, :3, :3] = joint_rotmats
 48        T_parent_joint[0, :3, 3] = j_tpose[0]
 49        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]
 50
 51        # Forward kinematics.
 52        T_world_joint = T_parent_joint.copy()
 53        for i in range(1, self.num_joints):
 54            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]
 55
 56        # Linear blend skinning.
 57        pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten()
 58        v_blend = v_tpose + np.einsum("byn,n->by", self.posedirs, pose_delta)
 59        v_delta = np.ones((v_blend.shape[0], self.num_joints, 4))
 60        v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :]
 61        v_posed = np.einsum(
 62            "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self.weights, v_delta
 63        )
 64        return SmplOutputs(v_posed, self.faces, T_world_joint, T_parent_joint)
 65
 66
 67def main(
 68    model_path: Path = Path(__file__).parent / "../assets/SMPLH_NEUTRAL.npz",
 69) -> None:
 70    server = viser.ViserServer()
 71    server.scene.set_up_direction("+y")
 72    server.initial_camera.position = (2.5, 1.0, 2.5)
 73    server.scene.add_grid("/grid", position=(0.0, -1.3, 0.0), plane="xz")
 74
 75    # Main loop. We'll read pose/shape from the GUI elements, compute the mesh,
 76    # and then send the updated mesh in a loop.
 77    model = SmplHelper(model_path)
 78    gui_elements = make_gui_elements(
 79        server,
 80        num_betas=model.num_betas,
 81        num_joints=model.num_joints,
 82        parent_idx=model.parent_idx,
 83    )
 84    body_handle = server.scene.add_mesh_simple(
 85        "/human",
 86        model.v_template,
 87        model.faces,
 88        wireframe=gui_elements.gui_wireframe.value,
 89        color=gui_elements.gui_rgb.value,
 90    )
 91
 92    # Add a vertex selector to the mesh. This will allow us to click on
 93    # vertices to get indices.
 94    red_sphere = trimesh.creation.icosphere(radius=0.001, subdivisions=1)
 95    red_sphere.visual.vertex_colors = (255, 0, 0, 255)  # type: ignore
 96    vertex_selector = server.scene.add_batched_meshes_trimesh(
 97        "/selector",
 98        red_sphere,
 99        batched_positions=model.v_template,
100        batched_wxyzs=((1.0, 0.0, 0.0, 0.0),) * model.v_template.shape[0],
101    )
102
103    @vertex_selector.on_click
104    def _(event: viser.SceneNodePointerEvent) -> None:
105        event.client.add_notification(
106            f"Clicked on vertex {event.instance_index}",
107            body="",
108            auto_close_seconds=3.0,
109        )
110
111    while True:
112        # Do nothing if no change.
113        time.sleep(0.02)
114        if not gui_elements.changed:
115            continue
116
117        gui_elements.changed = False
118
119        # If anything has changed, re-compute SMPL outputs.
120        smpl_outputs = model.get_outputs(
121            betas=np.array([x.value for x in gui_elements.gui_betas]),
122            joint_rotmats=tf.SO3.exp(
123                # (num_joints, 3)
124                np.array([x.value for x in gui_elements.gui_joints])
125            ).as_matrix(),
126        )
127
128        # Update the mesh properties based on the SMPL model output + GUI
129        # elements.
130        body_handle.vertices = smpl_outputs.vertices
131        body_handle.wireframe = gui_elements.gui_wireframe.value
132        body_handle.color = gui_elements.gui_rgb.value
133        vertex_selector.batched_positions = smpl_outputs.vertices
134
135        # Match transform control gizmos to joint positions.
136        for i, control in enumerate(gui_elements.transform_controls):
137            control.position = smpl_outputs.T_parent_joint[i, :3, 3]
138
139
@dataclass()
class GuiElements:
    """Handles to the widgets and gizmos created by ``make_gui_elements``.

    ``changed`` is a dirty flag: widget callbacks set it to True and the
    viewer's main loop clears it after recomputing the mesh.
    """

    # Mesh appearance ("View" tab).
    gui_rgb: viser.GuiInputHandle[tuple[int, int, int]]
    gui_wireframe: viser.GuiInputHandle[bool]
    # One slider per shape coefficient ("Shape" tab).
    gui_betas: list[viser.GuiInputHandle[float]]
    # One axis-angle vector per joint ("Joints" tab).
    gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]]
    # One rotation gizmo per joint in the 3D scene.
    transform_controls: list[viser.TransformControlsHandle]

    # Dirty flag polled by the render loop.
    changed: bool
150
151
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Build the viewer's GUI tabs and one rotation gizmo per joint.

    Args:
        server: Server that receives the GUI widgets and scene nodes.
        num_betas: Number of shape sliders to create.
        num_joints: Number of joint-angle inputs and gizmos to create.
        parent_idx: Parent index of each joint. Entry 0 (the root) is never
            read. Assumes ``parent_idx[i] < i`` for ``i > 0``, so a parent's
            gizmo path exists before its children's.

    Returns:
        Handles to the created widgets, with ``changed`` set to True so the
        first frame gets rendered.
    """
    # Slider range for shape coefficients; random samples are clipped to it.
    beta_min, beta_max = -5.0, 5.0

    # Containers filled in below. They are created up front and stored in
    # ``out`` by reference, so callbacks never refer to an unbound name.
    gui_betas: list[viser.GuiInputHandle[float]] = []
    gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]] = []
    transform_controls: list[viser.TransformControlsHandle] = []

    tab_group = server.gui.add_tab_group()

    # GUI elements: mesh settings + visibility.
    with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)

        # Build the result as soon as its scalar handles exist. The lists are
        # shared, so later appends are visible through ``out``.
        out = GuiElements(
            gui_rgb,
            gui_wireframe,
            gui_betas,
            gui_joints,
            transform_controls=transform_controls,
            changed=True,
        )

        def set_changed(_) -> None:
            out.changed = True

        gui_rgb.on_update(set_changed)
        gui_wireframe.on_update(set_changed)

        @gui_show_controls.on_update
        def _(_) -> None:
            for control in transform_controls:
                control.visible = gui_show_controls.value

    # GUI elements: shape parameters.
    with tab_group.add_tab("Shape", viser.Icon.BOX):
        gui_reset_shape = server.gui.add_button("Reset Shape")
        gui_random_shape = server.gui.add_button("Random Shape")

        @gui_reset_shape.on_click
        def _(_) -> None:
            for beta in gui_betas:
                beta.value = 0.0

        @gui_random_shape.on_click
        def _(_) -> None:
            rng = np.random.default_rng()
            for beta in gui_betas:
                # A standard-normal draw can (rarely) fall outside the slider.
                sample = rng.normal(loc=0.0, scale=1.0)
                beta.value = float(np.clip(sample, beta_min, beta_max))

        for i in range(num_betas):
            beta = server.gui.add_slider(
                f"beta{i}", min=beta_min, max=beta_max, step=0.01, initial_value=0.0
            )
            gui_betas.append(beta)
            beta.on_update(set_changed)

    def sync_gizmo_from_joint(i: int, gui_joint: viser.GuiInputHandle) -> None:
        """Mirror joint ``i``'s axis-angle GUI value onto its gizmo rotation."""

        @gui_joint.on_update
        def _(_) -> None:
            transform_controls[i].wxyz = tf.SO3.exp(
                np.array(gui_joints[i].value)
            ).wxyz
            out.changed = True

    def sync_joint_from_gizmo(i: int, control: viser.TransformControlsHandle) -> None:
        """Write gizmo ``i``'s rotation back to the joint's axis-angle input."""

        @control.on_update
        def _(_) -> None:
            axisangle = tf.SO3(transform_controls[i].wxyz).log()
            gui_joints[i].value = tuple(float(x) for x in axisangle)

    # GUI elements: joint angles.
    with tab_group.add_tab("Joints", viser.Icon.ANGLE):
        gui_reset_joints = server.gui.add_button("Reset Joints")
        gui_random_joints = server.gui.add_button("Random Joints")

        @gui_reset_joints.on_click
        def _(_) -> None:
            for joint in gui_joints:
                joint.value = (0.0, 0.0, 0.0)

        @gui_random_joints.on_click
        def _(_) -> None:
            rng = np.random.default_rng()
            for joint in gui_joints:
                axisangle = tf.SO3.sample_uniform(rng).log()
                # Store a plain tuple, matching every other writer of this value.
                joint.value = tuple(float(x) for x in axisangle)

        for i in range(num_joints):
            gui_joint = server.gui.add_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            gui_joints.append(gui_joint)
            sync_gizmo_from_joint(i, gui_joint)

    # Transform control gizmos on joints. Each gizmo's scene path nests it
    # under its parent's path, and its scale shrinks with tree depth.
    prefixed_joint_names: list[str] = []
    for i in range(num_joints):
        joint_path = f"joint_{i}"
        if i > 0:
            joint_path = prefixed_joint_names[parent_idx[i]] + "/" + joint_path
        prefixed_joint_names.append(joint_path)
        controls = server.scene.add_transform_controls(
            f"/smpl/{joint_path}",
            depth_test=False,
            scale=0.2 * (0.75 ** joint_path.count("/")),
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(controls)
        sync_joint_from_gizmo(i, controls)

    return out
273
274
if __name__ == "__main__":
    # Build the command-line interface from main()'s signature, using the
    # module docstring as the help description.
    cli_description = __doc__
    tyro.cli(main, description=cli_description)