Skip to content

Misc

utils

Miscellaneous utility functions (stdlib only).

Do not import any other modules from this package here.

logger

logger = getLogger(__name__)

P

P = ParamSpec('P')

S

S = TypeVar('S')

hook

hook(func: Callable[P, S]) -> Hook[P, S]

Wraps a function in a Hook object, making it interceptable.

Decorating a function with @hook allows its behavior to be observed, extended, or even completely replaced by downstream code.

Example usage:

@hook
def foo(state, t):
    ... # pure jax/torch/numpy code...
For example, to time the function:
import time

def timing_interceptor(original_fn, *args, **kwargs):
    start = time.perf_counter()
    result = original_fn(*args, **kwargs)
    end = time.perf_counter()
    print(f"{original_fn.__name__} took {end - start:.4f}s")
    return result

foo.intercept(timing_interceptor)

Source code in src/aerocore/utils.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def hook(func: Callable[P, S]) -> Hook[P, S]:
    """Make *func* interceptable by wrapping it in a `Hook`.

    Calling the returned `Hook` calls *func*. Downstream code can register
    interceptors on it (see `Hook.intercept`) to observe, extend, or entirely
    replace what *func* does, without editing *func* itself.

    Example usage:
    ```py
    @hook
    def foo(state, t):
        ... # pure jax/torch/numpy code...
    ```
    For example, to time the function:
    ```py
    import time

    def timing_interceptor(original_fn, *args, **kwargs):
        start = time.perf_counter()
        result = original_fn(*args, **kwargs)
        end = time.perf_counter()
        print(f"{original_fn.__name__} took {end - start:.4f}s")
        return result

    foo.intercept(timing_interceptor)
    ```
    """
    wrapped = Hook(handler=func)
    return wrapped

Hook

Hook(handler: Callable[P, S])

Bases: Generic[P, S]

A callable that implements the middleware pattern.

handler

handler: Callable[P, S]

__call__

__call__(*args: P.args, **kwargs: P.kwargs) -> S
Source code in src/aerocore/utils.py
54
55
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> S:
    """Invoke the current handler with the given arguments and return its result."""
    target = self.handler
    return target(*args, **kwargs)

intercept

intercept(
    interceptor: Callable[
        Concatenate[Callable[P, S], P], S
    ],
) -> Hook[P, S]
Source code in src/aerocore/utils.py
57
58
59
60
61
62
63
64
def intercept(
    self, interceptor: Callable[Concatenate[Callable[P, S], P], S]
) -> Hook[P, S]:
    original_handler = self.handler
    self.handler = lambda *args, **kwargs: interceptor(
        original_handler, *args, **kwargs
    )
    return self

debug

debug() -> Hook[P, S]

Log the function arguments and result.

Source code in src/aerocore/utils.py
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def debug(self) -> Hook[P, S]:
    """Log the function arguments and result.

    Installs an interceptor that calls the current handler and then logs
    ``name(arg1, arg2, key=value) -> result`` via the module ``logger`` at
    DEBUG level. The message, which calls ``repr`` on every argument and on
    the result, is only built when DEBUG is enabled for ``logger``.

    Returns:
        The hook itself (whatever ``self.intercept`` returns), for chaining.
    """
    import logging

    def debug_interceptor(
        original_fn: Callable[P, S], /, *args: P.args, **kwargs: P.kwargs
    ) -> S:
        result = original_fn(*args, **kwargs)
        # repr() of large arrays/tensors can be costly; skip all formatting
        # work when the record would be discarded anyway.
        if logger.isEnabledFor(logging.DEBUG):
            all_args = ", ".join(
                [*map(repr, args), *(f"{k}={v!r}" for k, v in kwargs.items())]
            )
            # Not every callable has __name__ (e.g. functools.partial or a
            # callable instance); don't crash after the real call succeeded.
            name = getattr(original_fn, "__name__", repr(original_fn))
            logger.debug("%s(%s) -> %r", name, all_args, result)
        return result

    return self.intercept(debug_interceptor)

column

Column

Column(
    unit: Any | None,
    display_name: str | None = None,
    symbol: str | None = None,
    identifier: str | None = None,
)

unit

unit: Any | None

display_name

display_name: str | None = None

symbol

symbol: str | None = None

identifier

identifier: str | None = None

Optional unique key used to retrieve this column's series from a dataframe (the column is read as `df[identifier]`).

label

label: str

__call__

__call__(df: DataFrame) -> Series
__call__(df: Any) -> Any
__call__(df: Any) -> Any

Returns series in the dataframe.

Source code in src/aerocore/column.py
42
43
44
def __call__(self, df: Any) -> Any:
    """Return this column's series from *df*.

    *df* is indexed directly with ``self.identifier`` as the key.
    """
    key = self.identifier
    series = df[key]
    return series

