from __future__ import annotations
import dataclasses
import re
import time
import urllib.parse
import uuid
import warnings
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
)
import imageio.v3 as iio
import numpy as onp
from typing_extensions import Protocol
from ._icons import svg_from_icon
from ._icons_enum import IconName
from ._message_api import _encode_image_base64
from ._messages import GuiCloseModalMessage, GuiRemoveMessage, GuiUpdateMessage, Message
from .infra import ClientId
if TYPE_CHECKING:
import plotly.graph_objects as go
from ._gui_api import GuiApi
from ._viser import ClientHandle
# Generic type of the value held by a GUI input handle.
T = TypeVar("T")
# Type of the handle that a `GuiEvent` refers to. The bound is a string forward
# reference because `_GuiInputHandle` is defined later in this module.
TGuiHandle = TypeVar("TGuiHandle", bound="_GuiInputHandle")
def _make_unique_id() -> str:
"""Return a unique ID for referencing GUI elements."""
return str(uuid.uuid4())
class GuiContainerProtocol(Protocol):
    """Structural type for GUI containers.

    Any object exposing a ``_children`` mapping from element ID to a removable
    handle satisfies this protocol.
    """

    # Maps the ID of each contained element to its handle. Declared without a
    # default: protocols are never instantiated, and a `dataclasses.field(...)`
    # default only has meaning inside a `@dataclass`.
    _children: Dict[str, SupportsRemoveProtocol]
class SupportsRemoveProtocol(Protocol):
    """Structural type for any GUI handle that can be removed."""

    def remove(self) -> None:
        """Remove the element; concrete handles decide what that entails."""
        ...
@dataclasses.dataclass
class _GuiHandleState(Generic[T]):
    """Internal state backing a GUI input handle.

    Public handles keep one of these in their ``_impl`` field so that private
    implementation details stay out of the user-facing API.
    """

    # Label text for the element.
    label: str
    # Python type of `value`.
    typ: Type[T]
    # GUI API object that this element belongs to.
    gui_api: GuiApi
    # Current value of the input.
    value: T
    # Time (from `time.time()`) of the most recent update to `value`.
    update_timestamp: float
    # ID of the container that this GUI input was placed into.
    parent_container_id: str
    # Registered functions to call when this input is updated.
    update_cb: List[Callable[[GuiEvent], None]]
    # Marks a button element, which requires special handling.
    is_button: bool
    # Optional callback for synchronizing inputs across clients.
    sync_cb: Optional[Callable[[ClientId, Dict[str, Any]], None]]
    # Whether user interaction with the input is disallowed.
    disabled: bool
    # Whether the element is currently shown.
    visible: bool
    # Order value, which dictates the position of the element.
    order: float
    # Unique ID used to reference the element in messages.
    id: str
    # Optional hint text; presumably displayed alongside the input (TODO confirm).
    hint: Optional[str]
    # Message class associated with this element; role assumed from the name.
    message_type: Type[Message]
