SMPL skinned mesh
Visualize SMPL human body models using skinned mesh deformation. When visualizing human motion sequences, updating only the bone transformations should be much faster than re-sending the entire mesh every frame.
- For instructions on downloading the SMPL body model files, see:
https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model
Features:
- SMPL skinned mesh with real-time deformation
- Interactive pose parameter controls
Note
This example requires external assets. To download them, run:
git clone -b v1.0.26 https://github.com/viser-project/viser.git
cd viser/examples
./assets/download_assets.sh
python 04_demos/04_smpl_skinned.py # With viser installed.
Source: examples/04_demos/04_smpl_skinned.py
Code
1# mypy: disable-error-code="assignment"
2#
3# Asymmetric properties are supported in Pyright, but not yet in mypy.
4# - https://github.com/python/mypy/issues/3004
5# - https://github.com/python/mypy/pull/11643
6
7from __future__ import annotations
8
9import time
10from dataclasses import dataclass
11from pathlib import Path
12from typing import List, Tuple
13
14import numpy as np
15import tyro
16
17import viser
18import viser.transforms as tf
19
20
@dataclass(frozen=True)
class SmplFkOutputs:
    """Forward-kinematics results returned by ``SmplHelper.get_outputs``."""

    # Joint-to-world homogeneous transforms, obtained by chaining the local
    # transforms from the root down the kinematic tree.
    T_world_joint: np.ndarray  # (num_joints, 4, 4)
    # Joint-to-parent (local) homogeneous transforms; for the root joint the
    # translation is its absolute rest-pose position.
    T_parent_joint: np.ndarray  # (num_joints, 4, 4)
