Imports#
import numpy as np
import jax.numpy as jnp
from numba import jit
import numba
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode
import timeit
import jax
import math
from jax.scipy.special import gammaln
from functools import partial
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import plotly.io as pio
pio.renderers.default = "notebook"
print("pytensor version:", pytensor.__version__)
print("jax version:", jax.__version__)
print("numba version:", numba.__version__)
pytensor version: 0+untagged.31343.g647570d.dirty
jax version: 0.7.2
numba version: 0.62.1
class Benchmarker:
"""
Benchmark a set of functions by timing execution and summarizing statistics.
Parameters
----------
functions : list of callables
List of callables to benchmark.
names : list of str, optional
Names corresponding to each function. Default is ['func_0', 'func_1', ...].
number : int or None, optional
Number of loops per timing. If None, auto-calibrated via Timer.autorange().
Default is None.
    min_rounds : int, optional
        Minimum number of timing rounds per function/input pair. Default is 5.
    max_time : float or None, optional
        Approximate time budget in seconds used to scale the number of rounds
        above min_rounds. Default is 1.0.
    target_time : float, optional
        Intended target duration in seconds for auto-calibration. Default is 0.2.
        (Timer.autorange() uses its own internal ~0.2 s target.)
    Attributes
    ----------
    results : dict
        Mapping from (function name, input name) tuples to a dict with keys:
        - 'raw_us': numpy.ndarray of raw timings in microseconds
        - 'loops': number of loops used per timing
        - 'rounds': number of timing rounds taken
    Methods
    -------
    run(inputs)
        Auto-calibrate (if needed) and run timings for all functions on each input set.
    summary(unit='us') -> pandas.DataFrame
        Return a summary DataFrame with statistics converted to the given unit.
    raw(name=None) -> dict or numpy.ndarray
        Return raw timing data in microseconds for a specific entry or all.
    _convert_times(times, unit) -> numpy.ndarray
        Convert an array of times from microseconds to the specified unit.
    """
def __init__(
self, functions, names=None, number=None, min_rounds=5, max_time=1.0, target_time=0.2
):
self.functions = functions
self.names = names or [f"func_{i}" for i in range(len(functions))]
self.number = number
self.min_rounds = min_rounds
self.max_time = max_time
self.target_time = target_time
self.results = {}
def run(self, inputs: dict[str, dict]):
"""
Auto-calibrate loop count and sample rounds if needed, then time each function.
"""
for name, func in zip(self.names, self.functions):
for input_name, kwargs in inputs.items():
timer = timeit.Timer(partial(func, **kwargs))
# Calibrate loops
if self.number is None:
loops, calib_time = timer.autorange()
else:
loops = self.number
calib_time = timer.timeit(number=loops)
# Determine rounds based on max_time and min_rounds
if self.max_time is not None:
rounds = max(self.min_rounds, int(np.ceil(self.max_time / calib_time)))
else:
rounds = self.min_rounds
raw_round_times = np.array(timer.repeat(repeat=rounds, number=loops))
# Convert to microseconds per single execution
raw_us = raw_round_times / loops * 1e6
self.results[(name, input_name)] = {
"raw_us": raw_us,
"loops": loops,
"rounds": rounds,
}
def summary(self, unit="us"):
"""
Summarize benchmark statistics in a DataFrame.
Parameters
----------
unit : {'us', 'ms', 'ns', 's'}, optional
Unit for output times. 'us' means microseconds, 'ms' milliseconds,
'ns' nanoseconds, 's' seconds. Default is 'us'.
Returns
-------
pandas.DataFrame
Summary with columns:
            Loops, Min, Max, Mean, StdDev, Median, IQR (all in the given unit),
            OPS (Kops/s), Samples; indexed by (function name, input name).
"""
records = []
indexes = []
for name, data in self.results.items():
raw_us = data["raw_us"]
# Convert to target unit
times = self._convert_times(raw_us, unit)
            if isinstance(name, tuple) and len(name) == 1:
                indexes.append(name[0])
            else:
                indexes.append(name)
stats = {
"Loops": data["loops"],
f"Min ({unit})": np.min(times),
f"Max ({unit})": np.max(times),
f"Mean ({unit})": np.mean(times),
f"StdDev ({unit})": np.std(times),
f"Median ({unit})": np.median(times),
f"IQR ({unit})": np.percentile(times, 75) - np.percentile(times, 25),
"OPS (Kops/s)": 1e3 / (np.mean(raw_us)),
"Samples": len(raw_us),
}
records.append(stats)
if all(isinstance(idx, tuple) for idx in indexes):
index = pd.MultiIndex.from_tuples(indexes)
else:
index = pd.Index(indexes)
return pd.DataFrame(records, index=index)
def raw(self, name=None):
"""
Get raw timing data in microseconds.
Parameters
----------
        name : tuple, optional
            Key of the form (function name, input name). If given, returns the
            raw_us array for that entry. Otherwise returns a dict of all raw results.
Returns
-------
numpy.ndarray or dict
"""
if name:
return self.results.get(name, {}).get("raw_us")
return {n: d["raw_us"] for n, d in self.results.items()}
def _convert_times(self, times, unit):
"""
Convert an array of times from microseconds to the specified unit.
Parameters
----------
times : array-like
Times in microseconds.
unit : {'us', 'ms', 'ns', 's'}
Target unit: 'us' microseconds, 'ms' milliseconds,
'ns' nanoseconds, 's' seconds.
Returns
-------
numpy.ndarray
Converted times.
Raises
------
ValueError
If `unit` is not one of the supported options.
"""
unit = unit.lower()
if unit == "us":
factor = 1.0
elif unit == "ms":
factor = 1e-3
elif unit == "ns":
factor = 1e3
elif unit == "s":
factor = 1e-6
else:
raise ValueError(f"Unsupported unit: {unit}")
return times * factor
# def block(func):
# def inner(*args, **kwargs):
# return jax.block_until_ready(func(*args, **kwargs))
# return inner
# The simpler version above would be fine whenever jax.block_until_ready accepts
# pytrees (it does in recent JAX releases); the explicit tree_map below blocks on
# every array leaf regardless, so it is safe for multi-output functions.
def block(func):
def inner(*args, **kwargs):
result = func(*args, **kwargs)
# Recursively block on all JAX arrays in result
jax.tree_util.tree_map(lambda x: x.block_until_ready() if hasattr(x, "block_until_ready") else None, result)
return result
return inner
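As a quick sanity check of the pytree behaviour (a minimal sketch, assuming a recent JAX in which jax.block_until_ready accepts arbitrary pytrees, not just single arrays):
@jax.jit
def _two_outputs(x):
    return x + 1, {"sq": x ** 2}  # a tuple/dict pytree, not a single array
# blocks on every array leaf before returning
_ = jax.block_until_ready(_two_outputs(jnp.arange(3)))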
# Set Pytensor to use float32
pytensor.config.floatX = "float32"
Introduction#
Baby Steps#
Fibonacci Algorithm#
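The scan iterates the step $(a, b) \mapsto (a + b,\ a)$ 100,000 times. Int32 arithmetic wraps around long before that, so the returned value is not the true Fibonacci number; the point is purely to time each backend's loop machinery.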
N_STEPS = 100_000
b_symbolic = pt.vector("b", dtype="int32", shape=(1,))
def step(a, b):
return a + b, a
(outputs_a, outputs_b), _ = pytensor.scan(
fn=step,
outputs_info=[pt.ones(1, dtype="int32"), b_symbolic],
n_steps=N_STEPS
)
# compile function returning final a
fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)
fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)
@jit(nopython=True)
def fibonacci_numba(b):
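    # note: mutates its input b in place across benchmark repeats; harmless here,
    # since the int32 values wrap anyway and only the timing is of interest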
a = np.ones(1, dtype=np.int32)
for _ in range(N_STEPS):
a[0], b[0] = a[0] + b[0], a[0]
return a[0]
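# NOTE: the Python for-loop below is unrolled at trace time into ~100,000 ops,
# and (unlike the later JAX benchmarks) this function is not wrapped in `block`,
# so its timings may partly reflect JAX's asynchronous dispatch.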
@jax.jit
def fibonacci_jax(b):
a = jnp.array(1, dtype=np.int32)
for _ in range(N_STEPS):
a, b = a + b, a
return a
fibonacci_bench = Benchmarker(
functions=[fibonacci_pytensor, fibonacci_numba, fibonacci_jax, fibonacci_pytensor_numba],
names=['fibonacci_pytensor', 'fibonacci_numba', 'fibonacci_jax', 'fibonacci_pytensor_numba'],
number=10
)
fibonacci_bench.run(
inputs={
"fibonacci_inputs": {"b": np.ones(1, dtype=np.int32)},
}
)
fibonacci_bench.summary()
| | | Loops | Min (us) | Max (us) | Mean (us) | StdDev (us) | Median (us) | IQR (us) | OPS (Kops/s) | Samples |
|---|---|---|---|---|---|---|---|---|---|---|
| fibonacci_pytensor | fibonacci_inputs | 10 | 54149.950002 | 55278.412500 | 54654.332500 | 434.902898 | 54807.945801 | 694.087401 | 0.018297 | 5 |
| fibonacci_numba | fibonacci_inputs | 10 | 84.208400 | 105.558301 | 95.575969 | 4.793104 | 95.562500 | 5.033301 | 10.462881 | 13 |
| fibonacci_jax | fibonacci_inputs | 10 | 7.162499 | 31.558401 | 14.425020 | 8.923158 | 12.233300 | 5.920898 | 69.323996 | 5 |
| fibonacci_pytensor_numba | fibonacci_inputs | 10 | 2338.045801 | 2429.037497 | 2385.059159 | 34.438988 | 2392.570800 | 58.641701 | 0.419277 | 5 |
Element-wise Multiplication Algorithm#
# a_symbolic = pt.vector("a", dtype="int32")
# b_symbolic = pt.vector("b", dtype="int32")
# def step(a_element, b_element):
# return a_element * b_element
# c, _ = pytensor.scan(
# fn=step,
# sequences=[a_symbolic, b_symbolic]
# )
# # compile function returning final a
# c_mode = get_mode("FAST_RUN").excluding("scan_push_out_seq")
# elementwise_multiply_pytensor = pytensor.function([a_symbolic, b_symbolic], c, trust_input=True, mode=c_mode)
# numba_mode = get_mode("NUMBA").excluding("scan_push_out_seq")
# elementwise_multiply_pytensor_numba = pytensor.function([a_symbolic, b_symbolic], c, mode=numba_mode, trust_input=True)
# Not sure whether fixing n_steps makes a difference when scan already receives sequences
a_symbolic = pt.vector("a", dtype="int32")
b_symbolic = pt.vector("b", dtype="int32")
N_STEPS = 1000
def step(a_element, b_element):
return a_element * b_element
c, _ = pytensor.scan(
fn=step,
sequences=[a_symbolic, b_symbolic],
n_steps=N_STEPS
)
# compile function returning final a
c_mode = get_mode("FAST_RUN").excluding("scan_push_out_seq")
elementwise_multiply_pytensor = pytensor.function([a_symbolic, b_symbolic], c, trust_input=True, mode=c_mode)
numba_mode = get_mode("NUMBA").excluding("scan_push_out_seq")
elementwise_multiply_pytensor_numba = pytensor.function([a_symbolic, b_symbolic], c, mode=numba_mode, trust_input=True)
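For reference, the `scan_push_out_seq` rewrite being excluded above would lift the elementwise body out of the scan, yielding (roughly) the plain vectorized product below; it is compiled here only to illustrate what the exclusion avoids.
elementwise_multiply_vectorized = pytensor.function(
    [a_symbolic, b_symbolic], a_symbolic * b_symbolic, trust_input=True
)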
@jit(nopython=True)
def elementwise_multiply_numba(a, b):
n = a.shape[0]
c = np.empty(n, dtype=a.dtype)
for i in range(n):
c[i] = a[i] * b[i]
return c
@block
@jax.jit
def elementwise_multiply_jax(a, b):
n = a.shape[0]
c_init = jnp.empty(n, dtype=a.dtype)
def step(i, c):
return jax.lax.dynamic_update_index_in_dim(c, a[i] * b[i], i, axis=0)
c = jax.lax.fori_loop(0, n, step, c_init)
return c
a = np.random.normal(0, 1, (N_STEPS)).astype(np.int32)
b = np.random.normal(0, 1, (N_STEPS)).astype(np.int32)
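# note: standard normals cast to int32 are almost all -1, 0, or 1; that is fine
# here because only the timing matters, not the products themselves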
elem_mult_bench = Benchmarker(
functions=[elementwise_multiply_pytensor, elementwise_multiply_numba, elementwise_multiply_jax, elementwise_multiply_pytensor_numba],
names=['elementwise_multiply_pytensor', 'elementwise_multiply_numba', 'elementwise_multiply_jax', 'elementwise_multiply_pytensor_numba'],
number=10
)
elem_mult_bench.run(
inputs={
"elem_mult_inputs": {"a": a, "b": b},
}
)
elem_mult_bench.summary()
| | | Loops | Min (us) | Max (us) | Mean (us) | StdDev (us) | Median (us) | IQR (us) | OPS (Kops/s) | Samples |
|---|---|---|---|---|---|---|---|---|---|---|
| elementwise_multiply_pytensor | elem_mult_inputs | 10 | 450.595800 | 967.095900 | 492.991945 | 36.396199 | 491.412499 | 18.221901 | 2.028431 | 220 |
| elementwise_multiply_numba | elem_mult_inputs | 10 | 0.366700 | 0.720800 | 0.394247 | 0.073563 | 0.379200 | 0.012497 | 2536.478255 | 21 |
| elementwise_multiply_jax | elem_mult_inputs | 10 | 7.512499 | 10.391598 | 8.152427 | 0.621782 | 7.895901 | 0.593750 | 122.662853 | 55 |
| elementwise_multiply_pytensor_numba | elem_mult_inputs | 10 | 34.662499 | 50.737502 | 39.280821 | 5.934219 | 37.408300 | 3.912600 | 25.457717 | 5 |
Changepoint Detection Algorithms#
Cumulative Sum (CUSUM) Algorithm#
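All implementations below use the standard two-sided CUSUM recursions with an exponential-moving-average baseline $\mu_t$:

$$\mu_t = \alpha x_t + (1-\alpha)\,\mu_{t-1}, \qquad s_t^{+} = \max\!\big(0,\ s_{t-1}^{+} + x_t - \mu_t - k\big), \qquad s_t^{-} = \max\!\big(0,\ s_{t-1}^{-} - (x_t - \mu_t) - k\big),$$

with an alarm raised whenever $s_t^{+} > h$ (upward change) or $s_t^{-} > h$ (downward change).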
@jit(nopython=True)
def cusum_adaptive_numba(x, alpha=0.01, k=0.5, h=5.0):
"""
Two-sided CUSUM with adaptive exponential moving average baseline.
Parameters
----------
x: np.ndarray
input signal
alpha: float
EMA smoothing factor (0 < alpha <= 1)
k: float
slack to avoid small changes triggering alarms
h: float
threshold for raising an alarm
Returns
-------
s_pos: np.ndarray
upper CUSUM stats
s_neg: np.ndarray
lower CUSUM stats
mu_t: np.ndarray
evolving baseline estimate
alarms_pos: np.ndarray
alarms for upward changes
alarms_neg: np.ndarray
alarms for downward changes
"""
n = x.shape[0]
s_pos = np.zeros(n, dtype=np.float32)
s_neg = np.zeros(n, dtype=np.float32)
mu_t = np.zeros(n, dtype=np.float32)
alarms_pos = np.zeros(n, dtype=np.bool_)
alarms_neg = np.zeros(n, dtype=np.bool_)
# Initialization
mu_t[0] = x[0]
for i in range(1, n):
# Update baseline (EMA)
mu_t[i] = alpha * x[i] + (1 - alpha) * mu_t[i-1]
# Update CUSUM stats
s_pos[i] = max(0.0, s_pos[i-1] + x[i] - mu_t[i] - k)
s_neg[i] = max(0.0, s_neg[i-1] - (x[i] - mu_t[i]) - k)
# Alarms
alarms_pos[i] = s_pos[i] > h
alarms_neg[i] = s_neg[i] > h
return s_pos, s_neg, mu_t, alarms_pos, alarms_neg
@block
@jax.jit
def cusum_adaptive_jax(x, alpha=0.01, k=0.5, h=5.0):
"""
Two-sided CUSUM with adaptive exponential moving average baseline.
Parameters
----------
x: jnp.ndarray
input signal
alpha: float
EMA smoothing factor (0 < alpha <= 1)
k: float
slack to avoid small changes triggering alarms
h: float
threshold for raising an alarm
Returns
-------
s_pos: jnp.ndarray
upper CUSUM stats
s_neg: jnp.ndarray
lower CUSUM stats
mu_t: jnp.ndarray
evolving baseline estimate
alarms_pos: jnp.ndarray
alarms for upward changes
alarms_neg: jnp.ndarray
alarms for downward changes
"""
def body(carry, x_t):
s_pos_prev, s_neg_prev, mu_prev = carry
# Update EMA baseline
mu_t = alpha * x_t + (1 - alpha) * mu_prev
# Update CUSUMs using updated baseline
s_pos = jnp.maximum(0.0, s_pos_prev + x_t - mu_t - k)
s_neg = jnp.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)
new_carry = (s_pos, s_neg, mu_t)
output = (s_pos, s_neg, mu_t)
return new_carry, output
# Initialize: CUSUMs at 0, initial mean = first sample
s0 = (0.0, 0.0, x[0])
_, (s_pos_vals, s_neg_vals, mu_vals) = jax.lax.scan(body, s0, x)
# Thresholding
alarms_pos = s_pos_vals > h
alarms_neg = s_neg_vals > h
return s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg
x_symbolic = pt.vector("x")
alpha_symbolic = pt.scalar("alpha")
k_symbolic = pt.scalar("k")
h_symbolic = pt.scalar("h")
N_STEPS = 100  # Fixed just in case, though it likely makes no difference when scan receives sequences as input
def step(x_t, s_pos_prev, s_neg_prev, mu_prev, alpha, k):
# Update EMA baseline
mu_t = alpha * x_t + (1 - alpha) * mu_prev
# Update CUSUMs using updated baseline
s_pos = pt.maximum(0.0, s_pos_prev + x_t - mu_t - k)
s_neg = pt.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)
return s_pos, s_neg, mu_t
(s_pos_vals, s_neg_vals, mu_vals), updates = pytensor.scan(
fn=step,
outputs_info=[pt.constant(0., dtype="float32"), pt.constant(0., dtype="float32"), x_symbolic[0]],
non_sequences=[alpha_symbolic, k_symbolic],
sequences=[x_symbolic],
n_steps=N_STEPS
)
# Thresholding
alarms_pos = s_pos_vals > h_symbolic
alarms_neg = s_neg_vals > h_symbolic
cusum_adaptive_pytensor = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], trust_input=True)
cusum_adaptive_pytensor_numba = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode="NUMBA", trust_input=True)
cusum_adaptive_pytensor_jax = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode="JAX", trust_input=True)
xs1 = np.random.normal(80, 20, size=(int(N_STEPS/2)))
xs2 = np.random.normal(50, 20, size=(int(N_STEPS/2)))
xs = np.concat((xs1, xs2))
xs = xs.astype(np.float32)
xs_std = (xs - xs.mean()) / xs.std()
cusum_bench = Benchmarker(
functions=[cusum_adaptive_pytensor, cusum_adaptive_numba, cusum_adaptive_jax, cusum_adaptive_pytensor_numba, block(cusum_adaptive_pytensor_jax)],
names=['cusum_adaptive_pytensor', 'cusum_adaptive_numba', 'cusum_adaptive_jax', 'cusum_adaptive_pytensor_numba', 'cusum_adaptive_pytensor_jax'],
number=10
)
cusum_bench.run(
inputs={
"cusum_inputs": {"x": xs, "alpha": 0.1, "k": 0.5, "h": 3.5},
}
)
cusum_bench.summary()
| | | Loops | Min (us) | Max (us) | Mean (us) | StdDev (us) | Median (us) | IQR (us) | OPS (Kops/s) | Samples |
|---|---|---|---|---|---|---|---|---|---|---|
| cusum_adaptive_pytensor | cusum_inputs | 10 | 124.412501 | 226.245899 | 141.837420 | 8.902762 | 139.566700 | 10.325000 | 7.050326 | 681 |
| cusum_adaptive_numba | cusum_inputs | 10 | 1.729201 | 2.174999 | 1.806571 | 0.150796 | 1.745898 | 0.018753 | 553.534780 | 7 |
| cusum_adaptive_jax | cusum_inputs | 10 | 14.366599 | 18.899998 | 15.076731 | 0.942976 | 14.633298 | 0.738525 | 66.327378 | 36 |
| cusum_adaptive_pytensor_numba | cusum_inputs | 10 | 24.316600 | 36.341700 | 27.204980 | 4.596065 | 25.033302 | 1.241703 | 36.757976 | 5 |
| cusum_adaptive_pytensor_jax | cusum_inputs | 10 | 21.479101 | 27.295901 | 22.782493 | 1.642595 | 21.893748 | 1.734347 | 43.893352 | 30 |
outputs = cusum_adaptive_numba(xs_std, alpha=0.1, k=0.5, h=3.5)
fig = go.Figure()
fig.add_traces(
[
go.Scatter(
x = np.arange(len(xs)),
y = xs_std,
name="series"
),
go.Scatter(
x = np.arange(len(xs)),
y = outputs[0],
name="cum. positive devs."
),
go.Scatter(
x = np.arange(len(xs)),
y = outputs[1],
name="cum. negative devs."
),
go.Scatter(
x = np.arange(len(xs)),
y = outputs[2],
name="Exp. Mean"
),
go.Scatter(
x = np.arange(len(xs)),
y = outputs[3].astype(np.float16),
name="positive alarms"
),
go.Scatter(
x = np.arange(len(xs)),
y = outputs[4].astype(np.float16),
name="negative alarms"
),
]
)
fig.update_layout(
title = dict(
text = "CUSUM Change Point Detection Algorithm"
),
xaxis=dict(
title = "Time Index"
),
yaxis=dict(
title = "Standardized Series Scaled"
),
legend=dict(
yanchor="top",
y=1.1,
xanchor="left",
x=0,
orientation="h"
),
template="plotly_dark"
)
Pruned Exact Linear Time (PELT) Algorithm#
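All three implementations minimize the penalized segmentation cost by dynamic programming over the most recent changepoint $s$:

$$C(t) = \min_{s < t} \big[\, C(s) + \mathrm{cost}(s, t) + \beta \,\big], \qquad \mathrm{cost}(s, t) = \big(S_2(t) - S_2(s)\big) - \frac{\big(S_1(t) - S_1(s)\big)^2}{t - s},$$

i.e. the SSE around the segment mean, computed in O(1) from the cumulative sums $S_1, S_2$. To keep the three backends identical, none of them applies PELT's pruning of candidate points, so strictly speaking these are the $O(n^2)$ optimal-partitioning variant of the algorithm.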
@jit(nopython=True)
def segment_cost_numba(S1, S2, i, j):
"""Cost of segment x[i:j], SSE around mean"""
n = j - i
sum_x = S1[j] - S1[i]
sum_x2 = S2[j] - S2[i]
if n > 0:
return sum_x2 - (sum_x ** 2) / n
else:
return np.inf
@jit(nopython=True)
def pelt_numba(x, beta=10.0):
"""
Pruned Exact Linear Time algorithm for change point detection
Parameters
----------
x: np.ndarray
The timeseries signal
beta: float
Penalty of segmenting the series
Returns
-------
C: np.ndarray
The best costs up to segment t
last_change: np.ndarray
The last change point up to segment t
"""
n = len(x)
# cumulative sums for cost
S1 = np.empty(n+1, dtype=np.float32)
S2 = np.empty(n+1, dtype=np.float32)
S1[0], S2[0] = 0.0, 0.0
for i in range(1, n+1):
S1[i] = S1[i-1] + x[i-1]
S2[i] = S2[i-1] + x[i-1]**2
# DP arrays
C = np.full((n+1,), np.inf)
C[0] = -beta
last_change = np.full((n+1,), -1)
min_size = 3
for t in range(1, n+1):
costs = np.full(n, np.inf)
for s in range(n):
if s < t and (t - s) >= min_size:
costs[s] = C[s] + segment_cost_numba(S1, S2, s, t) + beta
best_s = np.argmin(costs)
C[t] = costs[best_s]
last_change[t] = best_s
return C, last_change
def segment_cost_jax(S1, S2, i, j):
"""Cost of segment x[i:j], SSE around mean"""
n = j - i
sum_x = S1[j] - S1[i]
sum_x2 = S2[j] - S2[i]
return jnp.where(n > 0, sum_x2 - (sum_x ** 2) / n, jnp.inf)
@block
@jax.jit
def pelt_jax(x, beta=10.0):
"""
Pruned Exact Linear Time algorithm for change point detection
Parameters
----------
x: np.ndarray
The timeseries signal
beta: float
Penalty of segmenting the series
Returns
-------
C: jnp.ndarray
The best costs up to segment t
last_change: jnp.ndarray
The last change point up to segment t
"""
n = len(x)
# cumulative sums for cost
S1 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x)])
S2 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x**2)])
# DP arrays
C = jnp.full((n+1,), jnp.inf)
C = C.at[0].set(-beta)
last_change = jnp.full((n+1,), -1)
min_size = 3
s_all = jnp.arange(n) # all possible candidates
def body(t, carry):
C, last_change = carry
# Compute cost for all s < t, mask invalid
# valid = s_all < t & ((t - s_all) >= min_size)
valid = (s_all < t) & ((t - s_all) >= min_size)
costs = jnp.where(
valid,
C[s_all] + segment_cost_jax(S1, S2, s_all, t) + beta,
jnp.inf
)
best_s = jnp.argmin(costs)
C = C.at[t].set(costs[best_s])
last_change = last_change.at[t].set(best_s)
return C, last_change
C, last_change = jax.lax.fori_loop(1, n+1, body, (C, last_change))
return C, last_change
def segment_cost_pytensor(S1, S2, i, j):
"""Cost of segment x[i:j], SSE around mean"""
n = j - i
sum_x = S1[j] - S1[i]
sum_x2 = S2[j] - S2[i]
return pt.switch(
pt.gt(n, 0),
sum_x2 - (sum_x ** 2) / n,
np.inf
)
x_symbolic = pt.vector("x")
beta_symbolic = pt.scalar("beta")
n = x_symbolic.shape[0]
N_STEPS = 100
# cumulative sums for cost
S1 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic)])
S2 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic**2)])
# DP arrays
C_init = pt.alloc(np.inf, n+1)
C_init = pt.set_subtensor(C_init[0], -beta_symbolic)
last_change_init = pt.alloc(-1, n+1)
s_all = pt.arange(n) # candidate change points
min_size = 3
def step(t, C_prev, last_change_prev, S1, S2, beta_symbolic, s_all):
# valid = (s_all < t) & ((t - s_all) >= min_size)
valid = pt.and_(pt.lt(s_all, t), pt.ge(t - s_all, min_size))
# compute costs for all candidates
costs, _ = pytensor.scan(
fn=lambda s: pt.switch(
valid[s],
C_prev[s] + segment_cost_pytensor(S1, S2, s, t) + beta_symbolic,
np.inf
),
sequences=[pt.arange(n)]
)
costs = costs.flatten()
best_s = pt.argmin(costs, axis=0)
C_new = pt.set_subtensor(C_prev[t], costs[best_s])
last_change_new = pt.set_subtensor(last_change_prev[t], best_s)
return C_new, last_change_new
(C_vals, last_change_vals), _ = pytensor.scan(
fn=step,
sequences=[pt.arange(1, n+1)],
outputs_info=[C_init, last_change_init],
non_sequences=[S1, S2, beta_symbolic, s_all],
n_steps=N_STEPS # Added fixed iterations here
)
pelt_pytensor = pytensor.function([x_symbolic, beta_symbolic], [C_vals[-1], last_change_vals[-1]], trust_input=True)
pelt_pytensor_numba = pytensor.function(inputs=[x_symbolic, beta_symbolic], outputs=[C_vals[-1], last_change_vals[-1]], mode="NUMBA", trust_input=True)
pelt_bench = Benchmarker(
functions=[pelt_pytensor, pelt_numba, pelt_jax, pelt_pytensor_numba],
names=['pelt_pytensor', 'pelt_numba', 'pelt_jax', 'pelt_pytensor_numba'],
number=10
)
pelt_bench.run(
inputs={
"pelt_inputs": {"x": xs_std, "beta": 2. * np.log(len(xs_std))},
}
)
pelt_bench.summary()
| | | Loops | Min (us) | Max (us) | Mean (us) | StdDev (us) | Median (us) | IQR (us) | OPS (Kops/s) | Samples |
|---|---|---|---|---|---|---|---|---|---|---|
| pelt_pytensor | pelt_inputs | 10 | 11903.712500 | 12687.708301 | 12234.242133 | 245.989123 | 12224.875001 | 277.799901 | 0.081738 | 9 |
| pelt_numba | pelt_inputs | 10 | 17.408299 | 22.204101 | 19.400820 | 1.585318 | 19.241701 | 1.000002 | 51.544213 | 5 |
| pelt_jax | pelt_inputs | 10 | 64.616700 | 82.345799 | 71.339679 | 4.879389 | 70.279100 | 4.893748 | 14.017445 | 19 |
| pelt_pytensor_numba | pelt_inputs | 10 | 2330.925001 | 2496.158300 | 2424.299980 | 56.514842 | 2435.225001 | 61.833399 | 0.412490 | 5 |
outputs = pelt_numba(xs_std, 2. * np.log(len(xs_std)))
def plot_pelt_diagnostics(x, cps, C):
    """
    Diagnostic plots for PELT changepoint detection.
    Args:
        x: 1D array, original time series
        cps: list of changepoint indices (sorted ascending)
        C: 1D array, cumulative DP cost from pelt()
    """
n = len(x)
cps_full = [0] + cps + [n]
# Segment means, std, SSE
segment_means = []
segment_stds = []
segment_costs = []
for start, end in zip(cps_full[:-1], cps_full[1:]):
seg = x[start:end]
mean = np.mean(seg)
std = np.std(seg)
cost = np.sum((seg - mean)**2)
segment_means.append(mean)
segment_stds.append(std)
segment_costs.append(cost)
# Step function for segment mean
mean_step = np.zeros(n)
for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):
mean_step[start:end] = segment_means[i]
# Step function for segment std
std_step = np.zeros(n)
for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):
std_step[start:end] = segment_stds[i]
if len(x) < 20:
title1 = "<span style='color: red;'>Warning</span>: Sample size is small - Detected Changepoints"
else:
title1 = "Detected Changepoints"
fig = make_subplots(
rows=4,
cols=1,
subplot_titles=(title1, "Average Shifts", "Variability Shifts", "Cumulative Cost")
)
fig.add_trace(
go.Scatter(
x = np.arange(len(x)),
y = x,
line_color="royalblue",
name = "Actuals",
mode="lines",
showlegend=False,
hovertemplate="<b>Time Point</b>: %{x}<br><b>Actual</b>: %{y}"
),
row=1, col=1
)
for cp in cps:
fig.add_vline(x=cp, line_dash='dash', line_color="red", row=1, col=1)
fig.add_trace(
go.Scatter(
x = np.arange(len(x)),
y = x,
name = "Actuals",
mode="lines",
line_color="rgba(105, 105, 105, 0.25)",
showlegend=False,
hoverinfo="skip"
),
row=2, col=1
)
fig.add_trace(
go.Scatter(
x = np.arange(len(x)),
y = mean_step,
name = "Average",
line_color="royalblue",
showlegend=False,
hovertemplate="<b>Time Point</b>: %{x}<br><b>Average</b>: %{y}"
),
row=2, col=1
)
fig.add_trace(
go.Scatter(
x = np.arange(len(x)),
y = std_step,
name = "Standard Deviation",
line_color="royalblue",
showlegend=False,
hovertemplate="<b>Time Point</b>: %{x}<br><b>Standard Deviation</b>: %{y}"
),
row=3, col=1
)
fig.add_trace(
go.Scatter(
x = np.arange(len(x)),
y = C,
name = "Cumulative Cost",
line_color="royalblue",
showlegend=False,
hovertemplate="<b>Time Point</b>: %{x}<br><b>Cost</b>: %{y}"
),
row=4, col=1
)
for cp in cps:
fig.add_vline(x=cp, line_dash='dash', line_color="red", row=4, col=1)
return fig.update_layout(height=1000, width=1200, template="plotly_dark")
def get_changepoints(last_change, n):
"""
Backtrack changepoints from last_change array.
Args:
last_change: array from pelt()
n: length of input series
Returns:
list of changepoint indices (sorted ascending)
"""
cps = []
t = n
while t > 0:
s = int(last_change[t])
if s <= 0:
break
cps.append(s)
t = s
return list(reversed(cps))
cps = get_changepoints(outputs[1], n=len(xs_std))
plot_pelt_diagnostics(xs, cps, outputs[0])
Kalman Filter Algorithms#
Linear Gaussian Kalman Filter#
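Each implementation runs the standard linear Gaussian predict/update recursions:

$$\hat{x}_t = F x_{t-1}, \qquad \hat{P}_t = F P_{t-1} F^\top + Q,$$

$$y_t = z_t - H \hat{x}_t, \qquad S_t = H \hat{P}_t H^\top + R, \qquad K_t = \hat{P}_t H^\top S_t^{-1},$$

$$x_t = \hat{x}_t + K_t y_t, \qquad P_t = (I - K_t H)\, \hat{P}_t.$$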
@jit(nopython=True)
def atrocious_kalman_filter_numba(z, F, H, Q, R, x0, P0):
"""
    This implementation of the Kalman filter is atrocious and in standard Python would be a
    BIG NO-NO. That said, it SIGNIFICANTLY reduces Numba compilation time.
Linear Gaussian Kalman filter algorithm
Parameters
----------
z: np.ndarray
shape (T, m) - observations
F: np.ndarray
state transition matrix - shape (n, n)
H: np.ndarray
observation/design matrix - shape (m, n)
Q: np.ndarray
process noise covariance - shape (n, n)
R: np.ndarray
observation noise covariance - shape (m, m)
x0: np.ndarray
initial state mean - shape (n,)
P0: np.ndarray
initial state covariance - shape (n, n)
Returns
-------
    xs: np.ndarray
        shape (T, n) - filtered state means
    Ps: np.ndarray
        shape (T, n, n) - filtered state covariances
    x_pred: np.ndarray
        shape (T, n) - one-step-ahead predicted state means
    P_pred: np.ndarray
        shape (T, n, n) - one-step-ahead predicted state covariances
    """
T = z.shape[0]
m = z.shape[1]
n = x0.shape[0]
xs = np.empty((T, n), dtype=np.float32)
Ps = np.empty((T, n, n), dtype=np.float32)
# local working arrays
x = np.empty(n, dtype=np.float32)
for i in range(n):
x[i] = x0[i]
P = np.empty((n, n), dtype=np.float32)
for i in range(n):
for j in range(n):
P[i, j] = P0[i, j]
# temporary matrices/vectors
x_pred = np.empty((T, n), dtype=np.float32)
P_pred = np.empty((T, n, n), dtype=np.float32)
y = np.empty(m, dtype=np.float32)
S = np.empty((m, m), dtype=np.float32)
K = np.empty((n, m), dtype=np.float32)
I_n = np.eye(n, dtype=np.float32)
for t in range(T):
# === Predict ===
# x_pred = F @ x
for i in range(n):
s = 0.0
for j in range(n):
s += F[i, j] * x[j]
x_pred[t, i] = s
# P_pred = F @ P @ F.T + Q
# temp = F @ P
temp = np.empty((n, n), dtype=np.float32)
for i in range(n):
for j in range(n):
s = 0.0
for k in range(n):
s += F[i, k] * P[k, j]
temp[i, j] = s
# P_pred = temp @ F.T
for i in range(n):
for j in range(n):
s = 0.0
for k in range(n):
s += temp[i, k] * F[j, k] # F.T[k, j] = F[j, k]
P_pred[t, i, j] = s + Q[i, j]
# === Update ===
# y = z[t] - H @ x_pred
for i in range(m):
s = 0.0
for j in range(n):
s += H[i, j] * x_pred[t, j]
y[i] = z[t, i] - s
# S = H @ P_pred @ H.T + R
# temp2 = H @ P_pred
temp2 = np.empty((m, n), dtype=np.float32)
for i in range(m):
for j in range(n):
s = 0.0
for k in range(n):
s += H[i, k] * P_pred[t, k, j]
temp2[i, j] = s
# S = temp2 @ H.T
for i in range(m):
for j in range(m):
s = 0.0
for k in range(n):
s += temp2[i, k] * H[j, k] # H.T[k,j] = H[j,k]
S[i, j] = s + R[i, j]
# K = P_pred @ H.T @ inv(S)
# first compute P_pred @ H.T -> (n, m)
P_Ht = np.empty((n, m), dtype=np.float32)
for i in range(n):
for j in range(m):
s = 0.0
for k in range(n):
s += P_pred[t, i, k] * H[j, k] # H.T[k,j] = H[j,k]
P_Ht[i, j] = s
# invert S
S_inv = np.linalg.inv(S)
# K = P_Ht @ S_inv (n,m) @ (m,m) -> (n,m)
for i in range(n):
for j in range(m):
s = 0.0
for k in range(m):
s += P_Ht[i, k] * S_inv[k, j]
K[i, j] = s
# x = x_pred + K @ y
for i in range(n):
s = 0.0
for j in range(m):
s += K[i, j] * y[j]
x[i] = x_pred[t, i] + s
# P = (I - K H) P_pred
# compute (I - K H)
KH = np.empty((n, n), dtype=np.float32)
for i in range(n):
for j in range(n):
s = 0.0
for k in range(m):
s += K[i, k] * H[k, j]
KH[i, j] = s
I_minus_KH = np.empty((n, n), dtype=np.float32)
for i in range(n):
for j in range(n):
I_minus_KH[i, j] = I_n[i, j] - KH[i, j]
# P = I_minus_KH @ P_pred
for i in range(n):
for j in range(n):
s = 0.0
for k in range(n):
s += I_minus_KH[i, k] * P_pred[t, k, j]
P[i, j] = s
# store results
for i in range(n):
xs[t, i] = x[i]
for i in range(n):
for j in range(n):
Ps[t, i, j] = P[i, j]
return xs, Ps, x_pred, P_pred
@jit(nopython=True)
def kalman_filter_numba(z, F, H, Q, R, x0, P0):
"""
Linear Gaussian Kalman filter algorithm
Parameters
----------
z: np.ndarray
shape (T, m) - observations
F: np.ndarray
state transition matrix - shape (n, n)
H: np.ndarray
observation/design matrix - shape (m, n)
Q: np.ndarray
process noise covariance - shape (n, n)
R: np.ndarray
observation noise covariance - shape (m, m)
x0: np.ndarray
initial state mean - shape (n,)
P0: np.ndarray
initial state covariance - shape (n, n)
Returns
-------
    xs: np.ndarray
        shape (T, n) - filtered state means
    Ps: np.ndarray
        shape (T, n, n) - filtered state covariances
    x_pred: np.ndarray
        shape (T, n) - one-step-ahead predicted state means
    P_pred: np.ndarray
        shape (T, n, n) - one-step-ahead predicted state covariances
    """
T, m = z.shape
n = x0.shape[0]
xs = np.zeros((T, n), dtype=np.float32)
Ps = np.zeros((T, n, n), dtype=np.float32)
x_pred = np.zeros((T, n), dtype=np.float32)
P_pred = np.zeros((T, n, n), dtype=np.float32)
x = x0.copy()
P = P0.copy()
I = np.eye(n, dtype=np.float32)
for t in range(T):
# --- Predict ---
x_pred[t] = F @ x
P_pred[t] = F @ P @ F.T + Q
# --- Update ---
y = z[t] - H @ x_pred[t]
S = H @ P_pred[t] @ H.T + R
K = P_pred[t] @ H.T @ np.linalg.inv(S)
x = x_pred[t] + K @ y
P = (I - K @ H) @ P_pred[t]
xs[t] = x
Ps[t] = P
return xs, Ps, x_pred, P_pred
@block
@jax.jit
def kalman_filter_jax(z, F, H, Q, R, x0, P0):
"""
Linear Gaussian Kalman filter algorithm
Parameters
----------
z: np.ndarray
shape (T, m) - observations
F: np.ndarray
state transition matrix - shape (n, n)
H: np.ndarray
observation/design matrix - shape (m, n)
Q: np.ndarray
process noise covariance - shape (n, n)
R: np.ndarray
observation noise covariance - shape (m, m)
x0: np.ndarray
initial state mean - shape (n,)
P0: np.ndarray
initial state covariance - shape (n, n)
Returns
-------
    xs: jnp.ndarray
        shape (T, n) - filtered state means
    Ps: jnp.ndarray
        shape (T, n, n) - filtered state covariances
    x_pred: jnp.ndarray
        shape (T, n) - one-step-ahead predicted state means
    P_pred: jnp.ndarray
        shape (T, n, n) - one-step-ahead predicted state covariances
    """
n = x0.shape[0]
I = jnp.eye(n)
    X_pred_init = jnp.zeros_like(x0)  # match x_pred's shape so the scan carry is consistent for any n
    P_pred_init = jnp.zeros_like(P0)  # match P_pred's shape
def step(carry, z_t):
x, P, _, _ = carry
# --- Predict ---
x_pred = F @ x
P_pred = F @ P @ F.T + Q
# --- Update ---
y = z_t - H @ x_pred
S = H @ P_pred @ H.T + R
K = P_pred @ H.T @ jnp.linalg.inv(S)
x_new = x_pred + K @ y
P_new = (I - K @ H) @ P_pred
return (x_new, P_new, x_pred, P_pred), (x_new, P_new, x_pred, P_pred)
# run scan
(_, _, _, _), (xs, Ps, x_pred, P_pred) = jax.lax.scan(step, (x0, P0, X_pred_init, P_pred_init), z)
return xs, Ps, x_pred, P_pred
z_symbolic = pt.matrix("z")
F_symbolic = pt.matrix("F")
H_symbolic = pt.matrix("H")
Q_symbolic = pt.matrix("Q")
R_symbolic = pt.matrix("R")
x0_symbolic = pt.vector("x0")
P0_symbolic = pt.matrix("P0")
n = x0_symbolic.shape[0]
I = pt.eye(n)
X_pred_init = pt.zeros_like(x0_symbolic)
P_pred_init = pt.zeros_like(P0_symbolic)
N_STEPS = 500
def step(z_t, x, P, x_pred, P_pred, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I):
# --- Predict ---
x_pred = F_symbolic @ x
P_pred = F_symbolic @ P @ F_symbolic.T + Q_symbolic
# --- Update ---
y = z_t - H_symbolic @ x_pred
S = H_symbolic @ P_pred @ H_symbolic.T + R_symbolic
K = P_pred @ H_symbolic.T @ pt.linalg.inv(S)
x_new = x_pred + K @ y
P_new = (I - K @ H_symbolic) @ P_pred
return x_new, P_new, x_pred, P_pred
# run scan
(xs, Ps, x_pred, P_pred), _ = pytensor.scan(
fn=step,
outputs_info=[x0_symbolic, P0_symbolic, X_pred_init, P_pred_init],
sequences=[z_symbolic],
non_sequences=[F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I],
n_steps=N_STEPS
)
kalman_filter_pytensor = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], trust_input=True)
kalman_filter_pytensor_numba = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode="NUMBA", trust_input=True)
kalman_filter_pytensor_jax = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode="JAX", trust_input=True)
T = 500
F = np.array([[1.0]]).astype(np.float32)
H = np.array([[1.0]]).astype(np.float32)
Q = np.array([[0.01]]).astype(np.float32)
R = np.array([[0.1]]).astype(np.float32)
x0 = np.array([0.0]).astype(np.float32)
P0 = np.array([[1.0]]).astype(np.float32)
true = 1.0
z = (true + 0.4*np.random.randn(T)).reshape(T, 1).astype(np.float32)
kalman_filter_bench = Benchmarker(
functions=[kalman_filter_pytensor, atrocious_kalman_filter_numba, kalman_filter_numba, kalman_filter_jax, kalman_filter_pytensor_numba, block(kalman_filter_pytensor_jax)],
names=['kalman_filter_pytensor', 'atrocious_kalman_filter_numba', 'kalman_filter_numba', 'kalman_filter_jax', 'kalman_filter_pytensor_numba', 'kalman_filter_pytensor_jax'],
number=10
)
kalman_filter_bench.run(
inputs={
"kalman_filter_inputs": {"z": z, "F": F, "H": H, "Q": Q, "R": R, "x0": x0, "P0": P0},
}
)
kalman_filter_bench.summary()
/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/tmpmbe0qi_u:33: NumbaWarning:
Cannot cache compiled function "scan" as it uses dynamic globals (such as ctypes pointers and large global arrays)
| | | Loops | Min (us) | Max (us) | Mean (us) | StdDev (us) | Median (us) | IQR (us) | OPS (Kops/s) | Samples |
|---|---|---|---|---|---|---|---|---|---|---|
| kalman_filter_pytensor | kalman_filter_inputs | 10 | 5675.591601 | 6334.566601 | 5949.072047 | 159.617745 | 5937.045900 | 191.600001 | 0.168093 | 17 |
| atrocious_kalman_filter_numba | kalman_filter_inputs | 10 | 297.462501 | 330.600000 | 314.415020 | 11.425574 | 315.533401 | 14.287402 | 3.180510 | 5 |
| kalman_filter_numba | kalman_filter_inputs | 10 | 641.612499 | 718.120800 | 684.415800 | 32.973390 | 706.012500 | 62.024998 | 1.461100 | 5 |
| kalman_filter_jax | kalman_filter_inputs | 10 | 305.941701 | 368.020899 | 333.580105 | 18.368842 | 329.077049 | 22.434351 | 2.997781 | 18 |
| kalman_filter_pytensor_numba | kalman_filter_inputs | 10 | 853.600001 | 933.024997 | 897.570000 | 33.462131 | 918.749999 | 61.308398 | 1.114119 | 5 |
| kalman_filter_pytensor_jax | kalman_filter_inputs | 10 | 304.041599 | 372.449998 | 330.685825 | 17.271642 | 327.668698 | 24.356249 | 3.024018 | 20 |
xs, Ps, x_pred, P_pred = kalman_filter_jax(z, F, H, Q, R, x0, P0)
def compute_pred_intervals(z, x_pred, P_pred, H, R, zscore=1.96):
T = z.shape[0]
m = H.shape[0]
mean = np.zeros((T, m))
lower = np.zeros((T, m))
upper = np.zeros((T, m))
outside = np.zeros(T, dtype=np.bool_)
for t in range(T):
mean[t] = H @ x_pred[t]
S = H @ P_pred[t] @ H.T + R
std = np.sqrt(np.diag(S))
lower[t] = mean[t] - zscore * std
upper[t] = mean[t] + zscore * std
# check coverage of actual obs
outside[t] = np.any((z[t] < lower[t]) | (z[t] > upper[t]))
coverage = 1 - outside.mean()
return mean, lower, upper, coverage
mean, lower, upper, coverage = compute_pred_intervals(z, x_pred, P_pred, H, R)
coverage
np.float64(0.91)
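So the nominal 95% prediction intervals cover about 91% of the observations here, slightly under-covering.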
fig= go.Figure()
fig.add_traces(
[
go.Scatter(
x = np.arange(T),
y = z.ravel(),
mode="markers",
marker_color = "royalblue",
name = "actuals"
),
go.Scatter(
x = np.arange(T),
y = xs.ravel(),
mode = "lines",
marker_color = "orange",
name = "filtered mean"
),
go.Scatter(
name="",
x=np.arange(T),
y=upper.ravel(),
mode="lines",
marker=dict(color="#eb8c34"),
line=dict(width=0),
legendgroup="95% CI",
showlegend=False
),
go.Scatter(
name="95% CI",
x=np.arange(T),
y=lower.ravel(),
mode="lines", marker=dict(color="#eb8c34"),
line=dict(width=0),
legendgroup="95% CI",
fill='tonexty',
fillcolor='rgba(235, 140, 52, 0.2)'
),
]
)
fig.update_layout(
xaxis=dict(
title = "Time Index",
),
yaxis=dict(
title = "y"
),
template = "plotly_dark"
)
Non-linear Kalman Filter#
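This section swaps the linear Gaussian model for a log-linear Poisson state-space model,

$$x_t = A x_{t-1} + \epsilon_t, \quad \epsilon_t \sim \mathcal{N}(0, Q), \qquad y_t \mid x_t \sim \mathrm{Poisson}\big(e^{x_t}\big),$$

which has no closed-form filter, so each backend implements a bootstrap particle filter: propagate particles through the state transition, reweight by the Poisson likelihood, and resample to avoid weight degeneracy.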
@jit(nopython=True)
def loglik_poisson_numba(s, y):
"""Poisson Log Likelihood"""
mu = np.exp(s)
return y * np.log(mu + 1e-30) - mu - math.lgamma(y + 1.0) # numba does not support scipy.special gammaln
@jit(nopython=True)
def particle_filter_1d_predict_numba(A, Q, x0_mean, x0_std, ys, N=1000, seed=2):
"""
1D particle filter.
Parameters
----------
A: float
State transition
Q: float
Process covariance
x0_mean: float
Prior mean for the latent state
x0_std: float
Prior standard deviation
ys: np.ndarray
observations
N: int
number of particles
seed: int
rng seed for reproducibility
Returns
-------
filtered_means: np.ndarray
The filtered mean for the latent state
filtered_vars: np.ndarray
The filtered variance for the latent state
pred_means: np.ndarray
observation predicted mean
"""
np.random.seed(seed)
T = ys.shape[0]
particles = np.random.normal(x0_mean, x0_std, size=N)
weights = np.ones(N) / N
filtered_means = np.zeros(T)
filtered_vars = np.zeros(T)
pred_means = np.zeros(T)
for t in range(T):
y = ys[t]
# propagate (vectorized)
particles = A * particles + np.random.normal(0, np.sqrt(Q), size=N)
# update weights
logw = np.zeros(N)
for i in range(N):
logw[i] = loglik_poisson_numba(particles[i], y)
logw = logw - np.max(logw)
weights *= np.exp(logw)
weights /= np.sum(weights) + 1e-12
# filtered moments
mean_t = np.sum(weights * particles)
var_t = np.sum(weights * (particles - mean_t) ** 2)
# predictive mean
pred_mean = np.sum(weights * np.exp(particles))
filtered_means[t] = mean_t
filtered_vars[t] = var_t
pred_means[t] = pred_mean
# resample (multinomial resampling) because numba doesn't support np.random.choice
cumulative_sum = np.cumsum(weights)
cumulative_sum[-1] = 1.0 # guard against rounding error
indices = np.searchsorted(cumulative_sum, np.random.rand(N))
particles = particles[indices]
weights = np.ones(N) / N
return filtered_means, filtered_vars, pred_means
# The log-likelihood and the PRNG key are fixed inside the function so the
# Benchmarker can call every implementation with the same signature; N is marked
# static in jax.jit because it determines array shapes.
def loglik_poisson_jax(s, y):
"""Poisson Log Likelihood"""
mu = jnp.exp(s)
return y * jnp.log(mu + 1e-30) - mu - gammaln(y + 1.0)
@block
@partial(jax.jit, static_argnums=5)
def particle_filter_1d_predict_jax(
A, Q, x0_mean, x0_std, ys, N=1000,
):
"""
1D particle filter.
Parameters
----------
A: float
State transition
Q: float
Process covariance
x0_mean: float
Prior mean for the latent state
x0_std: float
Prior standard deviation
    ys: jnp.ndarray
        observations
    N: int
        number of particles (static argument for jit)
Returns
-------
filtered_means: jnp.ndarray
The filtered mean for the latent state
filtered_vars: jnp.ndarray
The filtered variance for the latent state
pred_means: jnp.ndarray
observation predicted mean
"""
key = jax.random.PRNGKey(0)
T = ys.shape[0]
particles = jax.random.normal(key, (N,)) * x0_std + x0_mean # init particles from gaussian priors
weights = jnp.ones(N) / N # particle weights, all particles equally likely prior
def body_fun(carry, t):
particles, weights, key = carry
y = ys[t]
# propagate
key, subkey = jax.random.split(key)
particles = A * particles + jax.random.normal(subkey, (N,)) * jnp.sqrt(Q) # state transition model
# update weights
logw = jax.vmap(lambda x: loglik_poisson_jax(x, y))(particles) # update particles in parallel
logw = logw - jnp.max(logw) # avoid overflow
weights = weights * jnp.exp(logw) # old weights times the likelihood
weights /= jnp.sum(weights) + 1e-12 # normalize so that weights sum to 1
# filtered moments
mean_t = jnp.sum(weights * particles) # posterior mean of latent state
var_t = jnp.sum(weights * (particles - mean_t)**2) # posterior variance of latent state
# predictive mean
pred_mean = jnp.sum(weights * jnp.exp(particles))
# resample to prevent dominant particles
key, subkey = jax.random.split(key)
indices = jax.random.choice(subkey, N, p=weights, shape=(N,))
particles = particles[indices]
weights = jnp.ones(N) / N
carry = (particles, weights, key)
out = (mean_t, var_t, pred_mean)
return carry, out
_, outputs = jax.lax.scan(body_fun, (particles, weights, key), jnp.arange(T))
return outputs
from pytensor.tensor.random.utils import RandomStream
# Random stream for PyTensor
srng = RandomStream(seed=42)
# Poisson log-likelihood
def loglik_poisson_pytensor(s, y):
mu = pt.exp(s)
return y.flatten() * pt.log(mu + 1e-30) - mu - pt.gammaln(y.flatten() + 1.0)
ys_symbolic = pt.vector("ys")
x0_mean_symbolic = pt.scalar("x0_mean")
x0_std_symbolic = pt.scalar("x0_std")
A_symbolic = pt.scalar("A")
Q_symbolic = pt.scalar("Q")
N_symbolic = pt.scalar("N", dtype='int64')
N_STEPS = 300
# Initialize particles and weights
particles_init = srng.normal(size=(N_symbolic,)) * x0_std_symbolic + x0_mean_symbolic
weights_init = pt.ones((N_symbolic,)) / N_symbolic
# Step function for scan
def step(y_t, particles_prev, weights_prev, A_symbolic, Q_symbolic):
# Propagate particles
particles_prop = A_symbolic * particles_prev + srng.normal(size=(N_symbolic,)) * pt.sqrt(Q_symbolic)
# Update weights
# logw = pt.stack([loglik_poisson_pytensor(p, y_t) for p in particles_prop])
logw = loglik_poisson_pytensor(particles_prop, y_t)
logw_stable = logw - pt.max(logw)
w_unnorm = weights_prev * pt.exp(logw_stable)
w = w_unnorm / (pt.sum(w_unnorm) + 1e-12)
# Filtered moments
mean_t = pt.sum(w * particles_prop)
var_t = pt.sum(w * (particles_prop - mean_t) ** 2)
pred_mean = pt.sum(w * pt.exp(particles_prop))
# Resample particles
idx = srng.choice(size=(N_symbolic,), a=N_symbolic, p=w)
particles_resampled = particles_prop[idx]
weights_resampled = pt.ones((N_symbolic,)) / N_symbolic
# Return flat tuple
return particles_resampled, weights_resampled, mean_t, var_t, pred_mean
# first two are recurrent, rest are collected
outputs_info = [
particles_init,
weights_init,
None,
None,
None
]
(particles_seq, weights_seq, means_seq, vars_seq, preds_seq), updates = pytensor.scan(
fn=step,
sequences=[ys_symbolic],
outputs_info=outputs_info,
non_sequences=[A_symbolic, Q_symbolic],
n_steps=N_STEPS
)
particle_filter_1d_predict_pytensor = pytensor.function(
[A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],
[means_seq, vars_seq, preds_seq],
updates=updates,
no_default_updates=True,
trust_input=True
)
particle_filter_1d_predict_pytensor_numba = pytensor.function(
[A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],
[means_seq, vars_seq, preds_seq],
updates=updates,
no_default_updates=True,
mode="NUMBA",
trust_input=True
)
key = jax.random.PRNGKey(0)
T = 300
A = 0.95
Q = 0.05
rng = np.random.RandomState(1)
target_mean = 10.0
latent_var = Q / (1 - A**2)
x0_mean = np.log(target_mean) - 0.5 * latent_var
x0_std = 1.0
# Simulate latent
x = np.zeros(T)
x[0] = rng.normal() * np.sqrt(latent_var) + x0_mean
for t in range(1, T):
x[t] = A * x[t-1] + rng.normal() * np.sqrt(Q)
ys = np.array(rng.poisson(np.exp(x)), dtype=np.float32)
nonlinear_kalman_filter_bench = Benchmarker(
functions=[particle_filter_1d_predict_pytensor, particle_filter_1d_predict_numba, particle_filter_1d_predict_jax, particle_filter_1d_predict_pytensor_numba,],
names=['particle_filter_1d_predict_pytensor', 'particle_filter_1d_predict_numba', 'particle_filter_1d_predict_jax', 'particle_filter_1d_predict_pytensor_numba',],
    number=5  # this benchmark takes a while to run, so reduce the number of loops
)
nonlinear_kalman_filter_bench.run(
inputs={
"kalman_filter_inputs": {"A": A, "Q": Q, "x0_mean": x0_mean, "x0_std": x0_std, "ys": ys, "N": 2000},
}
)
nonlinear_kalman_filter_bench.summary()
| | | Loops | Min (us) | Max (us) | Mean (us) | StdDev (us) | Median (us) | IQR (us) | OPS (Kops/s) | Samples |
|---|---|---|---|---|---|---|---|---|---|---|
| particle_filter_1d_predict_pytensor | kalman_filter_inputs | 5 | 728815.783397 | 742948.366801 | 734642.386761 | 5002.169131 | 733596.250002 | 6332.200003 | 0.001361 | 5 |
| particle_filter_1d_predict_numba | kalman_filter_inputs | 5 | 48678.408400 | 48904.825002 | 48782.921640 | 89.835338 | 48744.633398 | 160.525000 | 0.020499 | 5 |
| particle_filter_1d_predict_jax | kalman_filter_inputs | 5 | 33170.266601 | 33644.350001 | 33410.141721 | 155.134231 | 33411.583403 | 125.924998 | 0.029931 | 5 |
| particle_filter_1d_predict_pytensor_numba | kalman_filter_inputs | 5 | 676612.958201 | 678200.574999 | 677432.478319 | 647.987321 | 677334.924997 | 1278.466597 | 0.001476 | 5 |
The estimates below differ slightly across backends because the random number streams couldn't be reproduced 1:1.
filtered_means, filtered_vars, pred_means = particle_filter_1d_predict_numba(
A, Q, x0_mean, x0_std, ys, N=2000, seed=2
)
fig = make_subplots(
rows=2, cols=1,
subplot_titles=("Observation Predictions", "Latent State Estimation"),
vertical_spacing=0.07,
shared_xaxes=True
)
fig.add_traces(
[
go.Scatter(
x = np.arange(T),
y = ys,
mode = "markers",
marker_color = "cornflowerblue",
name = "actuals"
),
go.Scatter(
x = np.arange(T),
y = pred_means,
mode = "lines",
marker_color = "#eb8c34",
name = "predicted mean"
),
go.Scatter(
name="",
x=np.arange(T),
y=pred_means + 2*jnp.sqrt(pred_means),
mode="lines",
marker=dict(color="#eb8c34"),
line=dict(width=0),
legendgroup="predicted mean 95% CI",
showlegend=False
),
go.Scatter(
name="predicted mean 95% CI",
x=np.arange(T),
y=pred_means - 2*jnp.sqrt(pred_means),
mode="lines", marker=dict(color="#eb8c34"),
line=dict(width=0),
legendgroup="predicted mean 95% CI",
fill='tonexty',
fillcolor='rgba(235, 140, 52, 0.2)'
),
],
rows=1, cols=1
)
fig.add_traces(
[
go.Scatter(
x = np.arange(T),
y = x,
mode = "lines",
marker_color = "cornflowerblue",
name = "true latent state"
),
go.Scatter(
x = np.arange(T),
y = filtered_means,
mode = "lines",
marker_color = "#eb8c34",
name = "filtered state mean"
),
go.Scatter(
name="",
x=np.arange(T),
y=filtered_means + 2*jnp.sqrt(filtered_vars),
mode="lines",
marker=dict(color="#eb8c34"),
line=dict(width=0),
legendgroup="filtered state mean 95% CI",
showlegend=False
),
go.Scatter(
name="filtered state mean 95% CI",
x=np.arange(T),
y=filtered_means - 2*jnp.sqrt(filtered_vars),
mode="lines", marker=dict(color="#eb8c34"),
line=dict(width=0),
legendgroup="filtered state mean 95% CI",
fill='tonexty',
fillcolor='rgba(235, 140, 52, 0.2)'
),
],
rows=2, cols=1
)
for i, yaxis in enumerate(fig.select_yaxes(), 1):
legend_name = f"legend{i}"
fig.update_layout({legend_name: dict(y=yaxis.domain[1], yanchor="top")}, showlegend=True)
fig.update_traces(row=i, legend=legend_name)
fig.update_layout(height=1000, width=1200, template="plotly_dark")
fig.update_layout(
legend1=dict(
yanchor="top",
y=1.0,
xanchor="left",
x=0,
orientation="h"
),
legend2=dict(
yanchor="top",
y=.465,
xanchor="left",
x=0,
orientation="h"
),
)