{
"cells": [
{
"cell_type": "markdown",
"id": "9571ee33",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "73cc811c",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax.numpy as jnp\n",
"from numba import jit\n",
"import numba\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"from pytensor.compile.mode import get_mode\n",
"import timeit\n",
"import jax\n",
"import math\n",
"from jax.scipy.special import gammaln\n",
"from functools import partial\n",
"\n",
"import plotly.graph_objects as go\n",
"from plotly.subplots import make_subplots\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d5fc1350",
"metadata": {},
"outputs": [],
"source": [
"import plotly.io as pio\n",
"pio.renderers.default = \"notebook\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "daa0969b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pytensor version: 0+untagged.31343.g647570d.dirty\n",
"jax version: 0.7.2\n",
"numba version: 0.62.1\n"
]
}
],
"source": [
"print(\"pytensor version:\", pytensor.__version__)\n",
"print(\"jax version:\", jax.__version__)\n",
"print(\"numba version:\", numba.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8294718d",
"metadata": {},
"outputs": [],
"source": [
"class Benchmarker:\n",
" \"\"\"\n",
" Benchmark a set of functions by timing execution and summarizing statistics.\n",
"\n",
" Parameters\n",
" ----------\n",
" functions : list of callables\n",
" List of callables to benchmark.\n",
" names : list of str, optional\n",
" Names corresponding to each function. Default is ['func_0', 'func_1', ...].\n",
" number : int or None, optional\n",
" Number of loops per timing. If None, auto-calibrated via Timer.autorange().\n",
" Default is None.\n",
" repeat : int, optional\n",
" Number of repeats for timing. Default is 7.\n",
" target_time : float, optional\n",
" Target duration in seconds for auto-calibration. Default is 0.2.\n",
"\n",
" Attributes\n",
" ----------\n",
" results : dict\n",
" Mapping from function names to a dict with keys:\n",
" - 'raw_us': numpy.ndarray of raw timings in microseconds\n",
" - 'loops': number of loops used per timing\n",
"\n",
" Methods\n",
" -------\n",
" run()\n",
" Auto-calibrate (if needed) and run timings for all functions.\n",
" summary(unit='us') -> pandas.DataFrame\n",
" Return a summary DataFrame with statistics converted to the given unit.\n",
" raw(name=None) -> dict or numpy.ndarray\n",
" Return raw timing data in microseconds for a specific function or all.\n",
" _convert_times(times, unit) -> numpy.ndarray\n",
" Convert an array of times from microseconds to the specified unit.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self, functions, names=None, number=None, min_rounds=5, max_time=1.0, target_time=0.2\n",
" ):\n",
" self.functions = functions\n",
" self.names = names or [f\"func_{i}\" for i in range(len(functions))]\n",
" self.number = number\n",
" self.min_rounds = min_rounds\n",
" self.max_time = max_time\n",
" self.target_time = target_time\n",
" self.results = {}\n",
"\n",
" def run(self, inputs: dict[str, dict]):\n",
" \"\"\"\n",
" Auto-calibrate loop count and sample rounds if needed, then time each function.\n",
" \"\"\"\n",
" for name, func in zip(self.names, self.functions):\n",
" for input_name, kwargs in inputs.items():\n",
" timer = timeit.Timer(partial(func, **kwargs))\n",
"\n",
" # Calibrate loops\n",
" if self.number is None:\n",
" loops, calib_time = timer.autorange()\n",
" else:\n",
" loops = self.number\n",
" calib_time = timer.timeit(number=loops)\n",
"\n",
" # Determine rounds based on max_time and min_rounds\n",
" if self.max_time is not None:\n",
" rounds = max(self.min_rounds, int(np.ceil(self.max_time / calib_time)))\n",
" else:\n",
" rounds = self.min_rounds\n",
"\n",
" raw_round_times = np.array(timer.repeat(repeat=rounds, number=loops))\n",
"\n",
" # Convert to microseconds per single execution\n",
" raw_us = raw_round_times / loops * 1e6\n",
"\n",
" self.results[(name, input_name)] = {\n",
" \"raw_us\": raw_us,\n",
" \"loops\": loops,\n",
" \"rounds\": rounds,\n",
" }\n",
"\n",
" def summary(self, unit=\"us\"):\n",
" \"\"\"\n",
" Summarize benchmark statistics in a DataFrame.\n",
"\n",
" Parameters\n",
" ----------\n",
" unit : {'us', 'ms', 'ns', 's'}, optional\n",
" Unit for output times. 'us' means microseconds, 'ms' milliseconds,\n",
" 'ns' nanoseconds, 's' seconds. Default is 'us'.\n",
"\n",
" Returns\n",
" -------\n",
" pandas.DataFrame\n",
" Summary with columns:\n",
" Name, Loops, Min, Max, Mean, StdDev, Median, IQR (all in given unit),\n",
" OPS (Kops/unit), Samples.\n",
" \"\"\"\n",
" records = []\n",
" indexes = []\n",
" for name, data in self.results.items():\n",
" raw_us = data[\"raw_us\"]\n",
" # Convert to target unit\n",
" times = self._convert_times(raw_us, unit)\n",
" if isinstance(name, tuple) and len(name) > 1:\n",
" indexes.append(name)\n",
" elif isinstance(name, tuple) and len(name) == 1:\n",
" indexes.append(name[0])\n",
" else:\n",
" indexes.append(name)\n",
"\n",
" stats = {\n",
" \"Loops\": data[\"loops\"],\n",
" f\"Min ({unit})\": np.min(times),\n",
" f\"Max ({unit})\": np.max(times),\n",
" f\"Mean ({unit})\": np.mean(times),\n",
" f\"StdDev ({unit})\": np.std(times),\n",
" f\"Median ({unit})\": np.median(times),\n",
" f\"IQR ({unit})\": np.percentile(times, 75) - np.percentile(times, 25),\n",
" \"OPS (Kops/s)\": 1e3 / (np.mean(raw_us)),\n",
" \"Samples\": len(raw_us),\n",
" }\n",
" records.append(stats)\n",
"\n",
" if all(isinstance(idx, tuple) for idx in indexes):\n",
" index = pd.MultiIndex.from_tuples(indexes)\n",
" else:\n",
" index = pd.Index(indexes)\n",
" return pd.DataFrame(records, index=index)\n",
"\n",
" def raw(self, name=None):\n",
" \"\"\"\n",
" Get raw timing data in microseconds.\n",
"\n",
" Parameters\n",
" ----------\n",
" name : str, optional\n",
" If given, returns the raw_us array for that function. Otherwise returns\n",
" a dict of all raw results.\n",
"\n",
" Returns\n",
" -------\n",
" numpy.ndarray or dict\n",
" \"\"\"\n",
" if name:\n",
" return self.results.get(name, {}).get(\"raw_us\")\n",
" return {n: d[\"raw_us\"] for n, d in self.results.items()}\n",
"\n",
" def _convert_times(self, times, unit):\n",
" \"\"\"\n",
" Convert an array of times from microseconds to the specified unit.\n",
"\n",
" Parameters\n",
" ----------\n",
" times : array-like\n",
" Times in microseconds.\n",
" unit : {'us', 'ms', 'ns', 's'}\n",
" Target unit: 'us' microseconds, 'ms' milliseconds,\n",
" 'ns' nanoseconds, 's' seconds.\n",
"\n",
" Returns\n",
" -------\n",
" numpy.ndarray\n",
" Converted times.\n",
"\n",
" Raises\n",
" ------\n",
" ValueError\n",
" If `unit` is not one of the supported options.\n",
" \"\"\"\n",
" unit = unit.lower()\n",
" if unit == \"us\":\n",
" factor = 1.0\n",
" elif unit == \"ms\":\n",
" factor = 1e-3\n",
" elif unit == \"ns\":\n",
" factor = 1e3\n",
" elif unit == \"s\":\n",
" factor = 1e-6\n",
" else:\n",
" raise ValueError(f\"Unsupported unit: {unit}\")\n",
" return times * factor"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "799a759b",
"metadata": {},
"outputs": [],
"source": [
"# def block(func):\n",
"# def inner(*args, **kwargs):\n",
"# return jax.block_until_ready(func(*args, **kwargs))\n",
"# return inner\n",
"\n",
"\n",
"# I believe the above will only work if the return is a single jnp.array (at least that is what AI thinks) I would appreciate your insights here, Adrian.\n",
"\n",
"def block(func):\n",
" def inner(*args, **kwargs):\n",
" result = func(*args, **kwargs)\n",
" # Recursively block on all JAX arrays in result\n",
" jax.tree_util.tree_map(lambda x: x.block_until_ready() if hasattr(x, \"block_until_ready\") else None, result)\n",
" return result\n",
" return inner\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6e76a860",
"metadata": {},
"outputs": [],
"source": [
"# Set Pytensor to use float32\n",
"pytensor.config.floatX = \"float32\""
]
},
{
"cell_type": "markdown",
"id": "1c3eb543",
"metadata": {},
"source": [
"# Introduction"
]
},
{
"cell_type": "markdown",
"id": "e066c9c4",
"metadata": {},
"source": [
"# Baby Steps"
]
},
{
"cell_type": "markdown",
"id": "962c851c",
"metadata": {},
"source": [
"## Fibonacci Algorithm"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1c8d1654",
"metadata": {},
"outputs": [],
"source": [
"N_STEPS = 100_000\n",
"\n",
"b_symbolic = pt.vector(\"b\", dtype=\"int32\", shape=(1,))\n",
"\n",
"def step(a, b):\n",
" return a + b, a\n",
"\n",
"(outputs_a, outputs_b), _ = pytensor.scan(\n",
" fn=step,\n",
" outputs_info=[pt.ones(1, dtype=\"int32\"), b_symbolic],\n",
" n_steps=N_STEPS\n",
")\n",
"\n",
"# compile function returning final a\n",
"fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)\n",
"fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "bf48d6bf",
"metadata": {},
"outputs": [],
"source": [
"@jit(nopython=True)\n",
"def fibonacci_numba(b):\n",
" a = np.ones(1, dtype=np.int32)\n",
" for _ in range(N_STEPS):\n",
" a[0], b[0] = a[0] + b[0], a[0]\n",
" return a[0]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "317ee391",
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def fibonacci_jax(b):\n",
" a = jnp.array(1, dtype=np.int32)\n",
" for _ in range(N_STEPS):\n",
" a, b = a + b, a\n",
" return a"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "63bfdffe",
"metadata": {},
"outputs": [],
"source": [
"fibonacci_bench = Benchmarker(\n",
" functions=[fibonacci_pytensor, fibonacci_numba, fibonacci_jax, fibonacci_pytensor_numba], \n",
" names=['fibonacci_pytensor', 'fibonacci_numba', 'fibonacci_jax', 'fibonacci_pytensor_numba'],\n",
" number=10\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "65bf994e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" | \n",
" Loops | \n",
" Min (us) | \n",
" Max (us) | \n",
" Mean (us) | \n",
" StdDev (us) | \n",
" Median (us) | \n",
" IQR (us) | \n",
" OPS (Kops/s) | \n",
" Samples | \n",
"
\n",
" \n",
" \n",
" \n",
" fibonacci_pytensor | \n",
" fibonacci_inputs | \n",
" 10 | \n",
" 54149.950002 | \n",
" 55278.412500 | \n",
" 54654.332500 | \n",
" 434.902898 | \n",
" 54807.945801 | \n",
" 694.087401 | \n",
" 0.018297 | \n",
" 5 | \n",
"
\n",
" \n",
" fibonacci_numba | \n",
" fibonacci_inputs | \n",
" 10 | \n",
" 84.208400 | \n",
" 105.558301 | \n",
" 95.575969 | \n",
" 4.793104 | \n",
" 95.562500 | \n",
" 5.033301 | \n",
" 10.462881 | \n",
" 13 | \n",
"
\n",
" \n",
" fibonacci_jax | \n",
" fibonacci_inputs | \n",
" 10 | \n",
" 7.162499 | \n",
" 31.558401 | \n",
" 14.425020 | \n",
" 8.923158 | \n",
" 12.233300 | \n",
" 5.920898 | \n",
" 69.323996 | \n",
" 5 | \n",
"
\n",
" \n",
" fibonacci_pytensor_numba | \n",
" fibonacci_inputs | \n",
" 10 | \n",
" 2338.045801 | \n",
" 2429.037497 | \n",
" 2385.059159 | \n",
" 34.438988 | \n",
" 2392.570800 | \n",
" 58.641701 | \n",
" 0.419277 | \n",
" 5 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Loops Min (us) Max (us) \\\n",
"fibonacci_pytensor fibonacci_inputs 10 54149.950002 55278.412500 \n",
"fibonacci_numba fibonacci_inputs 10 84.208400 105.558301 \n",
"fibonacci_jax fibonacci_inputs 10 7.162499 31.558401 \n",
"fibonacci_pytensor_numba fibonacci_inputs 10 2338.045801 2429.037497 \n",
"\n",
" Mean (us) StdDev (us) \\\n",
"fibonacci_pytensor fibonacci_inputs 54654.332500 434.902898 \n",
"fibonacci_numba fibonacci_inputs 95.575969 4.793104 \n",
"fibonacci_jax fibonacci_inputs 14.425020 8.923158 \n",
"fibonacci_pytensor_numba fibonacci_inputs 2385.059159 34.438988 \n",
"\n",
" Median (us) IQR (us) \\\n",
"fibonacci_pytensor fibonacci_inputs 54807.945801 694.087401 \n",
"fibonacci_numba fibonacci_inputs 95.562500 5.033301 \n",
"fibonacci_jax fibonacci_inputs 12.233300 5.920898 \n",
"fibonacci_pytensor_numba fibonacci_inputs 2392.570800 58.641701 \n",
"\n",
" OPS (Kops/s) Samples \n",
"fibonacci_pytensor fibonacci_inputs 0.018297 5 \n",
"fibonacci_numba fibonacci_inputs 10.462881 13 \n",
"fibonacci_jax fibonacci_inputs 69.323996 5 \n",
"fibonacci_pytensor_numba fibonacci_inputs 0.419277 5 "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fibonacci_bench.run(\n",
" inputs={\n",
" \"fibonacci_inputs\": {\"b\": np.ones(1, dtype=np.int32)},\n",
" }\n",
")\n",
"fibonacci_bench.summary()"
]
},