"""SMPL visualizer (Skinned Mesh)

Requires a .npz model file.

See here for download instructions:
    https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model
"""
1from __future__ import annotations
2
3import time
4from dataclasses import dataclass
5from pathlib import Path
6from typing import List, Tuple
7
8import numpy as np
9import numpy as onp
10import tyro
11
12import viser
13import viser.transforms as tf
14
15
@dataclass(frozen=True)
class SmplOutputs:
    """Immutable result of a single ``SmplHelper.get_outputs`` evaluation."""

    # Posed (skinned) vertex positions, one xyz row per vertex.
    vertices: np.ndarray
    # Face vertex-index array, passed through unchanged from the model's ``f`` entry.
    faces: np.ndarray
    # Joint-to-world rigid transforms after forward kinematics.
    T_world_joint: np.ndarray  # (num_joints, 4, 4)
    # Joint-to-parent rigid transforms; the root's parent frame is the world.
    T_parent_joint: np.ndarray  # (num_joints, 4, 4)
22
23
class SmplHelper:
    """Helper for models in the SMPL family, implemented in numpy.

    Loads the model arrays from an ``.npz`` file once, then evaluates shape
    blend shapes, pose-corrective blend shapes, forward kinematics and linear
    blend skinning on demand via ``get_outputs``.
    """

    def __init__(self, model_path: Path) -> None:
        """Load SMPL-family model arrays from ``model_path``.

        Args:
            model_path: Path to the model file; must have a ``.npz`` suffix.

        Raises:
            ValueError: If ``model_path`` does not have a ``.npz`` suffix.
        """
        # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``.
        if model_path.suffix.lower() != ".npz":
            raise ValueError(f"Model should be an .npz file! Got: {model_path}")

        # ``np.load`` on an .npz returns a lazily-reading NpzFile that keeps the
        # file handle open. Read every array eagerly inside the context manager
        # so the handle is released as soon as loading is done.
        # NOTE: allow_pickle=True can execute code embedded in the file; only
        # load model files from trusted sources.
        with onp.load(model_path, allow_pickle=True) as npz:
            body_dict = {key: npz[key] for key in npz.files}

        self._J_regressor = body_dict["J_regressor"]
        self._weights = body_dict["weights"]
        self._v_template = body_dict["v_template"]
        self._posedirs = body_dict["posedirs"]
        self._shapedirs = body_dict["shapedirs"]
        self._faces = body_dict["f"]

        self.num_joints: int = self._weights.shape[-1]
        self.num_betas: int = self._shapedirs.shape[-1]
        # First row of the kinematic tree; only entries 1: are ever read (the
        # root's entry is never used as an index).
        self.parent_idx: np.ndarray = body_dict["kintree_table"][0]

    def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs:
        """Evaluate the model for one set of shape and pose parameters.

        Args:
            betas: Shape coefficients, one per shape blend shape
                (length ``self.num_betas``).
            joint_rotmats: Local (parent-relative) joint rotation matrices,
                shape ``(num_joints, 3, 3)``.

        Returns:
            Posed vertices, faces, and per-joint world and parent-relative
            4x4 transforms.
        """
        # Get shaped vertices + joint positions, when all local poses are identity.
        v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas)
        j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose)

        # Local SE(3) transforms: rotation from the input, translation is the
        # rest-pose offset from the parent joint (or from the origin for the root).
        T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
        T_parent_joint[:, :3, :3] = joint_rotmats
        T_parent_joint[0, :3, 3] = j_tpose[0]
        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]

        # Forward kinematics. Requires parent_idx[i] < i for every i > 0, so a
        # parent's world transform is final before any child reads it.
        T_world_joint = T_parent_joint.copy()
        for i in range(1, self.num_joints):
            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]

        # Pose-corrective blend shapes are driven by (R - I) of every non-root joint.
        pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten()
        v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta)

        # Linear blend skinning: express each vertex relative to every rest-pose
        # joint (homogeneous coords), transform by that joint's world pose, and
        # blend with the skinning weights.
        v_delta = np.ones((v_blend.shape[0], self.num_joints, 4))
        v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :]
        v_posed = np.einsum(
            "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta
        )
        return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint)
