"""Visualizer for SMPL human body models. Requires a .npz model file.

See here for download instructions:
    https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model
"""
from __future__ import annotations

import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple

import numpy as np
import numpy as onp
import tyro
import viser
import viser.transforms as tf
13
14
@dataclass(frozen=True)
class SmplOutputs:
    """Immutable bundle of the arrays produced by one SMPL forward pass."""

    # Posed mesh vertices, (num_vertices, 3).
    vertices: np.ndarray
    # Face index array, passed through unchanged from the model file's "f"
    # entry. Presumably (num_faces, 3) triangle indices -- TODO confirm.
    faces: np.ndarray
    # Pose of every joint expressed in the world frame.
    T_world_joint: np.ndarray  # (num_joints, 4, 4)
    # Pose of every joint expressed in its parent's frame.
    T_parent_joint: np.ndarray  # (num_joints, 4, 4)
21
22
class SmplHelper:
    """Helper for models in the SMPL family, implemented in numpy."""

    def __init__(self, model_path: Path) -> None:
        """Load the model arrays needed for skinning from an ``.npz`` file.

        Args:
            model_path: Location of the model file. Its suffix must be
                ``.npz`` (case-insensitive).

        Raises:
            AssertionError: If ``model_path`` does not have an ``.npz`` suffix.
            KeyError: If a required array is missing from the archive.
        """
        assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"

        # NOTE(security): allow_pickle=True lets numpy unpickle object arrays,
        # which can execute arbitrary code. Only open model files that come
        # from a trusted source.
        #
        # The archive is opened as a context manager so the underlying file
        # handle is released as soon as the arrays have been read.
        with np.load(model_path, allow_pickle=True) as body_dict:
            self._J_regressor = body_dict["J_regressor"]  # (num_joints, num_vertices)
            self._weights = body_dict["weights"]  # (num_vertices, num_joints)
            self._v_template = body_dict["v_template"]  # (num_vertices, 3)
            # (num_vertices, 3, 9 * (num_joints - 1))
            self._posedirs = body_dict["posedirs"]
            self._shapedirs = body_dict["shapedirs"]  # (num_vertices, 3, num_betas)
            self._faces = body_dict["f"]
            kintree_table = body_dict["kintree_table"]

        self.num_joints: int = self._weights.shape[-1]
        self.num_betas: int = self._shapedirs.shape[-1]
        # Row 0 of the kinematic tree table holds each joint's parent index.
        self.parent_idx: np.ndarray = kintree_table[0]

    def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs:
        """Pose the model and return the skinned mesh plus joint transforms.

        Args:
            betas: Shape coefficients, ``(num_betas,)``.
            joint_rotmats: Local rotation matrix of every joint relative to
                its parent, ``(num_joints, 3, 3)``.

        Returns:
            The posed vertices, the face array, and the world-frame and
            parent-frame transform of every joint.
        """
        # Get shaped vertices + joint positions, when all local poses are identity.
        v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas)
        j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose)

        # Local SE(3) transforms: rotation from the input, translation from the
        # rest-pose offset between a joint and its parent (the root keeps its
        # absolute rest position).
        T_parent_joint = np.tile(np.eye(4), (self.num_joints, 1, 1))
        T_parent_joint[:, :3, :3] = joint_rotmats
        T_parent_joint[0, :3, 3] = j_tpose[0]
        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]

        # Forward kinematics. A single in-order pass is only correct when every
        # parent index is smaller than its child's index.
        T_world_joint = T_parent_joint.copy()
        for i in range(1, self.num_joints):
            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]

        # Pose-dependent corrective offsets; the root joint is excluded.
        pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten()
        v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta)

        # Linear blend skinning. Each vertex is expressed relative to every
        # rest-pose joint in homogeneous coordinates, moved by that joint's
        # world transform, and the results are mixed with the skinning weights.
        v_delta = np.ones((v_blend.shape[0], self.num_joints, 4))
        v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :]
        v_posed = np.einsum(
            "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta
        )
        return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint)
66
67
def main(model_path: Path) -> None:
    """Serve an interactive viewer for the SMPL model stored at ``model_path``.

    Starts a viser server, builds the GUI, and then polls the GUI state
    forever: whenever the ``changed`` flag is set, the mesh is recomputed from
    the current shape/pose inputs and re-sent. This function does not return.

    Args:
        model_path: Path to the ``.npz`` model file, forwarded to ``SmplHelper``.
    """
    server = viser.ViserServer()
    server.set_up_direction("+y")
    server.configure_theme(control_layout="collapsible")

    # Main loop. We'll read pose/shape from the GUI elements, compute the mesh,
    # and then send the updated mesh in a loop.
    model = SmplHelper(model_path)
    gui_elements = make_gui_elements(
        server,
        num_betas=model.num_betas,
        num_joints=model.num_joints,
        parent_idx=model.parent_idx,
    )
    while True:
        # Do nothing if no change.
        time.sleep(0.02)
        if not gui_elements.changed:
            continue

        # Clear the flag before reading any values. Presumably GUI callbacks
        # can set it again while we recompute; in that case the next poll
        # picks the change up instead of losing it.
        gui_elements.changed = False

        # Compute SMPL outputs.
        betas = np.array([slider.value for slider in gui_elements.gui_betas])
        # (num_joints, 3) axis-angle vectors.
        axis_angles = np.array([vec.value for vec in gui_elements.gui_joints])
        smpl_outputs = model.get_outputs(
            betas=betas,
            joint_rotmats=tf.SO3.exp(axis_angles).as_matrix(),
        )
        server.add_mesh_simple(
            "/human",
            smpl_outputs.vertices,
            smpl_outputs.faces,
            wireframe=gui_elements.gui_wireframe.value,
            color=gui_elements.gui_rgb.value,
        )

        # Match transform control gizmos to joint positions. Each gizmo gets
        # the translation of its joint relative to the parent joint.
        for i, control in enumerate(gui_elements.transform_controls):
            control.position = smpl_outputs.T_parent_joint[i, :3, 3]