@dataclasses.dataclass
class _GuiInputHandle(Generic[T]):
    """Shared base for GUI input handles.

    Exposes the state stored in `_impl` through properties. Assigning `value`
    (except for buttons), `disabled`, or `visible` queues a `GuiUpdateMessage` so
    the client stays in sync.
    """

    # Let's shove private implementation details in here...
    _impl: _GuiHandleState[T]

    # Should we use @property for get_value / set_value, set_hidden, etc?
    #
    # Benefits:
    #   @property is syntactically very nice.
    #   `gui.value = ...` is really tempting!
    #   Feels a bit more magical.
    #
    # Downsides:
    #   Consistency: not everything that can be written can be read, and not everything
    #   that can be read can be written. `get_`/`set_` makes this really clear.
    #   Clarity: some things that we read (like client mappings) are copied before
    #   they're returned. An attribute access obfuscates the overhead here.
    #   Flexibility: getter/setter types should match. https://github.com/python/mypy/issues/3004
    #   Feels a bit more magical.
    #
    # Is this worth the tradeoff?

    @property
    def order(self) -> float:
        """Read-only order value, which dictates the position of the GUI element."""
        return self._impl.order

    @property
    def value(self) -> T:
        """Value of the GUI input. Synchronized automatically when assigned."""
        return self._impl.value

    @value.setter
    def value(self, value: T | onp.ndarray) -> None:
        if isinstance(value, onp.ndarray):
            assert len(value.shape) <= 1, f"{value.shape} should be at most 1D!"
            value = tuple(map(float, value))  # type: ignore

        # Send to client, except for buttons.
        if not self._impl.is_button:
            self._impl.gui_api._get_api()._queue(
                GuiUpdateMessage(self._impl.id, {"value": value})
            )

        # Set internal state. We automatically convert numpy arrays to the expected
        # internal type. (eg 1D arrays to tuples)
        self._impl.value = type(self._impl.value)(value)  # type: ignore
        self._impl.update_timestamp = time.time()

        # Call update callbacks.
        for cb in self._impl.update_cb:
            # Pushing callbacks into separate threads helps prevent deadlocks when we
            # have a lock in a callback. TODO: revisit other callbacks.
            #
            # The callback and its event are passed as `submit` arguments instead of
            # being captured by a lambda. A lambda would look up the loop variable
            # `cb` only when the executor runs it, by which point the loop may have
            # advanced, so several submissions could all call the last callback.
            self._impl.gui_api._get_api()._thread_executor.submit(
                cb,
                GuiEvent(
                    client_id=None,
                    client=None,
                    target=self,
                ),
            )

    @property
    def update_timestamp(self) -> float:
        """Read-only timestamp when this input was last updated."""
        return self._impl.update_timestamp

    @property
    def disabled(self) -> bool:
        """Allow/disallow user interaction with the input. Synchronized automatically
        when assigned."""
        return self._impl.disabled

    @disabled.setter
    def disabled(self, disabled: bool) -> None:
        # Skip the message when nothing changes.
        if disabled == self.disabled:
            return
        self._impl.gui_api._get_api()._queue(
            GuiUpdateMessage(self._impl.id, {"disabled": disabled})
        )
        self._impl.disabled = disabled

    @property
    def visible(self) -> bool:
        """Temporarily show or hide this GUI element from the visualizer. Synchronized
        automatically when assigned."""
        return self._impl.visible

    @visible.setter
    def visible(self, visible: bool) -> None:
        # Skip the message when nothing changes.
        if visible == self.visible:
            return
        self._impl.gui_api._get_api()._queue(
            GuiUpdateMessage(self._impl.id, {"visible": visible})
        )
        self._impl.visible = visible

    def __post_init__(self) -> None:
        """We need to register ourselves after construction for callbacks to work."""
        gui_api = self._impl.gui_api

        # TODO: the current way we track GUI handles and children is very manual +
        # error-prone. We should revisit this design.
        gui_api._gui_input_handle_from_id[self._impl.id] = self
        parent = gui_api._container_handle_from_id[self._impl.parent_container_id]
        parent._children[self._impl.id] = self

    def remove(self) -> None:
        """Permanently remove this GUI element from the visualizer."""
        gui_api = self._impl.gui_api
        gui_api._get_api()._queue(GuiRemoveMessage(self._impl.id))
        # Unregister from both the handle lookup and the parent container.
        gui_api._gui_input_handle_from_id.pop(self._impl.id)
        parent = gui_api._container_handle_from_id[self._impl.parent_container_id]
        parent._children.pop(self._impl.id)
