.. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. SMPL visualizer (Skinned Mesh) ========================================== Requires a .npz model file. See here for download instructions: https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model .. code-block:: python :linenos: from __future__ import annotations import time from dataclasses import dataclass from pathlib import Path from typing import List, Tuple import numpy as np import tyro import viser import viser.transforms as tf @dataclass(frozen=True) class SmplFkOutputs: T_world_joint: np.ndarray # (num_joints, 4, 4) T_parent_joint: np.ndarray # (num_joints, 4, 4) class SmplHelper: """Helper for models in the SMPL family, implemented in numpy. Does not include blend skinning.""" def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" body_dict = dict(**np.load(model_path, allow_pickle=True)) self.J_regressor = body_dict["J_regressor"] self.weights = body_dict["weights"] self.v_template = body_dict["v_template"] self.posedirs = body_dict["posedirs"] self.shapedirs = body_dict["shapedirs"] self.faces = body_dict["f"] self.num_joints: int = self.weights.shape[-1] self.num_betas: int = self.shapedirs.shape[-1] self.parent_idx: np.ndarray = body_dict["kintree_table"][0] def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]: # Get shaped vertices + joint positions, when all local poses are identity. v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas) j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose) return v_tpose, j_tpose def get_outputs( self, betas: np.ndarray, joint_rotmats: np.ndarray ) -> SmplFkOutputs: # Get shaped vertices + joint positions, when all local poses are identity. v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas) j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose) # Local SE(3) transforms. T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4) T_parent_joint[:, :3, :3] = joint_rotmats T_parent_joint[0, :3, 3] = j_tpose[0] T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]] # Forward kinematics. T_world_joint = T_parent_joint.copy() for i in range(1, self.num_joints): T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i] return SmplFkOutputs(T_world_joint, T_parent_joint) def main(model_path: Path) -> None: server = viser.ViserServer() server.scene.set_up_direction("+y") server.gui.configure_theme(control_layout="collapsible") # Main loop. We'll read pose/shape from the GUI elements, compute the mesh, # and then send the updated mesh in a loop. model = SmplHelper(model_path) gui_elements = make_gui_elements( server, num_betas=model.num_betas, num_joints=model.num_joints, parent_idx=model.parent_idx, ) v_tpose, j_tpose = model.get_tpose(np.zeros((model.num_betas,))) mesh_handle = server.scene.add_mesh_skinned( "/human", v_tpose, model.faces, bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz, bone_positions=j_tpose, skin_weights=model.weights, wireframe=gui_elements.gui_wireframe.value, color=gui_elements.gui_rgb.value, ) while True: # Do nothing if no change. time.sleep(0.02) if not gui_elements.changed: continue # Shapes changed: update vertices / joint positions. if gui_elements.betas_changed: v_tpose, j_tpose = model.get_tpose( np.array([gui_beta.value for gui_beta in gui_elements.gui_betas]) ) mesh_handle.vertices = v_tpose mesh_handle.bone_positions = j_tpose gui_elements.changed = False gui_elements.betas_changed = False # Render as wireframe? mesh_handle.wireframe = gui_elements.gui_wireframe.value # Compute SMPL outputs. smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), joint_rotmats=np.stack( [ tf.SO3.exp(np.array(x.value)).as_matrix() for x in gui_elements.gui_joints ], axis=0, ), ) # Match transform control gizmos to joint positions. for i, control in enumerate(gui_elements.transform_controls): control.position = smpl_outputs.T_parent_joint[i, :3, 3] mesh_handle.bones[i].wxyz = tf.SO3.from_matrix( smpl_outputs.T_world_joint[i, :3, :3] ).wxyz mesh_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3] @dataclass class GuiElements: """Structure containing handles for reading from GUI elements.""" gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]] gui_wireframe: viser.GuiInputHandle[bool] gui_betas: List[viser.GuiInputHandle[float]] gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] transform_controls: List[viser.TransformControlsHandle] changed: bool """This flag will be flipped to True whenever any input is changed.""" betas_changed: bool """This flag will be flipped to True whenever the shape changes.""" def make_gui_elements( server: viser.ViserServer, num_betas: int, num_joints: int, parent_idx: np.ndarray, ) -> GuiElements: """Make GUI elements for interacting with the model.""" tab_group = server.gui.add_tab_group() def set_changed(_) -> None: out.changed = True # out is defined later! def set_betas_changed(_) -> None: out.betas_changed = True out.changed = True # GUI elements: mesh settings + visibility. with tab_group.add_tab("View", viser.Icon.VIEWFINDER): gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255)) gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False) gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True) gui_control_size = server.gui.add_slider( "Handle size", min=0.0, max=10.0, step=0.01, initial_value=1.0 ) gui_rgb.on_update(set_changed) gui_wireframe.on_update(set_changed) @gui_show_controls.on_update def _(_): for control in transform_controls: control.visible = gui_show_controls.value @gui_control_size.on_update def _(_): for control in transform_controls: prefixed_joint_name = control.name control.scale = ( 0.2 * (0.75 ** prefixed_joint_name.count("/")) * gui_control_size.value ) # GUI elements: shape parameters. with tab_group.add_tab("Shape", viser.Icon.BOX): gui_reset_shape = server.gui.add_button("Reset Shape") gui_random_shape = server.gui.add_button("Random Shape") @gui_reset_shape.on_click def _(_): for beta in gui_betas: beta.value = 0.0 @gui_random_shape.on_click def _(_): for beta in gui_betas: beta.value = np.random.normal(loc=0.0, scale=1.0) gui_betas = [] for i in range(num_betas): beta = server.gui.add_slider( f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 ) gui_betas.append(beta) beta.on_update(set_betas_changed) # GUI elements: joint angles. with tab_group.add_tab("Joints", viser.Icon.ANGLE): gui_reset_joints = server.gui.add_button("Reset Joints") gui_random_joints = server.gui.add_button("Random Joints") @gui_reset_joints.on_click def _(_): for joint in gui_joints: joint.value = (0.0, 0.0, 0.0) @gui_random_joints.on_click def _(_): rng = np.random.default_rng() for joint in gui_joints: joint.value = tf.SO3.sample_uniform(rng).log() gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] for i in range(num_joints): gui_joint = server.gui.add_vector3( label=f"Joint {i}", initial_value=(0.0, 0.0, 0.0), step=0.05, ) gui_joints.append(gui_joint) def set_callback_in_closure(i: int) -> None: @gui_joint.on_update def _(_): transform_controls[i].wxyz = tf.SO3.exp( np.array(gui_joints[i].value) ).wxyz out.changed = True set_callback_in_closure(i) # Transform control gizmos on joints. transform_controls: List[viser.TransformControlsHandle] = [] prefixed_joint_names = [] # Joint names, but prefixed with parents. for i in range(num_joints): prefixed_joint_name = f"joint_{i}" if i > 0: prefixed_joint_name = ( prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name ) prefixed_joint_names.append(prefixed_joint_name) controls = server.scene.add_transform_controls( f"/smpl/{prefixed_joint_name}", depth_test=False, scale=0.2 * (0.75 ** prefixed_joint_name.count("/")), disable_axes=True, disable_sliders=True, visible=gui_show_controls.value, ) transform_controls.append(controls) def set_callback_in_closure(i: int) -> None: @controls.on_update def _(_) -> None: axisangle = tf.SO3(transform_controls[i].wxyz).log() gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2]) set_callback_in_closure(i) out = GuiElements( gui_rgb, gui_wireframe, gui_betas, gui_joints, transform_controls=transform_controls, changed=True, betas_changed=False, ) return out if __name__ == "__main__": tyro.cli(main, description=__doc__)