SMPL skinned mesh
Visualize SMPL human body models using skinned mesh deformation. For visualizing human motion sequences, updating only the bone transformations should be much faster than re-sending the entire mesh.
For instructions on downloading the SMPL model files, see:
https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model
Features:
- SMPL skinned mesh with real-time deformation
- Interactive pose and shape parameter controls
Note
This example requires external assets. To download them, run:
cd /path/to/viser/examples/assets
./download_assets.sh
Source: examples/04_demos/04_smpl_skinned.py

Code
1# mypy: disable-error-code="assignment"
2#
3# Asymmetric properties are supported in Pyright, but not yet in mypy.
4# - https://github.com/python/mypy/issues/3004
5# - https://github.com/python/mypy/pull/11643
6
7from __future__ import annotations
8
9import time
10from dataclasses import dataclass
11from pathlib import Path
12from typing import List, Tuple
13
14import numpy as np
15import tyro
16
17import viser
18import viser.transforms as tf
19
20
@dataclass(frozen=True)
class SmplFkOutputs:
    """Immutable result of SMPL forward kinematics.

    Each field is a stack of 4x4 homogeneous transforms, one per joint.
    """

    # Joint frame expressed in the world frame; shape (num_joints, 4, 4).
    T_world_joint: np.ndarray
    # Joint frame expressed in its parent's frame; shape (num_joints, 4, 4).
    T_parent_joint: np.ndarray
