"""
Utility functions for type checking array dimensions and types.
These functions are intended to be used as type guards for static type checking
with runtime assertions.
"""
from __future__ import annotations
from typing import Literal, Any, overload
import sys
if sys.version_info < (3, 13):
from typing_extensions import TypeIs
else:
from typing import TypeIs
import numpy as np
from .arraytypes import MeterArray, MeterDtype, TruePeakArray, TruePeakDtype
from .types import (
AnyArray, AnyNdArray, Any1dArray, Any2dArray, Any3dArray, AnyFloatArray,
IndexArray, BoolArray, ComplexArray, NumChannelsT,
ShapeT, DType_co, DType_t,
)
[docs]
def is_1d_array(arr: AnyNdArray[Any, DType_co]) -> TypeIs[Any1dArray[DType_co]]:
    """Check whether *arr* has exactly one dimension

    When this returns ``True``, static type checkers narrow *arr* to
    :obj:`~.types.Any1dArray` while keeping its dtype parameter.

    Arguments:
        arr: The array to check
    """
    return is_nd_array(arr, ndim=1)
[docs]
def is_2d_array(arr: AnyNdArray[Any, DType_co]) -> TypeIs[Any2dArray[DType_co]]:
    """Check whether *arr* has exactly two dimensions

    When this returns ``True``, static type checkers narrow *arr* to
    :obj:`~.types.Any2dArray` while keeping its dtype parameter.

    Arguments:
        arr: The array to check
    """
    return is_nd_array(arr, ndim=2)
[docs]
def is_3d_array(arr: AnyNdArray[Any, DType_co]) -> TypeIs[Any3dArray[DType_co]]:
    """Check whether *arr* has exactly three dimensions

    When this returns ``True``, static type checkers narrow *arr* to
    :obj:`~.types.Any3dArray` while keeping its dtype parameter.

    Arguments:
        arr: The array to check
    """
    return is_nd_array(arr, ndim=3)
# Typing-only overloads for is_nd_array: when ``ndim`` is the literal 1, 2 or 3
# the narrowed type is the matching rank-specific alias; any other ``int``
# falls back to the generic ``AnyNdArray`` form.
# NOTE(review): unlike is_1d_array / is_2d_array / is_3d_array, the literal
# overloads return the bare alias (e.g. ``Any1dArray``) with no dtype
# parameter, and the ``int`` overload mentions ``DType_co`` only in its return
# type, so the dtype of ``arr`` is not carried through. Typing ``arr`` as
# ``AnyNdArray[Any, DType_co]`` would fix that if ``AnyArray`` permits it —
# confirm against the definitions in ``.types`` before changing.
@overload
def is_nd_array(arr: AnyArray, ndim: Literal[1]) -> TypeIs[Any1dArray]: ...
@overload
def is_nd_array(arr: AnyArray, ndim: Literal[2]) -> TypeIs[Any2dArray]: ...
@overload
def is_nd_array(arr: AnyArray, ndim: Literal[3]) -> TypeIs[Any3dArray]: ...
@overload
def is_nd_array(arr: AnyArray, ndim: int) -> TypeIs[AnyNdArray[Any, DType_co]]: ...
[docs]
def is_nd_array(arr: AnyArray, ndim: int) -> TypeIs[AnyNdArray[Any, DType_co]]:
    """Check whether *arr* has exactly *ndim* dimensions

    Arguments:
        arr: The array to check (only its ``ndim`` attribute is read)
        ndim: The required number of dimensions
    """
    actual_ndim = arr.ndim
    return actual_ndim == ndim
[docs]
def ensure_1d_array(arr: AnyNdArray[Any, DType_co]) -> Any1dArray[DType_co]:
    """Ensure the given array is 1-dimensional and return it

    The check is an ``assert``, so it is skipped when Python runs with ``-O``.

    Arguments:
        arr: The array to check

    Raises:
        AssertionError: If *arr* does not have exactly one dimension
    """
    assert is_1d_array(arr), f"expected a 1-dimensional array, got ndim={arr.ndim}"
    return arr
[docs]
def ensure_2d_array(arr: AnyNdArray[Any, DType_co]) -> Any2dArray[DType_co]:
    """Ensure the given array is 2-dimensional and return it

    The check is an ``assert``, so it is skipped when Python runs with ``-O``.

    Arguments:
        arr: The array to check

    Raises:
        AssertionError: If *arr* does not have exactly two dimensions
    """
    assert is_2d_array(arr), f"expected a 2-dimensional array, got ndim={arr.ndim}"
    return arr
[docs]
def ensure_3d_array(arr: AnyNdArray[Any, DType_co]) -> Any3dArray[DType_co]:
    """Ensure the given array is 3-dimensional and return it

    The check is an ``assert``, so it is skipped when Python runs with ``-O``.

    Arguments:
        arr: The array to check

    Raises:
        AssertionError: If *arr* does not have exactly three dimensions
    """
    assert is_3d_array(arr), f"expected a 3-dimensional array, got ndim={arr.ndim}"
    return arr