# Type of the option strings held by a dropdown handle. Bound to `str` so that
# `Literal[...]` string types can be inferred for specific dropdowns.
StringType = TypeVar("StringType", bound=str)
# GuiInputHandle[T] is used for all inputs except for buttons.
#
# We inherit from _GuiInputHandle to special-case buttons because the usage semantics
# are slightly different: we have `on_click()` instead of `on_update()`.
@dataclasses.dataclass(frozen=True)
class GuiEvent(Generic[TGuiHandle]):
    """Information associated with a GUI event, such as an update or click.

    Passed as input to callback functions. Instances are immutable. When an
    event is created by assigning `.value` from Python, `client` and
    `client_id` are None.
    """

    client: Optional[ClientHandle]
    """Client that triggered this event."""
    client_id: Optional[int]
    """ID of client that triggered this event."""
    target: TGuiHandle
    """GUI element that was affected."""
@dataclasses.dataclass
class UploadedFile:
    """A file received through an upload button.

    Pairs the file's name with its raw bytes.
    """

    name: str
    """Name of the uploaded file."""
    content: bytes
    """Raw bytes of the uploaded file."""
@dataclasses.dataclass
class GuiUploadButtonHandle(_GuiInputHandle[UploadedFile]):
    """Handle for an upload file button in our visualizer.

    The `.value` attribute will be updated with the contents of uploaded files.
    """

    def on_upload(
        self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None]
    ) -> Callable[[GuiEvent[TGuiHandle]], None]:
        """Attach a function to call when a file is uploaded. Happens in a thread.

        `func` is registered as an update callback and returned unchanged, so this
        method can also be used as a decorator.
        """
        self._impl.update_cb.append(func)
        return func
@dataclasses.dataclass
class GuiDropdownHandle(GuiInputHandle[StringType], Generic[StringType]):
    """Handle for a dropdown-style GUI input in our visualizer.

    Lets us get values, set values, and detect updates."""

    _impl_options: Tuple[StringType, ...]

    @property
    def options(self) -> Tuple[StringType, ...]:
        """Options for our dropdown. Synchronized automatically when assigned.

        For projects that care about typing: the static type of `options` should be
        consistent with the `StringType` associated with a handle. Literal types will be
        inferred where possible when handles are instantiated; for the most flexibility,
        we can declare handles as `GuiDropdownHandle[str]`.
        """
        return self._impl_options

    @options.setter
    def options(self, options: Iterable[StringType]) -> None:
        self._impl_options = tuple(options)

        # If the current value is no longer one of the options, fall back to the
        # first option and send it to the client together with the new options.
        # NOTE(review): an empty `options` raises IndexError here, after
        # `_impl_options` has already been replaced.
        need_to_overwrite_value = self.value not in self._impl_options
        update: Dict[str, Any] = {"options": self._impl_options}
        if need_to_overwrite_value:
            update["value"] = self._impl_options[0]

        self._impl.gui_api._get_api()._queue(GuiUpdateMessage(self._impl.id, update))

        if need_to_overwrite_value:
            self._impl.value = self._impl_options[0]
@dataclasses.dataclass(frozen=True)
class GuiTabGroupHandle:
    """Handle for a group of tabs in our visualizer.

    Tabs are added with `add_tab()`. Labels, icons, and tab handles are kept in
    parallel lists and sent to the client whenever they change.
    """

    _tab_group_id: str
    _labels: List[str]
    _icons_html: List[Optional[str]]
    _tabs: List[GuiTabHandle]
    _gui_api: GuiApi
    _parent_container_id: str
    _order: float

    @property
    def order(self) -> float:
        """Read-only order value, which dictates the position of the GUI element."""
        return self._order

    def add_tab(self, label: str, icon: Optional[IconName] = None) -> GuiTabHandle:
        """Add a tab. Returns a handle we can use to add GUI elements to it."""
        tab_id = _make_unique_id()

        # We may want to make this thread-safe in the future.
        out = GuiTabHandle(_parent=self, _id=tab_id)

        self._labels.append(label)
        self._icons_html.append(None if icon is None else svg_from_icon(icon))
        self._tabs.append(out)

        self._sync_with_client()
        return out

    def __post_init__(self) -> None:
        """Register this tab group as a child of its parent container."""
        parent = self._gui_api._container_handle_from_id[self._parent_container_id]
        parent._children[self._tab_group_id] = self

    def remove(self) -> None:
        """Remove this tab group and all contained GUI elements."""
        # Iterate over a snapshot: each `tab.remove()` pops itself from `_tabs`.
        for tab in tuple(self._tabs):
            tab.remove()
        gui_api = self._gui_api
        gui_api._get_api()._queue(GuiRemoveMessage(self._tab_group_id))
        parent = gui_api._container_handle_from_id[self._parent_container_id]
        parent._children.pop(self._tab_group_id)

    def _sync_with_client(self) -> None:
        """Send messages for syncing tab state with the client."""
        self._gui_api._get_api()._queue(
            GuiUpdateMessage(
                self._tab_group_id,
                {
                    "tab_labels": tuple(self._labels),
                    "tab_icons_html": tuple(self._icons_html),
                    "tab_container_ids": tuple(tab._id for tab in self._tabs),
                },
            )
        )
