Imports#

import numpy as np
import jax.numpy as jnp
from numba import jit
import numba
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode
import timeit
import jax
import math
from jax.scipy.special import gammaln
from functools import partial

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import plotly.io as pio
pio.renderers.default = "notebook"
print("pytensor version:", pytensor.__version__)
print("jax version:", jax.__version__)
print("numba version:", numba.__version__)
pytensor version: 0+untagged.31343.g647570d.dirty
jax version: 0.7.2
numba version: 0.62.1
class Benchmarker:
    """
    Benchmark a set of functions by timing execution and summarizing statistics.

    Parameters
    ----------
    functions : list of callables
        List of callables to benchmark.
    names : list of str, optional
        Names corresponding to each function. Default is ['func_0', 'func_1', ...].
    number : int or None, optional
        Number of loops per timing. If None, auto-calibrated via Timer.autorange().
        Default is None.
    min_rounds : int, optional
        Minimum number of timing rounds per function. Default is 5.
    max_time : float or None, optional
        Approximate time budget in seconds used to scale the number of rounds.
        Default is 1.0.
    target_time : float, optional
        Target duration in seconds for auto-calibration. Default is 0.2.

    Attributes
    ----------
    results : dict
        Mapping from (function name, input name) tuples to a dict with keys:
        - 'raw_us': numpy.ndarray of raw timings in microseconds
        - 'loops': number of loops used per timing
        - 'rounds': number of timing rounds

    Methods
    -------
    run(inputs)
        Auto-calibrate (if needed) and time all functions over the given named
        input sets.
    summary(unit='us') -> pandas.DataFrame
        Return a summary DataFrame with statistics converted to the given unit.
    raw(name=None) -> dict or numpy.ndarray
        Return raw timing data in microseconds for a specific function or all.
    _convert_times(times, unit) -> numpy.ndarray
        Convert an array of times from microseconds to the specified unit.
    """

    def __init__(
        self, functions, names=None, number=None, min_rounds=5, max_time=1.0, target_time=0.2
    ):
        self.functions = functions
        self.names = names or [f"func_{i}" for i in range(len(functions))]
        self.number = number
        self.min_rounds = min_rounds
        self.max_time = max_time
        self.target_time = target_time
        self.results = {}

    def run(self, inputs: dict[str, dict]):
        """
        Auto-calibrate loop count and sample rounds if needed, then time each function.
        """
        for name, func in zip(self.names, self.functions):
            for input_name, kwargs in inputs.items():
                timer = timeit.Timer(partial(func, **kwargs))

                # Calibrate loops
                if self.number is None:
                    loops, calib_time = timer.autorange()
                else:
                    loops = self.number
                    calib_time = timer.timeit(number=loops)

                # Determine rounds based on max_time and min_rounds
                if self.max_time is not None:
                    rounds = max(self.min_rounds, int(np.ceil(self.max_time / calib_time)))
                else:
                    rounds = self.min_rounds

                raw_round_times = np.array(timer.repeat(repeat=rounds, number=loops))

                # Convert to microseconds per single execution
                raw_us = raw_round_times / loops * 1e6

                self.results[(name, input_name)] = {
                    "raw_us": raw_us,
                    "loops": loops,
                    "rounds": rounds,
                }

    def summary(self, unit="us"):
        """
        Summarize benchmark statistics in a DataFrame.

        Parameters
        ----------
        unit : {'us', 'ms', 'ns', 's'}, optional
            Unit for output times. 'us' means microseconds, 'ms' milliseconds,
            'ns' nanoseconds, 's' seconds. Default is 'us'.

        Returns
        -------
        pandas.DataFrame
            Summary with columns:
            Name, Loops, Min, Max, Mean, StdDev, Median, IQR (all in given unit),
            OPS (Kops/s), Samples.
        """
        records = []
        indexes = []
        for name, data in self.results.items():
            raw_us = data["raw_us"]
            # Convert to target unit
            times = self._convert_times(raw_us, unit)
            if isinstance(name, tuple) and len(name) == 1:
                indexes.append(name[0])
            else:
                indexes.append(name)

            stats = {
                "Loops": data["loops"],
                f"Min ({unit})": np.min(times),
                f"Max ({unit})": np.max(times),
                f"Mean ({unit})": np.mean(times),
                f"StdDev ({unit})": np.std(times),
                f"Median ({unit})": np.median(times),
                f"IQR ({unit})": np.percentile(times, 75) - np.percentile(times, 25),
                "OPS (Kops/s)": 1e3 / (np.mean(raw_us)),
                "Samples": len(raw_us),
            }
            records.append(stats)

        if all(isinstance(idx, tuple) for idx in indexes):
            index = pd.MultiIndex.from_tuples(indexes)
        else:
            index = pd.Index(indexes)
        return pd.DataFrame(records, index=index)

    def raw(self, name=None):
        """
        Get raw timing data in microseconds.

        Parameters
        ----------
        name : tuple, optional
            If given, returns the raw_us array for that (function name, input
            name) key. Otherwise returns a dict of all raw results.

        Returns
        -------
        numpy.ndarray or dict
        """
        if name:
            return self.results.get(name, {}).get("raw_us")
        return {n: d["raw_us"] for n, d in self.results.items()}

    def _convert_times(self, times, unit):
        """
        Convert an array of times from microseconds to the specified unit.

        Parameters
        ----------
        times : array-like
            Times in microseconds.
        unit : {'us', 'ms', 'ns', 's'}
            Target unit: 'us' microseconds, 'ms' milliseconds,
            'ns' nanoseconds, 's' seconds.

        Returns
        -------
        numpy.ndarray
            Converted times.

        Raises
        ------
        ValueError
            If `unit` is not one of the supported options.
        """
        unit = unit.lower()
        if unit == "us":
            factor = 1.0
        elif unit == "ms":
            factor = 1e-3
        elif unit == "ns":
            factor = 1e3
        elif unit == "s":
            factor = 1e-6
        else:
            raise ValueError(f"Unsupported unit: {unit}")
        return times * factor
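
A minimal usage sketch (hypothetical toy benchmark; the real benchmarks below pass the compiled functions directly):

bench = Benchmarker([np.sort], names=["np_sort"], number=10)
bench.run(inputs={"random_1k": {"a": np.random.rand(1000)}})
bench.summary(unit="us")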
# The helper below forces JAX's asynchronous dispatch to finish before timing.
# A simpler variant would just call jax.block_until_ready on the result
# (recent JAX versions accept arbitrary pytrees, not only a single jnp.array);
# here we block explicitly on every array in the result via tree_map instead.

def block(func):
    def inner(*args, **kwargs):
        result = func(*args, **kwargs)
        # Recursively block on all JAX arrays in result
        jax.tree_util.tree_map(lambda x: x.block_until_ready() if hasattr(x, "block_until_ready") else None, result)
        return result
    return inner
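
Why blocking matters: jitted JAX calls return as soon as the work is dispatched, so un-blocked timings can measure dispatch rather than compute. A quick illustration (toy example, not part of the benchmarks below):

f = jax.jit(lambda v: (v @ v.T).sum())
v = jnp.ones((2000, 2000))
f(v)  # warm-up, triggers compilation
# timeit.timeit(lambda: f(v), number=10)         # times async dispatch only
# timeit.timeit(lambda: block(f)(v), number=10)  # waits for the actual result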
# Set PyTensor to use float32
pytensor.config.floatX = "float32"

Introduction#

Baby Steps#

Fibonacci Algorithm#

N_STEPS = 100_000

b_symbolic = pt.vector("b", dtype="int32", shape=(1,))

def step(a, b):
    return a + b, a

(outputs_a, outputs_b), _ = pytensor.scan(
    fn=step,
    outputs_info=[pt.ones(1, dtype="int32"), b_symbolic],
    n_steps=N_STEPS
)

# compile function returning final a
fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)
fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)
@jit(nopython=True)
def fibonacci_numba(b):
    a = np.ones(1, dtype=np.int32)
    for _ in range(N_STEPS):
        a[0], b[0] = a[0] + b[0], a[0]
    return a[0]
@jax.jit
def fibonacci_jax(b):
    a = jnp.array(1, dtype=np.int32)
    for _ in range(N_STEPS):
        a, b = a + b, a
    return a
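
A caveat for fibonacci_jax: the Python for loop is unrolled during tracing, so XLA sees on the order of N_STEPS operations and compile time can dominate. A sketch (assumed equivalent, not benchmarked) using jax.lax.fori_loop, which keeps the traced graph constant-size:

@jax.jit
def fibonacci_jax_fori(b):
    a = jnp.ones(1, dtype=jnp.int32)
    def body(i, carry):
        a, b = carry
        return a + b, a
    a, _ = jax.lax.fori_loop(0, N_STEPS, body, (a, b))
    return a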
fibonacci_bench = Benchmarker(
    functions=[fibonacci_pytensor, fibonacci_numba, fibonacci_jax, fibonacci_pytensor_numba], 
    names=['fibonacci_pytensor', 'fibonacci_numba', 'fibonacci_jax', 'fibonacci_pytensor_numba'],
    number=10
)
fibonacci_bench.run(
    inputs={
        "fibonacci_inputs": {"b": np.ones(1, dtype=np.int32)},
    }
)
fibonacci_bench.summary()
Loops Min (us) Max (us) Mean (us) StdDev (us) Median (us) IQR (us) OPS (Kops/s) Samples
fibonacci_pytensor fibonacci_inputs 10 54149.950002 55278.412500 54654.332500 434.902898 54807.945801 694.087401 0.018297 5
fibonacci_numba fibonacci_inputs 10 84.208400 105.558301 95.575969 4.793104 95.562500 5.033301 10.462881 13
fibonacci_jax fibonacci_inputs 10 7.162499 31.558401 14.425020 8.923158 12.233300 5.920898 69.323996 5
fibonacci_pytensor_numba fibonacci_inputs 10 2338.045801 2429.037497 2385.059159 34.438988 2392.570800 58.641701 0.419277 5

Element-wise Multiplication Algorithm#

# The scan below fixes n_steps explicitly; it is unclear whether this changes
# anything when sequences already determine the number of iterations.
a_symbolic = pt.vector("a", dtype="int32")
b_symbolic = pt.vector("b", dtype="int32")
N_STEPS = 1000

def step(a_element, b_element):
    return a_element * b_element

c, _ = pytensor.scan(
    fn=step,
    sequences=[a_symbolic, b_symbolic],
    n_steps=N_STEPS
)

# compile function returning the elementwise product
c_mode = get_mode("FAST_RUN").excluding("scan_push_out_seq")
elementwise_multiply_pytensor = pytensor.function([a_symbolic, b_symbolic], c, trust_input=True, mode=c_mode)

numba_mode = get_mode("NUMBA").excluding("scan_push_out_seq")
elementwise_multiply_pytensor_numba = pytensor.function([a_symbolic, b_symbolic], c, mode=numba_mode, trust_input=True)
@jit(nopython=True)
def elementwise_multiply_numba(a, b):
    n = a.shape[0]
    c = np.empty(n, dtype=a.dtype)
    for i in range(n):
        c[i] = a[i] * b[i]
    return c
@block
@jax.jit
def elementwise_multiply_jax(a, b):
    n = a.shape[0]
    c_init = jnp.empty(n, dtype=a.dtype)
    def step(i, c):
        return jax.lax.dynamic_update_index_in_dim(c, a[i] * b[i], i, axis=0)
    
    c = jax.lax.fori_loop(0, n, step, c_init)
    return c
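
For context, the idiomatic version on every backend is a single vectorized multiply; the explicit loops above exist purely to benchmark scan-style iteration. A sketch (not benchmarked):

@block
@jax.jit
def elementwise_multiply_jax_vectorized(a, b):
    return a * b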
a = np.random.normal(0, 1, (N_STEPS)).astype(np.int32)
b = np.random.normal(0, 1, (N_STEPS)).astype(np.int32)
elem_mult_bench = Benchmarker(
    functions=[elementwise_multiply_pytensor, elementwise_multiply_numba, elementwise_multiply_jax, elementwise_multiply_pytensor_numba], 
    names=['elementwise_multiply_pytensor', 'elementwise_multiply_numba', 'elementwise_multiply_jax', 'elementwise_multiply_pytensor_numba'],
    number=10
)
elem_mult_bench.run(
    inputs={
        "elem_mult_inputs": {"a": a, "b": b},
    }
)
elem_mult_bench.summary()
Loops Min (us) Max (us) Mean (us) StdDev (us) Median (us) IQR (us) OPS (Kops/s) Samples
elementwise_multiply_pytensor elem_mult_inputs 10 450.595800 967.095900 492.991945 36.396199 491.412499 18.221901 2.028431 220
elementwise_multiply_numba elem_mult_inputs 10 0.366700 0.720800 0.394247 0.073563 0.379200 0.012497 2536.478255 21
elementwise_multiply_jax elem_mult_inputs 10 7.512499 10.391598 8.152427 0.621782 7.895901 0.593750 122.662853 55
elementwise_multiply_pytensor_numba elem_mult_inputs 10 34.662499 50.737502 39.280821 5.934219 37.408300 3.912600 25.457717 5

Changepoint Detection Algorithms#

Cumulative Sum (CUSUM) Algorithm#
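
All implementations below compute the same recursions: an EMA baseline mu_t = alpha * x_t + (1 - alpha) * mu_{t-1}, and the one-sided statistics s_pos_t = max(0, s_pos_{t-1} + (x_t - mu_t) - k) and s_neg_t = max(0, s_neg_{t-1} - (x_t - mu_t) - k), with an alarm raised whenever either statistic exceeds the threshold h.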

@jit(nopython=True)
def cusum_adaptive_numba(x, alpha=0.01, k=0.5, h=5.0):
    """
    Two-sided CUSUM with adaptive exponential moving average baseline.
    
    Parameters
    ----------
    x: np.ndarray
        input signal
    alpha: float
        EMA smoothing factor (0 < alpha <= 1)
    k: float
        slack to avoid small changes triggering alarms
    h: float
        threshold for raising an alarm
        
    Returns
    -------
    s_pos: np.ndarray
        upper CUSUM stats
    s_neg: np.ndarray
        lower CUSUM stats
    mu_t: np.ndarray
        evolving baseline estimate
    alarms_pos: np.ndarray
        alarms for upward changes
    alarms_neg: np.ndarray
        alarms for downward changes
    """
    n = x.shape[0]

    s_pos = np.zeros(n, dtype=np.float32)
    s_neg = np.zeros(n, dtype=np.float32)
    mu_t  = np.zeros(n, dtype=np.float32)
    alarms_pos = np.zeros(n, dtype=np.bool_)
    alarms_neg = np.zeros(n, dtype=np.bool_)

    # Initialization
    mu_t[0] = x[0]

    for i in range(1, n):
        # Update baseline (EMA)
        mu_t[i] = alpha * x[i] + (1 - alpha) * mu_t[i-1]

        # Update CUSUM stats
        s_pos[i] = max(0.0, s_pos[i-1] + x[i] - mu_t[i] - k)
        s_neg[i] = max(0.0, s_neg[i-1] - (x[i] - mu_t[i]) - k)

        # Alarms
        alarms_pos[i] = s_pos[i] > h
        alarms_neg[i] = s_neg[i] > h

    return s_pos, s_neg, mu_t, alarms_pos, alarms_neg
@block
@jax.jit
def cusum_adaptive_jax(x, alpha=0.01, k=0.5, h=5.0):
    """
    Two-sided CUSUM with adaptive exponential moving average baseline.
    
    Parameters
    ----------
    x: jnp.ndarray
        input signal
    alpha: float
        EMA smoothing factor (0 < alpha <= 1)
    k: float
        slack to avoid small changes triggering alarms
    h: float
        threshold for raising an alarm
        
    Returns
    -------
    s_pos: jnp.ndarray
        upper CUSUM stats
    s_neg: jnp.ndarray
        lower CUSUM stats
    mu_t: jnp.ndarray
        evolving baseline estimate
    alarms_pos: jnp.ndarray
        alarms for upward changes
    alarms_neg: jnp.ndarray
        alarms for downward changes
    """
    def body(carry, x_t):
        s_pos_prev, s_neg_prev, mu_prev = carry
        
        # Update EMA baseline
        mu_t = alpha * x_t + (1 - alpha) * mu_prev
        
        # Update CUSUMs using updated baseline
        s_pos = jnp.maximum(0.0, s_pos_prev + x_t - mu_t - k)
        s_neg = jnp.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)
        
        new_carry = (s_pos, s_neg, mu_t)
        output = (s_pos, s_neg, mu_t)
        return new_carry, output

    # Initialize: CUSUMs at 0, initial mean = first sample
    s0 = (0.0, 0.0, x[0])
    _, (s_pos_vals, s_neg_vals, mu_vals) = jax.lax.scan(body, s0, x)
    
    # Thresholding
    alarms_pos = s_pos_vals > h
    alarms_neg = s_neg_vals > h

    return s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg
x_symbolic = pt.vector("x")
alpha_symbolic = pt.scalar("alpha")
k_symbolic = pt.scalar("k")
h_symbolic = pt.scalar("h")

N_STEPS = 100  # Fixing this just in case, though it likely changes nothing when scan iterates over sequences

def step(x_t, s_pos_prev, s_neg_prev, mu_prev, alpha, k):
    # Update EMA baseline
    mu_t = alpha * x_t + (1 - alpha) * mu_prev
    
    # Update CUSUMs using updated baseline
    s_pos = pt.maximum(0.0, s_pos_prev + x_t - mu_t - k)
    s_neg = pt.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)
    
    return s_pos, s_neg, mu_t


(s_pos_vals, s_neg_vals, mu_vals), updates = pytensor.scan(
    fn=step,
    outputs_info=[pt.constant(0., dtype="float32"), pt.constant(0., dtype="float32"), x_symbolic[0]],
    non_sequences=[alpha_symbolic, k_symbolic],
    sequences=[x_symbolic],
    n_steps=N_STEPS
)

# Thresholding
alarms_pos = s_pos_vals > h_symbolic
alarms_neg = s_neg_vals > h_symbolic

cusum_adaptive_pytensor = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], trust_input=True)

cusum_adaptive_pytensor_numba = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode="NUMBA", trust_input=True)
cusum_adaptive_pytensor_jax = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode="JAX", trust_input=True)
xs1 = np.random.normal(80, 20, size=(int(N_STEPS/2)))
xs2 = np.random.normal(50, 20, size=(int(N_STEPS/2)))
xs = np.concat((xs1, xs2))
xs = xs.astype(np.float32)
xs_std = (xs - xs.mean()) / xs.std()
cusum_bench = Benchmarker(
    functions=[cusum_adaptive_pytensor, cusum_adaptive_numba, cusum_adaptive_jax, cusum_adaptive_pytensor_numba, block(cusum_adaptive_pytensor_jax)], 
    names=['cusum_adaptive_pytensor', 'cusum_adaptive_numba', 'cusum_adaptive_jax', 'cusum_adaptive_pytensor_numba', 'cusum_adaptive_pytensor_jax'],
    number=10
)
cusum_bench.run(
    inputs={
        "cusum_inputs": {"x": xs, "alpha": 0.1, "k": 0.5, "h": 3.5},
    }
)
cusum_bench.summary()
Loops Min (us) Max (us) Mean (us) StdDev (us) Median (us) IQR (us) OPS (Kops/s) Samples
cusum_adaptive_pytensor cusum_inputs 10 124.412501 226.245899 141.837420 8.902762 139.566700 10.325000 7.050326 681
cusum_adaptive_numba cusum_inputs 10 1.729201 2.174999 1.806571 0.150796 1.745898 0.018753 553.534780 7
cusum_adaptive_jax cusum_inputs 10 14.366599 18.899998 15.076731 0.942976 14.633298 0.738525 66.327378 36
cusum_adaptive_pytensor_numba cusum_inputs 10 24.316600 36.341700 27.204980 4.596065 25.033302 1.241703 36.757976 5
cusum_adaptive_pytensor_jax cusum_inputs 10 21.479101 27.295901 22.782493 1.642595 21.893748 1.734347 43.893352 30
outputs = cusum_adaptive_numba(xs_std, alpha=0.1, k=0.5, h=3.5)
fig = go.Figure()
fig.add_traces(
    [
        go.Scatter(
            x = np.arange(len(xs)),
            y = xs_std,
            name="series"
        ),
        go.Scatter(
            x = np.arange(len(xs)),
            y = outputs[0],
            name="cum. positive devs."
        ),
        go.Scatter(
            x = np.arange(len(xs)),
            y = outputs[1],
            name="cum. negative devs."
        ),
        go.Scatter(
            x = np.arange(len(xs)),
            y = outputs[2],
            name="Exp. Mean"
        ),
        go.Scatter(
            x = np.arange(len(xs)),
            y = outputs[3].astype(np.float16),
            name="positive alarms"
        ),
        go.Scatter(
            x = np.arange(len(xs)),
            y = outputs[4].astype(np.float16),
            name="negative alarms"
        ),
        
    ]
)
fig.update_layout(
    title = dict(
        text = "CUSUM Change Point Detection Algorithm"
    ),
    xaxis=dict(
        title = "Time Index"
    ),
    yaxis=dict(
        title = "Standardized Series Scaled"
    ),
    legend=dict(
        yanchor="top",
        y=1.1,
        xanchor="left",
        x=0,
        orientation="h"
    ),
    template="plotly_dark"
)

Pruned Exact Linear Time (PELT) Algorithm#
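
The dynamic program below fills C(t) = min over s < t of [C(s) + cost(s, t) + beta], where cost(s, t) is the sum of squared errors of x[s:t] around its mean, computed in O(1) from the cumulative sums S1 and S2; last_change records the argmin for backtracking the changepoints.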

@jit(nopython=True)
def segment_cost_numba(S1, S2, i, j):
    """Cost of segment x[i:j], SSE around mean"""
    n = j - i
    sum_x = S1[j] - S1[i]
    sum_x2 = S2[j] - S2[i]
    if n > 0:
        return sum_x2 - (sum_x ** 2) / n
    else:
        return np.inf

@jit(nopython=True)
def pelt_numba(x, beta=10.0):
    """
    Pruned Exact Linear Time algorithm for change point detection

    Parameters
    ----------
    x: np.ndarray
        The timeseries signal
    beta: float
        Penalty of segmenting the series

    Returns
    -------
    C: np.ndarray
        The best costs up to segment t
    last_change: np.ndarray
        The last change point up to segment t
    """
    n = len(x)

    # cumulative sums for cost
    S1 = np.empty(n+1, dtype=np.float32)
    S2 = np.empty(n+1, dtype=np.float32)
    S1[0], S2[0] = 0.0, 0.0
    for i in range(1, n+1):
        S1[i] = S1[i-1] + x[i-1]
        S2[i] = S2[i-1] + x[i-1]**2

    # DP arrays
    C = np.full((n+1,), np.inf)
    C[0] = -beta
    last_change = np.full((n+1,), -1)
    min_size = 3

    for t in range(1, n+1):
        costs = np.full(n, np.inf)
        for s in range(n):
            if s < t and (t - s) >= min_size:
                costs[s] = C[s] + segment_cost_numba(S1, S2, s, t) + beta
        best_s = np.argmin(costs)
        C[t] = costs[best_s]
        last_change[t] = best_s

    return C, last_change
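
Strictly speaking, the loop above evaluates every candidate s at every t, which is Optimal Partitioning rather than PELT proper; PELT adds a pruning step that discards candidates that can no longer become optimal. A plain-NumPy sketch of the pruned recursion (illustrative only, not benchmarked; min_size omitted for brevity):

def pelt_pruned_numpy(x, beta=10.0):
    n = len(x)
    S1 = np.concatenate(([0.0], np.cumsum(x)))
    S2 = np.concatenate(([0.0], np.cumsum(x * x)))

    def cost(i, j):  # SSE of x[i:j] around its mean
        s = S1[j] - S1[i]
        return (S2[j] - S2[i]) - s * s / (j - i)

    C = np.full(n + 1, np.inf)
    C[0] = -beta
    last_change = np.full(n + 1, -1, dtype=np.int64)
    candidates = [0]
    for t in range(1, n + 1):
        vals = np.array([C[s] + cost(s, t) + beta for s in candidates])
        best = int(np.argmin(vals))
        C[t] = vals[best]
        last_change[t] = candidates[best]
        # pruning: s survives only if it could still beat the current optimum
        candidates = [s for s, v in zip(candidates, vals) if v <= C[t] + beta]
        candidates.append(t)
    return C, last_change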
def segment_cost_jax(S1, S2, i, j):
    """Cost of segment x[i:j], SSE around mean"""
    n = j - i
    sum_x = S1[j] - S1[i]
    sum_x2 = S2[j] - S2[i]
    return jnp.where(n > 0, sum_x2 - (sum_x ** 2) / n, jnp.inf)

@block
@jax.jit
def pelt_jax(x, beta=10.0):
    """
    Pruned Exact Linear Time algorithm for change point detection

    Parameters
    ----------
    x: np.ndarray
        The timeseries signal
    beta: float
        Penalty of segmenting the series

    Returns
    -------
    C: jnp.ndarray
        The best costs up to segment t
    last_change: jnp.ndarray
        The last change point up to segment t
    """
    n = len(x)

    # cumulative sums for cost
    S1 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x)])
    S2 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x**2)])

    # DP arrays
    C = jnp.full((n+1,), jnp.inf)
    C = C.at[0].set(-beta)
    last_change = jnp.full((n+1,), -1)
    min_size = 3

    s_all = jnp.arange(n)   # all possible candidates

    def body(t, carry):
        C, last_change = carry

        # Compute cost for all s < t, mask invalid
        # valid = s_all < t & ((t - s_all) >= min_size)
        
        valid = (s_all < t) & ((t - s_all) >= min_size)
        costs = jnp.where(
            valid,
            C[s_all] + segment_cost_jax(S1, S2, s_all, t) + beta,
            jnp.inf
        )

        best_s = jnp.argmin(costs)
        C = C.at[t].set(costs[best_s])
        last_change = last_change.at[t].set(best_s)
        return C, last_change

    C, last_change = jax.lax.fori_loop(1, n+1, body, (C, last_change))
    return C, last_change
def segment_cost_pytensor(S1, S2, i, j):
    """Cost of segment x[i:j], SSE around mean"""
    n = j - i
    sum_x = S1[j] - S1[i]
    sum_x2 = S2[j] - S2[i]
    return pt.switch(
        pt.gt(n, 0),
        sum_x2 - (sum_x ** 2) / n,
        np.inf
    )
x_symbolic = pt.vector("x")
beta_symbolic = pt.scalar("beta")
n = x_symbolic.shape[0]
N_STEPS = 100

# cumulative sums for cost
S1 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic)])
S2 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic**2)])

# DP arrays
C_init = pt.alloc(np.inf, n+1)
C_init = pt.set_subtensor(C_init[0], -beta_symbolic)
last_change_init = pt.alloc(-1, n+1)

s_all = pt.arange(n)   # candidate change points
min_size = 3

def step(t, C_prev, last_change_prev, S1, S2, beta_symbolic, s_all):
    # valid = (s_all < t) & ((t - s_all) >= min_size)
    valid = pt.and_(pt.lt(s_all, t), pt.ge(t - s_all, min_size))

    # compute costs for all candidates
    costs, _ = pytensor.scan(
        fn=lambda s: pt.switch(
            valid[s],
            C_prev[s] + segment_cost_pytensor(S1, S2, s, t) + beta_symbolic,
            np.inf
        ),
        sequences=[pt.arange(n)]
    )
    costs = costs.flatten()

    best_s = pt.argmin(costs, axis=0)
    C_new = pt.set_subtensor(C_prev[t], costs[best_s])
    last_change_new = pt.set_subtensor(last_change_prev[t], best_s)

    return C_new, last_change_new

(C_vals, last_change_vals), _ = pytensor.scan(
    fn=step,
    sequences=[pt.arange(1, n+1)],
    outputs_info=[C_init, last_change_init],
    non_sequences=[S1, S2, beta_symbolic, s_all],
    n_steps=N_STEPS # Added fixed iterations here
)

pelt_pytensor = pytensor.function([x_symbolic, beta_symbolic], [C_vals[-1], last_change_vals[-1]], trust_input=True)
pelt_pytensor_numba = pytensor.function(inputs=[x_symbolic, beta_symbolic], outputs=[C_vals[-1], last_change_vals[-1]], mode="NUMBA", trust_input=True)
pelt_bench = Benchmarker(
    functions=[pelt_pytensor, pelt_numba, pelt_jax, pelt_pytensor_numba], 
    names=['pelt_pytensor', 'pelt_numba', 'pelt_jax', 'pelt_pytensor_numba'],
    number=10
)
pelt_bench.run(
    inputs={
        "pelt_inputs": {"x": xs_std, "beta": 2. * np.log(len(xs_std))},
    }
)
pelt_bench.summary()
Loops Min (us) Max (us) Mean (us) StdDev (us) Median (us) IQR (us) OPS (Kops/s) Samples
pelt_pytensor pelt_inputs 10 11903.712500 12687.708301 12234.242133 245.989123 12224.875001 277.799901 0.081738 9
pelt_numba pelt_inputs 10 17.408299 22.204101 19.400820 1.585318 19.241701 1.000002 51.544213 5
pelt_jax pelt_inputs 10 64.616700 82.345799 71.339679 4.879389 70.279100 4.893748 14.017445 19
pelt_pytensor_numba pelt_inputs 10 2330.925001 2496.158300 2424.299980 56.514842 2435.225001 61.833399 0.412490 5
outputs = pelt_numba(xs_std, 2. * np.log(len(xs_std)))
def plot_pelt_diagnostics(x, cps, C):
    """
    Diagnostic plots for PELT changepoint detection.
    
    Args:
        x: 1D array, original time series
        cps: list of changepoint indices (sorted ascending)
        C: 1D array, cumulative DP cost from pelt()
    """
    n = len(x)
    cps_full = [0] + cps + [n]

    # Segment means, std, SSE
    segment_means = []
    segment_stds = []
    segment_costs = []
    for start, end in zip(cps_full[:-1], cps_full[1:]):
        seg = x[start:end]
        mean = np.mean(seg)
        std = np.std(seg)
        cost = np.sum((seg - mean)**2)
        segment_means.append(mean)
        segment_stds.append(std)
        segment_costs.append(cost)

    # Step function for segment mean
    mean_step = np.zeros(n)
    for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):
        mean_step[start:end] = segment_means[i]

    # Step function for segment std
    std_step = np.zeros(n)
    for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):
        std_step[start:end] = segment_stds[i]

    if len(x) < 20:
        title1 = "<span style='color: red;'>Warning</span>: Sample size is small - Detected Changepoints"
    else:
        title1 = "Detected Changepoints"

    fig = make_subplots(
        rows=4, 
        cols=1,
        subplot_titles=(title1, "Average Shifts", "Variability Shifts", "Cumulative Cost")
    )

    fig.add_trace(
        go.Scatter(
            x = np.arange(len(x)),
            y = x,
            line_color="royalblue",
            name = "Actuals",
            mode="lines",
            showlegend=False,
            hovertemplate="<b>Time Point</b>: %{x}<br><b>Actual</b>: %{y}"
        ),
        row=1, col=1
    )

    for cp in cps:
        fig.add_vline(x=cp, line_dash='dash', line_color="red", row=1, col=1)

    fig.add_trace(
        go.Scatter(
            x = np.arange(len(x)),
            y = x,
            name = "Actuals",
            mode="lines",
            line_color="rgba(105, 105, 105, 0.25)",
            showlegend=False,
            hoverinfo="skip"
        ),
        row=2, col=1
    )

    fig.add_trace(
        go.Scatter(
            x = np.arange(len(x)),
            y = mean_step,
            name = "Average",
            line_color="royalblue",
            showlegend=False,
            hovertemplate="<b>Time Point</b>: %{x}<br><b>Average</b>: %{y}"
        ),
        row=2, col=1
    )

    fig.add_trace(
        go.Scatter(
            x = np.arange(len(x)),
            y = std_step,
            name = "Standard Deviation",
            line_color="royalblue",
            showlegend=False,
            hovertemplate="<b>Time Point</b>: %{x}<br><b>Standard Deviation</b>: %{y}"
        ),
        row=3, col=1
    )

    fig.add_trace(
        go.Scatter(
            x = np.arange(len(x)),
            y = C,
            name = "Cumulative Cost",
            line_color="royalblue",
            showlegend=False,
            hovertemplate="<b>Time Point</b>: %{x}<br><b>Cost</b>: %{y}"
        ),
        row=4, col=1
    )

    for cp in cps:
        fig.add_vline(x=cp, line_dash='dash', line_color="red", row=4, col=1)

    return fig.update_layout(height=1000, width=1200, template="plotly_dark")
def get_changepoints(last_change, n):
    """
    Backtrack changepoints from last_change array.
    
    Args:
        last_change: array from pelt()
        n: length of input series

    Returns:
        list of changepoint indices (sorted ascending)
    """
    cps = []
    t = n
    while t > 0:
        s = int(last_change[t])
        if s <= 0:
            break
        cps.append(s)
        t = s
    return list(reversed(cps))
cps = get_changepoints(outputs[1], n=len(xs_std))
plot_pelt_diagnostics(xs, cps, outputs[0])

Kalman Filter Algorithms#

Linear Gaussian Kalman Filter#
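
Every step below performs the standard predict/update pair: predict with x_pred = F x and P_pred = F P F^T + Q; then update using the innovation y = z_t - H x_pred, innovation covariance S = H P_pred H^T + R, Kalman gain K = P_pred H^T S^{-1}, filtered mean x = x_pred + K y, and filtered covariance P = (I - K H) P_pred.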

@jit(nopython=True)
def atrocious_kalman_filter_numba(z, F, H, Q, R, x0, P0):
    """
    This implementation of the Kalman filter is atrocious and would be a big
    no-no in standard Python. That said, the explicit loops significantly
    reduce Numba compilation time.
    
    Linear Gaussian Kalman filter algorithm

    Parameters
    ----------
    z: np.ndarray
        shape (T, m)  - observations
    F: np.ndarray
        state transition matrix - shape (n, n)
    H: np.ndarray
        observation/design matrix - shape (m, n)
    Q: np.ndarray
        process noise covariance - shape (n, n)
    R: np.ndarray
        observation noise covariance - shape (m, m)
    x0: np.ndarray
        initial state mean - shape (n,)
    P0: np.ndarray
        initial state covariance - shape (n, n)

    Returns
    -------
    xs: np.ndarray
        shape (T, n)   - filtered state means
    Ps: np.ndarray
        shape (T, n, n) - filtered state covariances
    """
    T = z.shape[0]
    m = z.shape[1]
    n = x0.shape[0]

    xs = np.empty((T, n), dtype=np.float32)
    Ps = np.empty((T, n, n), dtype=np.float32)

    # local working arrays
    x = np.empty(n, dtype=np.float32)
    for i in range(n):
        x[i] = x0[i]
    P = np.empty((n, n), dtype=np.float32)
    for i in range(n):
        for j in range(n):
            P[i, j] = P0[i, j]

    # temporary matrices/vectors
    x_pred = np.empty((T, n), dtype=np.float32)
    P_pred = np.empty((T, n, n), dtype=np.float32)
    y = np.empty(m, dtype=np.float32)
    S = np.empty((m, m), dtype=np.float32)
    K = np.empty((n, m), dtype=np.float32)
    I_n = np.eye(n, dtype=np.float32)

    for t in range(T):
        # === Predict ===
        # x_pred = F @ x
        for i in range(n):
            s = 0.0
            for j in range(n):
                s += F[i, j] * x[j]
            x_pred[t, i] = s

        # P_pred = F @ P @ F.T + Q
        # temp = F @ P
        temp = np.empty((n, n), dtype=np.float32)
        for i in range(n):
            for j in range(n):
                s = 0.0
                for k in range(n):
                    s += F[i, k] * P[k, j]
                temp[i, j] = s
        # P_pred = temp @ F.T
        for i in range(n):
            for j in range(n):
                s = 0.0
                for k in range(n):
                    s += temp[i, k] * F[j, k]   # F.T[k, j] = F[j, k]
                P_pred[t, i, j] = s + Q[i, j]

        # === Update ===
        # y = z[t] - H @ x_pred
        for i in range(m):
            s = 0.0
            for j in range(n):
                s += H[i, j] * x_pred[t, j]
            y[i] = z[t, i] - s

        # S = H @ P_pred @ H.T + R
        # temp2 = H @ P_pred
        temp2 = np.empty((m, n), dtype=np.float32)
        for i in range(m):
            for j in range(n):
                s = 0.0
                for k in range(n):
                    s += H[i, k] * P_pred[t, k, j]
                temp2[i, j] = s
        # S = temp2 @ H.T
        for i in range(m):
            for j in range(m):
                s = 0.0
                for k in range(n):
                    s += temp2[i, k] * H[j, k]  # H.T[k,j] = H[j,k]
                S[i, j] = s + R[i, j]

        # K = P_pred @ H.T @ inv(S)
        # first compute P_pred @ H.T  -> (n, m)
        P_Ht = np.empty((n, m), dtype=np.float32)
        for i in range(n):
            for j in range(m):
                s = 0.0
                for k in range(n):
                    s += P_pred[t, i, k] * H[j, k]  # H.T[k,j] = H[j,k]
                P_Ht[i, j] = s

        # invert S
        S_inv = np.linalg.inv(S)

        # K = P_Ht @ S_inv  (n,m) @ (m,m) -> (n,m)
        for i in range(n):
            for j in range(m):
                s = 0.0
                for k in range(m):
                    s += P_Ht[i, k] * S_inv[k, j]
                K[i, j] = s

        # x = x_pred + K @ y
        for i in range(n):
            s = 0.0
            for j in range(m):
                s += K[i, j] * y[j]
            x[i] = x_pred[t, i] + s

        # P = (I - K H) P_pred
        # compute (I - K H)
        KH = np.empty((n, n), dtype=np.float32)
        for i in range(n):
            for j in range(n):
                s = 0.0
                for k in range(m):
                    s += K[i, k] * H[k, j]
                KH[i, j] = s

        I_minus_KH = np.empty((n, n), dtype=np.float32)
        for i in range(n):
            for j in range(n):
                I_minus_KH[i, j] = I_n[i, j] - KH[i, j]

        # P = I_minus_KH @ P_pred
        for i in range(n):
            for j in range(n):
                s = 0.0
                for k in range(n):
                    s += I_minus_KH[i, k] * P_pred[t, k, j]
                P[i, j] = s

        # store results
        for i in range(n):
            xs[t, i] = x[i]
        for i in range(n):
            for j in range(n):
                Ps[t, i, j] = P[i, j]

    return xs, Ps, x_pred, P_pred
@jit(nopython=True)
def kalman_filter_numba(z, F, H, Q, R, x0, P0):
    """
    Linear Gaussian Kalman filter algorithm

    Parameters
    ----------
    z: np.ndarray
        shape (T, m)  - observations
    F: np.ndarray
        state transition matrix - shape (n, n)
    H: np.ndarray
        observation/design matrix - shape (m, n)
    Q: np.ndarray
        process noise covariance - shape (n, n)
    R: np.ndarray
        observation noise covariance - shape (m, m)
    x0: np.ndarray
        initial state mean - shape (n,)
    P0: np.ndarray
        initial state covariance - shape (n, n)

    Returns
    -------
    xs: np.ndarray
        shape (T, n)   - filtered state means
    Ps: np.ndarray
        shape (T, n, n) - filtered state covariances
    """
    T, m = z.shape
    n = x0.shape[0]

    xs = np.zeros((T, n), dtype=np.float32)
    Ps = np.zeros((T, n, n), dtype=np.float32)

    x_pred = np.zeros((T, n), dtype=np.float32)
    P_pred = np.zeros((T, n, n), dtype=np.float32)

    x = x0.copy()
    P = P0.copy()

    I = np.eye(n, dtype=np.float32)

    for t in range(T):
        # --- Predict ---
        x_pred[t] = F @ x
        P_pred[t] = F @ P @ F.T + Q

        # --- Update ---
        y = z[t] - H @ x_pred[t]
        S = H @ P_pred[t] @ H.T + R
        K = P_pred[t] @ H.T @ np.linalg.inv(S)

        x = x_pred[t] + K @ y
        P = (I - K @ H) @ P_pred[t]

        xs[t] = x
        Ps[t] = P

    return xs, Ps, x_pred, P_pred
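
A small numerical aside: forming the gain through an explicit inverse is fine here (S is 1x1 in the benchmark), but for larger observation dimensions a linear solve is generally preferred. Since S is symmetric, the line computing K above could be replaced with:

# K = np.linalg.solve(S, (P_pred[t] @ H.T).T).T   # avoids np.linalg.inv(S)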
@block
@jax.jit
def kalman_filter_jax(z, F, H, Q, R, x0, P0):
    """
    Linear Gaussian Kalman filter algorithm

    Parameters
    ----------
    z: np.ndarray
        shape (T, m)  - observations
    F: np.ndarray
        state transition matrix - shape (n, n)
    H: np.ndarray
        observation/design matrix - shape (m, n)
    Q: np.ndarray
        process noise covariance - shape (n, n)
    R: np.ndarray
        observation noise covariance - shape (m, m)
    x0: np.ndarray
        initial state mean - shape (n,)
    P0: np.ndarray
        initial state covariance - shape (n, n)

    Returns
    -------
    xs: jnp.ndarray
        shape (T, n)   - filtered state means
    Ps: jnp.ndarray
        shape (T, n, n) - filtered state covariances
    """

    n = x0.shape[0]
    I = jnp.eye(n)
    # match the shapes of the scan step outputs so the carry structure is consistent
    X_pred_init = jnp.zeros_like(x0)
    P_pred_init = jnp.zeros_like(P0)

    def step(carry, z_t):
        x, P, _, _ = carry

        # --- Predict ---
        x_pred = F @ x
        P_pred = F @ P @ F.T + Q

        # --- Update ---
        y = z_t - H @ x_pred
        S = H @ P_pred @ H.T + R
        K = P_pred @ H.T @ jnp.linalg.inv(S)

        x_new = x_pred + K @ y
        P_new = (I - K @ H) @ P_pred

        return (x_new, P_new, x_pred, P_pred), (x_new, P_new, x_pred, P_pred)

    # run scan
    (_, _, _, _), (xs, Ps, x_pred, P_pred) = jax.lax.scan(step, (x0, P0, X_pred_init, P_pred_init), z)

    return xs, Ps, x_pred, P_pred
z_symbolic = pt.matrix("z")
F_symbolic = pt.matrix("F")
H_symbolic = pt.matrix("H")
Q_symbolic = pt.matrix("Q")
R_symbolic = pt.matrix("R")
x0_symbolic = pt.vector("x0")
P0_symbolic = pt.matrix("P0")

n = x0_symbolic.shape[0]
I = pt.eye(n)
X_pred_init = pt.zeros_like(x0_symbolic)
P_pred_init = pt.zeros_like(P0_symbolic)

N_STEPS = 500

def step(z_t, x, P, x_pred, P_pred, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I):

    # --- Predict ---
    x_pred = F_symbolic @ x
    P_pred = F_symbolic @ P @ F_symbolic.T + Q_symbolic

    # --- Update ---
    y = z_t - H_symbolic @ x_pred
    S = H_symbolic @ P_pred @ H_symbolic.T + R_symbolic
    K = P_pred @ H_symbolic.T @ pt.linalg.inv(S)

    x_new = x_pred + K @ y
    P_new = (I - K @ H_symbolic) @ P_pred

    return x_new, P_new, x_pred, P_pred

# run scan
(xs, Ps, x_pred, P_pred), _ = pytensor.scan(
    fn=step,
    outputs_info=[x0_symbolic, P0_symbolic, X_pred_init, P_pred_init],
    sequences=[z_symbolic],
    non_sequences=[F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I],
    n_steps=N_STEPS
)

kalman_filter_pytensor = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], trust_input=True)

kalman_filter_pytensor_numba = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode="NUMBA", trust_input=True)
kalman_filter_pytensor_jax = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode="JAX", trust_input=True)
T = 500
F = np.array([[1.0]]).astype(np.float32)
H = np.array([[1.0]]).astype(np.float32)
Q = np.array([[0.01]]).astype(np.float32)
R = np.array([[0.1]]).astype(np.float32)
x0 = np.array([0.0]).astype(np.float32)
P0 = np.array([[1.0]]).astype(np.float32)

true = 1.0
z = (true + 0.4*np.random.randn(T)).reshape(T, 1).astype(np.float32)
kalman_filter_bench = Benchmarker(
    functions=[kalman_filter_pytensor, atrocious_kalman_filter_numba, kalman_filter_numba, kalman_filter_jax, kalman_filter_pytensor_numba, block(kalman_filter_pytensor_jax)], 
    names=['kalman_filter_pytensor', 'atrocious_kalman_filter_numba', 'kalman_filter_numba', 'kalman_filter_jax', 'kalman_filter_pytensor_numba', 'kalman_filter_pytensor_jax'],
    number=10
)
kalman_filter_bench.run(
    inputs={
        "kalman_filter_inputs": {"z": z, "F": F, "H": H, "Q": Q, "R": R, "x0": x0, "P0": P0},
    }
)
kalman_filter_bench.summary()
/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/tmpmbe0qi_u:33: NumbaWarning:

Cannot cache compiled function "scan" as it uses dynamic globals (such as ctypes pointers and large global arrays)
Loops Min (us) Max (us) Mean (us) StdDev (us) Median (us) IQR (us) OPS (Kops/s) Samples
kalman_filter_pytensor kalman_filter_inputs 10 5675.591601 6334.566601 5949.072047 159.617745 5937.045900 191.600001 0.168093 17
atrocious_kalman_filter_numba kalman_filter_inputs 10 297.462501 330.600000 314.415020 11.425574 315.533401 14.287402 3.180510 5
kalman_filter_numba kalman_filter_inputs 10 641.612499 718.120800 684.415800 32.973390 706.012500 62.024998 1.461100 5
kalman_filter_jax kalman_filter_inputs 10 305.941701 368.020899 333.580105 18.368842 329.077049 22.434351 2.997781 18
kalman_filter_pytensor_numba kalman_filter_inputs 10 853.600001 933.024997 897.570000 33.462131 918.749999 61.308398 1.114119 5
kalman_filter_pytensor_jax kalman_filter_inputs 10 304.041599 372.449998 330.685825 17.271642 327.668698 24.356249 3.024018 20
xs, Ps, x_pred, P_pred = kalman_filter_jax(z, F, H, Q, R, x0, P0)
def compute_pred_intervals(z, x_pred, P_pred, H, R, zscore=1.96):
    T = z.shape[0]
    m = H.shape[0]
    mean = np.zeros((T, m))
    lower = np.zeros((T, m))
    upper = np.zeros((T, m))
    outside = np.zeros(T, dtype=np.bool_)

    for t in range(T):
        mean[t] = H @ x_pred[t]
        S = H @ P_pred[t] @ H.T + R
        std = np.sqrt(np.diag(S))
        lower[t] = mean[t] - zscore * std
        upper[t] = mean[t] + zscore * std

        # check coverage of actual obs
        outside[t] = np.any((z[t] < lower[t]) | (z[t] > upper[t]))

    coverage = 1 - outside.mean()
    return mean, lower, upper, coverage
mean, lower, upper, coverage = compute_pred_intervals(z, x_pred, P_pred, H, R)
coverage
np.float64(0.91)
fig = go.Figure()
fig.add_traces(
    [
        go.Scatter(
            x = np.arange(T),
            y = z.ravel(),
            mode="markers",
            marker_color = "royalblue",
            name = "actuals"
        ),
        go.Scatter(
            x = np.arange(T),
            y = xs.ravel(),
            mode = "lines",
            marker_color = "orange",
            name = "filtered mean"
        ),
        go.Scatter(
                name="", 
                x=np.arange(T), 
                y=upper.ravel(), 
                mode="lines", 
                marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="95% CI",
                showlegend=False
            ),
            go.Scatter(
                name="95% CI", 
                x=np.arange(T), 
                y=lower.ravel(), 
                mode="lines", marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="95% CI", 
                fill='tonexty', 
                fillcolor='rgba(235, 140, 52, 0.2)'
            ),

    ]
)
fig.update_layout(
    xaxis=dict(
        title = "Time Index",
    ),
    yaxis=dict(
        title = "y"
    ),
    template = "plotly_dark"
)

Non-linear Kalman Filter#
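
At each time step, the particle filters below (1) propagate the particles through the AR(1) state transition, (2) reweight them by the Poisson log-likelihood of the observation (subtracting the max log-weight for numerical stability), (3) normalize the weights and compute the filtered mean and variance of the latent state plus the predictive mean of the observation, and (4) resample to avoid weight degeneracy.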

@jit(nopython=True)
def loglik_poisson_numba(s, y):
    """Poisson Log Likelihood"""
    mu = np.exp(s)
    return y * np.log(mu + 1e-30) - mu - math.lgamma(y + 1.0) # numba does not support scipy.special gammaln

@jit(nopython=True)
def particle_filter_1d_predict_numba(A, Q, x0_mean, x0_std, ys, N=1000, seed=2):
    """
    1D particle filter.
    
    Parameters
    ----------
    A: float
        State transition
    Q: float
        Process covariance
    x0_mean: float
        Prior mean for the latent state
    x0_std: float
        Prior standard deviation 
    ys: np.ndarray
        observations
    N: int
        number of particles
    seed: int
        rng seed for reproducibility

    Returns
    -------
    filtered_means: np.ndarray
        The filtered mean for the latent state 
    filtered_vars: np.ndarray
        The filtered variance for the latent state
    pred_means: np.ndarray
        observation predicted mean 
    """
    np.random.seed(seed)
    T = ys.shape[0]
    particles = np.random.normal(x0_mean, x0_std, size=N)
    weights = np.ones(N) / N

    filtered_means = np.zeros(T)
    filtered_vars = np.zeros(T)
    pred_means = np.zeros(T)

    for t in range(T):
        y = ys[t]

        # propagate (vectorized)
        particles = A * particles + np.random.normal(0, np.sqrt(Q), size=N)

        # update weights
        logw = np.zeros(N)
        for i in range(N):
            logw[i] = loglik_poisson_numba(particles[i], y)
        logw = logw - np.max(logw)
        weights *= np.exp(logw)
        weights /= np.sum(weights) + 1e-12

        # filtered moments
        mean_t = np.sum(weights * particles)
        var_t = np.sum(weights * (particles - mean_t) ** 2)

        # predictive mean
        pred_mean = np.sum(weights * np.exp(particles))

        filtered_means[t] = mean_t
        filtered_vars[t] = var_t
        pred_means[t] = pred_mean

        # resample (multinomial resampling) because numba doesn't support np.random.choice
        cumulative_sum = np.cumsum(weights)
        cumulative_sum[-1] = 1.0  # guard against rounding error
        indices = np.searchsorted(cumulative_sum, np.random.rand(N))

        particles = particles[indices]
        weights = np.ones(N) / N

    return filtered_means, filtered_vars, pred_means
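
As an aside, multinomial resampling (used above because Numba lacks a weighted np.random.choice) has higher variance than systematic resampling, which is equally Numba-friendly. A hypothetical drop-in sketch:

@jit(nopython=True)
def systematic_resample_numba(weights):
    # one uniform offset, then N evenly spaced positions in [0, 1)
    N = weights.shape[0]
    positions = (np.random.rand() + np.arange(N)) / N
    return np.searchsorted(np.cumsum(weights), positions)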
# The log likelihood and PRNG key are fixed inside the function so the Benchmarker can call all implementations with the same signature
def loglik_poisson_jax(s, y):
    """Poisson Log Likelihood"""
    mu = jnp.exp(s)
    return y * jnp.log(mu + 1e-30) - mu - gammaln(y + 1.0)

@block
@partial(jax.jit, static_argnums=5)
def particle_filter_1d_predict_jax(
    A, Q, x0_mean, x0_std, ys, N=1000,
):
    """
    1D particle filter.
    
    Parameters
    ----------
    A: float
        State transition
    Q: float
        Process covariance
    x0_mean: float
        Prior mean for the latent state
    x0_std: float
        Prior standard deviation 
    ys: np.ndarray
        observations
    N: int
        number of particles (static argument for jit)

    Returns
    -------
    filtered_means: jnp.ndarray
        The filtered mean for the latent state 
    filtered_vars: jnp.ndarray
        The filtered variance for the latent state
    pred_means: jnp.ndarray
        observation predicted mean 
    """
    key = jax.random.PRNGKey(0)
    T = ys.shape[0]
    particles = jax.random.normal(key, (N,)) * x0_std + x0_mean # init particles from gaussian priors
    weights = jnp.ones(N) / N # particle weights, all particles equally likely prior

    def body_fun(carry, t):
        particles, weights, key = carry
        y = ys[t]

        # propagate
        key, subkey = jax.random.split(key)
        particles = A * particles + jax.random.normal(subkey, (N,)) * jnp.sqrt(Q) # state transition model

        # update weights
        logw = jax.vmap(lambda x: loglik_poisson_jax(x, y))(particles) # update particles in parallel
        logw = logw - jnp.max(logw) # avoid overflow
        weights = weights * jnp.exp(logw) # old weights times the likelihood
        weights /= jnp.sum(weights) + 1e-12 # normalize so that weights sum to 1

        # filtered moments
        mean_t = jnp.sum(weights * particles) # posterior mean of latent state
        var_t = jnp.sum(weights * (particles - mean_t)**2) # posterior variance of latent state

        # predictive mean
        pred_mean = jnp.sum(weights * jnp.exp(particles))

        # resample to prevent dominant particles
        key, subkey = jax.random.split(key)
        indices = jax.random.choice(subkey, N, p=weights, shape=(N,))
        particles = particles[indices]
        weights = jnp.ones(N) / N

        carry = (particles, weights, key)
        out = (mean_t, var_t, pred_mean)
        return carry, out

    _, outputs = jax.lax.scan(body_fun, (particles, weights, key), jnp.arange(T))
    return outputs
from pytensor.tensor.random.utils import RandomStream

# Random stream for PyTensor
srng = RandomStream(seed=42)

# Poisson log-likelihood
def loglik_poisson_pytensor(s, y):
    mu = pt.exp(s)
    return y.flatten() * pt.log(mu + 1e-30) - mu - pt.gammaln(y.flatten() + 1.0)
ys_symbolic = pt.vector("ys")
x0_mean_symbolic = pt.scalar("x0_mean")
x0_std_symbolic = pt.scalar("x0_std")
A_symbolic = pt.scalar("A")
Q_symbolic = pt.scalar("Q")
N_symbolic = pt.scalar("N", dtype='int64')

N_STEPS = 300

# Initialize particles and weights
particles_init = srng.normal(size=(N_symbolic,)) * x0_std_symbolic + x0_mean_symbolic
weights_init = pt.ones((N_symbolic,)) / N_symbolic 

# Step function for scan
def step(y_t, particles_prev, weights_prev, A_symbolic, Q_symbolic):
    # Propagate particles
    particles_prop = A_symbolic * particles_prev + srng.normal(size=(N_symbolic,)) * pt.sqrt(Q_symbolic)

    # Update weights
    # logw = pt.stack([loglik_poisson_pytensor(p, y_t) for p in particles_prop])
    logw = loglik_poisson_pytensor(particles_prop, y_t)
    logw_stable = logw - pt.max(logw)
    w_unnorm = weights_prev * pt.exp(logw_stable)
    w = w_unnorm / (pt.sum(w_unnorm) + 1e-12) 

    # Filtered moments
    mean_t = pt.sum(w * particles_prop)
    var_t = pt.sum(w * (particles_prop - mean_t) ** 2)
    pred_mean = pt.sum(w * pt.exp(particles_prop))

    # Resample particles
    idx = srng.choice(size=(N_symbolic,), a=N_symbolic, p=w) 
    particles_resampled = particles_prop[idx]
    weights_resampled = pt.ones((N_symbolic,)) / N_symbolic

    # Return flat tuple
    return particles_resampled, weights_resampled, mean_t, var_t, pred_mean

# first two are recurrent, rest are collected
outputs_info = [
    particles_init,
    weights_init,
    None,
    None,
    None
]

(particles_seq, weights_seq, means_seq, vars_seq, preds_seq), updates = pytensor.scan(
    fn=step,
    sequences=[ys_symbolic],
    outputs_info=outputs_info,
    non_sequences=[A_symbolic, Q_symbolic],
    n_steps=N_STEPS
)

particle_filter_1d_predict_pytensor = pytensor.function(
    [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],
    [means_seq, vars_seq, preds_seq],
    updates=updates,
    no_default_updates=True,
    trust_input=True
)

particle_filter_1d_predict_pytensor_numba = pytensor.function(
    [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],
    [means_seq, vars_seq, preds_seq],
    updates=updates,
    no_default_updates=True,
    mode="NUMBA", 
    trust_input=True
)
key = jax.random.PRNGKey(0)
T = 300
A = 0.95
Q = 0.05
rng = np.random.RandomState(1)

target_mean = 10.0
latent_var = Q / (1 - A**2)
x0_mean = np.log(target_mean) - 0.5 * latent_var
x0_std = 1.0

# Simulate latent
x = np.zeros(T)
x[0] = rng.normal() * np.sqrt(latent_var) + x0_mean
for t in range(1, T):
    x[t] = A * x[t-1] + rng.normal() * np.sqrt(Q)

ys = np.array(rng.poisson(np.exp(x)), dtype=np.float32)
nonlinear_kalman_filter_bench = Benchmarker(
    functions=[particle_filter_1d_predict_pytensor, particle_filter_1d_predict_numba, particle_filter_1d_predict_jax, particle_filter_1d_predict_pytensor_numba,], 
    names=['particle_filter_1d_predict_pytensor', 'particle_filter_1d_predict_numba', 'particle_filter_1d_predict_jax', 'particle_filter_1d_predict_pytensor_numba',],
    number=5  # This takes a while to run, so reduce the number of loops
)
nonlinear_kalman_filter_bench.run(
    inputs={
        "kalman_filter_inputs": {"A": A, "Q": Q, "x0_mean": x0_mean, "x0_std": x0_std, "ys": ys, "N": 2000},
    }
)
nonlinear_kalman_filter_bench.summary()
Loops Min (us) Max (us) Mean (us) StdDev (us) Median (us) IQR (us) OPS (Kops/s) Samples
particle_filter_1d_predict_pytensor kalman_filter_inputs 5 728815.783397 742948.366801 734642.386761 5002.169131 733596.250002 6332.200003 0.001361 5
particle_filter_1d_predict_numba kalman_filter_inputs 5 48678.408400 48904.825002 48782.921640 89.835338 48744.633398 160.525000 0.020499 5
particle_filter_1d_predict_jax kalman_filter_inputs 5 33170.266601 33644.350001 33410.141721 155.134231 33411.583403 125.924998 0.029931 5
particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 5 676612.958201 678200.574999 677432.478319 647.987321 677334.924997 1278.466597 0.001476 5

Slightly different estimates than the other backends because the random number streams couldn't be reproduced 1:1.

filtered_means, filtered_vars, pred_means = particle_filter_1d_predict_numba(
    A, Q, x0_mean, x0_std, ys, N=2000, seed=2
)
fig = make_subplots(
    rows=2, cols=1,
    subplot_titles=("Observation Predictions", "Latent State Estimation"),
    vertical_spacing=0.07,
    shared_xaxes=True
)

fig.add_traces(
    [
        go.Scatter(
            x = np.arange(T),
            y = ys,
            mode = "markers",
            marker_color = "cornflowerblue",
            name = "actuals"
        ),
        go.Scatter(
            x = np.arange(T),
            y = pred_means,
            mode = "lines",
            marker_color = "#eb8c34",
            name = "predicted mean"
        ),
        go.Scatter(
                name="", 
                x=np.arange(T), 
                y=pred_means + 2*jnp.sqrt(pred_means), 
                mode="lines", 
                marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="predicted mean 95% CI",
                showlegend=False
            ),
            go.Scatter(
                name="predicted mean 95% CI", 
                x=np.arange(T), 
                y=pred_means - 2*jnp.sqrt(pred_means), 
                mode="lines", marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="predicted mean 95% CI", 
                fill='tonexty', 
                fillcolor='rgba(235, 140, 52, 0.2)'
            ),
    ],
    rows=1, cols=1
)

fig.add_traces(
    [
        go.Scatter(
            x = np.arange(T),
            y = x,
            mode = "lines",
            marker_color = "cornflowerblue",
            name = "true latent state"
        ),
        go.Scatter(
            x = np.arange(T),
            y = filtered_means,
            mode = "lines",
            marker_color = "#eb8c34",
            name = "filtered state mean"
        ),
        go.Scatter(
                name="", 
                x=np.arange(T), 
                y=filtered_means + 2*jnp.sqrt(filtered_vars), 
                mode="lines", 
                marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="filtered state mean 95% CI",
                showlegend=False
            ),
            go.Scatter(
                name="filtered state mean 95% CI", 
                x=np.arange(T), 
                y=filtered_means - 2*jnp.sqrt(filtered_vars), 
                mode="lines", marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="filtered state mean 95% CI", 
                fill='tonexty', 
                fillcolor='rgba(235, 140, 52, 0.2)'
            ),
    ],
    rows=2, cols=1
)

for i, yaxis in enumerate(fig.select_yaxes(), 1):
    legend_name = f"legend{i}"
    fig.update_layout({legend_name: dict(y=yaxis.domain[1], yanchor="top")}, showlegend=True)
    fig.update_traces(row=i, legend=legend_name)

fig.update_layout(height=1000, width=1200, template="plotly_dark")

fig.update_layout(
    legend1=dict(
        yanchor="top",
        y=1.0,
        xanchor="left",
        x=0,
        orientation="h"
    ),
    legend2=dict(
        yanchor="top",
        y=.465,
        xanchor="left",
        x=0,
        orientation="h"
    ),
    )