SMPL model visualizer

Visualizer for SMPL human body models. Requires a .npz model file.

See the smplx README for model download instructions:

https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model

  1from __future__ import annotations
  2
  3import time
  4from dataclasses import dataclass
  5from pathlib import Path
  6
  7import numpy as np
  8import numpy as onp
  9import tyro
 10import viser
 11import viser.transforms as tf
 12
 13
 14@dataclass(frozen=True)
 15class SmplOutputs:
 16    vertices: np.ndarray
 17    faces: np.ndarray
 18    T_world_joint: np.ndarray  # (num_joints, 4, 4)
 19    T_parent_joint: np.ndarray  # (num_joints, 4, 4)
 20
 21
 22class SmplHelper:
 23    """Helper for models in the SMPL family, implemented in numpy."""
 24
 25    def __init__(self, model_path: Path) -> None:
 26        assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
 27        body_dict = dict(**onp.load(model_path, allow_pickle=True))
 28
 29        self._J_regressor = body_dict["J_regressor"]
 30        self._weights = body_dict["weights"]
 31        self._v_template = body_dict["v_template"]
 32        self._posedirs = body_dict["posedirs"]
 33        self._shapedirs = body_dict["shapedirs"]
 34        self._faces = body_dict["f"]
 35
 36        self.num_joints: int = self._weights.shape[-1]
 37        self.num_betas: int = self._shapedirs.shape[-1]
 38        self.parent_idx: np.ndarray = body_dict["kintree_table"][0]
 39
 40    def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs:
 41        # Get shaped vertices + joint positions, when all local poses are identity.
 42        v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas)
 43        j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose)
 44
 45        # Local SE(3) transforms.
 46        T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
 47        T_parent_joint[:, :3, :3] = joint_rotmats
 48        T_parent_joint[0, :3, 3] = j_tpose[0]
 49        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]
 50
 51        # Forward kinematics.
 52        T_world_joint = T_parent_joint.copy()
 53        for i in range(1, self.num_joints):
 54            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]
 55
 56        # Linear blend skinning.
 57        pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten()
 58        v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta)
 59        v_delta = np.ones((v_blend.shape[0], self.num_joints, 4))
 60        v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :]
 61        v_posed = np.einsum(
 62            "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta
 63        )
 64        return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint)
 65
 66
 67def main(model_path: Path) -> None:
 68    server = viser.ViserServer()
 69    server.scene.set_up_direction("+y")
 70    server.gui.configure_theme(control_layout="collapsible")
 71
 72    # Main loop. We'll read pose/shape from the GUI elements, compute the mesh,
 73    # and then send the updated mesh in a loop.
 74    model = SmplHelper(model_path)
 75    gui_elements = make_gui_elements(
 76        server,
 77        num_betas=model.num_betas,
 78        num_joints=model.num_joints,
 79        parent_idx=model.parent_idx,
 80    )
 81    while True:
 82        # Do nothing if no change.
 83        time.sleep(0.02)
 84        if not gui_elements.changed:
 85            continue
 86
 87        gui_elements.changed = False
 88
 89        # Compute SMPL outputs.
 90        smpl_outputs = model.get_outputs(
 91            betas=np.array([x.value for x in gui_elements.gui_betas]),
 92            joint_rotmats=tf.SO3.exp(
 93                # (num_joints, 3)
 94                np.array([x.value for x in gui_elements.gui_joints])
 95            ).as_matrix(),
 96        )
 97        server.scene.add_mesh_simple(
 98            "/human",
 99            smpl_outputs.vertices,
100            smpl_outputs.faces,
101            wireframe=gui_elements.gui_wireframe.value,
102            color=gui_elements.gui_rgb.value,
103        )
104
105        # Match transform control gizmos to joint positions.
106        for i, control in enumerate(gui_elements.transform_controls):
107            control.position = smpl_outputs.T_parent_joint[i, :3, 3]
108
109
110@dataclass
111class GuiElements:
112    """Structure containing handles for reading from GUI elements."""
113
114    gui_rgb: viser.GuiInputHandle[tuple[int, int, int]]
115    gui_wireframe: viser.GuiInputHandle[bool]
116    gui_betas: list[viser.GuiInputHandle[float]]
117    gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]]
118    transform_controls: list[viser.TransformControlsHandle]
119
120    changed: bool
121    """This flag will be flipped to True whenever the mesh needs to be re-generated."""
122
123
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Make GUI elements for interacting with the model.

    Builds three tabs ("View", "Shape", "Joints") and one transform-control
    gizmo per joint, then wires their callbacks together. Widget order in the
    GUI follows the order in which they are created below.

    Args:
        server: Viser server that receives the GUI widgets and scene gizmos.
        num_betas: Number of shape sliders to create.
        num_joints: Number of joint-rotation inputs and gizmos to create.
        parent_idx: Parent index of each joint; entry 0 is not read. Each
            parent must be an earlier joint (``parent_idx[i] < i``), because
            gizmo names are built from already-created parent names.

    Returns:
        A GuiElements holding the created handles, with ``changed=True`` so
        the first main-loop iteration draws the mesh.
    """

    tab_group = server.gui.add_tab_group()

    # Shared callback for widgets whose only effect is "regenerate the mesh".
    def set_changed(_) -> None:
        out.changed = True  # `out` is defined later! Relies on no callback firing before it is assigned.

    # GUI elements: mesh settings + visibility.
    with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.gui.add_checkbox("Handles", initial_value=False)

        gui_rgb.on_update(set_changed)
        gui_wireframe.on_update(set_changed)

        # Toggling "Handles" only shows/hides the joint gizmos; it does not set
        # `changed`, since the mesh itself is unaffected.
        @gui_show_controls.on_update
        def _(_):
            for control in transform_controls:
                control.visible = gui_show_controls.value

    # GUI elements: shape parameters.
    with tab_group.add_tab("Shape", viser.Icon.BOX):
        gui_reset_shape = server.gui.add_button("Reset Shape")
        gui_random_shape = server.gui.add_button("Random Shape")

        # These handlers only assign slider values. NOTE(review): marking the mesh
        # as changed presumably relies on programmatic `.value` assignment firing
        # each slider's on_update (set_changed) -- confirm against viser.
        @gui_reset_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = 0.0

        @gui_random_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = onp.random.normal(loc=0.0, scale=1.0)

        gui_betas = []
        for i in range(num_betas):
            beta = server.gui.add_slider(
                f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
            )
            gui_betas.append(beta)
            beta.on_update(set_changed)

    # GUI elements: joint angles.
    with tab_group.add_tab("Joints", viser.Icon.ANGLE):
        gui_reset_joints = server.gui.add_button("Reset Joints")
        gui_random_joints = server.gui.add_button("Random Joints")

        @gui_reset_joints.on_click
        def _(_):
            for joint in gui_joints:
                joint.value = (0.0, 0.0, 0.0)

        @gui_random_joints.on_click
        def _(_):
            for joint in gui_joints:
                # It's hard to uniformly sample orientations directly in so(3), so we
                # first sample on S^3 and then convert.
                quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,))
                quat /= onp.linalg.norm(quat)
                joint.value = tf.SO3(wxyz=quat).log()

        # Each joint rotation is edited as an so(3) tangent (axis-angle) vector,
        # which is mapped to a rotation with tf.SO3.exp.
        gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]] = []
        for i in range(num_joints):
            gui_joint = server.gui.add_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            gui_joints.append(gui_joint)

            # Registering inside a helper binds the current `i` as a parameter,
            # avoiding Python's late-binding closure pitfall in loops.
            def set_callback_in_closure(i: int) -> None:
                @gui_joint.on_update
                def _(_):
                    # Mirror the edited rotation onto this joint's gizmo, then
                    # request a mesh regeneration.
                    transform_controls[i].wxyz = tf.SO3.exp(
                        np.array(gui_joints[i].value)
                    ).wxyz
                    out.changed = True

            set_callback_in_closure(i)

    # Transform control gizmos on joints.
    # Scene paths are nested by the kinematic tree ("/smpl/joint_0/joint_3/...").
    transform_controls: list[viser.TransformControlsHandle] = []
    prefixed_joint_names = []  # Joint names, but prefixed with parents.
    for i in range(num_joints):
        prefixed_joint_name = f"joint_{i}"
        if i > 0:
            prefixed_joint_name = (
                prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
            )
        prefixed_joint_names.append(prefixed_joint_name)
        controls = server.scene.add_transform_controls(
            f"/smpl/{prefixed_joint_name}",
            depth_test=False,
            # Shrink gizmos by 0.75x per level of depth in the hierarchy.
            scale=0.2 * (0.75 ** prefixed_joint_name.count("/")),
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(controls)

        def set_callback_in_closure(i: int) -> None:
            @controls.on_update
            def _(_) -> None:
                # Write the gizmo's rotation back to the joint's vector3 input
                # as an SO(3) log (axis-angle) vector.
                axisangle = tf.SO3(transform_controls[i].wxyz).log()
                gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2])

        set_callback_in_closure(i)

    out = GuiElements(
        gui_rgb,
        gui_wireframe,
        gui_betas,
        gui_joints,
        transform_controls=transform_controls,
        changed=True,
    )
    return out
249
250
if __name__ == "__main__":
    # Parse command-line arguments for main() with tyro, using the module
    # docstring as the help description.
    tyro.cli(main, description=__doc__)