
{
"cell_type": "markdown",
"id": "7bbb35a2",
"metadata": {},
"source": [
"## Element-wise multiplication Algorithm"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7b8bb97b",
"metadata": {},
"outputs": [],
"source": [
"# a_symbolic = pt.vector(\"a\", dtype=\"int32\")\n",
"# b_symbolic = pt.vector(\"b\", dtype=\"int32\")\n",
"\n",
"# def step(a_element, b_element):\n",
"# return a_element * b_element\n",
" \n",
"# c, _ = pytensor.scan(\n",
"# fn=step,\n",
"# sequences=[a_symbolic, b_symbolic]\n",
"# )\n",
"\n",
"# # compile function returning final a\n",
"# c_mode = get_mode(\"FAST_RUN\").excluding(\"scan_push_out_seq\")\n",
"# elementwise_multiply_pytensor = pytensor.function([a_symbolic, b_symbolic], c, trust_input=True, mode=c_mode)\n",
"\n",
"# numba_mode = get_mode(\"NUMBA\").excluding(\"scan_push_out_seq\")\n",
"# elementwise_multiply_pytensor_numba = pytensor.function([a_symbolic, b_symbolic], c, mode=numba_mode, trust_input=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "0b076f77",
"metadata": {},
"outputs": [],
"source": [
"# Not sure this makes a difference\n",
"a_symbolic = pt.vector(\"a\", dtype=\"int32\")\n",
"b_symbolic = pt.vector(\"b\", dtype=\"int32\")\n",
"N_STEPS = 1000\n",
"\n",
"def step(a_element, b_element):\n",
" return a_element * b_element\n",
"\n",
"c, _ = pytensor.scan(\n",
" fn=step,\n",
" sequences=[a_symbolic, b_symbolic],\n",
" n_steps=N_STEPS\n",
")\n",
"\n",
"# compile function returning final a\n",
"c_mode = get_mode(\"FAST_RUN\").excluding(\"scan_push_out_seq\")\n",
"elementwise_multiply_pytensor = pytensor.function([a_symbolic, b_symbolic], c, trust_input=True, mode=c_mode)\n",
"\n",
"numba_mode = get_mode(\"NUMBA\").excluding(\"scan_push_out_seq\")\n",
"elementwise_multiply_pytensor_numba = pytensor.function([a_symbolic, b_symbolic], c, mode=numba_mode, trust_input=True)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "14327cbf",
"metadata": {},
"outputs": [],
"source": [
"@jit(nopython=True)\n",
"def elementwise_multiply_numba(a, b):\n",
" n = a.shape[0]\n",
" c = np.empty(n, dtype=a.dtype)\n",
" for i in range(n):\n",
" c[i] = a[i] * b[i]\n",
" return c"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "13715c9b",
"metadata": {},
"outputs": [],
"source": [
"@block\n",
"@jax.jit\n",
"def elementwise_multiply_jax(a, b):\n",
" n = a.shape[0]\n",
" c_init = jnp.empty(n, dtype=a.dtype)\n",
" def step(i, c):\n",
" return jax.lax.dynamic_update_index_in_dim(c, a[i] * b[i], i, axis=0)\n",
" \n",
" c = jax.lax.fori_loop(0, n, step, c_init)\n",
" return c"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "d43c4f9c",
"metadata": {},
"outputs": [],
"source": [
"a = np.random.normal(0, 1, (N_STEPS)).astype(np.int32)\n",
"b = np.random.normal(0, 1, (N_STEPS)).astype(np.int32)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f0f4ede5",
"metadata": {},
"outputs": [],
"source": [
"elem_mult_bench = Benchmarker(\n",
" functions=[elementwise_multiply_pytensor, elementwise_multiply_numba, elementwise_multiply_jax, elementwise_multiply_pytensor_numba], \n",
" names=['elementwise_multiply_pytensor', 'elementwise_multiply_numba', 'elementwise_multiply_jax', 'elementwise_multiply_pytensor_numba'],\n",
" number=10\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "cdab8946",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" | \n",
" Loops | \n",
" Min (us) | \n",
" Max (us) | \n",
" Mean (us) | \n",
" StdDev (us) | \n",
" Median (us) | \n",
" IQR (us) | \n",
" OPS (Kops/s) | \n",
" Samples | \n",
"
\n",
" \n",
" \n",
" \n",
" elementwise_multiply_pytensor | \n",
" elem_mult_inputs | \n",
" 10 | \n",
" 450.595800 | \n",
" 967.095900 | \n",
" 492.991945 | \n",
" 36.396199 | \n",
" 491.412499 | \n",
" 18.221901 | \n",
" 2.028431 | \n",
" 220 | \n",
"
\n",
" \n",
" elementwise_multiply_numba | \n",
" elem_mult_inputs | \n",
" 10 | \n",
" 0.366700 | \n",
" 0.720800 | \n",
" 0.394247 | \n",
" 0.073563 | \n",
" 0.379200 | \n",
" 0.012497 | \n",
" 2536.478255 | \n",
" 21 | \n",
"
\n",
" \n",
" elementwise_multiply_jax | \n",
" elem_mult_inputs | \n",
" 10 | \n",
" 7.512499 | \n",
" 10.391598 | \n",
" 8.152427 | \n",
" 0.621782 | \n",
" 7.895901 | \n",
" 0.593750 | \n",
" 122.662853 | \n",
" 55 | \n",
"
\n",
" \n",
" elementwise_multiply_pytensor_numba | \n",
" elem_mult_inputs | \n",
" 10 | \n",
" 34.662499 | \n",
" 50.737502 | \n",
" 39.280821 | \n",
" 5.934219 | \n",
" 37.408300 | \n",
" 3.912600 | \n",
" 25.457717 | \n",
" 5 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Loops Min (us) \\\n",
"elementwise_multiply_pytensor elem_mult_inputs 10 450.595800 \n",
"elementwise_multiply_numba elem_mult_inputs 10 0.366700 \n",
"elementwise_multiply_jax elem_mult_inputs 10 7.512499 \n",
"elementwise_multiply_pytensor_numba elem_mult_inputs 10 34.662499 \n",
"\n",
" Max (us) Mean (us) \\\n",
"elementwise_multiply_pytensor elem_mult_inputs 967.095900 492.991945 \n",
"elementwise_multiply_numba elem_mult_inputs 0.720800 0.394247 \n",
"elementwise_multiply_jax elem_mult_inputs 10.391598 8.152427 \n",
"elementwise_multiply_pytensor_numba elem_mult_inputs 50.737502 39.280821 \n",
"\n",
" StdDev (us) \\\n",
"elementwise_multiply_pytensor elem_mult_inputs 36.396199 \n",
"elementwise_multiply_numba elem_mult_inputs 0.073563 \n",
"elementwise_multiply_jax elem_mult_inputs 0.621782 \n",
"elementwise_multiply_pytensor_numba elem_mult_inputs 5.934219 \n",
"\n",
" Median (us) IQR (us) \\\n",
"elementwise_multiply_pytensor elem_mult_inputs 491.412499 18.221901 \n",
"elementwise_multiply_numba elem_mult_inputs 0.379200 0.012497 \n",
"elementwise_multiply_jax elem_mult_inputs 7.895901 0.593750 \n",
"elementwise_multiply_pytensor_numba elem_mult_inputs 37.408300 3.912600 \n",
"\n",
" OPS (Kops/s) Samples \n",
"elementwise_multiply_pytensor elem_mult_inputs 2.028431 220 \n",
"elementwise_multiply_numba elem_mult_inputs 2536.478255 21 \n",
"elementwise_multiply_jax elem_mult_inputs 122.662853 55 \n",
"elementwise_multiply_pytensor_numba elem_mult_inputs 25.457717 5 "
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"elem_mult_bench.run(\n",
" inputs={\n",
" \"elem_mult_inputs\": {\"a\": a, \"b\": b},\n",
" }\n",
")\n",
"elem_mult_bench.summary()"
]
},
{
"cell_type": "markdown",
"id": "25a87710",
"metadata": {},
"source": [
"# Changepoint Detection Algorithms"
]
},
{
"cell_type": "markdown",
"id": "16b2b312",
"metadata": {},
"source": [
"## Cumulative Sum (CUSUM) Algorithm"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "c1226cfc",
"metadata": {},
"outputs": [],
"source": [
"@jit(nopython=True)\n",
"def cusum_adaptive_numba(x, alpha=0.01, k=0.5, h=5.0):\n",
" \"\"\"\n",
" Two-sided CUSUM with adaptive exponential moving average baseline.\n",
" \n",
" Parameters\n",
" ----------\n",
" x: np.ndarray\n",
" input signal\n",
" alpha: float\n",
" EMA smoothing factor (0 < alpha <= 1)\n",
" k: float\n",
" slack to avoid small changes triggering alarms\n",
" h: float\n",
" threshold for raising an alarm\n",
" \n",
" Returns\n",
" -------\n",
" s_pos: np.ndarray\n",
" upper CUSUM stats\n",
" s_neg: np.ndarray\n",
" lower CUSUM stats\n",
" mu_t: np.ndarray\n",
" evolving baseline estimate\n",
" alarms_pos: np.ndarray\n",
" alarms for upward changes\n",
" alarms_neg: np.ndarray\n",
" alarms for downward changes\n",
" \"\"\"\n",
" n = x.shape[0]\n",
"\n",
" s_pos = np.zeros(n, dtype=np.float32)\n",
" s_neg = np.zeros(n, dtype=np.float32)\n",
" mu_t = np.zeros(n, dtype=np.float32)\n",
" alarms_pos = np.zeros(n, dtype=np.bool_)\n",
" alarms_neg = np.zeros(n, dtype=np.bool_)\n",
"\n",
" # Initialization\n",
" mu_t[0] = x[0]\n",
"\n",
" for i in range(1, n):\n",
" # Update baseline (EMA)\n",
" mu_t[i] = alpha * x[i] + (1 - alpha) * mu_t[i-1]\n",
"\n",
" # Update CUSUM stats\n",
" s_pos[i] = max(0.0, s_pos[i-1] + x[i] - mu_t[i] - k)\n",
" s_neg[i] = max(0.0, s_neg[i-1] - (x[i] - mu_t[i]) - k)\n",
"\n",
" # Alarms\n",
" alarms_pos[i] = s_pos[i] > h\n",
" alarms_neg[i] = s_neg[i] > h\n",
"\n",
" return s_pos, s_neg, mu_t, alarms_pos, alarms_neg"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "14937a2a",
"metadata": {},
"outputs": [],
"source": [
"@block\n",
"@jax.jit\n",
"def cusum_adaptive_jax(x, alpha=0.01, k=0.5, h=5.0):\n",
" \"\"\"\n",
" Two-sided CUSUM with adaptive exponential moving average baseline.\n",
" \n",
" Parameters\n",
" ----------\n",
" x: jnp.ndarray\n",
" input signal\n",
" alpha: float\n",
" EMA smoothing factor (0 < alpha <= 1)\n",
" k: float\n",
" slack to avoid small changes triggering alarms\n",
" h: float\n",
" threshold for raising an alarm\n",
" \n",
" Returns\n",
" -------\n",
" s_pos: jnp.ndarray\n",
" upper CUSUM stats\n",
" s_neg: jnp.ndarray\n",
" lower CUSUM stats\n",
" mu_t: jnp.ndarray\n",
" evolving baseline estimate\n",
" alarms_pos: jnp.ndarray\n",
" alarms for upward changes\n",
" alarms_neg: jnp.ndarray\n",
" alarms for downward changes\n",
" \"\"\"\n",
" def body(carry, x_t):\n",
" s_pos_prev, s_neg_prev, mu_prev = carry\n",
" \n",
" # Update EMA baseline\n",
" mu_t = alpha * x_t + (1 - alpha) * mu_prev\n",
" \n",
" # Update CUSUMs using updated baseline\n",
" s_pos = jnp.maximum(0.0, s_pos_prev + x_t - mu_t - k)\n",
" s_neg = jnp.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)\n",
" \n",
" new_carry = (s_pos, s_neg, mu_t)\n",
" output = (s_pos, s_neg, mu_t)\n",
" return new_carry, output\n",
"\n",
" # Initialize: CUSUMs at 0, initial mean = first sample\n",
" s0 = (0.0, 0.0, x[0])\n",
" _, (s_pos_vals, s_neg_vals, mu_vals) = jax.lax.scan(body, s0, x)\n",
" \n",
" # Thresholding\n",
" alarms_pos = s_pos_vals > h\n",
" alarms_neg = s_neg_vals > h\n",
"\n",
" return s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "815967a4",
"metadata": {},
"outputs": [],
"source": [
"x_symbolic = pt.vector(\"x\")\n",
"alpha_symbolic = pt.scalar(\"alpha\")\n",
"k_symbolic = pt.scalar(\"k\")\n",
"h_symbolic = pt.scalar(\"h\")\n",
"\n",
"N_STEPS = 100 # Fixing this just incase but I don't think this is changing anything when we have sequences as input to scan\n",
"\n",
"def step(x_t, s_pos_prev, s_neg_prev, mu_prev, alpha, k):\n",
" # Update EMA baseline\n",
" mu_t = alpha * x_t + (1 - alpha) * mu_prev\n",
" \n",
" # Update CUSUMs using updated baseline\n",
" s_pos = pt.maximum(0.0, s_pos_prev + x_t - mu_t - k)\n",
" s_neg = pt.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)\n",
" \n",
" return s_pos, s_neg, mu_t\n",
"\n",
"\n",
"(s_pos_vals, s_neg_vals, mu_vals), updates = pytensor.scan(\n",
" fn=step,\n",
" outputs_info=[pt.constant(0., dtype=\"float32\"), pt.constant(0., dtype=\"float32\"), x_symbolic[0]],\n",
" non_sequences=[alpha_symbolic, k_symbolic],\n",
" sequences=[x_symbolic],\n",
" n_steps=N_STEPS\n",
")\n",
"\n",
"# Thresholding\n",
"alarms_pos = s_pos_vals > h_symbolic\n",
"alarms_neg = s_neg_vals > h_symbolic\n",
"\n",
"cusum_adaptive_pytensor = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], trust_input=True)\n",
"\n",
"cusum_adaptive_pytensor_numba = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode=\"NUMBA\", trust_input=True)\n",
"cusum_adaptive_pytensor_jax = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode=\"JAX\", trust_input=True)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "6c892129",
"metadata": {},
"outputs": [],
"source": [
"xs1 = np.random.normal(80, 20, size=(int(N_STEPS/2)))\n",
"xs2 = np.random.normal(50, 20, size=(int(N_STEPS/2)))\n",
"xs = np.concat((xs1, xs2))\n",
"xs = xs.astype(np.float32)\n",
"xs = xs.astype(np.float32)\n",
"xs_std = (xs - xs.mean()) / xs.std()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "d21873d6",
"metadata": {},
"outputs": [],
"source": [
"cusum_bench = Benchmarker(\n",
" functions=[cusum_adaptive_pytensor, cusum_adaptive_numba, cusum_adaptive_jax, cusum_adaptive_pytensor_numba, block(cusum_adaptive_pytensor_jax)], \n",
" names=['cusum_adaptive_pytensor', 'cusum_adaptive_numba', 'cusum_adaptive_jax', 'cusum_adaptive_pytensor_numba', 'cusum_adaptive_pytensor_jax'],\n",
" number=10\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "d8eab72d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" | \n",
" Loops | \n",
" Min (us) | \n",
" Max (us) | \n",
" Mean (us) | \n",
" StdDev (us) | \n",
" Median (us) | \n",
" IQR (us) | \n",
" OPS (Kops/s) | \n",
" Samples | \n",
"
\n",
" \n",
" \n",
" \n",
" cusum_adaptive_pytensor | \n",
" cusum_inputs | \n",
" 10 | \n",
" 124.412501 | \n",
" 226.245899 | \n",
" 141.837420 | \n",
" 8.902762 | \n",
" 139.566700 | \n",
" 10.325000 | \n",
" 7.050326 | \n",
" 681 | \n",
"
\n",
" \n",
" cusum_adaptive_numba | \n",
" cusum_inputs | \n",
" 10 | \n",
" 1.729201 | \n",
" 2.174999 | \n",
" 1.806571 | \n",
" 0.150796 | \n",
" 1.745898 | \n",
" 0.018753 | \n",
" 553.534780 | \n",
" 7 | \n",
"
\n",
" \n",
" cusum_adaptive_jax | \n",
" cusum_inputs | \n",
" 10 | \n",
" 14.366599 | \n",
" 18.899998 | \n",
" 15.076731 | \n",
" 0.942976 | \n",
" 14.633298 | \n",
" 0.738525 | \n",
" 66.327378 | \n",
" 36 | \n",
"
\n",
" \n",
" cusum_adaptive_pytensor_numba | \n",
" cusum_inputs | \n",
" 10 | \n",
" 24.316600 | \n",
" 36.341700 | \n",
" 27.204980 | \n",
" 4.596065 | \n",
" 25.033302 | \n",
" 1.241703 | \n",
" 36.757976 | \n",
" 5 | \n",
"
\n",
" \n",
" cusum_adaptive_pytensor_jax | \n",
" cusum_inputs | \n",
" 10 | \n",
" 21.479101 | \n",
" 27.295901 | \n",
" 22.782493 | \n",
" 1.642595 | \n",
" 21.893748 | \n",
" 1.734347 | \n",
" 43.893352 | \n",
" 30 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Loops Min (us) Max (us) \\\n",
"cusum_adaptive_pytensor cusum_inputs 10 124.412501 226.245899 \n",
"cusum_adaptive_numba cusum_inputs 10 1.729201 2.174999 \n",
"cusum_adaptive_jax cusum_inputs 10 14.366599 18.899998 \n",
"cusum_adaptive_pytensor_numba cusum_inputs 10 24.316600 36.341700 \n",
"cusum_adaptive_pytensor_jax cusum_inputs 10 21.479101 27.295901 \n",
"\n",
" Mean (us) StdDev (us) \\\n",
"cusum_adaptive_pytensor cusum_inputs 141.837420 8.902762 \n",
"cusum_adaptive_numba cusum_inputs 1.806571 0.150796 \n",
"cusum_adaptive_jax cusum_inputs 15.076731 0.942976 \n",
"cusum_adaptive_pytensor_numba cusum_inputs 27.204980 4.596065 \n",
"cusum_adaptive_pytensor_jax cusum_inputs 22.782493 1.642595 \n",
"\n",
" Median (us) IQR (us) \\\n",
"cusum_adaptive_pytensor cusum_inputs 139.566700 10.325000 \n",
"cusum_adaptive_numba cusum_inputs 1.745898 0.018753 \n",
"cusum_adaptive_jax cusum_inputs 14.633298 0.738525 \n",
"cusum_adaptive_pytensor_numba cusum_inputs 25.033302 1.241703 \n",
"cusum_adaptive_pytensor_jax cusum_inputs 21.893748 1.734347 \n",
"\n",
" OPS (Kops/s) Samples \n",
"cusum_adaptive_pytensor cusum_inputs 7.050326 681 \n",
"cusum_adaptive_numba cusum_inputs 553.534780 7 \n",
"cusum_adaptive_jax cusum_inputs 66.327378 36 \n",
"cusum_adaptive_pytensor_numba cusum_inputs 36.757976 5 \n",
"cusum_adaptive_pytensor_jax cusum_inputs 43.893352 30 "
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cusum_bench.run(\n",
" inputs={\n",
" \"cusum_inputs\": {\"x\": xs, \"alpha\": 0.1, \"k\": 0.5, \"h\": 3.5},\n",
" }\n",
")\n",
"cusum_bench.summary()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "3e9c3339",
"metadata": {},
"outputs": [],
"source": [
"outputs = cusum_adaptive_numba(xs_std, alpha=0.1, k=0.5, h=3.5)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "b85d0c0e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = go.Figure()\n",
"fig.add_traces(\n",
" [\n",
" go.Scatter(\n",
" x = np.arange(len(xs)),\n",
" y = xs_std,\n",
" name=\"series\"\n",
" ),\n",
" go.Scatter(\n",
" x = np.arange(len(xs)),\n",
" y = outputs[0],\n",
" name=\"cum. positive devs.\"\n",
" ),\n",
" go.Scatter(\n",
" x = np.arange(len(xs)),\n",
" y = outputs[1],\n",
" name=\"cum. negative devs.\"\n",
" ),\n",
" go.Scatter(\n",
" x = np.arange(len(xs)),\n",
" y = outputs[2],\n",
" name=\"Exp. Mean\"\n",
" ),\n",
" go.Scatter(\n",
" x = np.arange(len(xs)),\n",
" y = outputs[3].astype(np.float16),\n",
" name=\"positive alarms\"\n",
" ),\n",
" go.Scatter(\n",
" x = np.arange(len(xs)),\n",
" y = outputs[4].astype(np.float16),\n",
" name=\"negative alarms\"\n",
" ),\n",
" \n",
" ]\n",
")\n",
"fig.update_layout(\n",
" title = dict(\n",
" text = \"CUSUM Change Point Detection Algorithm\"\n",
" ),\n",
" xaxis=dict(\n",
" title = \"Time Index\"\n",
" ),\n",
" yaxis=dict(\n",
" title = \"Standardized Series Scaled\"\n",
" ),\n",
" legend=dict(\n",
" yanchor=\"top\",\n",
" y=1.1,\n",
" xanchor=\"left\",\n",
" x=0,\n",
" orientation=\"h\"\n",
" ),\n",
" template=\"plotly_dark\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c2ce03ea",
"metadata": {},
"source": [
"## Pruned Exact Linear Time (PELT) Algorithm"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "eab4cbb9",
"metadata": {},
"outputs": [],
"source": [
"@jit(nopython=True)\n",
"def segment_cost_numba(S1, S2, i, j):\n",
" \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n",
" n = j - i\n",
" sum_x = S1[j] - S1[i]\n",
" sum_x2 = S2[j] - S2[i]\n",
" if n > 0:\n",
" return sum_x2 - (sum_x ** 2) / n\n",
" else:\n",
" return np.inf\n",
"\n",
"@jit(nopython=True)\n",
"def pelt_numba(x, beta=10.0):\n",
" \"\"\"\n",
" Pruned Exact Linear Time algorithm for change point detection\n",
"\n",
" Parameters\n",
" ----------\n",
" x: np.ndarray\n",
" The timeseries signal\n",
" beta: float\n",
" Penalty of segmenting the series\n",
"\n",
" Returns\n",
" -------\n",
" C: np.ndarray\n",
" The best costs up to segment t\n",
" last_change: np.ndarray\n",
" The last change point up to segment t\n",
" \"\"\"\n",
" n = len(x)\n",
"\n",
" # cumulative sums for cost\n",
" S1 = np.empty(n+1, dtype=np.float32)\n",
" S2 = np.empty(n+1, dtype=np.float32)\n",
" S1[0], S2[0] = 0.0, 0.0\n",
" for i in range(1, n+1):\n",
" S1[i] = S1[i-1] + x[i-1]\n",
" S2[i] = S2[i-1] + x[i-1]**2\n",
"\n",
" # DP arrays\n",
" C = np.full((n+1,), np.inf)\n",
" C[0] = -beta\n",
" last_change = np.full((n+1,), -1)\n",
" min_size = 3\n",
"\n",
" for t in range(1, n+1):\n",
" costs = np.full(n, np.inf)\n",
" for s in range(n):\n",
" if s < t and (t - s) >= min_size:\n",
" costs[s] = C[s] + segment_cost_numba(S1, S2, s, t) + beta\n",
" best_s = np.argmin(costs)\n",
" C[t] = costs[best_s]\n",
" last_change[t] = best_s\n",
"\n",
" return C, last_change"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "e6997389",
"metadata": {},
"outputs": [],
"source": [
"def segment_cost_jax(S1, S2, i, j):\n",
" \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n",
" n = j - i\n",
" sum_x = S1[j] - S1[i]\n",
" sum_x2 = S2[j] - S2[i]\n",
" return jnp.where(n > 0, sum_x2 - (sum_x ** 2) / n, jnp.inf)\n",
"\n",
"@block\n",
"@jax.jit\n",
"def pelt_jax(x, beta=10.0):\n",
" \"\"\"\n",
" Pruned Exact Linear Time algorithm for change point detection\n",
"\n",
" Parameters\n",
" ----------\n",
" x: np.ndarray\n",
" The timeseries signal\n",
" beta: float\n",
" Penalty of segmenting the series\n",
"\n",
" Returns\n",
" -------\n",
" C: jnp.ndarray\n",
" The best costs up to segment t\n",
" last_change: jnp.ndarray\n",
" The last change point up to segment t\n",
" \"\"\"\n",
" n = len(x)\n",
"\n",
" # cumulative sums for cost\n",
" S1 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x)])\n",
" S2 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x**2)])\n",
"\n",
" # DP arrays\n",
" C = jnp.full((n+1,), jnp.inf)\n",
" C = C.at[0].set(-beta)\n",
" last_change = jnp.full((n+1,), -1)\n",
" min_size = 3\n",
"\n",
" s_all = jnp.arange(n) # all possible candidates\n",
"\n",
" def body(t, carry):\n",
" C, last_change = carry\n",
"\n",
" # Compute cost for all s < t, mask invalid\n",
" # valid = s_all < t & ((t - s_all) >= min_size)\n",
" \n",
" valid = (s_all < t) & ((t - s_all) >= min_size)\n",
" costs = jnp.where(\n",
" valid,\n",
" C[s_all] + segment_cost_jax(S1, S2, s_all, t) + beta,\n",
" jnp.inf\n",
" )\n",
"\n",
" best_s = jnp.argmin(costs)\n",
" C = C.at[t].set(costs[best_s])\n",
" last_change = last_change.at[t].set(best_s)\n",
" return C, last_change\n",
"\n",
" C, last_change = jax.lax.fori_loop(1, n+1, body, (C, last_change))\n",
" return C, last_change"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "094b9e8e",
"metadata": {},
"outputs": [],
"source": [
"def segment_cost_pytensor(S1, S2, i, j):\n",
" \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n",
" n = j - i\n",
" sum_x = S1[j] - S1[i]\n",
" sum_x2 = S2[j] - S2[i]\n",
" return pt.switch(\n",
" pt.gt(n, 0),\n",
" sum_x2 - (sum_x ** 2) / n,\n",
" np.inf\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "03e5e927",
"metadata": {},
"outputs": [],
"source": [
"x_symbolic = pt.vector(\"x\")\n",
"beta_symbolic = pt.scalar(\"beta\")\n",
"n = x_symbolic.shape[0]\n",
"N_STEPS=100\n",
"\n",
"# cumulative sums for cost\n",
"S1 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic)])\n",
"S2 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic**2)])\n",
"\n",
"# DP arrays\n",
"C_init = pt.alloc(np.inf, n+1)\n",
"C_init = pt.set_subtensor(C_init[0], -beta_symbolic)\n",
"last_change_init = pt.alloc(-1, n+1)\n",
"\n",
"s_all = pt.arange(n) # candidate change points\n",
"min_size = 3\n",
"\n",
"def step(t, C_prev, last_change_prev, S1, S2, beta_symbolic, s_all):\n",
" # valid = (s_all < t) & ((t - s_all) >= min_size)\n",
" valid = pt.and_(pt.lt(s_all, t), pt.ge(t - s_all, min_size))\n",
"\n",
" # compute costs for all candidates\n",
" costs, _ = pytensor.scan(\n",
" fn=lambda s: pt.switch(\n",
" valid[s],\n",
" C_prev[s] + segment_cost_pytensor(S1, S2, s, t) + beta_symbolic,\n",
" np.inf\n",
" ),\n",
" sequences=[pt.arange(n)]\n",
" )\n",
" costs = costs.flatten()\n",
"\n",
" best_s = pt.argmin(costs, axis=0)\n",
" C_new = pt.set_subtensor(C_prev[t], costs[best_s])\n",
" last_change_new = pt.set_subtensor(last_change_prev[t], best_s)\n",
"\n",
" return C_new, last_change_new\n",
"\n",
"(C_vals, last_change_vals), _ = pytensor.scan(\n",
" fn=step,\n",
" sequences=[pt.arange(1, n+1)],\n",
" outputs_info=[C_init, last_change_init],\n",
" non_sequences=[S1, S2, beta_symbolic, s_all],\n",
" n_steps=N_STEPS # Added fixed iterations here\n",
")\n",
"\n",
"pelt_pytensor = pytensor.function([x_symbolic, beta_symbolic], [C_vals[-1], last_change_vals[-1]], trust_input=True)\n",
"pelt_pytensor_numba = pytensor.function(inputs=[x_symbolic, beta_symbolic], outputs=[C_vals[-1], last_change_vals[-1]], mode=\"NUMBA\", trust_input=True)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "69aea9a2",
"metadata": {},
"outputs": [],
"source": [
"pelt_bench = Benchmarker(\n",
" functions=[pelt_pytensor, pelt_numba, pelt_jax, pelt_pytensor_numba], \n",
" names=['pelt_pytensor', 'pelt_numba', 'pelt_jax', 'pelt_pytensor_numba'],\n",
" number=10\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "5aee4a18",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" | \n",
" Loops | \n",
" Min (us) | \n",
" Max (us) | \n",
" Mean (us) | \n",
" StdDev (us) | \n",
" Median (us) | \n",
" IQR (us) | \n",
" OPS (Kops/s) | \n",
" Samples | \n",
"
\n",
" \n",
" \n",
" \n",
" pelt_pytensor | \n",
" pelt_inputs | \n",
" 10 | \n",
" 11903.712500 | \n",
" 12687.708301 | \n",
" 12234.242133 | \n",
" 245.989123 | \n",
" 12224.875001 | \n",
" 277.799901 | \n",
" 0.081738 | \n",
" 9 | \n",
"
\n",
" \n",
" pelt_numba | \n",
" pelt_inputs | \n",
" 10 | \n",
" 17.408299 | \n",
" 22.204101 | \n",
" 19.400820 | \n",
" 1.585318 | \n",
" 19.241701 | \n",
" 1.000002 | \n",
" 51.544213 | \n",
" 5 | \n",
"
\n",
" \n",
" pelt_jax | \n",
" pelt_inputs | \n",
" 10 | \n",
" 64.616700 | \n",
" 82.345799 | \n",
" 71.339679 | \n",
" 4.879389 | \n",
" 70.279100 | \n",
" 4.893748 | \n",
" 14.017445 | \n",
" 19 | \n",
"
\n",
" \n",
" pelt_pytensor_numba | \n",
" pelt_inputs | \n",
" 10 | \n",
" 2330.925001 | \n",
" 2496.158300 | \n",
" 2424.299980 | \n",
" 56.514842 | \n",
" 2435.225001 | \n",
" 61.833399 | \n",
" 0.412490 | \n",
" 5 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Loops Min (us) Max (us) \\\n",
"pelt_pytensor pelt_inputs 10 11903.712500 12687.708301 \n",
"pelt_numba pelt_inputs 10 17.408299 22.204101 \n",
"pelt_jax pelt_inputs 10 64.616700 82.345799 \n",
"pelt_pytensor_numba pelt_inputs 10 2330.925001 2496.158300 \n",
"\n",
" Mean (us) StdDev (us) Median (us) \\\n",
"pelt_pytensor pelt_inputs 12234.242133 245.989123 12224.875001 \n",
"pelt_numba pelt_inputs 19.400820 1.585318 19.241701 \n",
"pelt_jax pelt_inputs 71.339679 4.879389 70.279100 \n",
"pelt_pytensor_numba pelt_inputs 2424.299980 56.514842 2435.225001 \n",
"\n",
" IQR (us) OPS (Kops/s) Samples \n",
"pelt_pytensor pelt_inputs 277.799901 0.081738 9 \n",
"pelt_numba pelt_inputs 1.000002 51.544213 5 \n",
"pelt_jax pelt_inputs 4.893748 14.017445 19 \n",
"pelt_pytensor_numba pelt_inputs 61.833399 0.412490 5 "
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pelt_bench.run(\n",
" inputs={\n",
" \"pelt_inputs\": {\"x\": xs_std, \"beta\": 2. * np.log(len(xs_std))},\n",
" }\n",
")\n",
"pelt_bench.summary()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "47cae530",
"metadata": {},
"outputs": [],
"source": [
"outputs = pelt_numba(xs_std, 2. * np.log(len(xs_std)))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "7b395b06",
"metadata": {},
"outputs": [],
"source": [
"def plot_pelt_diagnostics(x, cps, C):\n",
" \"\"\"\n",
" Diagnostic plots for PELT changepoint detection.\n",
" \n",
" Args:\n",
" x: 1D array, original time series\n",
" C: 1D array, cumulative DP cost from pelt()\n",
" cps: list of changepoint indices (sorted ascending)\n",
" \"\"\"\n",
" n = len(x)\n",
" cps_full = [0] + cps + [n]\n",
"\n",
" # Segment means, std, SSE\n",
" segment_means = []\n",
" segment_stds = []\n",
" segment_costs = []\n",
" for start, end in zip(cps_full[:-1], cps_full[1:]):\n",
" seg = x[start:end]\n",
" mean = np.mean(seg)\n",
" std = np.std(seg)\n",
" cost = np.sum((seg - mean)**2)\n",
" segment_means.append(mean)\n",
" segment_stds.append(std)\n",
" segment_costs.append(cost)\n",
"\n",
" # Step function for segment mean\n",
" mean_step = np.zeros(n)\n",
" for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):\n",
" mean_step[start:end] = segment_means[i]\n",
"\n",
" # Step function for segment std\n",
" std_step = np.zeros(n)\n",
" for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):\n",
" std_step[start:end] = segment_stds[i]\n",
"\n",
" if len(x) < 20:\n",
" title1 = \"Warning: Sample size is small - Detected Changepoints\"\n",
" else:\n",
" title1 = \"Detected Changepoints\"\n",
"\n",
" fig = make_subplots(\n",
" rows=4, \n",
" cols=1,\n",
" subplot_titles=(title1, \"Average Shifts\", \"Variability Shifts\", \"Cumulative Cost\")\n",
" )\n",
"\n",
" fig.add_trace(\n",
" go.Scatter(\n",
" x = np.arange(len(x)),\n",
" y = x,\n",
" line_color=\"royalblue\",\n",
" name = \"Actuals\",\n",
" mode=\"lines\",\n",
" showlegend=False,\n",
" hovertemplate=\"Time Point: %{x}
Actual: %{y}\"\n",
" ),\n",
" row=1, col=1\n",
" )\n",
"\n",
" for cp in cps:\n",
" fig.add_vline(x=cp, line_dash='dash', line_color=\"red\", row=1, col=1)\n",
"\n",
" fig.add_trace(\n",
" go.Scatter(\n",
" x = np.arange(len(x)),\n",
" y = x,\n",
" name = \"Actuals\",\n",
" mode=\"lines\",\n",
" line_color=\"rgba(105, 105, 105, 0.25)\",\n",
" showlegend=False,\n",
" hoverinfo=\"skip\"\n",
" ),\n",
" row=2, col=1\n",
" )\n",
"\n",
" fig.add_trace(\n",
" go.Scatter(\n",
" x = np.arange(len(x)),\n",
" y = mean_step,\n",
" name = \"Average\",\n",
" line_color=\"royalblue\",\n",
" showlegend=False,\n",
" hovertemplate=\"Time Point: %{x}
Average: %{y}\"\n",
" ),\n",
" row=2, col=1\n",
" )\n",
"\n",
" fig.add_trace(\n",
" go.Scatter(\n",
" x = np.arange(len(x)),\n",
" y = std_step,\n",
" name = \"Standard Deviation\",\n",
" line_color=\"royalblue\",\n",
" showlegend=False,\n",
" hovertemplate=\"Time Point: %{x}
Standard Deviation: %{y}\"\n",
" ),\n",
" row=3, col=1\n",
" )\n",
"\n",
" fig.add_trace(\n",
" go.Scatter(\n",
" x = np.arange(len(x)),\n",
" y = C,\n",
" name = \"Cumulative Cost\",\n",
" line_color=\"royalblue\",\n",
" showlegend=False,\n",
" hovertemplate=\"Time Point: %{x}
Cost: %{y}\"\n",
" ),\n",
" row=4, col=1\n",
" )\n",
"\n",
" for cp in cps:\n",
" fig.add_vline(x=cp, line_dash='dash', line_color=\"red\", row=4, col=1)\n",
"\n",
" return fig.update_layout(height=1000, width=1200, template=\"plotly_dark\")\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "e1de7df5",
"metadata": {},
"outputs": [],
"source": [
"def get_changepoints(last_change, n):\n",
" \"\"\"\n",
" Backtrack changepoints from last_change array.\n",
" \n",
" Args:\n",
" last_change: array from pelt()\n",
" n: length of input series\n",
"\n",
" Returns:\n",
" list of changepoint indices (sorted ascending)\n",
" \"\"\"\n",
" cps = []\n",
" t = n\n",
" while t > 0:\n",
" s = int(last_change[t])\n",
" if s <= 0:\n",
" break\n",
" cps.append(s)\n",
" t = s\n",
" return list(reversed(cps))"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "b73d80ac",
"metadata": {},
"outputs": [],
"source": [
"cps = get_changepoints(outputs[1], n=len(xs_std))"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "e2e376de",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_pelt_diagnostics(xs, cps, outputs[0])"
]
},
{
"cell_type": "markdown",
"id": "09f2e235",
"metadata": {},
"source": [
"# Kalman Filter Algorithms"
]
},
{
"cell_type": "markdown",
"id": "1c42336f",
"metadata": {},
"source": [
"## Linear Gaussian Kalman Filter"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "5004eeab",
"metadata": {},
"outputs": [],
"source": [
"@jit(nopython=True)\n",
"def atrocious_kalman_filter_numba(z, F, H, Q, R, x0, P0):\n",
" \"\"\"\n",
" This implementation of the Kalman filter is Atrocious and in standard Python would be a \n",
" BIG NO-NO. That being said this version SIGNIFICANTLY reduces Numba Compilation time. \n",
" \n",
" Linear Gaussian Kalman filter algorithm\n",
"\n",
" Parameters\n",
" ----------\n",
" z: np.ndarray\n",
" shape (T, m) - observations\n",
" F: np.ndarray\n",
" state transition matrix - shape (n, n)\n",
" H: np.ndarray\n",
" observation/design matrix - shape (m, n)\n",
" Q: np.ndarray\n",
" process noise covariance - shape (n, n)\n",
" R: np.ndarray\n",
" observation noise covariance - shape (m, m)\n",
" x0: np.ndarray\n",
" initial state mean - shape (n,)\n",
" P0: np.ndarray\n",
" initial state covariance - shape (n, n)\n",
"\n",
" Returns\n",
" -------\n",
" xs: np.ndarray\n",
" shape (T, n) - filtered state means\n",
" Ps: np.ndarray\n",
" shape (T, n, n) - filtered state covariances\n",
" \"\"\"\n",
" T = z.shape[0]\n",
" m = z.shape[1]\n",
" n = x0.shape[0]\n",
"\n",
" xs = np.empty((T, n), dtype=np.float32)\n",
" Ps = np.empty((T, n, n), dtype=np.float32)\n",
"\n",
" # local working arrays\n",
" x = np.empty(n, dtype=np.float32)\n",
" for i in range(n):\n",
" x[i] = x0[i]\n",
" P = np.empty((n, n), dtype=np.float32)\n",
" for i in range(n):\n",
" for j in range(n):\n",
" P[i, j] = P0[i, j]\n",
"\n",
" # temporary matrices/vectors\n",
" x_pred = np.empty((T, n), dtype=np.float32)\n",
" P_pred = np.empty((T, n, n), dtype=np.float32)\n",
" y = np.empty(m, dtype=np.float32)\n",
" S = np.empty((m, m), dtype=np.float32)\n",
" K = np.empty((n, m), dtype=np.float32)\n",
" I_n = np.eye(n, dtype=np.float32)\n",
"\n",
" for t in range(T):\n",
" # === Predict ===\n",
" # x_pred = F @ x\n",
" for i in range(n):\n",
" s = 0.0\n",
" for j in range(n):\n",
" s += F[i, j] * x[j]\n",
" x_pred[t, i] = s\n",
"\n",
" # P_pred = F @ P @ F.T + Q\n",
" # temp = F @ P\n",
" temp = np.empty((n, n), dtype=np.float32)\n",
" for i in range(n):\n",
" for j in range(n):\n",
" s = 0.0\n",
" for k in range(n):\n",
" s += F[i, k] * P[k, j]\n",
" temp[i, j] = s\n",
" # P_pred = temp @ F.T\n",
" for i in range(n):\n",
" for j in range(n):\n",
" s = 0.0\n",
" for k in range(n):\n",
" s += temp[i, k] * F[j, k] # F.T[k, j] = F[j, k]\n",
" P_pred[t, i, j] = s + Q[i, j]\n",
"\n",
" # === Update ===\n",
" # y = z[t] - H @ x_pred\n",
" for i in range(m):\n",
" s = 0.0\n",
" for j in range(n):\n",
" s += H[i, j] * x_pred[t, j]\n",
" y[i] = z[t, i] - s\n",
"\n",
" # S = H @ P_pred @ H.T + R\n",
" # temp2 = H @ P_pred\n",
" temp2 = np.empty((m, n), dtype=np.float32)\n",
" for i in range(m):\n",
" for j in range(n):\n",
" s = 0.0\n",
" for k in range(n):\n",
" s += H[i, k] * P_pred[t, k, j]\n",
" temp2[i, j] = s\n",
" # S = temp2 @ H.T\n",
" for i in range(m):\n",
" for j in range(m):\n",
" s = 0.0\n",
" for k in range(n):\n",
" s += temp2[i, k] * H[j, k] # H.T[k,j] = H[j,k]\n",
" S[i, j] = s + R[i, j]\n",
"\n",
" # K = P_pred @ H.T @ inv(S)\n",
" # first compute P_pred @ H.T -> (n, m)\n",
" P_Ht = np.empty((n, m), dtype=np.float32)\n",
" for i in range(n):\n",
" for j in range(m):\n",
" s = 0.0\n",
" for k in range(n):\n",
" s += P_pred[t, i, k] * H[j, k] # H.T[k,j] = H[j,k]\n",
" P_Ht[i, j] = s\n",
"\n",
" # invert S\n",
" S_inv = np.linalg.inv(S)\n",
"\n",
" # K = P_Ht @ S_inv (n,m) @ (m,m) -> (n,m)\n",
" for i in range(n):\n",
" for j in range(m):\n",
" s = 0.0\n",
" for k in range(m):\n",
" s += P_Ht[i, k] * S_inv[k, j]\n",
" K[i, j] = s\n",
"\n",
" # x = x_pred + K @ y\n",
" for i in range(n):\n",
" s = 0.0\n",
" for j in range(m):\n",
" s += K[i, j] * y[j]\n",
" x[i] = x_pred[t, i] + s\n",
"\n",
" # P = (I - K H) P_pred\n",
" # compute (I - K H)\n",
" KH = np.empty((n, n), dtype=np.float32)\n",
" for i in range(n):\n",
" for j in range(n):\n",
" s = 0.0\n",
" for k in range(m):\n",
" s += K[i, k] * H[k, j]\n",
" KH[i, j] = s\n",
"\n",
" I_minus_KH = np.empty((n, n), dtype=np.float32)\n",
" for i in range(n):\n",
" for j in range(n):\n",
" I_minus_KH[i, j] = I_n[i, j] - KH[i, j]\n",
"\n",
" # P = I_minus_KH @ P_pred\n",
" for i in range(n):\n",
" for j in range(n):\n",
" s = 0.0\n",
" for k in range(n):\n",
" s += I_minus_KH[i, k] * P_pred[t, k, j]\n",
" P[i, j] = s\n",
"\n",
" # store results\n",
" for i in range(n):\n",
" xs[t, i] = x[i]\n",
" for i in range(n):\n",
" for j in range(n):\n",
" Ps[t, i, j] = P[i, j]\n",
"\n",
" return xs, Ps, x_pred, P_pred\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "25bcb14e",
"metadata": {},
"outputs": [],
"source": [
"@jit(nopython=True)\n",
"def kalman_filter_numba(z, F, H, Q, R, x0, P0):\n",
" \"\"\"\n",
" Linear Gaussian Kalman filter algorithm\n",
"\n",
" Parameters\n",
" ----------\n",
" z: np.ndarray\n",
" shape (T, m) - observations\n",
" F: np.ndarray\n",
" state transition matrix - shape (n, n)\n",
" H: np.ndarray\n",
" observation/design matrix - shape (m, n)\n",
" Q: np.ndarray\n",
" process noise covariance - shape (n, n)\n",
" R: np.ndarray\n",
" observation noise covariance - shape (m, m)\n",
" x0: np.ndarray\n",
" initial state mean - shape (n,)\n",
" P0: np.ndarray\n",
" initial state covariance - shape (n, n)\n",
"\n",
" Returns\n",
" -------\n",
" xs: np.ndarray\n",
" shape (T, n) - filtered state means\n",
" Ps: np.ndarray\n",
" shape (T, n, n) - filtered state covariances\n",
" \"\"\"\n",
" T, m = z.shape\n",
" n = x0.shape[0]\n",
"\n",
" xs = np.zeros((T, n), dtype=np.float32)\n",
" Ps = np.zeros((T, n, n), dtype=np.float32)\n",
"\n",
" x_pred = np.zeros((T, n), dtype=np.float32)\n",
" P_pred = np.zeros((T, n, n), dtype=np.float32)\n",
"\n",
" x = x0.copy()\n",
" P = P0.copy()\n",
"\n",
" I = np.eye(n, dtype=np.float32)\n",
"\n",
" for t in range(T):\n",
" # --- Predict ---\n",
" x_pred[t] = F @ x\n",
" P_pred[t] = F @ P @ F.T + Q\n",
"\n",
" # --- Update ---\n",
" y = z[t] - H @ x_pred[t]\n",
" S = H @ P_pred[t] @ H.T + R\n",
" K = P_pred[t] @ H.T @ np.linalg.inv(S)\n",
"\n",
" x = x_pred[t] + K @ y\n",
" P = (I - K @ H) @ P_pred[t]\n",
"\n",
" xs[t] = x\n",
" Ps[t] = P\n",
"\n",
" return xs, Ps, x_pred, P_pred"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "cd715dfb",
"metadata": {},
"outputs": [],
"source": [
"@block\n",
"@jax.jit\n",
"def kalman_filter_jax(z, F, H, Q, R, x0, P0):\n",
" \"\"\"\n",
" Linear Gaussian Kalman filter algorithm\n",
"\n",
" Parameters\n",
" ----------\n",
" z: np.ndarray\n",
" shape (T, m) - observations\n",
" F: np.ndarray\n",
" state transition matrix - shape (n, n)\n",
" H: np.ndarray\n",
" observation/design matrix - shape (m, n)\n",
" Q: np.ndarray\n",
" process noise covariance - shape (n, n)\n",
" R: np.ndarray\n",
" observation noise covariance - shape (m, m)\n",
" x0: np.ndarray\n",
" initial state mean - shape (n,)\n",
" P0: np.ndarray\n",
" initial state covariance - shape (n, n)\n",
"\n",
" Returns\n",
" -------\n",
" xs: jnp.ndarray\n",
" shape (T, n) - filtered state means\n",
" Ps: jnp.ndarray\n",
" shape (T, n, n) - filtered state covariances\n",
" \"\"\"\n",
"\n",
" n = x0.shape[0]\n",
" I = jnp.eye(n)\n",
" X_pred_init = jnp.zeros((1,))\n",
" P_pred_init = jnp.zeros((1, 1,))\n",
"\n",
" def step(carry, z_t):\n",
" x, P, _, _ = carry\n",
"\n",
" # --- Predict ---\n",
" x_pred = F @ x\n",
" P_pred = F @ P @ F.T + Q\n",
"\n",
" # --- Update ---\n",
" y = z_t - H @ x_pred\n",
" S = H @ P_pred @ H.T + R\n",
" K = P_pred @ H.T @ jnp.linalg.inv(S)\n",
"\n",
" x_new = x_pred + K @ y\n",
" P_new = (I - K @ H) @ P_pred\n",
"\n",
" return (x_new, P_new, x_pred, P_pred), (x_new, P_new, x_pred, P_pred)\n",
"\n",
" # run scan\n",
" (_, _, _, _), (xs, Ps, x_pred, P_pred) = jax.lax.scan(step, (x0, P0, X_pred_init, P_pred_init), z)\n",
"\n",
" return xs, Ps, x_pred, P_pred"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "5af37c40",
"metadata": {},
"outputs": [],
"source": [
"z_symbolic = pt.matrix(\"z\")\n",
"F_symbolic = pt.matrix(\"F\")\n",
"H_symbolic = pt.matrix(\"H\")\n",
"Q_symbolic = pt.matrix(\"Q\")\n",
"R_symbolic = pt.matrix(\"R\")\n",
"x0_symbolic = pt.vector(\"x0\")\n",
"P0_symbolic = pt.matrix(\"P0\")\n",
"\n",
"n = x0_symbolic.shape[0]\n",
"I = pt.eye(n)\n",
"X_pred_init = pt.zeros_like(x0_symbolic)\n",
"P_pred_init = pt.zeros_like(P0_symbolic)\n",
"\n",
"N_STEPS = 500\n",
"\n",
"def step(z_t, x, P, x_pred, P_pred, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I):\n",
"\n",
" # --- Predict ---\n",
" x_pred = F_symbolic @ x\n",
" P_pred = F_symbolic @ P @ F_symbolic.T + Q_symbolic\n",
"\n",
" # --- Update ---\n",
" y = z_t - H_symbolic @ x_pred\n",
" S = H_symbolic @ P_pred @ H_symbolic.T + R_symbolic\n",
" K = P_pred @ H_symbolic.T @ pt.linalg.inv(S)\n",
"\n",
" x_new = x_pred + K @ y\n",
" P_new = (I - K @ H_symbolic) @ P_pred\n",
"\n",
" return x_new, P_new, x_pred, P_pred\n",
"\n",
"# run scan\n",
"(xs, Ps, x_pred, P_pred), _ = pytensor.scan(\n",
" fn=step,\n",
" outputs_info=[x0_symbolic, P0_symbolic, X_pred_init, P_pred_init],\n",
" sequences=[z_symbolic],\n",
" non_sequences=[F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I],\n",
" n_steps=N_STEPS\n",
")\n",
"\n",
"kalman_filter_pytensor = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], trust_input=True)\n",
"\n",
"kalman_filter_pytensor_numba = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode=\"NUMBA\", trust_input=True)\n",
"kalman_filter_pytensor_jax = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode=\"JAX\", trust_input=True)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "3c9e7254",
"metadata": {},
"outputs": [],
"source": [
"T = 500\n",
"F = np.array([[1.0]]).astype(np.float32)\n",
"H = np.array([[1.0]]).astype(np.float32)\n",
"Q = np.array([[0.01]]).astype(np.float32)\n",
"R = np.array([[0.1]]).astype(np.float32)\n",
"x0 = np.array([0.0]).astype(np.float32)\n",
"P0 = np.array([[1.0]]).astype(np.float32)\n",
"\n",
"true = 1.0\n",
"z = (true + 0.4*np.random.randn(T)).reshape(T, 1).astype(np.float32)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "00472afd",
"metadata": {},
"outputs": [],
"source": [
"kalman_filter_bench = Benchmarker(\n",
" functions=[kalman_filter_pytensor, atrocious_kalman_filter_numba, kalman_filter_numba, kalman_filter_jax, kalman_filter_pytensor_numba, block(kalman_filter_pytensor_jax)], \n",
" names=['kalman_filter_pytensor', 'atrocious_kalman_filter_numba', 'kalman_filter_numba', 'kalman_filter_jax', 'kalman_filter_pytensor_numba', 'kalman_filter_pytensor_jax'],\n",
" number=10\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "68ec703d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/tmpmbe0qi_u:33: NumbaWarning:\n",
"\n",
"\u001b[1m\u001b[1mCannot cache compiled function \"scan\" as it uses dynamic globals (such as ctypes pointers and large global arrays)\u001b[0m\u001b[0m\n",
"\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" | \n",
" Loops | \n",
" Min (us) | \n",
" Max (us) | \n",
" Mean (us) | \n",
" StdDev (us) | \n",
" Median (us) | \n",
" IQR (us) | \n",
" OPS (Kops/s) | \n",
" Samples | \n",
"
\n",
" \n",
" \n",
" \n",
" kalman_filter_pytensor | \n",
" kalman_filter_inputs | \n",
" 10 | \n",
" 5675.591601 | \n",
" 6334.566601 | \n",
" 5949.072047 | \n",
" 159.617745 | \n",
" 5937.045900 | \n",
" 191.600001 | \n",
" 0.168093 | \n",
" 17 | \n",
"
\n",
" \n",
" atrocious_kalman_filter_numba | \n",
" kalman_filter_inputs | \n",
" 10 | \n",
" 297.462501 | \n",
" 330.600000 | \n",
" 314.415020 | \n",
" 11.425574 | \n",
" 315.533401 | \n",
" 14.287402 | \n",
" 3.180510 | \n",
" 5 | \n",
"
\n",
" \n",
" kalman_filter_numba | \n",
" kalman_filter_inputs | \n",
" 10 | \n",
" 641.612499 | \n",
" 718.120800 | \n",
" 684.415800 | \n",
" 32.973390 | \n",
" 706.012500 | \n",
" 62.024998 | \n",
" 1.461100 | \n",
" 5 | \n",
"
\n",
" \n",
" kalman_filter_jax | \n",
" kalman_filter_inputs | \n",
" 10 | \n",
" 305.941701 | \n",
" 368.020899 | \n",
" 333.580105 | \n",
" 18.368842 | \n",
" 329.077049 | \n",
" 22.434351 | \n",
" 2.997781 | \n",
" 18 | \n",
"
\n",
" \n",
" kalman_filter_pytensor_numba | \n",
" kalman_filter_inputs | \n",
" 10 | \n",
" 853.600001 | \n",
" 933.024997 | \n",
" 897.570000 | \n",
" 33.462131 | \n",
" 918.749999 | \n",
" 61.308398 | \n",
" 1.114119 | \n",
" 5 | \n",
"
\n",
" \n",
" kalman_filter_pytensor_jax | \n",
" kalman_filter_inputs | \n",
" 10 | \n",
" 304.041599 | \n",
" 372.449998 | \n",
" 330.685825 | \n",
" 17.271642 | \n",
" 327.668698 | \n",
" 24.356249 | \n",
" 3.024018 | \n",
" 20 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Loops Min (us) \\\n",
"kalman_filter_pytensor kalman_filter_inputs 10 5675.591601 \n",
"atrocious_kalman_filter_numba kalman_filter_inputs 10 297.462501 \n",
"kalman_filter_numba kalman_filter_inputs 10 641.612499 \n",
"kalman_filter_jax kalman_filter_inputs 10 305.941701 \n",
"kalman_filter_pytensor_numba kalman_filter_inputs 10 853.600001 \n",
"kalman_filter_pytensor_jax kalman_filter_inputs 10 304.041599 \n",
"\n",
" Max (us) Mean (us) \\\n",
"kalman_filter_pytensor kalman_filter_inputs 6334.566601 5949.072047 \n",
"atrocious_kalman_filter_numba kalman_filter_inputs 330.600000 314.415020 \n",
"kalman_filter_numba kalman_filter_inputs 718.120800 684.415800 \n",
"kalman_filter_jax kalman_filter_inputs 368.020899 333.580105 \n",
"kalman_filter_pytensor_numba kalman_filter_inputs 933.024997 897.570000 \n",
"kalman_filter_pytensor_jax kalman_filter_inputs 372.449998 330.685825 \n",
"\n",
" StdDev (us) Median (us) \\\n",
"kalman_filter_pytensor kalman_filter_inputs 159.617745 5937.045900 \n",
"atrocious_kalman_filter_numba kalman_filter_inputs 11.425574 315.533401 \n",
"kalman_filter_numba kalman_filter_inputs 32.973390 706.012500 \n",
"kalman_filter_jax kalman_filter_inputs 18.368842 329.077049 \n",
"kalman_filter_pytensor_numba kalman_filter_inputs 33.462131 918.749999 \n",
"kalman_filter_pytensor_jax kalman_filter_inputs 17.271642 327.668698 \n",
"\n",
" IQR (us) OPS (Kops/s) \\\n",
"kalman_filter_pytensor kalman_filter_inputs 191.600001 0.168093 \n",
"atrocious_kalman_filter_numba kalman_filter_inputs 14.287402 3.180510 \n",
"kalman_filter_numba kalman_filter_inputs 62.024998 1.461100 \n",
"kalman_filter_jax kalman_filter_inputs 22.434351 2.997781 \n",
"kalman_filter_pytensor_numba kalman_filter_inputs 61.308398 1.114119 \n",
"kalman_filter_pytensor_jax kalman_filter_inputs 24.356249 3.024018 \n",
"\n",
" Samples \n",
"kalman_filter_pytensor kalman_filter_inputs 17 \n",
"atrocious_kalman_filter_numba kalman_filter_inputs 5 \n",
"kalman_filter_numba kalman_filter_inputs 5 \n",
"kalman_filter_jax kalman_filter_inputs 18 \n",
"kalman_filter_pytensor_numba kalman_filter_inputs 5 \n",
"kalman_filter_pytensor_jax kalman_filter_inputs 20 "
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"kalman_filter_bench.run(\n",
" inputs={\n",
" \"kalman_filter_inputs\": {\"z\": z, \"F\": F, \"H\": H, \"Q\": Q, \"R\": R, \"x0\": x0, \"P0\": P0},\n",
" }\n",
")\n",
"kalman_filter_bench.summary()"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "37526f86",
"metadata": {},
"outputs": [],
"source": [
"xs, Ps, x_pred, P_pred = kalman_filter_jax(z, F, H, Q, R, x0, P0)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "231140ba",
"metadata": {},
"outputs": [],
"source": [
"def compute_pred_intervals(z, x_pred, P_pred, H, R, zscore=1.96):\n",
" T = z.shape[0]\n",
" m = H.shape[0]\n",
" mean = np.zeros((T, m))\n",
" lower = np.zeros((T, m))\n",
" upper = np.zeros((T, m))\n",
" outside = np.zeros(T, dtype=np.bool_)\n",
"\n",
" for t in range(T):\n",
" mean[t] = H @ x_pred[t]\n",
" S = H @ P_pred[t] @ H.T + R\n",
" std = np.sqrt(np.diag(S))\n",
" lower[t] = mean[t] - zscore * std\n",
" upper[t] = mean[t] + zscore * std\n",
"\n",
" # check coverage of actual obs\n",
" outside[t] = np.any((z[t] < lower[t]) | (z[t] > upper[t]))\n",
"\n",
" coverage = 1 - outside.mean()\n",
" return mean, lower, upper, coverage\n"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "ff739003",
"metadata": {},
"outputs": [],
"source": [
"mean, lower, upper, coverage = compute_pred_intervals(z, x_pred, P_pred, H, R)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "7546c95c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"np.float64(0.91)"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"coverage"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "6b765a37",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig= go.Figure()\n",
"fig.add_traces(\n",
" [\n",
" go.Scatter(\n",
" x = np.arange(T),\n",
" y = z.ravel(),\n",
" mode=\"markers\",\n",
" marker_color = \"royalblue\",\n",
" name = \"actuals\"\n",
" ),\n",
" go.Scatter(\n",
" x = np.arange(T),\n",
" y = xs.ravel(),\n",
" mode = \"lines\",\n",
" marker_color = \"orange\",\n",
" name = \"filtered mean\"\n",
" ),\n",
" go.Scatter(\n",
" name=\"\", \n",
" x=np.arange(T), \n",
" y=upper.ravel(), \n",
" mode=\"lines\", \n",
" marker=dict(color=\"#eb8c34\"), \n",
" line=dict(width=0), \n",
" legendgroup=\"95% CI\",\n",
" showlegend=False\n",
" ),\n",
" go.Scatter(\n",
" name=\"95% CI\", \n",
" x=np.arange(T), \n",
" y=lower.ravel(), \n",
" mode=\"lines\", marker=dict(color=\"#eb8c34\"), \n",
" line=dict(width=0), \n",
" legendgroup=\"95% CI\", \n",
" fill='tonexty', \n",
" fillcolor='rgba(235, 140, 52, 0.2)'\n",
" ),\n",
"\n",
" ]\n",
")\n",
"fig.update_layout(\n",
" xaxis=dict(\n",
" title = \"Time Index\",\n",
" ),\n",
" yaxis=dict(\n",
" title = \"y\"\n",
" ),\n",
" template = \"plotly_dark\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "052b0194",
"metadata": {},
"source": [
"## Non-linear Kalman Filter"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "6b088d6b",
"metadata": {},
"outputs": [],
"source": [
"@jit(nopython=True)\n",
"def loglik_poisson_numba(s, y):\n",
" \"\"\"Poisson Log Likelihood\"\"\"\n",
" mu = np.exp(s)\n",
" return y * np.log(mu + 1e-30) - mu - math.lgamma(y + 1.0) # numba does not support scipy.special gammaln\n",
"\n",
"@jit(nopython=True)\n",
"def particle_filter_1d_predict_numba(A, Q, x0_mean, x0_std, ys, N=1000, seed=2):\n",
" \"\"\"\n",
" 1D particle filter.\n",
" \n",
" Parameters\n",
" ----------\n",
" A: float\n",
" State transition\n",
" Q: float\n",
" Process covariance\n",
" x0_mean: float\n",
" Prior mean for the latent state\n",
" x0_std: float\n",
" Prior standard deviation \n",
" ys: np.ndarray\n",
" observations\n",
" N: int\n",
" number of particles\n",
" seed: int\n",
" rng seed for reproducibility\n",
"\n",
" Returns\n",
" -------\n",
" filtered_means: np.ndarray\n",
" The filtered mean for the latent state \n",
" filtered_vars: np.ndarray\n",
" The filtered variance for the latent state\n",
" pred_means: np.ndarray\n",
" observation predicted mean \n",
" \"\"\"\n",
" np.random.seed(seed)\n",
" T = ys.shape[0]\n",
" particles = np.random.normal(x0_mean, x0_std, size=N)\n",
" weights = np.ones(N) / N\n",
"\n",
" filtered_means = np.zeros(T)\n",
" filtered_vars = np.zeros(T)\n",
" pred_means = np.zeros(T)\n",
"\n",
" for t in range(T):\n",
" y = ys[t]\n",
"\n",
" # propagate (vectorized)\n",
" particles = A * particles + np.random.normal(0, np.sqrt(Q), size=N)\n",
"\n",
" # update weights\n",
" logw = np.zeros(N)\n",
" for i in range(N):\n",
" logw[i] = loglik_poisson_numba(particles[i], y)\n",
" logw = logw - np.max(logw)\n",
" weights *= np.exp(logw)\n",
" weights /= np.sum(weights) + 1e-12\n",
"\n",
" # filtered moments\n",
" mean_t = np.sum(weights * particles)\n",
" var_t = np.sum(weights * (particles - mean_t) ** 2)\n",
"\n",
" # predictive mean\n",
" pred_mean = np.sum(weights * np.exp(particles))\n",
"\n",
" filtered_means[t] = mean_t\n",
" filtered_vars[t] = var_t\n",
" pred_means[t] = pred_mean\n",
"\n",
" # resample (multinomial resampling) because numba doesn't support np.random.choice\n",
" cumulative_sum = np.cumsum(weights)\n",
" cumulative_sum[-1] = 1.0 # guard against rounding error\n",
" indices = np.searchsorted(cumulative_sum, np.random.rand(N))\n",
"\n",
" particles = particles[indices]\n",
" weights = np.ones(N) / N\n",
"\n",
" return filtered_means, filtered_vars, pred_means"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "a468c403",
"metadata": {},
"outputs": [],
"source": [
"# Had to fix the loglikelihood and key to use benchmarker as is\n",
"def loglik_poisson_jax(s, y):\n",
" \"\"\"Poisson Log Likelihood\"\"\"\n",
" mu = jnp.exp(s)\n",
" return y * jnp.log(mu + 1e-30) - mu - gammaln(y + 1.0)\n",
"\n",
"@block\n",
"@partial(jax.jit, static_argnums=5)\n",
"def particle_filter_1d_predict_jax(\n",
" A, Q, x0_mean, x0_std, ys, N=1000,\n",
"):\n",
" \"\"\"\n",
" 1D particle filter.\n",
" \n",
" Parameters\n",
" ----------\n",
" A: float\n",
" State transition\n",
" Q: float\n",
" Process covariance\n",
" x0_mean: float\n",
" Prior mean for the latent state\n",
" x0_std: float\n",
" Prior standard deviation \n",
" ys: np.ndarray\n",
" observations\n",
" loglik_fn: function\n",
" The log likelihood function\n",
" key: \n",
" JAX prng key\n",
" N: int\n",
" number of particles\n",
"\n",
" Returns\n",
" -------\n",
" filtered_means: jnp.ndarray\n",
" The filtered mean for the latent state \n",
" filtered_vars: jnp.ndarray\n",
" The filtered variance for the latent state\n",
" pred_means: jnp.ndarray\n",
" observation predicted mean \n",
" \"\"\"\n",
" key = jax.random.PRNGKey(0)\n",
" T = ys.shape[0]\n",
" particles = jax.random.normal(key, (N,)) * x0_std + x0_mean # init particles from gaussian priors\n",
" weights = jnp.ones(N) / N # particle weights, all particles equally likely prior\n",
"\n",
" def body_fun(carry, t):\n",
" particles, weights, key = carry\n",
" y = ys[t]\n",
"\n",
" # propagate\n",
" key, subkey = jax.random.split(key)\n",
" particles = A * particles + jax.random.normal(subkey, (N,)) * jnp.sqrt(Q) # state transition model\n",
"\n",
" # update weights\n",
" logw = jax.vmap(lambda x: loglik_poisson_jax(x, y))(particles) # update particles in parallel\n",
" logw = logw - jnp.max(logw) # avoid overflow\n",
" weights = weights * jnp.exp(logw) # old weights times the likelihood\n",
" weights /= jnp.sum(weights) + 1e-12 # normalize so that weights sum to 1\n",
"\n",
" # filtered moments\n",
" mean_t = jnp.sum(weights * particles) # posterior mean of latent state\n",
" var_t = jnp.sum(weights * (particles - mean_t)**2) # posterior variance of latent state\n",
"\n",
" # predictive mean\n",
" pred_mean = jnp.sum(weights * jnp.exp(particles))\n",
"\n",
" # resample to prevent dominant particles\n",
" key, subkey = jax.random.split(key)\n",
" indices = jax.random.choice(subkey, N, p=weights, shape=(N,))\n",
" particles = particles[indices]\n",
" weights = jnp.ones(N) / N\n",
"\n",
" carry = (particles, weights, key)\n",
" out = (mean_t, var_t, pred_mean)\n",
" return carry, out\n",
"\n",
" _, outputs = jax.lax.scan(body_fun, (particles, weights, key), jnp.arange(T))\n",
" return outputs\n"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "15fd4a0c",
"metadata": {},
"outputs": [],
"source": [
"from pytensor.tensor.random.utils import RandomStream\n",
"\n",
"# Random stream for PyTensor\n",
"srng = RandomStream(seed=42)\n",
"\n",
"# Poisson log-likelihood\n",
"def loglik_poisson_pytensor(s, y):\n",
" mu = pt.exp(s)\n",
" return y.flatten() * pt.log(mu + 1e-30) - mu - pt.gammaln(y.flatten() + 1.0)\n"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "59525da5",
"metadata": {},
"outputs": [],
"source": [
"ys_symbolic = pt.vector(\"ys\")\n",
"x0_mean_symbolic = pt.scalar(\"x0_mean\")\n",
"x0_std_symbolic = pt.scalar(\"x0_std\")\n",
"A_symbolic = pt.scalar(\"A\")\n",
"Q_symbolic = pt.scalar(\"Q\")\n",
"N_symbolic = pt.scalar(\"N\", dtype='int64')\n",
"\n",
"N_STEPS = 300\n",
"\n",
"# Initialize particles and weights\n",
"particles_init = srng.normal(size=(N_symbolic,)) * x0_std_symbolic + x0_mean_symbolic\n",
"weights_init = pt.ones((N_symbolic,)) / N_symbolic \n",
"\n",
"# Step function for scan\n",
"def step(y_t, particles_prev, weights_prev, A_symbolic, Q_symbolic):\n",
" # Propagate particles\n",
" particles_prop = A_symbolic * particles_prev + srng.normal(size=(N_symbolic,)) * pt.sqrt(Q_symbolic)\n",
"\n",
" # Update weights\n",
" # logw = pt.stack([loglik_poisson_pytensor(p, y_t) for p in particles_prop])\n",
" logw = loglik_poisson_pytensor(particles_prop, y_t)\n",
" logw_stable = logw - pt.max(logw)\n",
" w_unnorm = weights_prev * pt.exp(logw_stable)\n",
" w = w_unnorm / (pt.sum(w_unnorm) + 1e-12) \n",
"\n",
" # Filtered moments\n",
" mean_t = pt.sum(w * particles_prop)\n",
" var_t = pt.sum(w * (particles_prop - mean_t) ** 2)\n",
" pred_mean = pt.sum(w * pt.exp(particles_prop))\n",
"\n",
" # Resample particles\n",
" idx = srng.choice(size=(N_symbolic,), a=N_symbolic, p=w) \n",
" particles_resampled = particles_prop[idx]\n",
" weights_resampled = pt.ones((N_symbolic,)) / N_symbolic\n",
"\n",
" # Return flat tuple\n",
" return particles_resampled, weights_resampled, mean_t, var_t, pred_mean\n",
"\n",
"# first two are recurrent, rest are collected\n",
"outputs_info = [\n",
" particles_init,\n",
" weights_init,\n",
" None,\n",
" None,\n",
" None\n",
"]\n",
"\n",
"(particles_seq, weights_seq, means_seq, vars_seq, preds_seq), updates = pytensor.scan(\n",
" fn=step,\n",
" sequences=[ys_symbolic],\n",
" outputs_info=outputs_info,\n",
" non_sequences=[A_symbolic, Q_symbolic],\n",
" n_steps=N_STEPS\n",
")\n",
"\n",
"particle_filter_1d_predict_pytensor = pytensor.function(\n",
" [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],\n",
" [means_seq, vars_seq, preds_seq],\n",
" updates=updates,\n",
" no_default_updates=True,\n",
" trust_input=True\n",
")\n",
"\n",
"particle_filter_1d_predict_pytensor_numba = pytensor.function(\n",
" [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],\n",
" [means_seq, vars_seq, preds_seq],\n",
" updates=updates,\n",
" no_default_updates=True,\n",
" mode=\"NUMBA\", \n",
" trust_input=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "c033a1d0",
"metadata": {},
"outputs": [],
"source": [
"key = jax.random.PRNGKey(0)\n",
"T = 300\n",
"A = 0.95\n",
"Q = 0.05\n",
"rng = np.random.RandomState(1)\n",
"\n",
"target_mean = 10.0\n",
"latent_var = Q / (1 - A**2)\n",
"x0_mean = np.log(target_mean) - 0.5 * latent_var\n",
"x0_std = 1.0\n",
"\n",
"# Simulate latent\n",
"x = np.zeros(T)\n",
"x[0] = rng.normal() * np.sqrt(latent_var) + x0_mean\n",
"for t in range(1, T):\n",
" x[t] = A * x[t-1] + rng.normal() * np.sqrt(Q)\n",
"\n",
"ys = np.array(rng.poisson(np.exp(x)), dtype=np.float32)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "2a9cbfa5",
"metadata": {},
"outputs": [],
"source": [
"nonlinear_kalman_filter_bench = Benchmarker(\n",
" functions=[particle_filter_1d_predict_pytensor, particle_filter_1d_predict_numba, particle_filter_1d_predict_jax, particle_filter_1d_predict_pytensor_numba,], \n",
" names=['particle_filter_1d_predict_pytensor', 'particle_filter_1d_predict_numba', 'particle_filter_1d_predict_jax', 'particle_filter_1d_predict_pytensor_numba',],\n",
" number=5 # This takes a while to run reducing number of loops\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "c782c42b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" | \n",
" Loops | \n",
" Min (us) | \n",
" Max (us) | \n",
" Mean (us) | \n",
" StdDev (us) | \n",
" Median (us) | \n",
" IQR (us) | \n",
" OPS (Kops/s) | \n",
" Samples | \n",
"
\n",
" \n",
" \n",
" \n",
" particle_filter_1d_predict_pytensor | \n",
" kalman_filter_inputs | \n",
" 5 | \n",
" 728815.783397 | \n",
" 742948.366801 | \n",
" 734642.386761 | \n",
" 5002.169131 | \n",
" 733596.250002 | \n",
" 6332.200003 | \n",
" 0.001361 | \n",
" 5 | \n",
"
\n",
" \n",
" particle_filter_1d_predict_numba | \n",
" kalman_filter_inputs | \n",
" 5 | \n",
" 48678.408400 | \n",
" 48904.825002 | \n",
" 48782.921640 | \n",
" 89.835338 | \n",
" 48744.633398 | \n",
" 160.525000 | \n",
" 0.020499 | \n",
" 5 | \n",
"
\n",
" \n",
" particle_filter_1d_predict_jax | \n",
" kalman_filter_inputs | \n",
" 5 | \n",
" 33170.266601 | \n",
" 33644.350001 | \n",
" 33410.141721 | \n",
" 155.134231 | \n",
" 33411.583403 | \n",
" 125.924998 | \n",
" 0.029931 | \n",
" 5 | \n",
"
\n",
" \n",
" particle_filter_1d_predict_pytensor_numba | \n",
" kalman_filter_inputs | \n",
" 5 | \n",
" 676612.958201 | \n",
" 678200.574999 | \n",
" 677432.478319 | \n",
" 647.987321 | \n",
" 677334.924997 | \n",
" 1278.466597 | \n",
" 0.001476 | \n",
" 5 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Loops \\\n",
"particle_filter_1d_predict_pytensor kalman_filter_inputs 5 \n",
"particle_filter_1d_predict_numba kalman_filter_inputs 5 \n",
"particle_filter_1d_predict_jax kalman_filter_inputs 5 \n",
"particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 5 \n",
"\n",
" Min (us) \\\n",
"particle_filter_1d_predict_pytensor kalman_filter_inputs 728815.783397 \n",
"particle_filter_1d_predict_numba kalman_filter_inputs 48678.408400 \n",
"particle_filter_1d_predict_jax kalman_filter_inputs 33170.266601 \n",
"particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 676612.958201 \n",
"\n",
" Max (us) \\\n",
"particle_filter_1d_predict_pytensor kalman_filter_inputs 742948.366801 \n",
"particle_filter_1d_predict_numba kalman_filter_inputs 48904.825002 \n",
"particle_filter_1d_predict_jax kalman_filter_inputs 33644.350001 \n",
"particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 678200.574999 \n",
"\n",
" Mean (us) \\\n",
"particle_filter_1d_predict_pytensor kalman_filter_inputs 734642.386761 \n",
"particle_filter_1d_predict_numba kalman_filter_inputs 48782.921640 \n",
"particle_filter_1d_predict_jax kalman_filter_inputs 33410.141721 \n",
"particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 677432.478319 \n",
"\n",
" StdDev (us) \\\n",
"particle_filter_1d_predict_pytensor kalman_filter_inputs 5002.169131 \n",
"particle_filter_1d_predict_numba kalman_filter_inputs 89.835338 \n",
"particle_filter_1d_predict_jax kalman_filter_inputs 155.134231 \n",
"particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 647.987321 \n",
"\n",
" Median (us) \\\n",
"particle_filter_1d_predict_pytensor kalman_filter_inputs 733596.250002 \n",
"particle_filter_1d_predict_numba kalman_filter_inputs 48744.633398 \n",
"particle_filter_1d_predict_jax kalman_filter_inputs 33411.583403 \n",
"particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 677334.924997 \n",
"\n",
" IQR (us) \\\n",
"particle_filter_1d_predict_pytensor kalman_filter_inputs 6332.200003 \n",
"particle_filter_1d_predict_numba kalman_filter_inputs 160.525000 \n",
"particle_filter_1d_predict_jax kalman_filter_inputs 125.924998 \n",
"particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 1278.466597 \n",
"\n",
" OPS (Kops/s) \\\n",
"particle_filter_1d_predict_pytensor kalman_filter_inputs 0.001361 \n",
"particle_filter_1d_predict_numba kalman_filter_inputs 0.020499 \n",
"particle_filter_1d_predict_jax kalman_filter_inputs 0.029931 \n",
"particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 0.001476 \n",
"\n",
" Samples \n",
"particle_filter_1d_predict_pytensor kalman_filter_inputs 5 \n",
"particle_filter_1d_predict_numba kalman_filter_inputs 5 \n",
"particle_filter_1d_predict_jax kalman_filter_inputs 5 \n",
"particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 5 "
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nonlinear_kalman_filter_bench.run(\n",
" inputs={\n",
" \"kalman_filter_inputs\": {\"A\": A, \"Q\": Q, \"x0_mean\": x0_mean, \"x0_std\": x0_std, \"ys\": ys, \"N\": 2000},\n",
" }\n",
")\n",
"nonlinear_kalman_filter_bench.summary()"
]
},