# Typing-only overloads for ensure_nd_array: when ``ndim`` is the literal
# 1, 2 or 3 the return type is the matching rank-specific alias; any other
# ``int`` falls back to ``AnyNdArray``. In every case the dtype parameter
# ``DType_co`` of ``arr`` is carried through to the result.
@overload
def ensure_nd_array(arr: AnyNdArray[Any, DType_co], ndim: Literal[1]) -> Any1dArray[DType_co]: ...
@overload
def ensure_nd_array(arr: AnyNdArray[Any, DType_co], ndim: Literal[2]) -> Any2dArray[DType_co]: ...
@overload
def ensure_nd_array(arr: AnyNdArray[Any, DType_co], ndim: Literal[3]) -> Any3dArray[DType_co]: ...
@overload
def ensure_nd_array(arr: AnyNdArray[Any, DType_co], ndim: int) -> AnyNdArray[Any, DType_co]: ...
[docs]
def ensure_nd_array(arr: AnyNdArray[Any, DType_co], ndim: int) -> AnyNdArray[Any, DType_co]:
    """Ensure the given array has the specified number of dimensions and return it

    The check is an ``assert``, so it is skipped when Python runs with ``-O``.

    Arguments:
        arr: The array to check
        ndim: The required number of dimensions

    Raises:
        AssertionError: If ``arr.ndim`` differs from *ndim*
    """
    assert arr.ndim == ndim, (
        f"expected a {ndim}-dimensional array, got ndim={arr.ndim}"
    )
    return arr
[docs]
def is_array_of_shape(
    arr: np.ndarray[tuple[int, ...], DType_t],
    shape: ShapeT,
) -> TypeIs[np.ndarray[ShapeT, DType_t]]:
    """Check whether *arr* has exactly the given *shape*

    The test is a plain equality comparison against ``arr.shape``, so
    *shape* should be given as a tuple of ints.

    Arguments:
        arr: The array to check
        shape: The expected shape
    """
    actual_shape = arr.shape
    return actual_shape == shape
[docs]
def ensure_array_of_shape(
    arr: np.ndarray[tuple[int, ...], DType_t],
    shape: ShapeT,
) -> np.ndarray[ShapeT, DType_t]:
    """Ensure the given array's shape matches the specified shape and return it

    The check is an ``assert``, so it is skipped when Python runs with ``-O``.

    Arguments:
        arr: The array to check
        shape: The expected shape

    Raises:
        AssertionError: If ``arr.shape`` does not equal *shape*
    """
    assert is_array_of_shape(arr, shape), (
        f"expected an array of shape {shape}, got {arr.shape}"
    )
    return arr
[docs]
def is_array_of_dtype(
    arr: AnyNdArray[ShapeT, Any],
    dtype: DType_t,
) -> TypeIs[AnyNdArray[ShapeT, DType_t]]:
    """Check whether *arr* has the given *dtype*

    Arguments:
        arr: The array to check
        dtype: The expected dtype, compared against ``arr.dtype`` with ``==``
    """
    actual_dtype = arr.dtype
    return actual_dtype == dtype
[docs]
def is_float_array(arr: AnyNdArray[ShapeT, Any]) -> TypeIs[AnyFloatArray[ShapeT]]:
    """Check whether *arr* has a real floating-point dtype

    Any subtype of :obj:`numpy.floating` (such as float16, float32 or
    float64) matches; complex dtypes do not.
    """
    element_dtype = arr.dtype
    return np.issubdtype(element_dtype, np.floating)
[docs]
def is_float32_array(arr: AnyNdArray[ShapeT, Any]) -> TypeIs[AnyNdArray[ShapeT, np.dtype[np.float32]]]:
    """Check whether *arr* has the dtype :obj:`numpy.float32`

    The comparison is an equality test against the native ``float32`` dtype.
    """
    target = np.dtype(np.float32)
    return arr.dtype == target
[docs]
def is_float64_array(arr: AnyNdArray[ShapeT, Any]) -> TypeIs[AnyNdArray[ShapeT, np.dtype[np.float64]]]:
    """Check whether *arr* has the dtype :obj:`numpy.float64`

    The comparison is an equality test against the native ``float64`` dtype.
    """
    target = np.dtype(np.float64)
    return arr.dtype == target
[docs]
def is_index_array(arr: AnyNdArray[ShapeT, Any]) -> TypeIs[IndexArray[ShapeT]]:
    """Check whether *arr* has the dtype :obj:`numpy.intp`

    Only the platform's native indexing integer (``numpy.intp``) matches.
    Other integer dtypes are rejected, even though NumPy would accept them
    as indices. Examples are ``int32`` on a 64-bit platform and every
    unsigned type.
    """
    return arr.dtype == np.intp
[docs]
def is_bool_array(arr: AnyNdArray[ShapeT, Any]) -> TypeIs[BoolArray[ShapeT]]:
    """Check whether *arr* has the boolean dtype :obj:`numpy.bool_`
    """
    boolean = np.dtype(np.bool_)
    return arr.dtype == boolean
