SMPL model visualizer
Visualizer for SMPL human body models. Requires a .npz model file.
- Download instructions for the model files are in the smplx README:
https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model
1from __future__ import annotations
2
3import time
4from dataclasses import dataclass
5from pathlib import Path
6
7import numpy as np
8import tyro
9import viser
10import viser.transforms as tf
11
12
@dataclass(frozen=True)
class SmplOutputs:
    """Immutable bundle of results produced by ``SmplHelper.get_outputs``.

    Both transform arrays stack one homogeneous 4x4 matrix per joint.
    """

    # Posed mesh vertex positions, (num_vertices, 3).
    vertices: np.ndarray
    # Triangle vertex indices, passed through unchanged from the model file.
    faces: np.ndarray
    # Joint-to-world transforms after forward kinematics, (num_joints, 4, 4).
    T_world_joint: np.ndarray
    # Joint-to-parent (local) transforms, (num_joints, 4, 4).
    T_parent_joint: np.ndarray
19
20
class SmplHelper:
    """Helper for models in the SMPL family, implemented in numpy.

    Holds the model arrays read from an ``.npz`` file and evaluates the model
    (shape blend shapes, joint regression, forward kinematics, pose blend
    shapes, linear blend skinning) for given shape and pose parameters.
    """

    def __init__(self, model_path: Path) -> None:
        """Load the model arrays from disk.

        Args:
            model_path: Path to the ``.npz`` model file. A ``str`` is also
                accepted and converted to ``Path``.

        Raises:
            ValueError: If the file does not have an ``.npz`` suffix.
        """
        model_path = Path(model_path)
        # Raise instead of assert: asserts are stripped under `python -O`.
        if model_path.suffix.lower() != ".npz":
            raise ValueError(f"Model should be an .npz file! Got: {model_path}")

        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code; only load trusted model files.
        # The `with` block closes the archive's file handle once every array
        # has been copied into `body_dict`.
        with np.load(model_path, allow_pickle=True) as body_npz:
            body_dict = dict(body_npz)

        self.J_regressor = body_dict["J_regressor"]
        self.weights = body_dict["weights"]
        self.v_template = body_dict["v_template"]
        self.posedirs = body_dict["posedirs"]
        self.shapedirs = body_dict["shapedirs"]
        self.faces = body_dict["f"]

        self.num_joints: int = self.weights.shape[-1]
        self.num_betas: int = self.shapedirs.shape[-1]
        # First row of the kinematic tree is used as each joint's parent
        # index; the root's (index 0) entry is never read.
        self.parent_idx: np.ndarray = body_dict["kintree_table"][0]

    def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs:
        """Evaluate the model for one set of shape and pose parameters.

        Args:
            betas: Shape coefficients, shape (k,) with k <= num_betas. Only the
                first k shape directions are used, so a truncated vector is
                accepted.
            joint_rotmats: Local (parent-relative) rotation of every joint,
                shape (num_joints, 3, 3).

        Returns:
            SmplOutputs with posed vertices (num_vertices, 3), the model faces,
            and the world and parent-relative joint transforms.
        """
        betas = np.asarray(betas)
        joint_rotmats = np.asarray(joint_rotmats)

        # Get shaped vertices + joint positions, when all local poses are identity.
        shapedirs = self.shapedirs[..., : betas.shape[-1]]
        v_tpose = self.v_template + np.einsum("vxb,b->vx", shapedirs, betas)
        j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)

        # Local SE(3) transforms: rotation from the pose; translation is the
        # rest-pose offset from the parent joint (the root is offset from origin).
        T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
        T_parent_joint[:, :3, :3] = joint_rotmats
        T_parent_joint[0, :3, 3] = j_tpose[0]
        T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]

        # Forward kinematics. Requires every joint's parent to have a smaller
        # index so that T_world_joint[parent] is final before it is read.
        T_world_joint = T_parent_joint.copy()
        for i in range(1, self.num_joints):
            T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]

        # Linear blend skinning.
        # Pose blend shapes are driven by (R - I) of every non-root joint.
        pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten()
        v_blend = v_tpose + np.einsum("byn,n->by", self.posedirs, pose_delta)
        # Homogeneous offset of each vertex from each rest-pose joint, so that
        # T_world_joint[j] applied to v_delta[:, j] moves the vertex rigidly
        # with joint j.
        v_delta = np.ones((v_blend.shape[0], self.num_joints, 4))
        v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :]
        # Blend the per-joint rigid transforms with the skinning weights.
        v_posed = np.einsum(
            "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self.weights, v_delta
        )
        return SmplOutputs(v_posed, self.faces, T_world_joint, T_parent_joint)
