Source code for viser.infra._typescript_interface_gen

import dataclasses
import types
from collections import defaultdict
from typing import Any, Type, Union, cast

import numpy as np
from typing_extensions import (
    Annotated,
    Literal,
    NotRequired,
    get_args,
    get_origin,
    get_type_hints,
    is_typeddict,
)

try:
    from typing import Literal as LiteralAlt
except ImportError:
    LiteralAlt = Literal  # type: ignore

from ._messages import Message

_raw_type_mapping = {
    bool: "boolean",
    float: "number",
    int: "number",
    str: "string",
    # For numpy arrays, we directly serialize the underlying data buffer.
    np.ndarray: "Uint8Array",
    bytes: "Uint8Array",
    Any: "any",
    None: "null",
    type(None): "null",
}


def _get_ts_type(typ: Type[Any]) -> str:
    origin_typ = get_origin(typ)

    # Look for TypeScriptAnnotationOverride in the annotations.
    if origin_typ is Annotated:
        args = get_args(typ)
        for arg in args[1:]:
            if isinstance(arg, TypeScriptAnnotationOverride):
                return arg.annotation

        # If no override is found, just use the unwrapped type.
        origin_typ = args[0]

    # Automatic Python => TypeScript conversion.
    UnionType = getattr(types, "UnionType", Union)
    if origin_typ is tuple:
        args = get_args(typ)
        if len(args) == 2 and args[1] == ...:
            return _get_ts_type(args[0]) + "[]"
        else:
            return "[" + ", ".join(map(_get_ts_type, args)) + "]"
    elif origin_typ is list:
        args = get_args(typ)
        assert len(args) == 1
        return _get_ts_type(args[0]) + "[]"
    elif origin_typ is dict:
        args = get_args(typ)
        assert len(args) == 2
        return "{[key: " + _get_ts_type(args[0]) + "]: " + _get_ts_type(args[1]) + "}"
    elif origin_typ in (Literal, LiteralAlt):
        return " | ".join(
            map(
                lambda lit: repr(lit).lower() if type(lit) is bool else repr(lit),
                get_args(typ),
            )
        )
    elif origin_typ in (Union, UnionType):
        return (
            "("
            + " | ".join(
                # We're using dictionary as an ordered set.
                {_get_ts_type(t): None for t in get_args(typ)}.keys()
            )
            + ")"
        )
    elif origin_typ is list:
        args = get_args(typ)
        return _get_ts_type(args[0]) + "[]"
    elif origin_typ is dict:
        args = get_args(typ)
        assert len(args) == 2
        return "{ [key: " + _get_ts_type(args[0]) + "]: " + _get_ts_type(args[1]) + " }"
    elif is_typeddict(typ) or dataclasses.is_dataclass(typ):
        hints = get_type_hints(typ)
        if dataclasses.is_dataclass(typ):
            hints = {field.name: hints[field.name] for field in dataclasses.fields(typ)}
        optional_keys = getattr(typ, "__optional_keys__", [])

        def fmt(key):
            val = hints[key]
            optional = key in optional_keys
            if is_typeddict(typ) and get_origin(val) is NotRequired:
                val = get_args(val)[0]
            ret = f"'{key}'{'?' if optional else ''}" + ": " + _get_ts_type(val)
            return ret

        ret = "{" + ", ".join(map(fmt, hints)) + "}"
        return ret
    else:
        # Like get_origin(), but also supports numpy.typing.NDArray[dtype].
        typ = cast(Any, getattr(typ, "__origin__", typ))

        assert typ in _raw_type_mapping, f"Unsupported type {typ}"
        return _raw_type_mapping[typ]


[docs] @dataclasses.dataclass(frozen=True) class TypeScriptAnnotationOverride: """Use with `typing.Annotated[]` to override the automatically-generated TypeScript annotation corresponding to a dataclass field.""" annotation: str
[docs] def generate_typescript_interfaces(message_cls: Type[Message]) -> str: """Generate TypeScript definitions for all subclasses of a base message class.""" out_lines = [] message_types = message_cls.get_subclasses() tag_map = defaultdict(list) # Generate interfaces for each specific message. for cls in message_types: if cls.__doc__ is not None: docstring = "\n * ".join( map(lambda line: line.strip(), cls.__doc__.split("\n")) ) out_lines.append(f"/** {docstring}") out_lines.append(" *") out_lines.append(" * (automatically generated)") out_lines.append(" */") for tag in getattr(cls, "_tags", []): tag_map[tag].append(cls.__name__) out_lines.append(f"export interface {cls.__name__} " + "{") out_lines.append(f' type: "{cls.__name__}";') field_names = set([f.name for f in dataclasses.fields(cls)]) # type: ignore for name, typ in get_type_hints(cls, include_extras=True).items(): if name in field_names: typ = _get_ts_type(typ) else: continue out_lines.append(f" {name}: {typ};") out_lines.append("}") out_lines.append("") # Generate union type over all messages. out_lines.append("export type Message = ") for cls in message_types: out_lines.append(f" | {cls.__name__}") out_lines[-1] = out_lines[-1] + ";" # Generate union type over all tags. for tag, cls_names in tag_map.items(): out_lines.append(f"export type {tag} = ") for cls_name in cls_names: out_lines.append(f" | {cls_name}") out_lines[-1] = out_lines[-1] + ";" for tag, cls_names in tag_map.items(): out_lines.extend( [ f"const typeSet{tag} = new Set(['" + "', '".join(cls_names) + "']);" f"export function is{tag}(message: Message): message is {tag}" + " {", f" return typeSet{tag}.has(message.type);", "}", ] ) generated_typescript = "\n".join(out_lines) + "\n" # Add header and return. return ( "\n".join( [ ( "// AUTOMATICALLY GENERATED message interfaces, from Python" " dataclass definitions." ), "// This file should not be manually modified.", "", ] ) + generated_typescript )