SMPL visualizer (Skinned Mesh)
Requires a .npz model file.
- See here for download instructions:
https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model
1from __future__ import annotations
2
3import time
4from dataclasses import dataclass
5from pathlib import Path
6from typing import List, Tuple
7
8import numpy as np
9import tyro
10
11import viser
12import viser.transforms as tf
13
14
@dataclass(frozen=True)
class SmplFkOutputs:
    """Immutable container for the per-joint transforms produced by forward kinematics."""

    # Joint poses expressed in the world frame; shape (num_joints, 4, 4).
    T_world_joint: np.ndarray

    # Joint poses expressed relative to each joint's parent; shape (num_joints, 4, 4).
    T_parent_joint: np.ndarray
19
20
class SmplHelper:
    """Helper for models in the SMPL family, implemented in numpy. Does not include blend skinning."""

    def __init__(self, model_path: Path) -> None:
        """Load the model arrays from an ``.npz`` file.

        Args:
            model_path: Path to the model file; its suffix must be ``.npz``
                (case-insensitive).

        Raises:
            AssertionError: If ``model_path`` does not have an ``.npz`` suffix.
            KeyError: If one of the expected arrays is missing from the file.
        """
        assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"

        # NOTE(review): allow_pickle=True lets numpy unpickle object arrays, which
        # can execute arbitrary code. Only load model files from a trusted source.
        # The context manager closes the underlying archive once every array has
        # been read into memory.
        with np.load(model_path, allow_pickle=True) as npz_file:
            body_dict = {key: npz_file[key] for key in npz_file.files}

        self.J_regressor = body_dict["J_regressor"]
        self.weights = body_dict["weights"]
        self.v_template = body_dict["v_template"]
        self.posedirs = body_dict["posedirs"]
        self.shapedirs = body_dict["shapedirs"]
        self.faces = body_dict["f"]

        # The last axis of the skinning weights indexes joints; the last axis of
        # the shape basis indexes shape coefficients.
        self.num_joints: int = self.weights.shape[-1]
        self.num_betas: int = self.shapedirs.shape[-1]
        # First row of the kinematic tree table: parent index for each joint.
        self.parent_idx: np.ndarray = body_dict["kintree_table"][0]

    def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Get shaped vertices + joint positions, when all local poses are identity.

        Args:
            betas: Shape coefficients, a 1-D array contracted against the last
                axis of ``shapedirs``.

        Returns:
            A ``(v_tpose, j_tpose)`` tuple: the template vertices offset by the
            shape blend, and the joint positions regressed from those vertices.
        """
        v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
        return v_tpose, j_tpose

    def get_outputs(
        self, betas: np.ndarray, joint_rotmats: np.ndarray
    ) -> SmplFkOutputs:
        """Run forward kinematics for the given shape and local joint rotations.

        Args:
            betas: Shape coefficients, as accepted by ``get_tpose``.
            joint_rotmats: Local rotation matrix of each joint; written into the
                ``(num_joints, 3, 3)`` rotation part of the local transforms.

        Returns:
            The world-frame and parent-frame 4x4 transform of every joint.
        """
        # Only the joint positions are needed here; the shaped vertices are unused.
        _, j_tpose = self.get_tpose(betas)

        # Local SE(3) transforms: start from identity, then fill in rotation and
        # translation. The root is placed at its T-pose position; every other
        # joint is offset from its parent's T-pose position.
        T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
        T_parent_joint[:, :3, :3] = joint_rotmats
        T_parent_joint[0, :3, 3] = j_tpose[0]
        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]

        # Forward kinematics. Joint 0 is used as-is (its local transform is its
        # world transform). NOTE(review): this single pass assumes every parent
        # index is smaller than its child's index -- confirm for the model used.
        T_world_joint = T_parent_joint.copy()
        for i in range(1, self.num_joints):
            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]

        return SmplFkOutputs(T_world_joint, T_parent_joint)
64
65
def main(model_path: Path) -> None:
    """Serve an interactive viewer for a SMPL-family skinned mesh.

    Starts a viser server, loads the body model, builds the GUI, adds the
    skinned mesh in its zero-shape T-pose, and then polls the GUI state forever,
    pushing updated vertices and bone transforms whenever an input changed.

    Args:
        model_path: Path to the ``.npz`` model file.
    """
    server = viser.ViserServer()
    server.scene.set_up_direction("+y")

    # Main loop. We'll read pose/shape from the GUI elements, compute the mesh,
    # and then send the updated mesh in a loop.
    model = SmplHelper(model_path)
    gui_elements = make_gui_elements(
        server,
        num_betas=model.num_betas,
        num_joints=model.num_joints,
        parent_idx=model.parent_idx,
    )
    v_tpose, j_tpose = model.get_tpose(np.zeros((model.num_betas,)))
    mesh_handle = server.scene.add_mesh_skinned(
        "/human",
        v_tpose,
        model.faces,
        bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz,
        bone_positions=j_tpose,
        skin_weights=model.weights,
        wireframe=gui_elements.gui_wireframe.value,
        color=gui_elements.gui_rgb.value,
    )
    server.scene.add_grid("/grid", position=(0.0, -1.3, 0.0), plane="xz")

    while True:
        # Do nothing if no change.
        time.sleep(0.02)
        if not gui_elements.changed:
            continue

        # Clear the flags *before* reading any GUI value. GUI callbacks may fire
        # while this iteration is running (presumably from another thread --
        # confirm against viser); clearing afterwards would erase such an update
        # and leave the mesh stale until the next input change.
        gui_elements.changed = False
        betas_changed = gui_elements.betas_changed
        if betas_changed:
            gui_elements.betas_changed = False

        # Read the shape once so vertices and bone transforms agree.
        betas = np.array([gui_beta.value for gui_beta in gui_elements.gui_betas])

        # Shapes changed: update vertices / joint positions.
        if betas_changed:
            v_tpose, j_tpose = model.get_tpose(betas)
            mesh_handle.vertices = v_tpose
            mesh_handle.bone_positions = j_tpose

        mesh_handle.color = gui_elements.gui_rgb.value

        # Render as wireframe?
        mesh_handle.wireframe = gui_elements.gui_wireframe.value

        # Compute SMPL outputs. Each joint's GUI vector is treated as an
        # axis-angle rotation and converted to a rotation matrix.
        smpl_outputs = model.get_outputs(
            betas=betas,
            joint_rotmats=np.stack(
                [
                    tf.SO3.exp(np.array(x.value)).as_matrix()
                    for x in gui_elements.gui_joints
                ],
                axis=0,
            ),
        )

        # Match transform control gizmos to joint positions. Gizmos take the
        # parent-relative translation; mesh bones take the world-frame pose.
        for i, control in enumerate(gui_elements.transform_controls):
            control.position = smpl_outputs.T_parent_joint[i, :3, 3]
            mesh_handle.bones[i].wxyz = tf.SO3.from_matrix(
                smpl_outputs.T_world_joint[i, :3, :3]
            ).wxyz
            mesh_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]
132
133
@dataclass
class GuiElements:
    """Structure containing handles for reading from GUI elements.

    Bundles the input handles, the per-joint transform gizmos, and two mutable
    dirty flags that the GUI callbacks set and the main loop reads and resets.
    """

    # Mesh color picker and wireframe toggle.
    gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]]
    gui_wireframe: viser.GuiInputHandle[bool]

    # One handle per shape coefficient, and one 3-vector handle per joint.
    gui_betas: List[viser.GuiInputHandle[float]]
    gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]]

    # One transform gizmo per joint, in joint order.
    transform_controls: List[viser.TransformControlsHandle]

    changed: bool
    """This flag will be flipped to True whenever any input is changed."""

    betas_changed: bool
    """This flag will be flipped to True whenever the shape changes."""
149
150
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Make GUI elements for interacting with the model.

    Builds three GUI tabs ("View", "Shape", "Joints") plus one transform gizmo
    per joint in the scene, and wires their callbacks so that any edit flips the
    dirty flags on the returned structure.

    Args:
        server: Viser server to attach the GUI and scene nodes to.
        num_betas: Number of shape sliders to create.
        num_joints: Number of joint inputs and transform gizmos to create.
        parent_idx: Parent joint index of each joint; used to nest each gizmo's
            scene path under its parent's path.

    Returns:
        Handles to the created elements, with ``changed`` initially True and
        ``betas_changed`` initially False.
    """

    tab_group = server.gui.add_tab_group()

    def set_changed(_) -> None:
        out.changed = True  # out is defined later!

    def set_betas_changed(_) -> None:
        out.betas_changed = True
        out.changed = True

    # GUI elements: mesh settings + visibility.
    with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)
        gui_control_size = server.gui.add_slider(
            "Handle size", min=0.0, max=10.0, step=0.01, initial_value=1.0
        )

        gui_rgb.on_update(set_changed)
        gui_wireframe.on_update(set_changed)

        @gui_show_controls.on_update
        def _(_):
            for control in transform_controls:
                control.visible = gui_show_controls.value

        @gui_control_size.on_update
        def _(_):
            # Rescale from the stored per-control base scale. Deriving the depth
            # from the scene node name would also count the slashes of the
            # "/smpl/" prefix and disagree with the scale used at creation.
            for control, base_scale in zip(transform_controls, control_base_scales):
                control.scale = base_scale * gui_control_size.value

    # GUI elements: shape parameters.
    with tab_group.add_tab("Shape", viser.Icon.BOX):
        gui_reset_shape = server.gui.add_button("Reset Shape")
        gui_random_shape = server.gui.add_button("Random Shape")

        @gui_reset_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = 0.0

        @gui_random_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = np.random.normal(loc=0.0, scale=1.0)

        gui_betas = []
        for i in range(num_betas):
            beta = server.gui.add_slider(
                f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
            )
            gui_betas.append(beta)
            beta.on_update(set_betas_changed)

    # GUI elements: joint angles.
    with tab_group.add_tab("Joints", viser.Icon.ANGLE):
        gui_reset_joints = server.gui.add_button("Reset Joints")
        gui_random_joints = server.gui.add_button("Random Joints")

        @gui_reset_joints.on_click
        def _(_):
            for joint in gui_joints:
                joint.value = (0.0, 0.0, 0.0)

        @gui_random_joints.on_click
        def _(_):
            rng = np.random.default_rng()
            for joint in gui_joints:
                joint.value = tf.SO3.sample_uniform(rng).log()

        gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = []
        for i in range(num_joints):
            gui_joint = server.gui.add_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            gui_joints.append(gui_joint)

            # The helper pins the current value of `i` for the callback; a plain
            # closure over the loop variable would see its final value.
            def set_callback_in_closure(i: int) -> None:
                @gui_joint.on_update
                def _(_):
                    # Mirror the edited axis-angle vector onto the joint's gizmo.
                    transform_controls[i].wxyz = tf.SO3.exp(
                        np.array(gui_joints[i].value)
                    ).wxyz
                    out.changed = True

            set_callback_in_closure(i)

    # Transform control gizmos on joints.
    transform_controls: List[viser.TransformControlsHandle] = []
    control_base_scales: List[float] = []  # Scale of each gizmo at handle size 1.
    prefixed_joint_names = []  # Joint names, but prefixed with parents.
    for i in range(num_joints):
        prefixed_joint_name = f"joint_{i}"
        if i > 0:
            prefixed_joint_name = (
                prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
            )
        prefixed_joint_names.append(prefixed_joint_name)

        # Gizmos shrink by 25% per level of depth in the joint hierarchy.
        base_scale = 0.2 * (0.75 ** prefixed_joint_name.count("/"))
        control_base_scales.append(base_scale)
        controls = server.scene.add_transform_controls(
            f"/smpl/{prefixed_joint_name}",
            depth_test=False,
            scale=base_scale * gui_control_size.value,
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(controls)

        def set_callback_in_closure(i: int) -> None:
            @controls.on_update
            def _(_) -> None:
                # Mirror the gizmo's rotation back into the joint's GUI vector.
                axisangle = tf.SO3(transform_controls[i].wxyz).log()
                gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2])

        set_callback_in_closure(i)

    out = GuiElements(
        gui_rgb,
        gui_wireframe,
        gui_betas,
        gui_joints,
        transform_controls=transform_controls,
        changed=True,
        betas_changed=False,
    )
    return out
291
292
if __name__ == "__main__":
    # Script entry point: hand main() to tyro.cli, passing the module docstring
    # as the description.
    tyro.cli(main, description=__doc__)