64
65
def main(model_path: Path) -> None:
    """Serve an interactive SMPL viewer and keep the mesh in sync with the GUI.

    Args:
        model_path: Path to the ``.npz`` SMPL-family model file.
    """
    server = viser.ViserServer()
    server.scene.set_up_direction("+y")
    server.gui.configure_theme(control_layout="collapsible")

    # Load the model and build the GUI that drives it.
    model = SmplHelper(model_path)
    gui = make_gui_elements(
        server,
        num_betas=model.num_betas,
        num_joints=model.num_joints,
        parent_idx=model.parent_idx,
    )
    # Start from the unposed template mesh; it is replaced as soon as the loop
    # below sees the `changed` flag set.
    mesh = server.scene.add_mesh_simple(
        "/human",
        model.v_template,
        model.faces,
        wireframe=gui.gui_wireframe.value,
        color=gui.gui_rgb.value,
    )

    # Poll the GUI's dirty flag and rebuild the mesh only when it is set.
    while True:
        time.sleep(0.02)
        if gui.changed:
            # Clear the flag before reading values, so edits made while we
            # recompute are picked up on the next pass.
            gui.changed = False

            betas = np.array([slider.value for slider in gui.gui_betas])
            # (num_joints, 3) axis-angle vectors, one per joint input.
            axis_angles = np.array([vec.value for vec in gui.gui_joints])
            outputs = model.get_outputs(
                betas=betas,
                joint_rotmats=tf.SO3.exp(axis_angles).as_matrix(),
            )

            # Push the new geometry and the current appearance settings.
            mesh.vertices = outputs.vertices
            mesh.wireframe = gui.gui_wireframe.value
            mesh.color = gui.gui_rgb.value

            # Gizmos are nested under their parent joints in the scene tree,
            # so each one takes its parent-relative translation.
            for idx, gizmo in enumerate(gui.transform_controls):
                gizmo.position = outputs.T_parent_joint[idx, :3, 3]
113
114
@dataclass
class GuiElements:
    """Handles to the GUI widgets that drive the SMPL model, plus a dirty flag."""

    # Mesh color picker.
    gui_rgb: viser.GuiInputHandle[tuple[int, int, int]]
    # Wireframe on/off toggle.
    gui_wireframe: viser.GuiInputHandle[bool]
    # One slider per shape coefficient.
    gui_betas: list[viser.GuiInputHandle[float]]
    # One axis-angle vector input per joint.
    gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]]
    # One rotation gizmo per joint.
    transform_controls: list[viser.TransformControlsHandle]

    # Flipped to True whenever the mesh needs to be re-generated.
    changed: bool