@dataclasses.dataclass
class GuiFolderHandle:
    """Use as a context to place GUI elements into a folder."""

    _gui_api: GuiApi
    _id: str  # Used as container ID for children.
    _order: float
    _parent_container_id: str  # Container ID of parent.
    _container_id_restore: Optional[str] = None
    _children: Dict[str, SupportsRemoveProtocol] = dataclasses.field(
        default_factory=dict
    )

    @property
    def order(self) -> float:
        """Read-only order value, which dictates the position of the GUI element."""
        return self._order

    def __enter__(self) -> GuiFolderHandle:
        """Make this folder the active container, remembering the previous one."""
        self._container_id_restore = self._gui_api._get_container_id()
        self._gui_api._set_container_id(self._id)
        return self

    def __exit__(self, *args) -> None:
        """Restore the container that was active before `__enter__`."""
        del args
        assert self._container_id_restore is not None
        self._gui_api._set_container_id(self._container_id_restore)
        self._container_id_restore = None

    def __post_init__(self) -> None:
        """Register this folder as a container and as a child of its parent."""
        self._gui_api._container_handle_from_id[self._id] = self
        parent = self._gui_api._container_handle_from_id[self._parent_container_id]
        parent._children[self._id] = self

    def remove(self) -> None:
        """Permanently remove this folder and all contained GUI elements from the
        visualizer."""
        self._gui_api._get_api()._queue(GuiRemoveMessage(self._id))
        # Iterate over a snapshot, since each child's `remove()` may pop itself
        # from `_children`.
        for child in tuple(self._children.values()):
            child.remove()
        parent = self._gui_api._container_handle_from_id[self._parent_container_id]
        parent._children.pop(self._id)
        self._gui_api._container_handle_from_id.pop(self._id)
@dataclasses.dataclass
class GuiModalHandle:
    """Use as a context to place GUI elements into a modal."""

    _gui_api: GuiApi
    _id: str  # Used as container ID of children.
    _container_id_restore: Optional[str] = None
    _children: Dict[str, SupportsRemoveProtocol] = dataclasses.field(
        default_factory=dict
    )

    def __enter__(self) -> GuiModalHandle:
        """Make this modal the active container, remembering the previous one."""
        self._container_id_restore = self._gui_api._get_container_id()
        self._gui_api._set_container_id(self._id)
        return self

    def __exit__(self, *args) -> None:
        """Restore the container that was active before `__enter__`."""
        del args
        assert self._container_id_restore is not None
        self._gui_api._set_container_id(self._container_id_restore)
        self._container_id_restore = None

    def __post_init__(self) -> None:
        """Register this modal as a container so elements can be placed in it."""
        self._gui_api._container_handle_from_id[self._id] = self

    def close(self) -> None:
        """Close this modal and permanently remove all contained GUI elements."""
        self._gui_api._get_api()._queue(
            GuiCloseModalMessage(self._id),
        )
        # Iterate over a snapshot, since each child's `remove()` may pop itself
        # from `_children`.
        for child in tuple(self._children.values()):
            child.remove()
        self._gui_api._container_handle_from_id.pop(self._id)
@dataclasses.dataclass
class GuiTabHandle:
    """Use as a context to place GUI elements into a tab."""

    _parent: GuiTabGroupHandle
    _id: str  # Used as container ID of children.
    _container_id_restore: Optional[str] = None
    _children: Dict[str, SupportsRemoveProtocol] = dataclasses.field(
        default_factory=dict
    )

    def __enter__(self) -> GuiTabHandle:
        """Make this tab the active container, remembering the previous one."""
        self._container_id_restore = self._parent._gui_api._get_container_id()
        self._parent._gui_api._set_container_id(self._id)
        return self

    def __exit__(self, *args) -> None:
        """Restore the container that was active before `__enter__`."""
        del args
        assert self._container_id_restore is not None
        self._parent._gui_api._set_container_id(self._container_id_restore)
        self._container_id_restore = None

    def __post_init__(self) -> None:
        """Register this tab as a container so elements can be placed in it."""
        self._parent._gui_api._container_handle_from_id[self._id] = self

    def remove(self) -> None:
        """Permanently remove this tab and all contained GUI elements from the
        visualizer."""
        # We may want to make this thread-safe in the future.
        container_index = next(
            (i for i, tab in enumerate(self._parent._tabs) if tab is self), -1
        )
        assert container_index != -1, "Tab already removed!"

        # Labels, icons, and tab handles are parallel lists in the parent group.
        self._parent._labels.pop(container_index)
        self._parent._icons_html.pop(container_index)
        self._parent._tabs.pop(container_index)
        self._parent._sync_with_client()

        # Iterate over a snapshot, since each child's `remove()` may pop itself
        # from `_children`.
        for child in tuple(self._children.values()):
            child.remove()
        self._parent._gui_api._container_handle_from_id.pop(self._id)
def _get_data_url(url: str, image_root: Optional[Path]) -> str:
if not url.startswith("http") and not image_root:
warnings.warn(
(
"No `image_root` provided. All relative paths will be scoped to viser's"
" installation path."
),
stacklevel=2,
)
if url.startswith("http") or url.startswith("data:"):
return url
if image_root is None:
image_root = Path(__file__).parent
try:
image = iio.imread(image_root / url)
data_uri = _encode_image_base64(image, "png")
url = urllib.parse.quote(f"{data_uri[1]}")
return f"data:{data_uri[0]};base64,{url}"
except (IOError, FileNotFoundError):
warnings.warn(
f"Failed to read image {url}, with image_root set to {image_root}.",
stacklevel=2,
)
return url
def _parse_markdown(markdown: str, image_root: Optional[Path]) -> str:
    """Rewrite the URLs of markdown image tags so the client can display them.

    Each ``![alt](url)`` tag keeps its alt text, and its URL is replaced by the
    result of `_get_data_url(url, image_root)`.

    Args:
        markdown: Markdown source text.
        image_root: Directory forwarded to `_get_data_url` for resolving relative
            image paths.

    Returns:
        The markdown with every image URL rewritten.
    """

    def _replace(match: re.Match) -> str:
        alt_text, url = match.group(1), match.group(2)
        return f"![{alt_text}]({_get_data_url(url, image_root)})"

    # Alt text runs up to the first "]". The URL is a run of non-parenthesis
    # characters, optionally containing balanced "(...)" groups, and ends at the
    # first unmatched ")". Text after the tag, such as "![a](x.png) (fig. 1)",
    # is therefore not absorbed into the URL.
    return re.sub(r"!\[([^]]*)\]\(((?:[^()]|\([^()]*\))*)\)", _replace, markdown)