67
68
def _add_skinned_human(
    server: viser.ViserServer,
    model: SmplHelper,
    betas: np.ndarray,
    rgb: Tuple[int, int, int],
    wireframe: bool,
):
    """Add the skinned SMPL mesh at scene path ``/human`` in its rest pose.

    The rest pose uses ``betas`` for the shape and identity local rotations for
    every joint. Those vertices and bone transforms become the mesh's bind
    pose, and the caller then poses the returned handle's ``bones``. Note that
    pose-corrective blend shapes are therefore not reflected on screen: the
    rest vertices are computed with zero pose delta and only skinning deforms
    them.

    NOTE(review): relies on viser replacing an existing scene node of the same
    name, so calling this again swaps in a freshly shaped mesh — confirm.

    Returns:
        The handle returned by ``server.scene.add_mesh_skinned``.
    """
    rest = model.get_outputs(
        betas=betas,
        joint_rotmats=np.zeros((model.num_joints, 3, 3)) + np.eye(3),
    )
    bone_wxyzs = np.array(
        [tf.SO3.from_matrix(R).wxyz for R in rest.T_world_joint[:, :3, :3]]
    )
    return server.scene.add_mesh_skinned(
        "/human",
        rest.vertices,
        rest.faces,
        bone_wxyzs=bone_wxyzs,
        bone_positions=rest.T_world_joint[:, :3, 3],
        skin_weights=model._weights,
        wireframe=wireframe,
        color=rgb,
    )


def main(model_path: Path) -> None:
    """Serve an interactive skinned SMPL mesh driven by GUI controls.

    Args:
        model_path: Path to an SMPL-family ``.npz`` model file.
    """
    server = viser.ViserServer()
    server.scene.set_up_direction("+y")
    server.gui.configure_theme(control_layout="collapsible")

    model = SmplHelper(model_path)
    gui_elements = make_gui_elements(
        server,
        num_betas=model.num_betas,
        num_joints=model.num_joints,
        parent_idx=model.parent_idx,
    )

    def read_mesh_settings() -> Tuple[Tuple[float, ...], Tuple[int, int, int], bool]:
        """Snapshot the GUI values that are baked into the skinned mesh."""
        return (
            tuple(float(x.value) for x in gui_elements.gui_betas),
            tuple(gui_elements.gui_rgb.value),
            bool(gui_elements.gui_wireframe.value),
        )

    mesh_settings = read_mesh_settings()
    skinned_handle = _add_skinned_human(
        server, model, np.array(mesh_settings[0]), mesh_settings[1], mesh_settings[2]
    )

    # Main loop: poll the GUI, and re-pose (or rebuild) the mesh when it changes.
    while True:
        # Do nothing if no change.
        time.sleep(0.02)
        if not gui_elements.changed:
            continue

        # Clear the flag *before* reading values, so an edit that arrives while
        # we update is handled on the next iteration instead of being lost.
        gui_elements.changed = False

        # Shape, color and wireframe are fixed when the skinned mesh is created.
        # Moving bones alone would warp the old rest mesh rather than reshape
        # it, and would ignore color/wireframe edits, so rebuild on any change.
        new_mesh_settings = read_mesh_settings()
        if new_mesh_settings != mesh_settings:
            mesh_settings = new_mesh_settings
            skinned_handle = _add_skinned_human(
                server,
                model,
                np.array(mesh_settings[0]),
                mesh_settings[1],
                mesh_settings[2],
            )

        # Compute SMPL outputs for the current pose.
        smpl_outputs = model.get_outputs(
            betas=np.array(mesh_settings[0]),
            joint_rotmats=np.stack(
                [
                    tf.SO3.exp(np.array(x.value)).as_matrix()
                    for x in gui_elements.gui_joints
                ],
                axis=0,
            ),
        )

        # Match transform control gizmos (nested under their parents in the
        # scene tree, hence parent-relative positions) and bones to the joints.
        for i, control in enumerate(gui_elements.transform_controls):
            control.position = smpl_outputs.T_parent_joint[i, :3, 3]
            skinned_handle.bones[i].wxyz = tf.SO3.from_matrix(
                smpl_outputs.T_world_joint[i, :3, :3]
            ).wxyz
            skinned_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]
131
132
@dataclass
class GuiElements:
    """Bundle of GUI handles that the main loop reads from.

    Built by ``make_gui_elements``. ``main`` polls ``changed`` and reads the
    current values from the handles below.
    """

    # Mesh appearance inputs.
    gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]]
    gui_wireframe: viser.GuiInputHandle[bool]
    # One slider per shape coefficient.
    gui_betas: List[viser.GuiInputHandle[float]]
    # One axis-angle 3-vector per joint.
    gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]]
    # Rotation gizmos, one per joint, in the same order as ``gui_joints``.
    transform_controls: List[viser.TransformControlsHandle]

    # Flipped to True whenever the mesh needs to be re-generated; main() resets
    # it to False once it starts handling the change.
    changed: bool
145
146
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Build the View/Shape/Joints tabs plus one rotation gizmo per joint.

    Every input that affects the mesh flips ``changed`` on the returned
    ``GuiElements``; joint inputs and gizmos are kept in sync with each other.

    Args:
        server: Server to attach the GUI and scene controls to.
        num_betas: Number of shape sliders to create.
        num_joints: Number of joint rotation inputs / gizmos to create.
        parent_idx: Parent joint index per joint; entry 0 (root) is never read.

    Returns:
        Handles for all created controls, with ``changed`` initially True.
    """
    tabs = server.gui.add_tab_group()

    def mark_changed(_) -> None:
        # ``result`` is only bound at the end of this function. Callbacks are
        # triggered by user interaction, which in practice happens after return.
        result.changed = True

    # View tab: mesh appearance + gizmo visibility.
    with tabs.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)

    for appearance_input in (gui_rgb, gui_wireframe):
        appearance_input.on_update(mark_changed)

    @gui_show_controls.on_update
    def _toggle_gizmos(_) -> None:
        visible = gui_show_controls.value
        for gizmo in transform_controls:
            gizmo.visible = visible

    # Shape tab: reset/randomize buttons, then one slider per beta.
    with tabs.add_tab("Shape", viser.Icon.BOX):
        reset_shape_button = server.gui.add_button("Reset Shape")
        random_shape_button = server.gui.add_button("Random Shape")

        @reset_shape_button.on_click
        def _reset_shape(_) -> None:
            for slider in gui_betas:
                slider.value = 0.0

        @random_shape_button.on_click
        def _randomize_shape(_) -> None:
            for slider in gui_betas:
                slider.value = onp.random.normal(loc=0.0, scale=1.0)

        gui_betas = [
            server.gui.add_slider(
                f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
            )
            for i in range(num_betas)
        ]
        for slider in gui_betas:
            slider.on_update(mark_changed)

    # Joints tab: reset/randomize buttons, then one axis-angle input per joint.
    with tabs.add_tab("Joints", viser.Icon.ANGLE):
        reset_joints_button = server.gui.add_button("Reset Joints")
        random_joints_button = server.gui.add_button("Random Joints")

        @reset_joints_button.on_click
        def _reset_joints(_) -> None:
            for joint_input in gui_joints:
                joint_input.value = (0.0, 0.0, 0.0)

        @random_joints_button.on_click
        def _randomize_joints(_) -> None:
            for joint_input in gui_joints:
                # Sampling uniformly in so(3) directly is awkward; a normalized
                # 4D Gaussian is uniform on S^3, so draw a unit quaternion and
                # map it back with the log map.
                wxyz = onp.random.normal(loc=0.0, scale=1.0, size=(4,))
                wxyz /= onp.linalg.norm(wxyz)
                joint_input.value = tf.SO3(wxyz=wxyz).log()

        gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [
            server.gui.add_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            for i in range(num_joints)
        ]
        for joint_idx, joint_input in enumerate(gui_joints):
            # Default argument binds this iteration's index into the callback.
            def _sync_gizmo_from_input(_, idx: int = joint_idx) -> None:
                transform_controls[idx].wxyz = tf.SO3.exp(
                    np.array(gui_joints[idx].value)
                ).wxyz
                result.changed = True

            joint_input.on_update(_sync_gizmo_from_input)

    # Rotation gizmos, each nested in the scene tree under its parent's gizmo.
    transform_controls: List[viser.TransformControlsHandle] = []
    scene_paths: List[str] = []  # e.g. "joint_0/joint_3/joint_6"
    for joint_idx in range(num_joints):
        node_name = f"joint_{joint_idx}"
        if joint_idx == 0:
            scene_path = node_name
        else:
            scene_path = f"{scene_paths[parent_idx[joint_idx]]}/{node_name}"
        scene_paths.append(scene_path)

        nesting_depth = scene_path.count("/")
        gizmo = server.scene.add_transform_controls(
            f"/smpl/{scene_path}",
            depth_test=False,
            scale=0.2 * (0.75**nesting_depth),  # Deeper joints get smaller gizmos.
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(gizmo)

        def _sync_input_from_gizmo(_, idx: int = joint_idx) -> None:
            axis_angle = tf.SO3(transform_controls[idx].wxyz).log()
            gui_joints[idx].value = (axis_angle[0], axis_angle[1], axis_angle[2])

        gizmo.on_update(_sync_input_from_gizmo)

    result = GuiElements(
        gui_rgb=gui_rgb,
        gui_wireframe=gui_wireframe,
        gui_betas=gui_betas,
        gui_joints=gui_joints,
        transform_controls=transform_controls,
        changed=True,
    )
    return result
272
273
if __name__ == "__main__":
    # tyro derives the command-line interface from main()'s signature; the
    # module docstring (``__doc__``) is passed through as the CLI description.
    cli_description = __doc__
    tyro.cli(main, description=cli_description)