Source code for xradio.schema.typing

"""Typing support for xarray data classes

This has been extracted from the xarray-dataclasses package by astropenguin
(see https://github.com/astropenguin/xarray-dataclasses/). The reason we
replicate this here is because we actually ignore / redo everything but the
type annotations, especially adding xradio-specific support for multiple
options in data variable / coordinate dimensionality and dtype.
"""

from typing import (
    Any,
    List,
    Tuple,
    Hashable,
    Iterable,
    Type,
    ClassVar,
    Dict,
    TypeVar,
    Union,
    Sequence,
    Generic,
    Collection,
    Literal,
    get_type_hints,
    get_args,
    get_origin,
    Annotated,
    Protocol,
)

from typing import Union

try:
    # Python 3.10 forward: TypeAlias, ParamSpec are standard, and there is the
    # "a | b" UnionType alternative to "Union[a,b]"
    from typing import TypeAlias, ParamSpec
    from types import UnionType

    HAVE_UNIONTYPE = True
except ImportError:
    # Python 3.9: Get TypeAlias, ParamSpec from typing_extensions, no support
    # for "a | b"
    from typing_extensions import (
        TypeAlias,
        ParamSpec,
    )

    HAVE_UNIONTYPE = False
import numpy as np
from itertools import chain
from enum import Enum

PInit = ParamSpec("PInit")
T = TypeVar("T")
TDataClass = TypeVar("TDataClass", bound="DataClass[Any]")
TDims = TypeVar("TDims", covariant=True)
TDType = TypeVar("TDType", covariant=True)
THashable = TypeVar("THashable", bound=Hashable)

AnyArray: TypeAlias = "np.ndarray[Any, Any]"
AnyDType: TypeAlias = "np.dtype[Any]"
AnyField: TypeAlias = "Field[Any]"
AnyXarray: TypeAlias = "xr.DataArray | xr.Dataset"
Dims = Tuple[str, ...]
Order = Literal["C", "F"]
Shape = Union[Sequence[int], int]
Sizes = Dict[str, int]