plot

Lightweight matplotlib utils.

When developing on a remote host, you may want to view interactive plots in a web browser. To do so, tunnel port 8988 and run `WEB=1 python3 scripts/<script>.py`; setting `WEB` enables the webagg backend.

Requires extras:

  • matplotlib
  • polars for additional plots

CMapCycler

CMapCycler(cmap_name: str = 'tab10')
Source code in src/aerocore/plot.py
33
34
35
36
def __init__(self, cmap_name: str = "tab10"):
    """Load the named matplotlib colormap and cache its colours for cycling.

    Args:
        cmap_name: A name registered in ``matplotlib.colormaps``. If the
            colormap exposes a ``colors`` attribute (e.g. qualitative maps
            like ``"tab10"``), that is used as-is. Otherwise, each of its
            ``N`` lookup-table entries is sampled as an RGB triple instead
            of raising ``AttributeError``.
    """
    cmap = mpl.colormaps[cmap_name]
    colors = getattr(cmap, "colors", None)
    if colors is None:
        # Calling a colormap with an integer indexes its lookup table and
        # returns RGBA; drop alpha to match the RGB element type.
        colors = tuple(tuple(cmap(idx)[:3]) for idx in range(cmap.N))
    # Variadic tuple: one RGB triple per colormap entry.
    self.colors: tuple[tuple[float, float, float], ...] = colors  # type: ignore
    self.N = cmap.N

colors

colors: tuple[tuple[float, float, float]] = colors

N

N = N

__getitem__

__getitem__(i: int) -> tuple[float, float, float]
Source code in src/aerocore/plot.py
38
39
def __getitem__(self, i: int) -> tuple[float, float, float]:
    """Return the colour at index *i*, wrapping around modulo ``self.N``."""
    wrapped_index = i % self.N
    return self.colors[wrapped_index]

C

C = CMapCycler()

Default color cycle: a `CMapCycler` over the `tab10` colormap.

init_style

init_style(
    dark: bool = False,
    fast: bool = True,
    use_tex: bool = True,
) -> None
Source code in src/aerocore/plot.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def init_style(
    dark: bool = False, fast: bool = True, use_tex: bool = True
) -> None:
    """Apply the package's matplotlib styles and rcParams.

    Args:
        dark: Also apply matplotlib's ``"dark_background"`` style.
        fast: Apply matplotlib's ``"fast"`` style first.
        use_tex: Render text with LaTeX (preamble loads amsmath, amssymb,
            siunitx and gensymb) in a serif font. ``text.usetex`` is always
            set to this value, so passing ``False`` turns LaTeX rendering
            back off after an earlier call enabled it.

    If the ``WEB`` environment variable is set to a non-empty value, the
    ``webagg`` backend is selected so plots can be viewed in a browser.
    """
    if fast:
        plt.style.use("fast")
    if dark:
        plt.style.use("dark_background")

    rc_params: dict[str, Any] = {
        "axes.axisbelow": True,  # make gridlines appear below plot
        # Set unconditionally: skipping it when use_tex is False would leave
        # a previously-enabled LaTeX renderer switched on.
        "text.usetex": use_tex,
    }
    if use_tex:
        rc_params.update(
            {
                "text.latex.preamble": (
                    r"\usepackage{amsmath}"  # for \text
                    r"\usepackage{amssymb}"  # for real
                    r"\usepackage{siunitx}"
                    r"\usepackage{gensymb}"  # for \degree
                ),
                "font.family": "serif",
                "font.serif": "cm",
            }
        )
    plt.rcParams.update(rc_params)
    if os.getenv("WEB"):
        mpl.use("webagg")

new_figure

new_figure(*args: Any, **kwargs: Any) -> Figure
Source code in src/aerocore/plot.py
76
77
78
79
80
81
82
def new_figure(*args: Any, **kwargs: Any) -> Figure:
    """Create a figure with the package style applied.

    Calls ``init_style()`` and then forwards all arguments to ``plt.figure``.
    Two defaults are applied:

    * ``figsize`` is ``(8.0, 4.5)`` (16:9), unless the caller supplied it by
      keyword or as the second positional argument (``plt.figure(num,
      figsize, ...)``).
    * The ``"tight"`` layout engine is used, unless the caller chose a
      layout via ``layout=``, ``tight_layout=`` or ``constrained_layout=``.
    """
    init_style()
    # figsize is plt.figure's 2nd positional parameter; also passing it by
    # keyword would raise "got multiple values for argument 'figsize'".
    if len(args) < 2 and "figsize" not in kwargs:
        kwargs["figsize"] = (16 * 0.5, 9 * 0.5)
    fig = plt.figure(*args, **kwargs)
    layout_keys = ("layout", "tight_layout", "constrained_layout")
    if not any(key in kwargs for key in layout_keys):
        fig.set_layout_engine("tight")
    return fig