109
110
@dataclass
class GuiElements:
    """Structure containing handles for reading from GUI elements."""

    # Mesh color picker.
    gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]]
    # Wireframe rendering toggle.
    gui_wireframe: viser.GuiInputHandle[bool]
    # One slider per shape coefficient.
    gui_betas: List[viser.GuiInputHandle[float]]
    # One 3-vector input per joint, holding an axis-angle rotation.
    gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]]
    # One transform gizmo per joint.
    transform_controls: List[viser.TransformControlsHandle]

    changed: bool
    """This flag will be flipped to True whenever the mesh needs to be re-generated."""
123
124
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Make GUI elements for interacting with the model.

    Creates three GUI tabs ("View", "Shape", "Joints") plus one transform
    gizmo per joint, and wires their callbacks together.

    Args:
        server: Viser server that the GUI and gizmos are added to.
        num_betas: Number of shape sliders to create.
        num_joints: Number of joint inputs and transform gizmos to create.
        parent_idx: Parent joint index of every joint. Only the entries for
            joints ``1`` and up are read; they decide how gizmo scene-node
            names are nested.

    Returns:
        Handles for the created elements, with ``changed`` initially ``True``.
    """

    tab_group = server.add_gui_tab_group()

    def set_changed(_) -> None:
        # `out` is defined at the bottom of this function; the name is only
        # looked up when the callback actually runs.
        out.changed = True

    # GUI elements: mesh settings + visibility.
    with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.add_gui_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.add_gui_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.add_gui_checkbox("Handles", initial_value=False)

        gui_rgb.on_update(set_changed)
        gui_wireframe.on_update(set_changed)

        @gui_show_controls.on_update
        def _(_) -> None:
            for control in transform_controls:
                control.visible = gui_show_controls.value

    # GUI elements: shape parameters.
    with tab_group.add_tab("Shape", viser.Icon.BOX):
        gui_reset_shape = server.add_gui_button("Reset Shape")
        gui_random_shape = server.add_gui_button("Random Shape")

        @gui_reset_shape.on_click
        def _(_) -> None:
            for beta in gui_betas:
                beta.value = 0.0

        @gui_random_shape.on_click
        def _(_) -> None:
            for beta in gui_betas:
                beta.value = np.random.normal(loc=0.0, scale=1.0)

        gui_betas = []
        for i in range(num_betas):
            beta = server.add_gui_slider(
                f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
            )
            gui_betas.append(beta)
            beta.on_update(set_changed)

    # GUI elements: joint angles.
    with tab_group.add_tab("Joints", viser.Icon.ANGLE):
        gui_reset_joints = server.add_gui_button("Reset Joints")
        gui_random_joints = server.add_gui_button("Random Joints")

        @gui_reset_joints.on_click
        def _(_) -> None:
            for joint in gui_joints:
                joint.value = (0.0, 0.0, 0.0)

        @gui_random_joints.on_click
        def _(_) -> None:
            for joint in gui_joints:
                # It's hard to uniformly sample orientations directly in so(3), so we
                # first sample on S^3 and then convert.
                quat = np.random.normal(loc=0.0, scale=1.0, size=(4,))
                quat /= np.linalg.norm(quat)
                axisangle = tf.SO3(wxyz=quat).log()
                # Store a plain 3-tuple, matching the handle's annotated value
                # type and what the gizmo callback below writes.
                joint.value = (
                    float(axisangle[0]),
                    float(axisangle[1]),
                    float(axisangle[2]),
                )

        def sync_gizmo_with_joint_input(i: int, joint_input) -> None:
            # Taking `i` as a parameter pins the joint index for this callback.
            @joint_input.on_update
            def _(_) -> None:
                transform_controls[i].wxyz = tf.SO3.exp(
                    np.array(gui_joints[i].value)
                ).wxyz
                out.changed = True

        gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = []
        for i in range(num_joints):
            gui_joint = server.add_gui_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            gui_joints.append(gui_joint)
            sync_gizmo_with_joint_input(i, gui_joint)

    def sync_joint_input_with_gizmo(i: int, gizmo) -> None:
        # Dragging a gizmo writes its rotation back into the joint input as
        # an axis-angle vector.
        @gizmo.on_update
        def _(_) -> None:
            axisangle = tf.SO3(transform_controls[i].wxyz).log()
            gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2])

    # Transform control gizmos on joints.
    transform_controls: List[viser.TransformControlsHandle] = []
    prefixed_joint_names: List[str] = []  # Joint names, but prefixed with parents.
    for i in range(num_joints):
        prefixed_joint_name = f"joint_{i}"
        if i > 0:
            # Nest the name under the parent's name so the scene-node path
            # mirrors the kinematic chain.
            prefixed_joint_name = (
                prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
            )
        prefixed_joint_names.append(prefixed_joint_name)
        controls = server.add_transform_controls(
            f"/smpl/{prefixed_joint_name}",
            depth_test=False,
            # Gizmos shrink by 25% for every level of depth in the chain.
            scale=0.2 * (0.75 ** prefixed_joint_name.count("/")),
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(controls)
        sync_joint_input_with_gizmo(i, controls)

    out = GuiElements(
        gui_rgb,
        gui_wireframe,
        gui_betas,
        gui_joints,
        transform_controls=transform_controls,
        changed=True,
    )
    return out
250
251
# Script entry point: build the command-line interface from main's signature,
# using the module docstring as the help description.
if __name__ == "__main__":
    tyro.cli(main, description=__doc__)