[docs]
def is_complex_array(arr: AnyNdArray[ShapeT, Any]) -> TypeIs[ComplexArray[ShapeT]]:
    """Check whether *arr* has a complex floating-point dtype

    Any subtype of :obj:`numpy.complexfloating` (such as complex64 or
    complex128) matches; real floating-point dtypes do not.
    """
    element_dtype = arr.dtype
    return np.issubdtype(element_dtype, np.complexfloating)
[docs]
def is_meter_array(arr: AnyArray) -> TypeIs[MeterArray]:
    """Check whether *arr* is a :class:`~.arraytypes.MeterArray`

    The object must be a :class:`numpy.ndarray` whose dtype equals
    ``MeterDtype``; anything that is not an ndarray is rejected.

    Arguments:
        arr: The array to check
    """
    if not isinstance(arr, np.ndarray):
        return False
    return arr.dtype == MeterDtype
[docs]
def ensure_meter_array(arr: AnyArray) -> MeterArray:
    """Ensure the given array is a :class:`~.arraytypes.MeterArray` and return it

    The check is an ``assert``, so it is skipped when Python runs with ``-O``.

    Arguments:
        arr: The array to check

    Raises:
        AssertionError: If *arr* is not a MeterArray
    """
    # getattr keeps the message itself from raising for non-array inputs.
    assert is_meter_array(arr), (
        f"expected a MeterArray, got {type(arr).__name__} "
        f"with dtype {getattr(arr, 'dtype', None)}"
    )
    return arr
[docs]
def is_true_peak_array(arr: AnyArray, num_channels: NumChannelsT) -> TypeIs[TruePeakArray[NumChannelsT]]:
    """Check whether *arr* is a :class:`~.arraytypes.TruePeakArray` for *num_channels*

    The object must be a :class:`numpy.ndarray` whose dtype equals the one
    produced by ``build_true_peak_dtype(num_channels)``.

    Arguments:
        arr: The array to check
        num_channels: The number of audio channels
    """
    expected_dtype = build_true_peak_dtype(num_channels)
    if not isinstance(arr, np.ndarray):
        return False
    return arr.dtype == expected_dtype
[docs]
def ensure_true_peak_array(
    arr: AnyArray,
    num_channels: NumChannelsT
) -> TruePeakArray[NumChannelsT]:
    """Ensure the given array is a :class:`~.arraytypes.TruePeakArray` for the specified number of channels and return it

    The check is an ``assert``, so it is skipped when Python runs with ``-O``.

    Arguments:
        arr: The array to check
        num_channels: The number of audio channels

    Raises:
        AssertionError: If *arr* is not a TruePeakArray for *num_channels*
    """
    # getattr keeps the message itself from raising for non-array inputs.
    assert is_true_peak_array(arr, num_channels), (
        f"expected a TruePeakArray for {num_channels} channel(s), "
        f"got {type(arr).__name__} with dtype {getattr(arr, 'dtype', None)}"
    )
    return arr
[docs]
def build_meter_array(size: int) -> MeterArray:
    """Build a zero-filled :obj:`~.arraytypes.MeterArray` with *size* elements

    Arguments:
        size: The number of elements in the array
    """
    meter = np.zeros(size, dtype=MeterDtype)
    # Runtime sanity check that also narrows the type for static checkers.
    assert is_meter_array(meter)
    return meter
[docs]
def build_true_peak_dtype(num_channels: NumChannelsT) -> TruePeakDtype[NumChannelsT]:
    """Build a :obj:`~.arraytypes.TruePeakDtype` for the given number of channels

    The result is a structured dtype with two fields:

    - ``t``: a scalar :obj:`numpy.float64`
    - ``tp``: a :obj:`numpy.float64` sub-array of shape ``(num_channels,)``

    Arguments:
        num_channels: The number of audio channels
    """
    # The sub-array shape is passed as a tuple on purpose. NumPy treats the
    # bare-int form ``(np.float64, 1)`` as a deprecated synonym for a plain
    # scalar ``float64`` field (emitting a FutureWarning), which silently
    # drops the channel axis for mono input. ``(num_channels,)`` always
    # yields a sub-array, so ``tp`` has a channel axis for every channel count.
    return np.dtype([
        ('t', np.float64),
        ('tp', (np.float64, (num_channels,))),
    ]) # type: ignore[return-value]
[docs]
def build_true_peak_array(num_channels: NumChannelsT, size: int) -> TruePeakArray[NumChannelsT]:
    """Build a zero-filled :obj:`~.arraytypes.TruePeakArray` of *size* elements

    The element dtype comes from ``build_true_peak_dtype(num_channels)``.

    Arguments:
        num_channels: The number of audio channels
        size: The number of elements in the array
    """
    peaks = np.zeros(size, dtype=build_true_peak_dtype(num_channels))
    # Runtime sanity check that also narrows the type for static checkers.
    assert is_true_peak_array(peaks, num_channels)
    return peaks