[docs] class DataClass(Protocol[PInit]): """Type hint for dataclass objects.""" def __init__(self, *args: PInit.args, **kwargs: PInit.kwargs) -> None: ... __dataclass_fields__: ClassVar[Dict[str, AnyField]]
[docs] class Labeled(Generic[TDims]): """Type hint for labeled objects.""" pass
# type hints (public)
[docs] class Role(Enum): """Annotations for typing dataclass fields.""" ATTR = "attr" """Annotation for attribute fields.""" COORD = "coord" """Annotation for coordinate fields.""" DATA = "data" """Annotation for data (variable) fields.""" NAME = "name" """Annotation for name fields.""" OTHER = "other" """Annotation for other fields."""
[docs] @classmethod def annotates(cls, tp: Any) -> bool: """Check if any role annotates a type hint.""" if get_origin(tp) is not Annotated: return False return any(isinstance(arg, cls) for arg in get_args(tp))
Attr = Annotated[T, Role.ATTR] """Type hint for attribute fields (``Attr[T]``). Example: :: @dataclass class Image(): data: Data[tuple[X, Y], float] long_name: Attr[str] = "luminance" units: Attr[str] = "cd / m^2" Hint: The following field names are specially treated when plotting. - ``long_name`` or ``standard_name``: Coordinate name. - ``units``: Coordinate units. Reference: https://xarray.pydata.org/en/stable/user-guide/plotting.html """ Coord = Annotated[Union[Labeled[TDims], Collection[TDType], TDType], Role.COORD] """Type hint for coordinate fields (``Coord[TDims, TDType]``). Example: :: @dataclass class Image(): data: Data[tuple[X, Y], float] mask: Coord[tuple[X, Y], bool] x: Coord[X, int] = 0 y: Coord[Y, int] = 0 Hint: A coordinate field whose name is the same as ``TDims`` (e.g. ``x: Coord[X, int]``) can define a dimension. """ Coordof = Annotated[Union[TDataClass, Any], Role.COORD] """Type hint for coordinate fields (``Coordof[TDataClass]``). Unlike ``Coord``, it specifies a dataclass that defines a DataArray class. This is useful when users want to add metadata to dimensions for plotting. Example: :: @dataclass class XAxis: data: Data[X, int] long_name: Attr[str] = "x axis" @dataclass class YAxis: data: Data[Y, int] long_name: Attr[str] = "y axis" @dataclass class Image(): data: Data[tuple[X, Y], float] x: Coordof[XAxis] = 0 y: Coordof[YAxis] = 0 """ Data = Annotated[Union[Labeled[TDims], Collection[TDType], TDType], Role.DATA] """Type hint for data fields (``Coordof[TDims, TDType]``). Example: Exactly one data field is allowed in a DataArray class (the second and subsequent data fields are just ignored):: @dataclass class Image(): data: Data[tuple[X, Y], float] Multiple data fields are allowed in a Dataset class:: @dataclass class ColorImage(): red: Data[tuple[X, Y], float] green: Data[tuple[X, Y], float] blue: Data[tuple[X, Y], float] """ Dataof = Annotated[Union[TDataClass, Any], Role.DATA] """Type hint for data fields (``Coordof[TDataClass]``). Unlike ``Data``, it specifies a dataclass that defines a DataArray class. This is useful when users want to reuse a dataclass in a Dataset class. Example: :: @dataclass class Image: data: Data[tuple[X, Y], float] x: Coord[X, int] = 0 y: Coord[Y, int] = 0 @dataclass class ColorImage(): red: Dataof[Image] green: Dataof[Image] blue: Dataof[Image] """ Name = Annotated[THashable, Role.NAME] """Type hint for name fields (``Name[THashable]``). Example: :: @dataclass class Image(): data: Data[tuple[X, Y], float] name: Name[str] = "image" """
[docs] def deannotate(tp: Any) -> Any: """Recursively remove annotations in a type hint.""" class Temporary: __annotations__ = dict(type=tp) return get_type_hints(Temporary)["type"]
[docs] def find_annotated(tp: Any) -> Iterable[Any]: """Generate all annotated types in a type hint.""" args = get_args(tp) if get_origin(tp) is Annotated: yield tp yield from find_annotated(args[0]) else: yield from chain(*map(find_annotated, args))
[docs] def get_annotated(tp: Any) -> Any: """Extract the first role-annotated type.""" for annotated in filter(Role.annotates, find_annotated(tp)): return deannotate(annotated) raise TypeError("Could not find any role-annotated type.")
[docs] def get_annotations(tp: Any) -> Tuple[Any, ...]: """Extract annotations of the first role-annotated type.""" for annotated in filter(Role.annotates, find_annotated(tp)): return get_args(annotated)[1:] raise TypeError("Could not find any role-annotated type.")
[docs] def get_dataclass(tp: Any) -> Type[DataClass[Any]]: """Extract a dataclass.""" try: dataclass = get_args(get_annotated(tp))[0] except TypeError: raise TypeError(f"Could not find any dataclass in {tp!r}.") if not is_dataclass(dataclass): raise TypeError(f"Could not find any dataclass in {tp!r}.") return dataclass
[docs] def get_dims(tp: Any) -> List[Dims]: """Extract data dimensions (dims).""" try: dims = get_args(get_args(get_annotated(tp))[0])[0] except TypeError: raise TypeError(f"Could not find any dims in {tp!r}.") # List of allowed dtypes (might just be one) if get_origin(dims) is Union: dims_in = get_args(dims) elif HAVE_UNIONTYPE and get_origin(dims) is UnionType: dims_in = get_args(dims) else: dims_in = [dims] dims_out = [] for dim in dims_in: args = list(get_args(dim)) origin = get_origin(dim) # One-dimensional dimension if origin is Literal: dims_out.append([str(args[0])]) continue if not (origin is tuple or origin is Tuple): raise TypeError(f"Could not find any dims in {tp!r}.") # Zero-dimensions if args == [] or args == [()]: dims_out.append([]) continue if not all(get_origin(arg) is Literal for arg in args): raise TypeError(f"Could not find any dims in {tp!r}.") dims_out.append([str(get_args(arg)[0]) for arg in args]) return dims_out
[docs] def get_types(tp: Any) -> List[AnyDType]: """Extract data types from type annotation E.g. Coord[..., Type1 | Type2 | ...] or Data[..., Type1 | Type2 | ...] """ try: typ = get_args(get_args(get_annotated(tp))[1])[0] except TypeError: raise TypeError(f"Could not find any dtype in {tp!r}.") # List of allowed dtypes (might just be one) if get_origin(typ) is Union: types_in = get_args(typ) elif HAVE_UNIONTYPE and get_origin(typ) is UnionType: types_in = get_args(typ) else: types_in = [typ] types_out = [] for dt in types_in: # Handle case that we want to allow "Any" if dt is Any or dt is type(None): types_out.append(None) continue # Allow specifying type as literal (e.g. string) elif get_origin(dt) is Literal: dt = get_args(dt)[0] # Return type types_out.append(dt) return types_out
[docs] def get_name(tp: Any, default: Hashable = None) -> Hashable: """Extract a name if found or return given default.""" try: annotations = get_annotations(tp)[1:] except TypeError: return default for annotation in annotations: if isinstance(annotation, Hashable): return annotation return default
[docs] def get_role(tp: Any, default: Role = Role.OTHER) -> Role: """Extract a role if found or return given default.""" try: return get_annotations(tp)[0] except TypeError: return default
[docs] def is_optional(type_ann): """ Check whether a type annotation indicates that the value is optional Boils down to checking whether it's a union type that includes None """ if get_origin(type_ann) is Union: return None.__class__ in get_args(type_ann) return False