25
26
class SmplHelper:
    """Minimal NumPy implementation of SMPL shape blending and forward kinematics.

    Only what skinned-mesh rendering needs is used: shape blend shapes, the
    joint regressor, skinning weights and the kinematic tree. Pose-dependent
    blend shapes (``posedirs``) are loaded but not applied.
    """

    def __init__(self, model_path: Path) -> None:
        """Load an SMPL-family model from an ``.npz`` archive.

        Args:
            model_path: Path to the model file; must have an ``.npz`` suffix.

        Raises:
            AssertionError: If ``model_path`` does not end in ``.npz``.
        """
        assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
        # The NpzFile keeps the archive open; the context manager closes it
        # once every array has been read into the dict.
        # NOTE: allow_pickle=True permits object arrays, which can execute
        # code on load -- only load trusted model files.
        with np.load(model_path, allow_pickle=True) as npz:
            body_dict = dict(npz)

        self.J_regressor = body_dict["J_regressor"]
        self.weights = body_dict["weights"]
        self.v_template = body_dict["v_template"]
        self.posedirs = body_dict["posedirs"]
        self.shapedirs = body_dict["shapedirs"]
        self.faces = body_dict["f"]

        self.num_joints: int = self.weights.shape[-1]
        self.num_betas: int = self.shapedirs.shape[-1]
        # Row 0 of the kinematic tree holds each joint's parent index; the
        # root's entry (index 0) is never read.
        self.parent_idx: np.ndarray = body_dict["kintree_table"][0]

    def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return shaped vertices and joint positions with all local poses at identity.

        Args:
            betas: Shape coefficients, shape (num_betas,).

        Returns:
            ``(v_tpose, j_tpose)``: rest-pose vertices (num_verts, 3) and joint
            positions (num_joints, 3).
        """
        v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
        return v_tpose, j_tpose

    def get_outputs(
        self, betas: np.ndarray, joint_rotmats: np.ndarray
    ) -> SmplFkOutputs:
        """Run forward kinematics for a body shape and per-joint local rotations.

        Args:
            betas: Shape coefficients, shape (num_betas,).
            joint_rotmats: Parent-relative rotation matrix per joint,
                shape (num_joints, 3, 3).

        Returns:
            Local (joint-to-parent) and world (joint-to-world) transforms for
            every joint.
        """
        # Rest-pose joint positions for this shape; the vertices are not needed.
        _, j_tpose = self.get_tpose(betas)

        # Local SE(3) transforms: rotation from the pose; translation is the
        # rest-pose offset from the parent (absolute position for the root).
        T_parent_joint = np.tile(np.eye(4), (self.num_joints, 1, 1))
        T_parent_joint[:, :3, :3] = joint_rotmats
        T_parent_joint[0, :3, 3] = j_tpose[0]
        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]

        # Forward kinematics: chain each joint onto its parent's world transform.
        # Assumes parents precede children (parent_idx[i] < i) -- TODO confirm
        # for every model file this is used with.
        T_world_joint = T_parent_joint.copy()
        for i in range(1, self.num_joints):
            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]

        return SmplFkOutputs(T_world_joint, T_parent_joint)
69
70
def main(
    model_path: Path = Path(__file__).parent / "../assets/SMPLH_NEUTRAL.npz",
) -> None:
    """Serve an interactive SMPL skinned-mesh viewer (runs until interrupted).

    Args:
        model_path: SMPL-family model stored as an ``.npz`` archive.
    """
    server = viser.ViserServer()
    server.scene.set_up_direction("+y")

    model = SmplHelper(model_path)
    gui_elements = make_gui_elements(
        server,
        num_betas=model.num_betas,
        num_joints=model.num_joints,
        parent_idx=model.parent_idx,
    )

    # Add the mesh once, in the zero-shape rest pose. Afterwards only bone
    # transforms are streamed, plus rest vertices/joints when the shape changes.
    v_tpose, j_tpose = model.get_tpose(np.zeros((model.num_betas,)))
    mesh_handle = server.scene.add_mesh_skinned(
        "/human",
        v_tpose,
        model.faces,
        bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz,
        bone_positions=j_tpose,
        skin_weights=model.weights,
        wireframe=gui_elements.gui_wireframe.value,
        color=gui_elements.gui_rgb.value,
    )
    server.scene.add_grid("/grid", position=(0.0, -1.3, 0.0), plane="xz")

    # Main loop: poll the GUI dirty flag; on change, recompute forward
    # kinematics and push the updated bone transforms.
    while True:
        time.sleep(0.02)
        if not gui_elements.changed:
            continue

        # Read the shape sliders once so the rest pose and FK use the same betas.
        betas = np.array([gui_beta.value for gui_beta in gui_elements.gui_betas])

        # Shape changed: update rest-pose vertices / joint positions.
        if gui_elements.betas_changed:
            v_tpose, j_tpose = model.get_tpose(betas)
            mesh_handle.vertices = v_tpose
            mesh_handle.bone_positions = j_tpose

        mesh_handle.color = gui_elements.gui_rgb.value
        # Flags are cleared before the FK update; GUI callbacks set them
        # again, so edits made in the meantime trigger another iteration.
        gui_elements.changed = False
        gui_elements.betas_changed = False

        mesh_handle.wireframe = gui_elements.gui_wireframe.value

        # Axis-angle inputs -> rotation matrices, in one batched call.
        joint_rotmats = tf.SO3.exp(
            np.array([joint.value for joint in gui_elements.gui_joints])
        ).as_matrix()
        smpl_outputs = model.get_outputs(betas=betas, joint_rotmats=joint_rotmats)

        # World-frame bone orientations for all joints, converted at once.
        bone_wxyzs = tf.SO3.from_matrix(smpl_outputs.T_world_joint[:, :3, :3]).wxyz

        # Gizmos are nested per parent, so they take parent-relative positions;
        # bones take world-frame transforms.
        for i, control in enumerate(gui_elements.transform_controls):
            control.position = smpl_outputs.T_parent_joint[i, :3, 3]
            mesh_handle.bones[i].wxyz = bone_wxyzs[i]
            mesh_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]
139
140
@dataclass
class GuiElements:
    """GUI handles for the SMPL demo, plus dirty flags polled by the main loop.

    Callbacks registered in ``make_gui_elements`` raise the flags; ``main``
    lowers them once it has applied the corresponding update.
    """

    # Mesh color picker.
    gui_rgb: viser.GuiInputHandle[tuple[int, int, int]]
    # Wireframe rendering toggle.
    gui_wireframe: viser.GuiInputHandle[bool]
    # One slider per shape coefficient (beta).
    gui_betas: list[viser.GuiInputHandle[float]]
    # One axis-angle vector input per joint.
    gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]]
    # One rotation gizmo per joint, index-aligned with gui_joints.
    transform_controls: list[viser.TransformControlsHandle]

    # Raised when anything that affects the rendered mesh changes.
    changed: bool

    # Raised when a shape coefficient changes (rest pose must be recomputed).
    betas_changed: bool
153
154
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Create the View/Shape/Joints GUI tabs and one transform gizmo per joint.

    Args:
        server: Server to add GUI elements and scene nodes to.
        num_betas: Number of shape-coefficient sliders to create.
        num_joints: Number of joints (one vector input and one gizmo each).
        parent_idx: Parent index per joint; entry 0 (the root) is not read.
            Assumes parent_idx[i] < i for i > 0 -- TODO confirm for the model.

    Returns:
        A GuiElements bundle whose dirty flags are raised by the callbacks
        registered here.
    """
    tab_group = server.gui.add_tab_group()

    # NOTE: the callbacks below close over names bound later in this function
    # (`out`, `gui_betas`, `gui_joints`, `transform_controls`,
    # `prefixed_joint_names`). This assumes no callback fires before return.

    def set_changed(_) -> None:
        out.changed = True  # `out` is assigned at the end of this function.

    def set_betas_changed(_) -> None:
        out.betas_changed = True
        out.changed = True

    def control_scale(prefixed_joint_name: str, size: float) -> float:
        """Gizmo scale: 0.2 at the root, shrinking by 0.75x per tree level."""
        # Names are parent-prefixed ("joint_0/joint_3/..."), so the slash
        # count equals the joint's depth in the kinematic tree.
        depth = prefixed_joint_name.count("/")
        return 0.2 * (0.75**depth) * size

    def add_joint_input_callback(i: int) -> None:
        """Rotate gizmo ``i`` whenever joint input ``i`` is edited."""

        def sync_gizmo(_) -> None:
            transform_controls[i].wxyz = tf.SO3.exp(
                np.array(gui_joints[i].value)
            ).wxyz
            out.changed = True

        gui_joints[i].on_update(sync_gizmo)

    def add_gizmo_callback(i: int) -> None:
        """Write gizmo ``i``'s rotation back to joint input ``i`` as axis-angle."""

        def sync_input(_) -> None:
            axisangle = tf.SO3(transform_controls[i].wxyz).log()
            gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2])

        transform_controls[i].on_update(sync_input)

    # GUI elements: mesh settings + visibility.
    with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)
        gui_control_size = server.gui.add_slider(
            "Handle size", min=0.0, max=10.0, step=0.01, initial_value=1.0
        )

        gui_rgb.on_update(set_changed)
        gui_wireframe.on_update(set_changed)

        @gui_show_controls.on_update
        def _(_):
            for control in transform_controls:
                control.visible = gui_show_controls.value

        @gui_control_size.on_update
        def _(_):
            # Use the joint-relative names that set the initial scale.
            # control.name is the full scene path ("/smpl/joint_0/..."), whose
            # two extra slashes would shrink every gizmo by another 0.75**2.
            for control, prefixed_joint_name in zip(
                transform_controls, prefixed_joint_names
            ):
                control.scale = control_scale(
                    prefixed_joint_name, gui_control_size.value
                )

    # GUI elements: shape parameters.
    with tab_group.add_tab("Shape", viser.Icon.BOX):
        gui_reset_shape = server.gui.add_button("Reset Shape")
        gui_random_shape = server.gui.add_button("Random Shape")

        @gui_reset_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = 0.0

        @gui_random_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = np.random.normal(loc=0.0, scale=1.0)

        gui_betas = []
        for i in range(num_betas):
            beta = server.gui.add_slider(
                f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
            )
            gui_betas.append(beta)
            beta.on_update(set_betas_changed)

    # GUI elements: joint angles.
    with tab_group.add_tab("Joints", viser.Icon.ANGLE):
        gui_reset_joints = server.gui.add_button("Reset Joints")
        gui_random_joints = server.gui.add_button("Random Joints")

        @gui_reset_joints.on_click
        def _(_):
            for joint in gui_joints:
                joint.value = (0.0, 0.0, 0.0)

        @gui_random_joints.on_click
        def _(_):
            rng = np.random.default_rng()
            for joint in gui_joints:
                wx, wy, wz = tf.SO3.sample_uniform(rng).log()
                # Match the input's declared Tuple[float, float, float] type.
                joint.value = (float(wx), float(wy), float(wz))

        gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = []
        for i in range(num_joints):
            gui_joint = server.gui.add_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            gui_joints.append(gui_joint)
            add_joint_input_callback(i)

    # Transform control gizmos on joints, nested under their parents' nodes.
    transform_controls: List[viser.TransformControlsHandle] = []
    prefixed_joint_names: List[str] = []  # Joint names, prefixed with parents.
    for i in range(num_joints):
        prefixed_joint_name = f"joint_{i}"
        if i > 0:
            prefixed_joint_name = (
                prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
            )
        prefixed_joint_names.append(prefixed_joint_name)
        controls = server.scene.add_transform_controls(
            f"/smpl/{prefixed_joint_name}",
            depth_test=False,
            scale=control_scale(prefixed_joint_name, gui_control_size.value),
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(controls)
        add_gizmo_callback(i)

    out = GuiElements(
        gui_rgb,
        gui_wireframe,
        gui_betas,
        gui_joints,
        transform_controls=transform_controls,
        changed=True,
        betas_changed=False,
    )
    return out
294
295
if __name__ == "__main__":
    # tyro builds a command-line interface from main()'s signature, so
    # model_path can be overridden from the command line.
    # NOTE(review): no module docstring is visible in this listing, so
    # __doc__ may be None here -- confirm the real file defines one.
    tyro.cli(main, description=__doc__)