25
26
class SmplHelper:
    """Lightweight NumPy helper for SMPL-family body models.

    Loads the model arrays from an ``.npz`` file and computes:

    * the shaped rest-pose ("T-pose") vertices and joint positions for a set of
      shape coefficients (``get_tpose``), and
    * per-joint forward-kinematics transforms for a set of local joint
      rotations (``get_outputs``).

    ``posedirs`` is loaded but not used by this class.
    """

    def __init__(self, model_path: Path) -> None:
        """Load the body model stored at ``model_path``.

        Args:
            model_path: Path to the model file; its suffix must be ``.npz``
                (case-insensitive).

        Raises:
            ValueError: If ``model_path`` does not have an ``.npz`` suffix.
        """
        # Raise instead of `assert`: asserts are stripped under `python -O`.
        if model_path.suffix.lower() != ".npz":
            raise ValueError("Model should be an .npz file!")

        # SECURITY NOTE: allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code. Only load model files from trusted
        # sources.
        # The archive is used as a context manager so its file handle is
        # closed; every array is materialized before the `with` block exits.
        with np.load(model_path, allow_pickle=True) as data:
            body_dict = {key: data[key] for key in data.files}

        # (num_joints, num_verts): regresses joint positions from vertices.
        self.J_regressor = body_dict["J_regressor"]
        # (num_verts, num_joints) skinning weights.
        self.weights = body_dict["weights"]
        # (num_verts, 3) mean-shape rest vertices.
        self.v_template = body_dict["v_template"]
        self.posedirs = body_dict["posedirs"]
        # (num_verts, 3, num_betas) linear shape basis.
        self.shapedirs = body_dict["shapedirs"]
        self.faces = body_dict["f"]

        self.num_joints: int = self.weights.shape[-1]
        self.num_betas: int = self.shapedirs.shape[-1]
        # First row of the kinematic tree: parent index of each joint. The
        # root's entry (index 0) is never read.
        self.parent_idx: np.ndarray = body_dict["kintree_table"][0]

    def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return shaped rest-pose vertices and joint positions.

        Args:
            betas: Shape coefficients, shape ``(num_betas,)``.

        Returns:
            ``(v_tpose, j_tpose)`` with shapes ``(num_verts, 3)`` and
            ``(num_joints, 3)``: geometry when all local poses are identity.
        """
        v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
        return v_tpose, j_tpose

    def get_outputs(
        self, betas: np.ndarray, joint_rotmats: np.ndarray
    ) -> SmplFkOutputs:
        """Run forward kinematics for the given shape and local joint rotations.

        Args:
            betas: Shape coefficients, shape ``(num_betas,)``.
            joint_rotmats: Local (parent-relative) rotation matrices,
                broadcastable to ``(num_joints, 3, 3)``.

        Returns:
            Local and world joint transforms as an ``SmplFkOutputs``.
        """
        # Reuse the rest-pose computation; only joint positions are needed.
        _, j_tpose = self.get_tpose(betas)

        # Local SE(3) transforms: rotation from the pose, translation equal to
        # the rest-pose offset from the parent joint (absolute for the root).
        T_parent_joint = np.tile(np.eye(4), (self.num_joints, 1, 1))
        T_parent_joint[:, :3, :3] = joint_rotmats
        T_parent_joint[0, :3, 3] = j_tpose[0]
        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]

        # Forward kinematics. Assumes every parent index is smaller than its
        # child's, so the parent's world transform is already computed.
        T_world_joint = T_parent_joint.copy()
        for i in range(1, self.num_joints):
            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]

        return SmplFkOutputs(T_world_joint, T_parent_joint)
69
70
def main(
    model_path: Path = Path(__file__).parent / "../assets/SMPLH_NEUTRAL.npz",
) -> None:
    """Serve an interactive SMPL skinned-mesh viewer.

    The mesh is uploaded once in its rest pose. Afterwards, every GUI change
    pushes per-bone transforms. Shape (beta) changes additionally re-upload
    the rest-pose vertices and bone positions.

    Args:
        model_path: Path to the ``.npz`` body model file.
    """
    server = viser.ViserServer()
    server.scene.set_up_direction("+y")
    server.initial_camera.position = (2.5, 1.0, 2.5)

    model = SmplHelper(model_path)
    gui = make_gui_elements(
        server,
        num_betas=model.num_betas,
        num_joints=model.num_joints,
        parent_idx=model.parent_idx,
    )

    # Initial upload: mean shape, identity bone rotations at rest-pose joints.
    rest_vertices, rest_joints = model.get_tpose(np.zeros((model.num_betas,)))
    mesh_handle = server.scene.add_mesh_skinned(
        "/human",
        rest_vertices,
        model.faces,
        bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz,
        bone_positions=rest_joints,
        skin_weights=model.weights,
        wireframe=gui.gui_wireframe.value,
        color=gui.gui_rgb.value,
    )
    server.scene.add_grid("/grid", position=(0.0, -1.3, 0.0), plane="xz")

    def read_betas() -> np.ndarray:
        # Snapshot of the current shape sliders.
        return np.array([slider.value for slider in gui.gui_betas])

    def read_joint_rotmats() -> np.ndarray:
        # Axis-angle GUI vectors -> stacked rotation matrices.
        rotmats = [
            tf.SO3.exp(np.array(handle.value)).as_matrix()
            for handle in gui.gui_joints
        ]
        return np.stack(rotmats, axis=0)

    def refresh() -> None:
        # A shape change moves the rest pose itself, so re-upload it.
        if gui.betas_changed:
            shaped_vertices, shaped_joints = model.get_tpose(read_betas())
            mesh_handle.vertices = shaped_vertices
            mesh_handle.bone_positions = shaped_joints

        mesh_handle.color = gui.gui_rgb.value
        gui.changed = False
        gui.betas_changed = False
        mesh_handle.wireframe = gui.gui_wireframe.value

        fk = model.get_outputs(
            betas=read_betas(),
            joint_rotmats=read_joint_rotmats(),
        )

        # Gizmos are nested under their parents, so they take the local
        # translation; bones take the full world transform.
        for joint_idx, control in enumerate(gui.transform_controls):
            T_world = fk.T_world_joint[joint_idx]
            control.position = fk.T_parent_joint[joint_idx, :3, 3]
            bone = mesh_handle.bones[joint_idx]
            bone.wxyz = tf.SO3.from_matrix(T_world[:3, :3]).wxyz
            bone.position = T_world[:3, 3]

    # Poll the GUI dirty flag; only push updates when something changed.
    while True:
        time.sleep(0.02)
        if gui.changed:
            refresh()
140
141
@dataclass
class GuiElements:
    """Handles created by ``make_gui_elements``, plus two dirty flags.

    The flags are set by GUI callbacks and cleared by the main loop after it
    has pushed the corresponding updates.
    """

    gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]]  # Mesh color.
    gui_wireframe: viser.GuiInputHandle[bool]  # Wireframe rendering toggle.
    # One slider per shape coefficient (beta).
    gui_betas: List[viser.GuiInputHandle[float]]
    # One axis-angle vector per joint; converted with SO3.exp by the main loop.
    gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]]
    # One transform gizmo per joint, index-aligned with gui_joints.
    transform_controls: List[viser.TransformControlsHandle]

    # Set when any input changes; tells the main loop to recompute and push.
    changed: bool

    # Set when a shape slider changes; rest-pose geometry must be re-uploaded.
    betas_changed: bool
154
155
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Build the GUI tabs and per-joint transform gizmos for the SMPL viewer.

    Three tabs are created:

    * "View": color, wireframe, gizmo visibility and gizmo size.
    * "Shape": one slider per shape coefficient.
    * "Joints": one axis-angle vector input per joint.

    Each joint also gets a transform gizmo (translation axes and sliders
    disabled), nested in the scene tree under its parent joint and kept in
    sync with its vector input.

    Args:
        server: Viser server to add GUI elements and scene nodes to.
        num_betas: Number of shape sliders to create.
        num_joints: Number of joint inputs and gizmos to create.
        parent_idx: Parent joint index for each joint; entry 0 (the root) is
            never read.

    Returns:
        The created handles, with ``changed=True`` so the first main-loop
        iteration renders the mesh.
    """
    # NOTE: the callbacks below close over names that are assigned later in
    # this function (`out`, `transform_controls`, `gui_betas`, `gui_joints`,
    # `prefixed_joint_names`). This assumes no callback fires before this
    # function has returned.
    tab_group = server.gui.add_tab_group()

    def set_changed(_) -> None:
        out.changed = True  # out is defined later!

    def set_betas_changed(_) -> None:
        out.betas_changed = True
        out.changed = True

    def joint_control_scale(prefixed_joint_name: str) -> float:
        # Single source of truth for gizmo size: shrink by 0.75 per level of
        # kinematic depth, then apply the "Handle size" slider uniformly.
        depth = prefixed_joint_name.count("/")
        return 0.2 * (0.75**depth) * gui_control_size.value

    # GUI elements: mesh settings + visibility.
    with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)
        gui_control_size = server.gui.add_slider(
            "Handle size", min=0.0, max=10.0, step=0.01, initial_value=1.0
        )

        gui_rgb.on_update(set_changed)
        gui_wireframe.on_update(set_changed)

        @gui_show_controls.on_update
        def _(_):
            for control in transform_controls:
                control.visible = gui_show_controls.value

        @gui_control_size.on_update
        def _(_):
            # Depth comes from the parent-prefixed joint name, not
            # `control.name`: the latter is the full scene path and also
            # contains the "/smpl/" prefix. Using it would add two extra
            # levels and shrink every gizmo relative to its initial size.
            for control, prefixed_joint_name in zip(
                transform_controls, prefixed_joint_names
            ):
                control.scale = joint_control_scale(prefixed_joint_name)

    # GUI elements: shape parameters.
    with tab_group.add_tab("Shape", viser.Icon.BOX):
        gui_reset_shape = server.gui.add_button("Reset Shape")
        gui_random_shape = server.gui.add_button("Random Shape")

        @gui_reset_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = 0.0

        @gui_random_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = np.random.normal(loc=0.0, scale=1.0)

        gui_betas = []
        for i in range(num_betas):
            beta = server.gui.add_slider(
                f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
            )
            gui_betas.append(beta)
            beta.on_update(set_betas_changed)

    # GUI elements: joint angles.
    with tab_group.add_tab("Joints", viser.Icon.ANGLE):
        gui_reset_joints = server.gui.add_button("Reset Joints")
        gui_random_joints = server.gui.add_button("Random Joints")

        @gui_reset_joints.on_click
        def _(_):
            for joint in gui_joints:
                joint.value = (0.0, 0.0, 0.0)

        @gui_random_joints.on_click
        def _(_):
            rng = np.random.default_rng()
            for joint in gui_joints:
                joint.value = tf.SO3.sample_uniform(rng).log()

        gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = []
        for i in range(num_joints):
            gui_joint = server.gui.add_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            gui_joints.append(gui_joint)

            # Wrapper binds the current `i` (and `gui_joint`, evaluated by the
            # decorator) instead of late-binding the loop variables.
            def set_callback_in_closure(i: int) -> None:
                @gui_joint.on_update
                def _(_):
                    # Vector input -> gizmo orientation.
                    transform_controls[i].wxyz = tf.SO3.exp(
                        np.array(gui_joints[i].value)
                    ).wxyz
                    out.changed = True

            set_callback_in_closure(i)

    # Transform control gizmos on joints.
    transform_controls: List[viser.TransformControlsHandle] = []
    prefixed_joint_names = []  # Joint names, but prefixed with parents.
    for i in range(num_joints):
        prefixed_joint_name = f"joint_{i}"
        if i > 0:
            # Nest under the parent's scene path so the gizmo inherits the
            # parent's transform.
            prefixed_joint_name = (
                prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
            )
        prefixed_joint_names.append(prefixed_joint_name)
        controls = server.scene.add_transform_controls(
            f"/smpl/{prefixed_joint_name}",
            depth_test=False,
            scale=joint_control_scale(prefixed_joint_name),
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(controls)

        def set_callback_in_closure(i: int) -> None:
            @controls.on_update
            def _(_) -> None:
                # Gizmo orientation -> vector input (axis-angle).
                axisangle = tf.SO3(transform_controls[i].wxyz).log()
                gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2])

        set_callback_in_closure(i)

    out = GuiElements(
        gui_rgb,
        gui_wireframe,
        gui_betas,
        gui_joints,
        transform_controls=transform_controls,
        changed=True,
        betas_changed=False,
    )
    return out
295
296
if __name__ == "__main__":
    # Expose `main`'s parameters as command-line arguments via tyro, using the
    # module docstring as the help text.
    cli_description = __doc__
    tyro.cli(main, description=cli_description)