
{
"cell_type": "markdown",
"id": "29fe0cc4",
"metadata": {},
"source": [
"Slightly different estimates because I couldn't reproduce 1:1 "
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "2bd23897",
"metadata": {},
"outputs": [],
"source": [
"filtered_means, filtered_vars, pred_means = particle_filter_1d_predict_numba(\n",
" A, Q, x0_mean, x0_std, ys, N=2000, seed=2\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "7caac5a5",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = make_subplots(\n",
" rows=2, cols=1,\n",
" subplot_titles=(\"Observation Predictions\", \"Latent State Estimation\"),\n",
" vertical_spacing=0.07,\n",
" shared_xaxes=True\n",
")\n",
"\n",
"fig.add_traces(\n",
" [\n",
" go.Scatter(\n",
" x = np.arange(T),\n",
" y = ys,\n",
" mode = \"markers\",\n",
" marker_color = \"cornflowerblue\",\n",
" name = \"actuals\"\n",
" ),\n",
" go.Scatter(\n",
" x = np.arange(T),\n",
" y = pred_means,\n",
" mode = \"lines\",\n",
" marker_color = \"#eb8c34\",\n",
" name = \"predicted mean\"\n",
" ),\n",
" go.Scatter(\n",
" name=\"\", \n",
" x=np.arange(T), \n",
" y=pred_means + 2*jnp.sqrt(pred_means), \n",
" mode=\"lines\", \n",
" marker=dict(color=\"#eb8c34\"), \n",
" line=dict(width=0), \n",
" legendgroup=\"predicted mean 95% CI\",\n",
" showlegend=False\n",
" ),\n",
" go.Scatter(\n",
" name=\"predicted mean 95% CI\", \n",
" x=np.arange(T), \n",
" y=pred_means - 2*jnp.sqrt(pred_means), \n",
" mode=\"lines\", marker=dict(color=\"#eb8c34\"), \n",
" line=dict(width=0), \n",
" legendgroup=\"predicted mean 95% CI\", \n",
" fill='tonexty', \n",
" fillcolor='rgba(235, 140, 52, 0.2)'\n",
" ),\n",
" ],\n",
" rows=1, cols=1\n",
")\n",
"\n",
"fig.add_traces(\n",
" [\n",
" go.Scatter(\n",
" x = np.arange(T),\n",
" y = x,\n",
" mode = \"lines\",\n",
" marker_color = \"cornflowerblue\",\n",
" name = \"true latent state\"\n",
" ),\n",
" go.Scatter(\n",
" x = np.arange(T),\n",
" y = filtered_means,\n",
" mode = \"lines\",\n",
" marker_color = \"#eb8c34\",\n",
" name = \"filtered state mean\"\n",
" ),\n",
" go.Scatter(\n",
" name=\"\", \n",
" x=np.arange(T), \n",
" y=filtered_means + 2*jnp.sqrt(filtered_vars), \n",
" mode=\"lines\", \n",
" marker=dict(color=\"#eb8c34\"), \n",
" line=dict(width=0), \n",
" legendgroup=\"filtered state mean 95% CI\",\n",
" showlegend=False\n",
" ),\n",
" go.Scatter(\n",
" name=\"filtered state mean 95% CI\", \n",
" x=np.arange(T), \n",
" y=filtered_means - 2*jnp.sqrt(filtered_vars), \n",
" mode=\"lines\", marker=dict(color=\"#eb8c34\"), \n",
" line=dict(width=0), \n",
" legendgroup=\"filtered state mean 95% CI\", \n",
" fill='tonexty', \n",
" fillcolor='rgba(235, 140, 52, 0.2)'\n",
" ),\n",
" ],\n",
" rows=2, cols=1\n",
")\n",
"\n",
"for i, yaxis in enumerate(fig.select_yaxes(), 1):\n",
" legend_name = f\"legend{i}\"\n",
" fig.update_layout({legend_name: dict(y=yaxis.domain[1], yanchor=\"top\")}, showlegend=True)\n",
" fig.update_traces(row=i, legend=legend_name)\n",
"\n",
"fig.update_layout(height=1000, width=1200, template=\"plotly_dark\")\n",
"\n",
"fig.update_layout(\n",
" legend1=dict(\n",
" yanchor=\"top\",\n",
" y=1.0,\n",
" xanchor=\"left\",\n",
" x=0,\n",
" orientation=\"h\"\n",
" ),\n",
" legend2=dict(\n",
" yanchor=\"top\",\n",
" y=.465,\n",
" xanchor=\"left\",\n",
" x=0,\n",
" orientation=\"h\"\n",
" ),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48ccc984",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytensor-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}