127
128
def make_gui_elements(
    server: viser.ViserServer,
    num_betas: int,
    num_joints: int,
    parent_idx: np.ndarray,
) -> GuiElements:
    """Make GUI elements for interacting with the model.

    Builds three tabs (view settings, shape sliders, joint angles) plus one
    rotation gizmo per joint. It wires their callbacks so that edits mark the
    returned ``GuiElements`` as changed, and keeps joint inputs and gizmos in
    sync.

    Args:
        server: Server whose GUI and scene the elements are added to.
        num_betas: Number of shape sliders to create.
        num_joints: Number of joint inputs and transform gizmos to create.
        parent_idx: Parent index of each joint; entry 0 (the root) is not
            read. Parents must have a smaller index than their children.

    Returns:
        Handles to the created elements, with ``changed`` initially True.
    """
    # One generator shared by the "Random Shape" and "Random Joints" buttons.
    rng = np.random.default_rng()
    beta_min, beta_max = -5.0, 5.0

    tab_group = server.gui.add_tab_group()

    def set_changed(_) -> None:
        # `out` is only bound at the end of this function; this relies on the
        # callbacks firing after make_gui_elements() has returned.
        out.changed = True

    def sync_gizmo_from_input(
        i: int, gui_joint: viser.GuiInputHandle[tuple[float, float, float]]
    ) -> None:
        """Rotate gizmo ``i`` to match joint input ``i`` and mark the mesh dirty."""

        @gui_joint.on_update
        def _(_) -> None:
            transform_controls[i].wxyz = tf.SO3.exp(
                np.array(gui_joints[i].value)
            ).wxyz
            out.changed = True

    def sync_input_from_gizmo(
        i: int, controls: viser.TransformControlsHandle
    ) -> None:
        """Write gizmo ``i``'s rotation back into joint input ``i`` (axis-angle)."""

        @controls.on_update
        def _(_) -> None:
            axisangle = tf.SO3(transform_controls[i].wxyz).log()
            gui_joints[i].value = (
                float(axisangle[0]),
                float(axisangle[1]),
                float(axisangle[2]),
            )

    # GUI elements: mesh settings + visibility.
    with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
        gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
        gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
        gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)

        gui_rgb.on_update(set_changed)
        gui_wireframe.on_update(set_changed)

        @gui_show_controls.on_update
        def _(_):
            for control in transform_controls:
                control.visible = gui_show_controls.value

    # GUI elements: shape parameters.
    with tab_group.add_tab("Shape", viser.Icon.BOX):
        gui_reset_shape = server.gui.add_button("Reset Shape")
        gui_random_shape = server.gui.add_button("Random Shape")

        @gui_reset_shape.on_click
        def _(_):
            for beta in gui_betas:
                beta.value = 0.0

        @gui_random_shape.on_click
        def _(_):
            for beta in gui_betas:
                # Clamp so the sampled value stays inside the slider's range.
                sample = float(rng.normal(loc=0.0, scale=1.0))
                beta.value = min(max(sample, beta_min), beta_max)

        gui_betas: list[viser.GuiInputHandle[float]] = []
        for i in range(num_betas):
            beta = server.gui.add_slider(
                f"beta{i}", min=beta_min, max=beta_max, step=0.01, initial_value=0.0
            )
            gui_betas.append(beta)
            beta.on_update(set_changed)

    # GUI elements: joint angles.
    with tab_group.add_tab("Joints", viser.Icon.ANGLE):
        gui_reset_joints = server.gui.add_button("Reset Joints")
        gui_random_joints = server.gui.add_button("Random Joints")

        @gui_reset_joints.on_click
        def _(_):
            for joint in gui_joints:
                joint.value = (0.0, 0.0, 0.0)

        @gui_random_joints.on_click
        def _(_):
            for joint in gui_joints:
                axisangle = tf.SO3.sample_uniform(rng).log()
                # Assign a plain tuple, matching the input's declared value type.
                joint.value = (
                    float(axisangle[0]),
                    float(axisangle[1]),
                    float(axisangle[2]),
                )

        gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]] = []
        for i in range(num_joints):
            gui_joint = server.gui.add_vector3(
                label=f"Joint {i}",
                initial_value=(0.0, 0.0, 0.0),
                step=0.05,
            )
            gui_joints.append(gui_joint)
            sync_gizmo_from_input(i, gui_joint)

    # Transform control gizmos on joints. Each gizmo's scene path nests it under
    # its parent's gizmo, so the scene tree mirrors the kinematic tree.
    transform_controls: list[viser.TransformControlsHandle] = []
    prefixed_joint_names: list[str] = []  # Joint names, prefixed with parents.
    for i in range(num_joints):
        prefixed_joint_name = f"joint_{i}"
        if i > 0:
            # Requires the parent's name to have been created already.
            prefixed_joint_name = (
                prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
            )
        prefixed_joint_names.append(prefixed_joint_name)
        controls = server.scene.add_transform_controls(
            f"/smpl/{prefixed_joint_name}",
            depth_test=False,
            # Shrink by 0.75x per level of nesting so deep joints stay readable.
            scale=0.2 * (0.75 ** prefixed_joint_name.count("/")),
            disable_axes=True,
            disable_sliders=True,
            visible=gui_show_controls.value,
        )
        transform_controls.append(controls)
        sync_input_from_gizmo(i, controls)

    out = GuiElements(
        gui_rgb,
        gui_wireframe,
        gui_betas,
        gui_joints,
        transform_controls=transform_controls,
        changed=True,
    )
    return out
251
252
if __name__ == "__main__":
    # Build a command-line interface from main()'s signature and run it; the
    # module docstring (None if the module has none) is the help text.
    cli_description = __doc__
    tyro.cli(main, description=cli_description)