setup_xy

setup_xy(
    ax: Axes,
    x: Column,
    y: Column,
    title: Callable[[Column, Column], str]
    | Literal["default"]
    | str
    | None = None,
) -> None

Parameters:

  • title (Callable[[Column, Column], str] | Literal['default'] | str | None) – a function that computes the title from x and y, a fixed string, or explicitly None. Default: None.
Source code in src/aerocore/plot.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
def setup_xy(
    ax: Axes,
    x: Column,
    y: Column,
    title: (
        Callable[[Column, Column], str] | Literal["default"] | str | None
    ) = None,
) -> None:
    """
    Label the axes of *ax* from columns *x* and *y*, and optionally set a title.

    :param ax: the axes to configure.
    :param x: column on the horizontal axis; ``x.label`` becomes the x label.
    :param y: column on the vertical axis; ``y.label`` becomes the y label.
    :param title: a function that computes the title from x and y,
        a fixed string, or explicitly None. The literal ``"default"`` gives
        ``"Plot of <y> against <x>"``, using each column's ``display_name``
        or, when that is None, its ``label``.
    """
    ax.set_xlabel(x.label)
    ax.set_ylabel(y.label)
    if title is None:
        return
    if callable(title):
        title_ = title(x, y)
    elif title == "default":
        # display_name is optional; avoid rendering "Plot of None against None".
        x_name = x.display_name if x.display_name is not None else x.label
        y_name = y.display_name if y.display_name is not None else y.label
        title_ = f"Plot of {y_name} against {x_name}"
    else:
        title_ = title
    ax.set_title(title_)

add_linear_trendline

add_linear_trendline(
    ax: Axes,
    x: Series,
    y: Series,
    *,
    x_symbol: str = "x",
    y_symbol: str = "y",
    with_legend: bool = True,
    **kwargs: Any,
) -> None

Note: the legend is NOT added; call ax.legend() manually.

Source code in src/aerocore/plot.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def add_linear_trendline(
    ax: Axes,
    x: pl.Series,
    y: pl.Series,
    *,
    x_symbol: str = "x",
    y_symbol: str = "y",
    with_legend: bool = True,
    **kwargs: Any,
) -> None:
    """
    Fit a least-squares line to (x, y) with `scipy.stats.linregress` and
    draw it on *ax* over 100 evenly spaced points spanning min(x)..max(x).

    With *with_legend* true, the line's ``label`` (overwriting any label in
    *kwargs*) shows the fitted equation in terms of *x_symbol*/*y_symbol*
    and R^2, each to 4 decimal places. All *kwargs* are passed to ``ax.plot``.

    Note: the legend is NOT added; call `ax.legend()` manually.
    """
    from scipy.stats import linregress

    import numpy as np

    gradient, offset, r_value, *_ = linregress(x, y)

    # A negative offset already carries its own "-" sign when formatted.
    offset_sign = "+" if offset >= 0 else ""
    legend_text = "\n".join(
        (
            "Linear Trendline",
            f"${y_symbol} = {gradient:.4f}{x_symbol}{offset_sign}{offset:.4f}$",
            f"$R^2 = {r_value**2:.4f}$",
        )
    )

    xs = np.linspace(x.min(), x.max(), 100)  # type: ignore
    ys = gradient * xs + offset
    if with_legend:
        kwargs["label"] = legend_text
    ax.plot(xs, ys, **kwargs)

basic_scatter

basic_scatter(
    df: DataFrame,
    x: Column,
    y: Column,
    *,
    with_line: bool = False,
) -> Figure
Source code in src/aerocore/plot.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def basic_scatter(
    df: pl.DataFrame,
    x: Column,
    y: Column,
    *,
    with_line: bool = False,
) -> Figure:
    """
    Create a new figure containing a scatter plot of column *y* against *x*.

    :param df: dataframe both columns are read from.
    :param x: column for the horizontal axis.
    :param y: column for the vertical axis.
    :param with_line: also join the points with a line, in row order.
    :return: the new figure, with axes labelled via ``setup_xy``.
    """
    # Extract each series once (not once per plotting call), and before the
    # figure exists so a bad column doesn't leave an orphan pyplot figure.
    xs = x(df)
    ys = y(df)

    fig = new_figure()
    ax = fig.subplots()
    setup_xy(ax, x, y)
    if with_line:
        ax.plot(xs, ys)
    ax.scatter(xs, ys, marker="+")
    return fig