@dataclasses.dataclass
class GuiMarkdownHandle:
    """Handle for a markdown element in our visualizer.

    Lets us update its content and visibility, or remove it.
    """

    _gui_api: GuiApi
    _id: str
    _visible: bool
    _parent_container_id: str  # Parent.
    _order: float
    _image_root: Optional[Path]
    _content: Optional[str]

    @property
    def content(self) -> str:
        """Current content of this markdown element. Synchronized automatically when assigned."""
        assert self._content is not None
        return self._content

    @content.setter
    def content(self, content: str) -> None:
        self._content = content
        # Image tags are rewritten (e.g. local images inlined) before sending.
        self._gui_api._get_api()._queue(
            GuiUpdateMessage(
                self._id,
                {"markdown": _parse_markdown(content, self._image_root)},
            )
        )

    @property
    def order(self) -> float:
        """Read-only order value, which dictates the position of the GUI element."""
        return self._order

    @property
    def visible(self) -> bool:
        """Temporarily show or hide this GUI element from the visualizer. Synchronized
        automatically when assigned."""
        return self._visible

    @visible.setter
    def visible(self, visible: bool) -> None:
        # Skip the message when nothing changes.
        if visible == self.visible:
            return
        self._gui_api._get_api()._queue(
            GuiUpdateMessage(self._id, {"visible": visible})
        )
        self._visible = visible

    def __post_init__(self) -> None:
        """We need to register ourselves after construction for callbacks to work."""
        parent = self._gui_api._container_handle_from_id[self._parent_container_id]
        parent._children[self._id] = self

    def remove(self) -> None:
        """Permanently remove this markdown from the visualizer."""
        api = self._gui_api._get_api()
        api._queue(GuiRemoveMessage(self._id))
        parent = self._gui_api._container_handle_from_id[self._parent_container_id]
        parent._children.pop(self._id)
@dataclasses.dataclass
class GuiPlotlyHandle:
    """Handle for a plotly figure in our visualizer.

    Lets us update the figure, its aspect ratio, and its visibility, or remove it.
    """

    _gui_api: GuiApi
    _id: str
    _visible: bool
    _parent_container_id: str  # Parent.
    _order: float
    _figure: Optional[go.Figure]
    _aspect: Optional[float]

    @property
    def figure(self) -> go.Figure:
        """Current plotly figure shown by this element. Synchronized automatically when
        assigned."""
        assert self._figure is not None
        return self._figure

    @figure.setter
    def figure(self, figure: go.Figure) -> None:
        self._figure = figure
        # The figure is sent to the client as its plotly JSON serialization.
        json_str = figure.to_json()
        assert isinstance(json_str, str)
        self._gui_api._get_api()._queue(
            GuiUpdateMessage(
                self._id,
                {"plotly_json_str": json_str},
            )
        )

    @property
    def aspect(self) -> float:
        """Aspect ratio of the plotly figure, in the control panel. Synchronized
        automatically when assigned."""
        assert self._aspect is not None
        return self._aspect

    @aspect.setter
    def aspect(self, aspect: float) -> None:
        self._aspect = aspect
        self._gui_api._get_api()._queue(
            GuiUpdateMessage(
                self._id,
                {"aspect": aspect},
            )
        )

    @property
    def order(self) -> float:
        """Read-only order value, which dictates the position of the GUI element."""
        return self._order

    @property
    def visible(self) -> bool:
        """Temporarily show or hide this GUI element from the visualizer. Synchronized
        automatically when assigned."""
        return self._visible

    @visible.setter
    def visible(self, visible: bool) -> None:
        # Skip the message when nothing changes.
        if visible == self.visible:
            return
        self._gui_api._get_api()._queue(
            GuiUpdateMessage(self._id, {"visible": visible})
        )
        self._visible = visible

    def __post_init__(self) -> None:
        """We need to register ourselves after construction for callbacks to work."""
        parent = self._gui_api._container_handle_from_id[self._parent_container_id]
        parent._children[self._id] = self

    def remove(self) -> None:
        """Permanently remove this plotly element from the visualizer."""
        api = self._gui_api._get_api()
        api._queue(GuiRemoveMessage(self._id))
        parent = self._gui_api._container_handle_from_id[self._parent_container_id]
        parent._children.pop(self._id)