{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9571ee33",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "73cc811c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "from numba import jit\n",
    "import numba\n",
    "import pytensor\n",
    "import pytensor.tensor as pt\n",
    "from pytensor.compile.mode import get_mode\n",
    "import timeit\n",
    "import jax\n",
    "import math\n",
    "from jax.scipy.special import gammaln\n",
    "from functools import partial, wraps\n",
    "\n",
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d5fc1350",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.io as pio\n",
    "pio.renderers.default = \"notebook\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "daa0969b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pytensor version: 0+untagged.31343.g647570d.dirty\n",
      "jax version: 0.7.2\n",
      "numba version: 0.62.1\n"
     ]
    }
   ],
   "source": [
    "print(\"pytensor version:\", pytensor.__version__)\n",
    "print(\"jax version:\", jax.__version__)\n",
    "print(\"numba version:\", numba.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8294718d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Benchmarker:\n",
    "    \"\"\"\n",
    "    Benchmark a set of functions by timing execution and summarizing statistics.\n",
    "\n",
    "    Every function is timed on every named input set passed to `run`, so results\n",
    "    are keyed by ``(function_name, input_name)`` tuples.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    functions : list of callables\n",
    "        List of callables to benchmark.\n",
    "    names : list of str, optional\n",
    "        Names corresponding to each function. Default is ['func_0', 'func_1', ...].\n",
    "    number : int or None, optional\n",
    "        Number of loops per timing round. If None, it is auto-calibrated so that a\n",
    "        single round lasts at least `target_time` seconds. Default is None.\n",
    "    min_rounds : int, optional\n",
    "        Minimum number of timing rounds per (function, input) pair. Default is 5.\n",
    "    max_time : float or None, optional\n",
    "        Approximate time budget in seconds for the timing rounds of one\n",
    "        (function, input) pair: faster functions get more rounds. If None,\n",
    "        exactly `min_rounds` rounds are run. Default is 1.0.\n",
    "    target_time : float, optional\n",
    "        Minimum duration in seconds of one round during auto-calibration; only\n",
    "        used when `number` is None. Default is 0.2.\n",
    "\n",
    "    Attributes\n",
    "    ----------\n",
    "    results : dict\n",
    "        Mapping from ``(function_name, input_name)`` to a dict with keys:\n",
    "        - 'raw_us': numpy.ndarray of per-call times in microseconds, one per round\n",
    "        - 'loops': number of loops used per round\n",
    "        - 'rounds': number of timing rounds\n",
    "\n",
    "    Methods\n",
    "    -------\n",
    "    run(inputs)\n",
    "        Calibrate (if needed) and time every function on every input set.\n",
    "    summary(unit='us') -> pandas.DataFrame\n",
    "        Return a summary DataFrame with statistics converted to the given unit.\n",
    "    raw(name=None) -> dict or numpy.ndarray\n",
    "        Return raw timing data in microseconds for one result key or for all.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, functions, names=None, number=None, min_rounds=5, max_time=1.0, target_time=0.2\n",
    "    ):\n",
    "        self.functions = functions\n",
    "        self.names = names or [f\"func_{i}\" for i in range(len(functions))]\n",
    "        self.number = number\n",
    "        self.min_rounds = min_rounds\n",
    "        self.max_time = max_time\n",
    "        self.target_time = target_time\n",
    "        self.results = {}\n",
    "\n",
    "    def _calibrate(self, timer):\n",
    "        \"\"\"\n",
    "        Find a loop count whose timing round lasts at least `target_time` seconds.\n",
    "\n",
    "        Uses the same 1, 2, 5, 10, 20, 50, ... sequence as `timeit.Timer.autorange`\n",
    "        (whose target is fixed at 0.2 s) but honours the configurable `target_time`.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        tuple of (int, float)\n",
    "            The loop count and the duration in seconds of the final calibration round.\n",
    "        \"\"\"\n",
    "        scale = 1\n",
    "        while True:\n",
    "            for multiplier in (1, 2, 5):\n",
    "                loops = scale * multiplier\n",
    "                elapsed = timer.timeit(number=loops)\n",
    "                if elapsed >= self.target_time:\n",
    "                    return loops, elapsed\n",
    "            scale *= 10\n",
    "\n",
    "    def run(self, inputs: dict[str, dict]):\n",
    "        \"\"\"\n",
    "        Time every function on every named input set and store the results.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        inputs : dict of str to dict\n",
    "            Maps an input-set name to the keyword arguments passed to each function.\n",
    "\n",
    "        Notes\n",
    "        -----\n",
    "        The calibration round is not recorded, so it also acts as a warm-up: one-off\n",
    "        costs such as JIT compilation on the first call are excluded from the results.\n",
    "        \"\"\"\n",
    "        for name, func in zip(self.names, self.functions):\n",
    "            for input_name, kwargs in inputs.items():\n",
    "                timer = timeit.Timer(partial(func, **kwargs))\n",
    "\n",
    "                # Calibrate loops (this round is discarded and doubles as a warm-up)\n",
    "                if self.number is None:\n",
    "                    loops, calib_time = self._calibrate(timer)\n",
    "                else:\n",
    "                    loops = self.number\n",
    "                    calib_time = timer.timeit(number=loops)\n",
    "\n",
    "                # Determine rounds based on max_time and min_rounds\n",
    "                # (guard against a zero calibration time from a coarse clock)\n",
    "                if self.max_time is not None and calib_time > 0:\n",
    "                    rounds = max(self.min_rounds, int(np.ceil(self.max_time / calib_time)))\n",
    "                else:\n",
    "                    rounds = self.min_rounds\n",
    "\n",
    "                raw_round_times = np.array(timer.repeat(repeat=rounds, number=loops))\n",
    "\n",
    "                # Convert to microseconds per single execution\n",
    "                raw_us = raw_round_times / loops * 1e6\n",
    "\n",
    "                self.results[(name, input_name)] = {\n",
    "                    \"raw_us\": raw_us,\n",
    "                    \"loops\": loops,\n",
    "                    \"rounds\": rounds,\n",
    "                }\n",
    "\n",
    "    def summary(self, unit=\"us\"):\n",
    "        \"\"\"\n",
    "        Summarize benchmark statistics in a DataFrame.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        unit : {'us', 'ms', 'ns', 's'}, optional\n",
    "            Unit for output times. 'us' means microseconds, 'ms' milliseconds,\n",
    "            'ns' nanoseconds, 's' seconds. Default is 'us'.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        pandas.DataFrame\n",
    "            One row per result key, with a (function name, input name) MultiIndex\n",
    "            when all keys are tuples. Columns: Loops; Min, Max, Mean, StdDev,\n",
    "            Median and IQR (in the given unit); OPS (thousands of calls per\n",
    "            second, independent of `unit`); Samples (number of rounds).\n",
    "        \"\"\"\n",
    "        records = []\n",
    "        indexes = []\n",
    "        for name, data in self.results.items():\n",
    "            raw_us = data[\"raw_us\"]\n",
    "            # Convert to target unit\n",
    "            times = self._convert_times(raw_us, unit)\n",
    "            if isinstance(name, tuple) and len(name) > 1:\n",
    "                indexes.append(name)\n",
    "            elif isinstance(name, tuple) and len(name) == 1:\n",
    "                indexes.append(name[0])\n",
    "            else:\n",
    "                indexes.append(name)\n",
    "\n",
    "            stats = {\n",
    "                \"Loops\": data[\"loops\"],\n",
    "                f\"Min ({unit})\": np.min(times),\n",
    "                f\"Max ({unit})\": np.max(times),\n",
    "                f\"Mean ({unit})\": np.mean(times),\n",
    "                f\"StdDev ({unit})\": np.std(times),\n",
    "                f\"Median ({unit})\": np.median(times),\n",
    "                f\"IQR ({unit})\": np.percentile(times, 75) - np.percentile(times, 25),\n",
    "                # mean per-call time in us -> thousands of calls per second\n",
    "                \"OPS (Kops/s)\": 1e3 / (np.mean(raw_us)),\n",
    "                \"Samples\": len(raw_us),\n",
    "            }\n",
    "            records.append(stats)\n",
    "\n",
    "        # MultiIndex.from_tuples cannot infer levels from an empty list\n",
    "        if indexes and all(isinstance(idx, tuple) for idx in indexes):\n",
    "            index = pd.MultiIndex.from_tuples(indexes)\n",
    "        else:\n",
    "            index = pd.Index(indexes)\n",
    "        return pd.DataFrame(records, index=index)\n",
    "\n",
    "    def raw(self, name=None):\n",
    "        \"\"\"\n",
    "        Get raw timing data in microseconds.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        name : tuple of (str, str), optional\n",
    "            A ``(function_name, input_name)`` result key. If given, returns the\n",
    "            raw_us array for that key (None if the key is unknown). Otherwise\n",
    "            returns a dict of all raw results.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        numpy.ndarray or None or dict\n",
    "        \"\"\"\n",
    "        if name:\n",
    "            return self.results.get(name, {}).get(\"raw_us\")\n",
    "        return {n: d[\"raw_us\"] for n, d in self.results.items()}\n",
    "\n",
    "    def _convert_times(self, times, unit):\n",
    "        \"\"\"\n",
    "        Convert an array of times from microseconds to the specified unit.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        times : array-like\n",
    "            Times in microseconds.\n",
    "        unit : {'us', 'ms', 'ns', 's'}\n",
    "            Target unit (case-insensitive): 'us' microseconds, 'ms' milliseconds,\n",
    "            'ns' nanoseconds, 's' seconds.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        numpy.ndarray\n",
    "            Converted times.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If `unit` is not one of the supported options.\n",
    "        \"\"\"\n",
    "        unit = unit.lower()\n",
    "        if unit == \"us\":\n",
    "            factor = 1.0\n",
    "        elif unit == \"ms\":\n",
    "            factor = 1e-3\n",
    "        elif unit == \"ns\":\n",
    "            factor = 1e3\n",
    "        elif unit == \"s\":\n",
    "            factor = 1e-6\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported unit: {unit}\")\n",
    "        return times * factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "799a759b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def block(func):\n",
    "    \"\"\"\n",
    "    Wrap `func` so that every call waits until its JAX results are computed.\n",
    "\n",
    "    JAX dispatches work asynchronously, so without blocking a timer would only\n",
    "    measure dispatch time. `jax.block_until_ready` accepts any pytree (a single\n",
    "    array, a tuple or list of arrays, ...) and blocks on each leaf that supports\n",
    "    it, so functions with several outputs are covered as well.\n",
    "    \"\"\"\n",
    "    @wraps(func)\n",
    "    def inner(*args, **kwargs):\n",
    "        return jax.block_until_ready(func(*args, **kwargs))\n",
    "    return inner\n"
   ]
  },
  { "cell_type": "code",
"execution_count": 6,
   "id": "6e76a860",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set Pytensor to use float32\n",
    "pytensor.config.floatX = \"float32\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c3eb543",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "This notebook benchmarks the same loop-heavy algorithms implemented with PyTensor `scan` (compiled with its default backend and, where shown, its Numba and JAX backends), with plain Numba and with JAX. It starts with two toy loops (Fibonacci and element-wise multiplication) and then moves on to changepoint detection algorithms (CUSUM and PELT).\n",
    "\n",
    "All timings come from the `Benchmarker` class defined above. JAX functions are wrapped with `block` so that asynchronous dispatch does not hide the actual computation time."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e066c9c4",
   "metadata": {},
   "source": [
    "# Baby Steps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "962c851c",
   "metadata": {},
   "source": [
    "## Fibonacci Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1c8d1654",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_STEPS = 100_000\n",
    "\n",
    "b_symbolic = pt.vector(\"b\", dtype=\"int32\", shape=(1,))\n",
    "\n",
    "def step(a, b):\n",
    "    return a + b, a\n",
    "\n",
    "(outputs_a, outputs_b), _ = pytensor.scan(\n",
    "    fn=step,\n",
    "    outputs_info=[pt.ones(1, dtype=\"int32\"), b_symbolic],\n",
    "    n_steps=N_STEPS\n",
    ")\n",
    "\n",
    "# compile function returning final a\n",
    "fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)\n",
    "fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bf48d6bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit(nopython=True)\n",
    "def fibonacci_numba(b):\n",
    "    # Numba freezes the global N_STEPS when the function is first compiled.\n",
    "    # Work on a copy: the benchmark reuses the same input array for every call\n",
    "    # (and for every other implementation), so it must not be mutated.\n",
    "    b = b.copy()\n",
    "    a = np.ones(1, dtype=np.int32)\n",
    "    for _ in range(N_STEPS):\n",
    "        a[0], b[0] = a[0] + b[0], a[0]\n",
    "    return a[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "317ee391",
   "metadata": {},
   "outputs": [],
   "source": [
    "@block\n",
    "@jax.jit\n",
    "def fibonacci_jax(b):\n",
    "    # lax.fori_loop keeps the loop inside XLA instead of unrolling N_STEPS Python\n",
    "    # iterations into the traced graph; @block makes the timer wait for the result.\n",
    "    def body(_, carry):\n",
    "        a, b = carry\n",
    "        return a + b, a\n",
    "\n",
    "    # a starts as ones with b's shape/dtype so the loop carry keeps a fixed type\n",
    "    a, _ = jax.lax.fori_loop(0, N_STEPS, body, (jnp.ones_like(b), b))\n",
    "    return a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "63bfdffe",
   "metadata": {},
   "outputs": [],
   "source": [
    "fibonacci_bench = Benchmarker(\n",
    "    functions=[fibonacci_pytensor, fibonacci_numba, fibonacci_jax, fibonacci_pytensor_numba], \n",
    "    names=['fibonacci_pytensor', 'fibonacci_numba', 'fibonacci_jax', 'fibonacci_pytensor_numba'],\n",
    "    number=10\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "65bf994e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
fibonacci_pytensorfibonacci_inputs1054149.95000255278.41250054654.332500434.90289854807.945801694.0874010.0182975
fibonacci_numbafibonacci_inputs1084.208400105.55830195.5759694.79310495.5625005.03330110.46288113
fibonacci_jaxfibonacci_inputs107.16249931.55840114.4250208.92315812.2333005.92089869.3239965
fibonacci_pytensor_numbafibonacci_inputs102338.0458012429.0374972385.05915934.4389882392.57080058.6417010.4192775
\n", "
" ], "text/plain": [
       " Loops Min (us) Max (us) \\\n",
       "fibonacci_pytensor fibonacci_inputs 10 54149.950002 55278.412500 \n",
       "fibonacci_numba fibonacci_inputs 10 84.208400 105.558301 \n",
       "fibonacci_jax fibonacci_inputs 10 7.162499 31.558401 \n",
       "fibonacci_pytensor_numba fibonacci_inputs 10 2338.045801 2429.037497 \n",
       "\n",
       " Mean (us) StdDev (us) \\\n",
       "fibonacci_pytensor fibonacci_inputs 54654.332500 434.902898 \n",
       "fibonacci_numba fibonacci_inputs 95.575969 4.793104 \n",
       "fibonacci_jax fibonacci_inputs 14.425020 8.923158 \n",
       "fibonacci_pytensor_numba fibonacci_inputs 2385.059159 34.438988 \n",
       "\n",
       " Median (us) IQR (us) \\\n",
       "fibonacci_pytensor fibonacci_inputs 54807.945801 694.087401 \n",
       "fibonacci_numba fibonacci_inputs 95.562500 5.033301 \n",
       "fibonacci_jax fibonacci_inputs 12.233300 5.920898 \n",
       "fibonacci_pytensor_numba fibonacci_inputs 2392.570800 58.641701 \n",
       "\n",
       " OPS (Kops/s) Samples \n",
       "fibonacci_pytensor fibonacci_inputs 0.018297 5 \n",
       "fibonacci_numba fibonacci_inputs 10.462881 13 \n",
       "fibonacci_jax fibonacci_inputs 69.323996 5 \n",
       "fibonacci_pytensor_numba fibonacci_inputs 0.419277 5 "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fibonacci_bench.run(\n",
    "    inputs={\n",
    "        \"fibonacci_inputs\": {\"b\": np.ones(1, dtype=np.int32)},\n",
    "    }\n",
    ")\n",
    "fibonacci_bench.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bbb35a2",
   "metadata": {},
   "source": [
    "## Element-wise multiplication Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0b076f77",
   "metadata": {},
   "outputs": [],
   "source": [
    "a_symbolic = pt.vector(\"a\", dtype=\"int32\")\n",
    "b_symbolic = pt.vector(\"b\", dtype=\"int32\")\n",
    "N_STEPS = 1000\n",
    "\n",
    "def step(a_element, b_element):\n",
    "    return a_element * b_element\n",
    "\n",
    "# With an explicit n_steps, scan iterates over the first N_STEPS entries of the\n",
    "# sequences, so both inputs must contain at least N_STEPS elements.\n",
    "c, _ = pytensor.scan(\n",
    "    fn=step,\n",
    "    sequences=[a_symbolic, b_symbolic],\n",
    "    n_steps=N_STEPS\n",
    ")\n",
    "\n",
    "# Exclude the `scan_push_out_seq` rewrite so the multiplication stays inside the\n",
    "# scan loop instead of being moved out of it: we want to benchmark the loop itself.\n",
    "# The compiled functions return the full product vector c.\n",
    "c_mode = get_mode(\"FAST_RUN\").excluding(\"scan_push_out_seq\")\n",
    "elementwise_multiply_pytensor = pytensor.function([a_symbolic, b_symbolic], c, trust_input=True, mode=c_mode)\n",
    "\n",
    "numba_mode = get_mode(\"NUMBA\").excluding(\"scan_push_out_seq\")\n",
    "elementwise_multiply_pytensor_numba = pytensor.function([a_symbolic, b_symbolic], c, mode=numba_mode, trust_input=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "14327cbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit(nopython=True)\n",
    "def elementwise_multiply_numba(a, b):\n",
    "    n = a.shape[0]\n",
    "    c = np.empty(n, dtype=a.dtype)\n",
    "    for i in range(n):\n",
    "        c[i] = a[i] * b[i]\n",
    "    return c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "13715c9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "@block\n",
    "@jax.jit\n",
    "def elementwise_multiply_jax(a, b):\n",
    "    n = a.shape[0]\n",
    "    c_init = jnp.empty(n, dtype=a.dtype)\n",
    "    # Explicit loop (rather than a * b) to mirror the scan/Numba implementations\n",
    "    def step(i, c):\n",
    "        return jax.lax.dynamic_update_index_in_dim(c, a[i] * b[i], i, axis=0)\n",
    "\n",
    "    c = jax.lax.fori_loop(0, n, step, c_init)\n",
    "    return c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d43c4f9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded integer inputs (casting N(0, 1) samples to int32 would make most entries 0)\n",
    "rng = np.random.default_rng(0)\n",
    "a = rng.integers(-100, 100, size=N_STEPS, dtype=np.int32)\n",
    "b = rng.integers(-100, 100, size=N_STEPS, dtype=np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f0f4ede5",
   "metadata": {},
   "outputs": [],
   "source": [
    "elem_mult_bench = Benchmarker(\n",
    "    functions=[elementwise_multiply_pytensor, elementwise_multiply_numba, elementwise_multiply_jax, elementwise_multiply_pytensor_numba], \n",
    "    names=['elementwise_multiply_pytensor', 'elementwise_multiply_numba', 'elementwise_multiply_jax', 'elementwise_multiply_pytensor_numba'],\n",
    "    number=10\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "cdab8946",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
elementwise_multiply_pytensorelem_mult_inputs10450.595800967.095900492.99194536.396199491.41249918.2219012.028431220
elementwise_multiply_numbaelem_mult_inputs100.3667000.7208000.3942470.0735630.3792000.0124972536.47825521
elementwise_multiply_jaxelem_mult_inputs107.51249910.3915988.1524270.6217827.8959010.593750122.66285355
elementwise_multiply_pytensor_numbaelem_mult_inputs1034.66249950.73750239.2808215.93421937.4083003.91260025.4577175
\n", "
" ], "text/plain": [
       " Loops Min (us) \\\n",
       "elementwise_multiply_pytensor elem_mult_inputs 10 450.595800 \n",
       "elementwise_multiply_numba elem_mult_inputs 10 0.366700 \n",
       "elementwise_multiply_jax elem_mult_inputs 10 7.512499 \n",
       "elementwise_multiply_pytensor_numba elem_mult_inputs 10 34.662499 \n",
       "\n",
       " Max (us) Mean (us) \\\n",
       "elementwise_multiply_pytensor elem_mult_inputs 967.095900 492.991945 \n",
       "elementwise_multiply_numba elem_mult_inputs 0.720800 0.394247 \n",
       "elementwise_multiply_jax elem_mult_inputs 10.391598 8.152427 \n",
       "elementwise_multiply_pytensor_numba elem_mult_inputs 50.737502 39.280821 \n",
       "\n",
       " StdDev (us) \\\n",
       "elementwise_multiply_pytensor elem_mult_inputs 36.396199 \n",
       "elementwise_multiply_numba elem_mult_inputs 0.073563 \n",
       "elementwise_multiply_jax elem_mult_inputs 0.621782 \n",
       "elementwise_multiply_pytensor_numba elem_mult_inputs 5.934219 \n",
       "\n",
       " Median (us) IQR (us) \\\n",
       "elementwise_multiply_pytensor elem_mult_inputs 491.412499 18.221901 \n",
       "elementwise_multiply_numba elem_mult_inputs 0.379200 0.012497 \n",
       "elementwise_multiply_jax elem_mult_inputs 7.895901 0.593750 \n",
       "elementwise_multiply_pytensor_numba elem_mult_inputs 37.408300 3.912600 \n",
       "\n",
       " OPS (Kops/s) Samples \n",
       "elementwise_multiply_pytensor elem_mult_inputs 2.028431 220 \n",
       "elementwise_multiply_numba elem_mult_inputs 2536.478255 21 \n",
       "elementwise_multiply_jax elem_mult_inputs 122.662853 55 \n",
       "elementwise_multiply_pytensor_numba elem_mult_inputs 25.457717 5 "
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "elem_mult_bench.run(\n",
    "    inputs={\n",
    "        \"elem_mult_inputs\": {\"a\": a, \"b\": b},\n",
    "    }\n",
    ")\n",
    "elem_mult_bench.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25a87710",
   "metadata": {},
   "source": [
    "# Changepoint Detection Algorithms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16b2b312",
   "metadata": {},
   "source": [
    "## Cumulative Sum (CUSUM) Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c1226cfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit(nopython=True)\n",
    "def cusum_adaptive_numba(x, alpha=0.01, k=0.5, h=5.0):\n",
    "    \"\"\"\n",
    "    Two-sided CUSUM with adaptive exponential moving average baseline.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x: np.ndarray\n",
    "        input signal\n",
    "    alpha: float\n",
    "        EMA smoothing factor (0 < alpha <= 1)\n",
    "    k: float\n",
    "        slack to avoid small changes triggering alarms\n",
    "    h: float\n",
    "        threshold for raising an alarm\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    s_pos: np.ndarray\n",
    "        upper CUSUM stats (float32)\n",
    "    s_neg: np.ndarray\n",
    "        lower CUSUM stats (float32)\n",
    "    mu_t: np.ndarray\n",
    "        evolving baseline estimate (float32)\n",
    "    alarms_pos: np.ndarray\n",
    "        alarms for upward changes (bool)\n",
    "    alarms_neg: np.ndarray\n",
    "        alarms for downward changes (bool)\n",
    "\n",
    "    All outputs have the length of `x`. Index 0 holds the initial state\n",
    "    (baseline x[0], both statistics 0, no alarm); an empty `x` gives empty outputs.\n",
    "    \"\"\"\n",
    "    n = x.shape[0]\n",
    "\n",
    "    s_pos = np.zeros(n, dtype=np.float32)\n",
    "    s_neg = np.zeros(n, dtype=np.float32)\n",
    "    mu_t = np.zeros(n, dtype=np.float32)\n",
    "    alarms_pos = np.zeros(n, dtype=np.bool_)\n",
    "    alarms_neg = np.zeros(n, dtype=np.bool_)\n",
    "\n",
    "    # Initialization (guarded so an empty signal does not raise an IndexError)\n",
    "    if n > 0:\n",
    "        mu_t[0] = x[0]\n",
    "\n",
    "    for i in range(1, n):\n",
    "        # Update baseline (EMA)\n",
    "        mu_t[i] = alpha * x[i] + (1 - alpha) * mu_t[i-1]\n",
    "\n",
    "        # Update CUSUM stats\n",
    "        s_pos[i] = max(0.0, s_pos[i-1] + x[i] - mu_t[i] - k)\n",
    "        s_neg[i] = max(0.0, s_neg[i-1] - (x[i] - mu_t[i]) - k)\n",
    "\n",
    "        # Alarms\n",
    "        alarms_pos[i] = s_pos[i] > h\n",
    "        alarms_neg[i] = s_neg[i] > h\n",
    "\n",
    "    return s_pos, s_neg, mu_t, alarms_pos, alarms_neg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "14937a2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "@block\n",
    "@jax.jit\n",
    "def cusum_adaptive_jax(x, alpha=0.01, k=0.5, h=5.0):\n",
    "    \"\"\"\n",
    "    Two-sided CUSUM with adaptive exponential moving average baseline.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x: jnp.ndarray\n",
    "        input signal\n",
    "    alpha: float\n",
    "        EMA smoothing factor (0 < alpha <= 1)\n",
    "    k: float\n",
    "        slack to avoid small changes triggering alarms\n",
    "    h: float\n",
    "        threshold for raising an alarm\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    s_pos: jnp.ndarray\n",
    "        upper CUSUM stats\n",
    "    s_neg: jnp.ndarray\n",
    "        lower CUSUM stats\n",
    "    mu_t: jnp.ndarray\n",
    "        evolving baseline estimate\n",
    "    alarms_pos: jnp.ndarray\n",
    "        alarms for upward changes\n",
    "    alarms_neg: jnp.ndarray\n",
    "        alarms for downward changes\n",
    "    \"\"\"\n",
    "    def body(carry, x_t):\n",
    "        s_pos_prev, s_neg_prev, mu_prev = carry\n",
    "\n",
    "        # Update EMA baseline\n",
    "        mu_t = alpha * x_t + (1 - alpha) * mu_prev\n",
    "\n",
    "        # Update CUSUMs using updated baseline\n",
    "        s_pos = jnp.maximum(0.0, s_pos_prev + x_t - mu_t - k)\n",
    "        s_neg = jnp.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)\n",
    "\n",
    "        new_carry = (s_pos, s_neg, mu_t)\n",
    "        output = (s_pos, s_neg, mu_t)\n",
    "        return new_carry, output\n",
    "\n",
    "    # Initialize: CUSUMs at 0, initial mean = first sample. Scanning x[0] as well\n",
    "    # leaves the baseline at x[0] and both statistics at 0 for the first step.\n",
    "    s0 = (0.0, 0.0, x[0])\n",
    "    _, (s_pos_vals, s_neg_vals, mu_vals) = jax.lax.scan(body, s0, x)\n",
    "\n",
    "    # Thresholding\n",
    "    alarms_pos = s_pos_vals > h\n",
    "    alarms_neg = s_neg_vals > h\n",
    "\n",
    "    return s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "815967a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_symbolic = pt.vector(\"x\")\n",
    "alpha_symbolic = pt.scalar(\"alpha\")\n",
    "k_symbolic = pt.scalar(\"k\")\n",
    "h_symbolic = pt.scalar(\"h\")\n",
    "\n",
    "# Length of the benchmark series. With `sequences`, an explicit n_steps makes scan\n",
    "# iterate over the first N_STEPS entries of x, so x needs at least N_STEPS samples.\n",
    "N_STEPS = 100\n",
    "\n",
    "def step(x_t, s_pos_prev, s_neg_prev, mu_prev, alpha, k):\n",
    "    # Update EMA baseline\n",
    "    mu_t = alpha * x_t + (1 - alpha) * mu_prev\n",
    "\n",
    "    # Update CUSUMs using updated baseline\n",
    "    s_pos = pt.maximum(0.0, s_pos_prev + x_t - mu_t - k)\n",
    "    s_neg = pt.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)\n",
    "\n",
    "    return s_pos, s_neg, mu_t\n",
    "\n",
    "\n",
    "(s_pos_vals, s_neg_vals, mu_vals), updates = pytensor.scan(\n",
    "    fn=step,\n",
    "    outputs_info=[pt.constant(0., dtype=\"float32\"), pt.constant(0., dtype=\"float32\"), x_symbolic[0]],\n",
    "    non_sequences=[alpha_symbolic, k_symbolic],\n",
    "    sequences=[x_symbolic],\n",
    "    n_steps=N_STEPS\n",
    ")\n",
    "\n",
    "# Thresholding\n",
    "alarms_pos = s_pos_vals > h_symbolic\n",
    "alarms_neg = s_neg_vals > h_symbolic\n",
    "\n",
    "# Same graph compiled for the default, Numba and JAX backends\n",
    "cusum_graph_inputs = [x_symbolic, alpha_symbolic, k_symbolic, h_symbolic]\n",
    "cusum_graph_outputs = [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg]\n",
    "\n",
    "cusum_adaptive_pytensor = pytensor.function(cusum_graph_inputs, cusum_graph_outputs, trust_input=True)\n",
    "cusum_adaptive_pytensor_numba = pytensor.function(cusum_graph_inputs, cusum_graph_outputs, mode=\"NUMBA\", trust_input=True)\n",
    "cusum_adaptive_pytensor_jax = pytensor.function(cusum_graph_inputs, cusum_graph_outputs, mode=\"JAX\", trust_input=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6c892129",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded test series with a downward mean shift half-way through (80 -> 50)\n",
    "rng = np.random.default_rng(42)\n",
    "xs1 = rng.normal(80, 20, size=N_STEPS // 2)\n",
    "xs2 = rng.normal(50, 20, size=N_STEPS - N_STEPS // 2)\n",
    "xs = np.concatenate((xs1, xs2)).astype(np.float32)\n",
    "xs_std = (xs - xs.mean()) / xs.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d21873d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "cusum_bench = Benchmarker(\n",
    "    functions=[cusum_adaptive_pytensor, cusum_adaptive_numba, cusum_adaptive_jax, cusum_adaptive_pytensor_numba, block(cusum_adaptive_pytensor_jax)], \n",
    "    names=['cusum_adaptive_pytensor', 'cusum_adaptive_numba', 'cusum_adaptive_jax', 'cusum_adaptive_pytensor_numba', 'cusum_adaptive_pytensor_jax'],\n",
    "    number=10\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "d8eab72d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
cusum_adaptive_pytensorcusum_inputs10124.412501226.245899141.8374208.902762139.56670010.3250007.050326681
cusum_adaptive_numbacusum_inputs101.7292012.1749991.8065710.1507961.7458980.018753553.5347807
cusum_adaptive_jaxcusum_inputs1014.36659918.89999815.0767310.94297614.6332980.73852566.32737836
cusum_adaptive_pytensor_numbacusum_inputs1024.31660036.34170027.2049804.59606525.0333021.24170336.7579765
cusum_adaptive_pytensor_jaxcusum_inputs1021.47910127.29590122.7824931.64259521.8937481.73434743.89335230
\n", "
" ], "text/plain": [ " Loops Min (us) Max (us) \\\n", "cusum_adaptive_pytensor cusum_inputs 10 124.412501 226.245899 \n", "cusum_adaptive_numba cusum_inputs 10 1.729201 2.174999 \n", "cusum_adaptive_jax cusum_inputs 10 14.366599 18.899998 \n", "cusum_adaptive_pytensor_numba cusum_inputs 10 24.316600 36.341700 \n", "cusum_adaptive_pytensor_jax cusum_inputs 10 21.479101 27.295901 \n", "\n", " Mean (us) StdDev (us) \\\n", "cusum_adaptive_pytensor cusum_inputs 141.837420 8.902762 \n", "cusum_adaptive_numba cusum_inputs 1.806571 0.150796 \n", "cusum_adaptive_jax cusum_inputs 15.076731 0.942976 \n", "cusum_adaptive_pytensor_numba cusum_inputs 27.204980 4.596065 \n", "cusum_adaptive_pytensor_jax cusum_inputs 22.782493 1.642595 \n", "\n", " Median (us) IQR (us) \\\n", "cusum_adaptive_pytensor cusum_inputs 139.566700 10.325000 \n", "cusum_adaptive_numba cusum_inputs 1.745898 0.018753 \n", "cusum_adaptive_jax cusum_inputs 14.633298 0.738525 \n", "cusum_adaptive_pytensor_numba cusum_inputs 25.033302 1.241703 \n", "cusum_adaptive_pytensor_jax cusum_inputs 21.893748 1.734347 \n", "\n", " OPS (Kops/s) Samples \n", "cusum_adaptive_pytensor cusum_inputs 7.050326 681 \n", "cusum_adaptive_numba cusum_inputs 553.534780 7 \n", "cusum_adaptive_jax cusum_inputs 66.327378 36 \n", "cusum_adaptive_pytensor_numba cusum_inputs 36.757976 5 \n", "cusum_adaptive_pytensor_jax cusum_inputs 43.893352 30 " ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Time every CUSUM implementation on the raw (unstandardized) series xs\n", "cusum_bench.run(\n", " inputs={\n", " \"cusum_inputs\": {\"x\": xs, \"alpha\": 0.1, \"k\": 0.5, \"h\": 3.5},\n", " }\n", ")\n", "cusum_bench.summary()" ] }, { "cell_type": "code", "execution_count": 25, "id": "3e9c3339", "metadata": {}, "outputs": [], "source": [ "# Run the Numba CUSUM once on the standardized series for the plot below.\n", "# outputs = (s_pos, s_neg, mu_t, alarms_pos, alarms_neg)\n", "outputs = cusum_adaptive_numba(xs_std, alpha=0.1, k=0.5, h=3.5)" ] }, { "cell_type": "code", "execution_count": 26, "id": "b85d0c0e", "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " \n", " " ] }, "metadata": {}, "output_type":
"display_data" }, { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = go.Figure()\n", "fig.add_traces(\n", " [\n", " go.Scatter(\n", " x = np.arange(len(xs)),\n", " y = xs_std,\n", " name=\"series\"\n", " ),\n", " go.Scatter(\n", " x = np.arange(len(xs)),\n", " y = outputs[0],\n", " name=\"cum. positive devs.\"\n", " ),\n", " go.Scatter(\n", " x = np.arange(len(xs)),\n", " y = outputs[1],\n", " name=\"cum. negative devs.\"\n", " ),\n", " go.Scatter(\n", " x = np.arange(len(xs)),\n", " y = outputs[2],\n", " name=\"Exp. Mean\"\n", " ),\n", " go.Scatter(\n", " x = np.arange(len(xs)),\n", " y = outputs[3].astype(np.float16),\n", " name=\"positive alarms\"\n", " ),\n", " go.Scatter(\n", " x = np.arange(len(xs)),\n", " y = outputs[4].astype(np.float16),\n", " name=\"negative alarms\"\n", " ),\n", " \n", " ]\n", ")\n", "fig.update_layout(\n", " title = dict(\n", " text = \"CUSUM Change Point Detection Algorithm\"\n", " ),\n", " xaxis=dict(\n", " title = \"Time Index\"\n", " ),\n", " yaxis=dict(\n", " title = \"Standardized Series Scaled\"\n", " ),\n", " legend=dict(\n", " yanchor=\"top\",\n", " y=1.1,\n", " xanchor=\"left\",\n", " x=0,\n", " orientation=\"h\"\n", " ),\n", " template=\"plotly_dark\"\n", ")" ] }, { "cell_type": "markdown", "id": "c2ce03ea", "metadata": {}, "source": [ "## Pruned Exact Linear Time (PELT) Algorithm" ] }, { "cell_type": "code", "execution_count": 27, "id": "eab4cbb9", "metadata": {}, "outputs": [], "source": [ "@jit(nopython=True)\n", "def segment_cost_numba(S1, S2, i, j):\n", " \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n", " n = j - i\n", " sum_x = S1[j] - S1[i]\n", " sum_x2 = S2[j] - S2[i]\n", " if n > 0:\n", " return sum_x2 - (sum_x ** 2) / n\n", " else:\n", " return np.inf\n", "\n", "@jit(nopython=True)\n", "def pelt_numba(x, beta=10.0):\n", " \"\"\"\n", " Pruned Exact Linear Time algorithm for change point detection\n", "\n", " Parameters\n", " ----------\n", " x: np.ndarray\n", " The timeseries signal\n", " beta: 
float\n", " Penalty of segmenting the series\n", "\n", " Returns\n", " -------\n", " C: np.ndarray\n", " The best costs up to segment t\n", " last_change: np.ndarray\n", " The last change point up to segment t\n", " \"\"\"\n", " n = len(x)\n", "\n", " # cumulative sums for cost\n", " S1 = np.empty(n+1, dtype=np.float32)\n", " S2 = np.empty(n+1, dtype=np.float32)\n", " S1[0], S2[0] = 0.0, 0.0\n", " for i in range(1, n+1):\n", " S1[i] = S1[i-1] + x[i-1]\n", " S2[i] = S2[i-1] + x[i-1]**2\n", "\n", " # DP arrays\n", " C = np.full((n+1,), np.inf)\n", " C[0] = -beta\n", " last_change = np.full((n+1,), -1)\n", " min_size = 3\n", "\n", " for t in range(1, n+1):\n", " costs = np.full(n, np.inf)\n", " for s in range(n):\n", " if s < t and (t - s) >= min_size:\n", " costs[s] = C[s] + segment_cost_numba(S1, S2, s, t) + beta\n", " best_s = np.argmin(costs)\n", " C[t] = costs[best_s]\n", " last_change[t] = best_s\n", "\n", " return C, last_change" ] }, { "cell_type": "code", "execution_count": 28, "id": "e6997389", "metadata": {}, "outputs": [], "source": [ "def segment_cost_jax(S1, S2, i, j):\n", " \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n", " n = j - i\n", " sum_x = S1[j] - S1[i]\n", " sum_x2 = S2[j] - S2[i]\n", " return jnp.where(n > 0, sum_x2 - (sum_x ** 2) / n, jnp.inf)\n", "\n", "@block\n", "@jax.jit\n", "def pelt_jax(x, beta=10.0):\n", " \"\"\"\n", " Pruned Exact Linear Time algorithm for change point detection\n", "\n", " Parameters\n", " ----------\n", " x: np.ndarray\n", " The timeseries signal\n", " beta: float\n", " Penalty of segmenting the series\n", "\n", " Returns\n", " -------\n", " C: jnp.ndarray\n", " The best costs up to segment t\n", " last_change: jnp.ndarray\n", " The last change point up to segment t\n", " \"\"\"\n", " n = len(x)\n", "\n", " # cumulative sums for cost\n", " S1 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x)])\n", " S2 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x**2)])\n", "\n", " # DP arrays\n", " C = jnp.full((n+1,), 
jnp.inf)\n", " C = C.at[0].set(-beta)\n", " last_change = jnp.full((n+1,), -1)\n", " min_size = 3\n", "\n", " s_all = jnp.arange(n) # all possible candidates\n", "\n", " def body(t, carry):\n", " C, last_change = carry\n", "\n", " # Compute cost for all s < t, mask invalid\n", " # valid = s_all < t & ((t - s_all) >= min_size)\n", " \n", " valid = (s_all < t) & ((t - s_all) >= min_size)\n", " costs = jnp.where(\n", " valid,\n", " C[s_all] + segment_cost_jax(S1, S2, s_all, t) + beta,\n", " jnp.inf\n", " )\n", "\n", " best_s = jnp.argmin(costs)\n", " C = C.at[t].set(costs[best_s])\n", " last_change = last_change.at[t].set(best_s)\n", " return C, last_change\n", "\n", " C, last_change = jax.lax.fori_loop(1, n+1, body, (C, last_change))\n", " return C, last_change" ] }, { "cell_type": "code", "execution_count": 29, "id": "094b9e8e", "metadata": {}, "outputs": [], "source": [ "def segment_cost_pytensor(S1, S2, i, j):\n", " \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n", " n = j - i\n", " sum_x = S1[j] - S1[i]\n", " sum_x2 = S2[j] - S2[i]\n", " return pt.switch(\n", " pt.gt(n, 0),\n", " sum_x2 - (sum_x ** 2) / n,\n", " np.inf\n", " )\n" ] }, { "cell_type": "code", "execution_count": 30, "id": "03e5e927", "metadata": {}, "outputs": [], "source": [ "x_symbolic = pt.vector(\"x\")\n", "beta_symbolic = pt.scalar(\"beta\")\n", "n = x_symbolic.shape[0]\n", "N_STEPS=100\n", "\n", "# cumulative sums for cost\n", "S1 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic)])\n", "S2 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic**2)])\n", "\n", "# DP arrays\n", "C_init = pt.alloc(np.inf, n+1)\n", "C_init = pt.set_subtensor(C_init[0], -beta_symbolic)\n", "last_change_init = pt.alloc(-1, n+1)\n", "\n", "s_all = pt.arange(n) # candidate change points\n", "min_size = 3\n", "\n", "def step(t, C_prev, last_change_prev, S1, S2, beta_symbolic, s_all):\n", " # valid = (s_all < t) & ((t - s_all) >= min_size)\n", " valid = pt.and_(pt.lt(s_all, t), pt.ge(t - s_all, 
min_size))\n", "\n", " # compute costs for all candidates at once (vectorised, mirrors pelt_jax)\n", " # instead of an inner scan that visits one candidate per iteration\n", " costs = pt.switch(\n", " valid,\n", " C_prev[s_all] + segment_cost_pytensor(S1, S2, s_all, t) + beta_symbolic,\n", " np.inf\n", " )\n", "\n", " best_s = pt.argmin(costs, axis=0)\n", " C_new = pt.set_subtensor(C_prev[t], costs[best_s])\n", " last_change_new = pt.set_subtensor(last_change_prev[t], best_s)\n", "\n", " return C_new, last_change_new\n", "\n", "# NOTE(review): n_steps fixes the number of DP iterations, so the final state only\n", "# covers the whole series when len(x) == N_STEPS - TODO confirm against len(xs_std)\n", "(C_vals, last_change_vals), _ = pytensor.scan(\n", " fn=step,\n", " sequences=[pt.arange(1, n+1)],\n", " outputs_info=[C_init, last_change_init],\n", " non_sequences=[S1, S2, beta_symbolic, s_all],\n", " n_steps=N_STEPS # Added fixed iterations here\n", ")\n", "\n", "pelt_pytensor = pytensor.function([x_symbolic, beta_symbolic], [C_vals[-1], last_change_vals[-1]], trust_input=True)\n", "pelt_pytensor_numba = pytensor.function(inputs=[x_symbolic, beta_symbolic], outputs=[C_vals[-1], last_change_vals[-1]], mode=\"NUMBA\", trust_input=True)" ] }, { "cell_type": "code", "execution_count": 31, "id": "69aea9a2", "metadata": {}, "outputs": [], "source": [ "pelt_bench = Benchmarker(\n", " functions=[pelt_pytensor, pelt_numba, pelt_jax, pelt_pytensor_numba], \n", " names=['pelt_pytensor', 'pelt_numba', 'pelt_jax', 'pelt_pytensor_numba'],\n", " number=10\n", ")" ] }, { "cell_type": "code", "execution_count": 32, "id": "5aee4a18", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
pelt_pytensorpelt_inputs1011903.71250012687.70830112234.242133245.98912312224.875001277.7999010.0817389
pelt_numbapelt_inputs1017.40829922.20410119.4008201.58531819.2417011.00000251.5442135
pelt_jaxpelt_inputs1064.61670082.34579971.3396794.87938970.2791004.89374814.01744519
pelt_pytensor_numbapelt_inputs102330.9250012496.1583002424.29998056.5148422435.22500161.8333990.4124905
\n", "
" ], "text/plain": [ " Loops Min (us) Max (us) \\\n", "pelt_pytensor pelt_inputs 10 11903.712500 12687.708301 \n", "pelt_numba pelt_inputs 10 17.408299 22.204101 \n", "pelt_jax pelt_inputs 10 64.616700 82.345799 \n", "pelt_pytensor_numba pelt_inputs 10 2330.925001 2496.158300 \n", "\n", " Mean (us) StdDev (us) Median (us) \\\n", "pelt_pytensor pelt_inputs 12234.242133 245.989123 12224.875001 \n", "pelt_numba pelt_inputs 19.400820 1.585318 19.241701 \n", "pelt_jax pelt_inputs 71.339679 4.879389 70.279100 \n", "pelt_pytensor_numba pelt_inputs 2424.299980 56.514842 2435.225001 \n", "\n", " IQR (us) OPS (Kops/s) Samples \n", "pelt_pytensor pelt_inputs 277.799901 0.081738 9 \n", "pelt_numba pelt_inputs 1.000002 51.544213 5 \n", "pelt_jax pelt_inputs 4.893748 14.017445 19 \n", "pelt_pytensor_numba pelt_inputs 61.833399 0.412490 5 " ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pelt_bench.run(\n", " inputs={\n", " \"pelt_inputs\": {\"x\": xs_std, \"beta\": 2. * np.log(len(xs_std))},\n", " }\n", ")\n", "pelt_bench.summary()" ] }, { "cell_type": "code", "execution_count": 33, "id": "47cae530", "metadata": {}, "outputs": [], "source": [ "outputs = pelt_numba(xs_std, 2. 
* np.log(len(xs_std)))" ] }, { "cell_type": "code", "execution_count": 34, "id": "7b395b06", "metadata": {}, "outputs": [], "source": [ "def plot_pelt_diagnostics(x, cps, C):\n", " \"\"\"\n", " Diagnostic plots for PELT changepoint detection.\n", " \n", " Args:\n", " x: 1D array, original time series\n", " C: 1D array, cumulative DP cost from pelt()\n", " cps: list of changepoint indices (sorted ascending)\n", " \"\"\"\n", " n = len(x)\n", " cps_full = [0] + cps + [n]\n", "\n", " # Segment means, std, SSE\n", " segment_means = []\n", " segment_stds = []\n", " segment_costs = []\n", " for start, end in zip(cps_full[:-1], cps_full[1:]):\n", " seg = x[start:end]\n", " mean = np.mean(seg)\n", " std = np.std(seg)\n", " cost = np.sum((seg - mean)**2)\n", " segment_means.append(mean)\n", " segment_stds.append(std)\n", " segment_costs.append(cost)\n", "\n", " # Step function for segment mean\n", " mean_step = np.zeros(n)\n", " for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):\n", " mean_step[start:end] = segment_means[i]\n", "\n", " # Step function for segment std\n", " std_step = np.zeros(n)\n", " for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):\n", " std_step[start:end] = segment_stds[i]\n", "\n", " if len(x) < 20:\n", " title1 = \"Warning: Sample size is small - Detected Changepoints\"\n", " else:\n", " title1 = \"Detected Changepoints\"\n", "\n", " fig = make_subplots(\n", " rows=4, \n", " cols=1,\n", " subplot_titles=(title1, \"Average Shifts\", \"Variability Shifts\", \"Cumulative Cost\")\n", " )\n", "\n", " fig.add_trace(\n", " go.Scatter(\n", " x = np.arange(len(x)),\n", " y = x,\n", " line_color=\"royalblue\",\n", " name = \"Actuals\",\n", " mode=\"lines\",\n", " showlegend=False,\n", " hovertemplate=\"Time Point: %{x}
Actual: %{y}\"\n", " ),\n", " row=1, col=1\n", " )\n", "\n", " for cp in cps:\n", " fig.add_vline(x=cp, line_dash='dash', line_color=\"red\", row=1, col=1)\n", "\n", " fig.add_trace(\n", " go.Scatter(\n", " x = np.arange(len(x)),\n", " y = x,\n", " name = \"Actuals\",\n", " mode=\"lines\",\n", " line_color=\"rgba(105, 105, 105, 0.25)\",\n", " showlegend=False,\n", " hoverinfo=\"skip\"\n", " ),\n", " row=2, col=1\n", " )\n", "\n", " fig.add_trace(\n", " go.Scatter(\n", " x = np.arange(len(x)),\n", " y = mean_step,\n", " name = \"Average\",\n", " line_color=\"royalblue\",\n", " showlegend=False,\n", " hovertemplate=\"Time Point: %{x}
Average: %{y}\"\n", " ),\n", " row=2, col=1\n", " )\n", "\n", " fig.add_trace(\n", " go.Scatter(\n", " x = np.arange(len(x)),\n", " y = std_step,\n", " name = \"Standard Deviation\",\n", " line_color=\"royalblue\",\n", " showlegend=False,\n", " hovertemplate=\"Time Point: %{x}
Standard Deviation: %{y}\"\n", " ),\n", " row=3, col=1\n", " )\n", "\n", " # C[t] is the best cost of x[:t], so pelt_* returns len(x) + 1 entries;\n", " # give C its own index axis instead of reusing the one built for x\n", " fig.add_trace(\n", " go.Scatter(\n", " x = np.arange(len(C)),\n", " y = C,\n", " name = \"Cumulative Cost\",\n", " line_color=\"royalblue\",\n", " showlegend=False,\n", " hovertemplate=\"Time Point: %{x}
Cost: %{y}\"\n", " ),\n", " row=4, col=1\n", " )\n", "\n", " for cp in cps:\n", " fig.add_vline(x=cp, line_dash='dash', line_color=\"red\", row=4, col=1)\n", "\n", " return fig.update_layout(height=1000, width=1200, template=\"plotly_dark\")\n" ] }, { "cell_type": "code", "execution_count": 35, "id": "e1de7df5", "metadata": {}, "outputs": [], "source": [ "def get_changepoints(last_change, n):\n", " \"\"\"\n", " Backtrack changepoints from last_change array.\n", " \n", " Args:\n", " last_change: array from pelt()\n", " n: length of input series\n", "\n", " Returns:\n", " list of changepoint indices (sorted ascending)\n", " \"\"\"\n", " cps = []\n", " t = n\n", " while t > 0:\n", " s = int(last_change[t])\n", " if s <= 0:\n", " break\n", " cps.append(s)\n", " t = s\n", " return list(reversed(cps))" ] }, { "cell_type": "code", "execution_count": 36, "id": "b73d80ac", "metadata": {}, "outputs": [], "source": [ "cps = get_changepoints(outputs[1], n=len(xs_std))" ] }, { "cell_type": "code", "execution_count": 37, "id": "e2e376de", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_pelt_diagnostics(xs, cps, outputs[0])" ] }, { "cell_type": "markdown", "id": "09f2e235", "metadata": {}, "source": [ "# Kalman Filter Algorithms" ] }, { "cell_type": "markdown", "id": "1c42336f", "metadata": {}, "source": [ "## Linear Gaussian Kalman Filter" ] }, { "cell_type": "code", "execution_count": 38, "id": "5004eeab", "metadata": {}, "outputs": [], "source": [ "@jit(nopython=True)\n", "def atrocious_kalman_filter_numba(z, F, H, Q, R, x0, P0):\n", " \"\"\"\n", " This implementation of the Kalman filter is Atrocious and in standard Python would be a \n", " BIG NO-NO. That being said this version SIGNIFICANTLY reduces Numba Compilation time. \n", " \n", " Linear Gaussian Kalman filter algorithm\n", "\n", " Parameters\n", " ----------\n", " z: np.ndarray\n", " shape (T, m) - observations\n", " F: np.ndarray\n", " state transition matrix - shape (n, n)\n", " H: np.ndarray\n", " observation/design matrix - shape (m, n)\n", " Q: np.ndarray\n", " process noise covariance - shape (n, n)\n", " R: np.ndarray\n", " observation noise covariance - shape (m, m)\n", " x0: np.ndarray\n", " initial state mean - shape (n,)\n", " P0: np.ndarray\n", " initial state covariance - shape (n, n)\n", "\n", " Returns\n", " -------\n", " xs: np.ndarray\n", " shape (T, n) - filtered state means\n", " Ps: np.ndarray\n", " shape (T, n, n) - filtered state covariances\n", " \"\"\"\n", " T = z.shape[0]\n", " m = z.shape[1]\n", " n = x0.shape[0]\n", "\n", " xs = np.empty((T, n), dtype=np.float32)\n", " Ps = np.empty((T, n, n), dtype=np.float32)\n", "\n", " # local working arrays\n", " x = np.empty(n, dtype=np.float32)\n", " for i in range(n):\n", " x[i] = x0[i]\n", " P = np.empty((n, n), dtype=np.float32)\n", " for i in range(n):\n", " for j in range(n):\n", " P[i, j] = P0[i, j]\n", "\n", " # temporary matrices/vectors\n", " x_pred = np.empty((T, n), dtype=np.float32)\n", " P_pred = np.empty((T, n, n), 
dtype=np.float32)\n", " y = np.empty(m, dtype=np.float32)\n", " S = np.empty((m, m), dtype=np.float32)\n", " K = np.empty((n, m), dtype=np.float32)\n", " I_n = np.eye(n, dtype=np.float32)\n", "\n", " for t in range(T):\n", " # === Predict ===\n", " # x_pred = F @ x\n", " for i in range(n):\n", " s = 0.0\n", " for j in range(n):\n", " s += F[i, j] * x[j]\n", " x_pred[t, i] = s\n", "\n", " # P_pred = F @ P @ F.T + Q\n", " # temp = F @ P\n", " temp = np.empty((n, n), dtype=np.float32)\n", " for i in range(n):\n", " for j in range(n):\n", " s = 0.0\n", " for k in range(n):\n", " s += F[i, k] * P[k, j]\n", " temp[i, j] = s\n", " # P_pred = temp @ F.T\n", " for i in range(n):\n", " for j in range(n):\n", " s = 0.0\n", " for k in range(n):\n", " s += temp[i, k] * F[j, k] # F.T[k, j] = F[j, k]\n", " P_pred[t, i, j] = s + Q[i, j]\n", "\n", " # === Update ===\n", " # y = z[t] - H @ x_pred\n", " for i in range(m):\n", " s = 0.0\n", " for j in range(n):\n", " s += H[i, j] * x_pred[t, j]\n", " y[i] = z[t, i] - s\n", "\n", " # S = H @ P_pred @ H.T + R\n", " # temp2 = H @ P_pred\n", " temp2 = np.empty((m, n), dtype=np.float32)\n", " for i in range(m):\n", " for j in range(n):\n", " s = 0.0\n", " for k in range(n):\n", " s += H[i, k] * P_pred[t, k, j]\n", " temp2[i, j] = s\n", " # S = temp2 @ H.T\n", " for i in range(m):\n", " for j in range(m):\n", " s = 0.0\n", " for k in range(n):\n", " s += temp2[i, k] * H[j, k] # H.T[k,j] = H[j,k]\n", " S[i, j] = s + R[i, j]\n", "\n", " # K = P_pred @ H.T @ inv(S)\n", " # first compute P_pred @ H.T -> (n, m)\n", " P_Ht = np.empty((n, m), dtype=np.float32)\n", " for i in range(n):\n", " for j in range(m):\n", " s = 0.0\n", " for k in range(n):\n", " s += P_pred[t, i, k] * H[j, k] # H.T[k,j] = H[j,k]\n", " P_Ht[i, j] = s\n", "\n", " # invert S\n", " S_inv = np.linalg.inv(S)\n", "\n", " # K = P_Ht @ S_inv (n,m) @ (m,m) -> (n,m)\n", " for i in range(n):\n", " for j in range(m):\n", " s = 0.0\n", " for k in range(m):\n", " s += P_Ht[i, k] * 
S_inv[k, j]\n", " K[i, j] = s\n", "\n", " # x = x_pred + K @ y\n", " for i in range(n):\n", " s = 0.0\n", " for j in range(m):\n", " s += K[i, j] * y[j]\n", " x[i] = x_pred[t, i] + s\n", "\n", " # P = (I - K H) P_pred\n", " # compute (I - K H)\n", " KH = np.empty((n, n), dtype=np.float32)\n", " for i in range(n):\n", " for j in range(n):\n", " s = 0.0\n", " for k in range(m):\n", " s += K[i, k] * H[k, j]\n", " KH[i, j] = s\n", "\n", " I_minus_KH = np.empty((n, n), dtype=np.float32)\n", " for i in range(n):\n", " for j in range(n):\n", " I_minus_KH[i, j] = I_n[i, j] - KH[i, j]\n", "\n", " # P = I_minus_KH @ P_pred\n", " for i in range(n):\n", " for j in range(n):\n", " s = 0.0\n", " for k in range(n):\n", " s += I_minus_KH[i, k] * P_pred[t, k, j]\n", " P[i, j] = s\n", "\n", " # store results\n", " for i in range(n):\n", " xs[t, i] = x[i]\n", " for i in range(n):\n", " for j in range(n):\n", " Ps[t, i, j] = P[i, j]\n", "\n", " return xs, Ps, x_pred, P_pred\n" ] }, { "cell_type": "code", "execution_count": 39, "id": "25bcb14e", "metadata": {}, "outputs": [], "source": [ "@jit(nopython=True)\n", "def kalman_filter_numba(z, F, H, Q, R, x0, P0):\n", " \"\"\"\n", " Linear Gaussian Kalman filter algorithm\n", "\n", " Parameters\n", " ----------\n", " z: np.ndarray\n", " shape (T, m) - observations\n", " F: np.ndarray\n", " state transition matrix - shape (n, n)\n", " H: np.ndarray\n", " observation/design matrix - shape (m, n)\n", " Q: np.ndarray\n", " process noise covariance - shape (n, n)\n", " R: np.ndarray\n", " observation noise covariance - shape (m, m)\n", " x0: np.ndarray\n", " initial state mean - shape (n,)\n", " P0: np.ndarray\n", " initial state covariance - shape (n, n)\n", "\n", " Returns\n", " -------\n", " xs: np.ndarray\n", " shape (T, n) - filtered state means\n", " Ps: np.ndarray\n", " shape (T, n, n) - filtered state covariances\n", " \"\"\"\n", " T, m = z.shape\n", " n = x0.shape[0]\n", "\n", " xs = np.zeros((T, n), dtype=np.float32)\n", " Ps = 
np.zeros((T, n, n), dtype=np.float32)\n", "\n", " x_pred = np.zeros((T, n), dtype=np.float32)\n", " P_pred = np.zeros((T, n, n), dtype=np.float32)\n", "\n", " x = x0.copy()\n", " P = P0.copy()\n", "\n", " I = np.eye(n, dtype=np.float32)\n", "\n", " for t in range(T):\n", " # --- Predict ---\n", " x_pred[t] = F @ x\n", " P_pred[t] = F @ P @ F.T + Q\n", "\n", " # --- Update ---\n", " y = z[t] - H @ x_pred[t]\n", " S = H @ P_pred[t] @ H.T + R\n", " K = P_pred[t] @ H.T @ np.linalg.inv(S)\n", "\n", " x = x_pred[t] + K @ y\n", " P = (I - K @ H) @ P_pred[t]\n", "\n", " xs[t] = x\n", " Ps[t] = P\n", "\n", " return xs, Ps, x_pred, P_pred" ] }, { "cell_type": "code", "execution_count": 40, "id": "cd715dfb", "metadata": {}, "outputs": [], "source": [ "@block\n", "@jax.jit\n", "def kalman_filter_jax(z, F, H, Q, R, x0, P0):\n", " \"\"\"\n", " Linear Gaussian Kalman filter algorithm\n", "\n", " Parameters\n", " ----------\n", " z: np.ndarray\n", " shape (T, m) - observations\n", " F: np.ndarray\n", " state transition matrix - shape (n, n)\n", " H: np.ndarray\n", " observation/design matrix - shape (m, n)\n", " Q: np.ndarray\n", " process noise covariance - shape (n, n)\n", " R: np.ndarray\n", " observation noise covariance - shape (m, m)\n", " x0: np.ndarray\n", " initial state mean - shape (n,)\n", " P0: np.ndarray\n", " initial state covariance - shape (n, n)\n", "\n", " Returns\n", " -------\n", " xs: jnp.ndarray\n", " shape (T, n) - filtered state means\n", " Ps: jnp.ndarray\n", " shape (T, n, n) - filtered state covariances\n", " x_pred: jnp.ndarray\n", " shape (T, n) - one-step-ahead predicted state means (before the update with z[t])\n", " P_pred: jnp.ndarray\n", " shape (T, n, n) - one-step-ahead predicted state covariances\n", " \"\"\"\n", "\n", " n = x0.shape[0]\n", " I = jnp.eye(n)\n", " # lax.scan requires the carry to keep the same shape/dtype on every step,\n", " # so the placeholder predictions must mirror the state (n,) and covariance (n, n)\n", " X_pred_init = jnp.zeros_like(x0)\n", " P_pred_init = jnp.zeros_like(P0)\n", "\n", " def step(carry, z_t):\n", " x, P, _, _ = carry\n", "\n", " # --- Predict ---\n", " x_pred = F @ x\n", " P_pred = F @ P @ F.T + Q\n", "\n", " # --- Update ---\n", " y = z_t - H @ x_pred\n", " S = H @ P_pred @ H.T + R\n", " K = P_pred @ H.T @ jnp.linalg.inv(S)\n", "\n", " x_new = x_pred + K @ y\n", " P_new = 
(I - K @ H) @ P_pred\n", "\n", " return (x_new, P_new, x_pred, P_pred), (x_new, P_new, x_pred, P_pred)\n", "\n", " # run scan\n", " (_, _, _, _), (xs, Ps, x_pred, P_pred) = jax.lax.scan(step, (x0, P0, X_pred_init, P_pred_init), z)\n", "\n", " return xs, Ps, x_pred, P_pred" ] }, { "cell_type": "code", "execution_count": 41, "id": "5af37c40", "metadata": {}, "outputs": [], "source": [ "z_symbolic = pt.matrix(\"z\")\n", "F_symbolic = pt.matrix(\"F\")\n", "H_symbolic = pt.matrix(\"H\")\n", "Q_symbolic = pt.matrix(\"Q\")\n", "R_symbolic = pt.matrix(\"R\")\n", "x0_symbolic = pt.vector(\"x0\")\n", "P0_symbolic = pt.matrix(\"P0\")\n", "\n", "n = x0_symbolic.shape[0]\n", "I = pt.eye(n)\n", "X_pred_init = pt.zeros_like(x0_symbolic)\n", "P_pred_init = pt.zeros_like(P0_symbolic)\n", "\n", "N_STEPS = 500\n", "\n", "def step(z_t, x, P, x_pred, P_pred, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I):\n", "\n", " # --- Predict ---\n", " x_pred = F_symbolic @ x\n", " P_pred = F_symbolic @ P @ F_symbolic.T + Q_symbolic\n", "\n", " # --- Update ---\n", " y = z_t - H_symbolic @ x_pred\n", " S = H_symbolic @ P_pred @ H_symbolic.T + R_symbolic\n", " K = P_pred @ H_symbolic.T @ pt.linalg.inv(S)\n", "\n", " x_new = x_pred + K @ y\n", " P_new = (I - K @ H_symbolic) @ P_pred\n", "\n", " return x_new, P_new, x_pred, P_pred\n", "\n", "# run scan\n", "(xs, Ps, x_pred, P_pred), _ = pytensor.scan(\n", " fn=step,\n", " outputs_info=[x0_symbolic, P0_symbolic, X_pred_init, P_pred_init],\n", " sequences=[z_symbolic],\n", " non_sequences=[F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I],\n", " n_steps=N_STEPS\n", ")\n", "\n", "kalman_filter_pytensor = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], trust_input=True)\n", "\n", "kalman_filter_pytensor_numba = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode=\"NUMBA\", 
trust_input=True)\n", "kalman_filter_pytensor_jax = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode=\"JAX\", trust_input=True)" ] }, { "cell_type": "code", "execution_count": 42, "id": "3c9e7254", "metadata": {}, "outputs": [], "source": [ "T = 500\n", "F = np.array([[1.0]]).astype(np.float32)\n", "H = np.array([[1.0]]).astype(np.float32)\n", "Q = np.array([[0.01]]).astype(np.float32)\n", "R = np.array([[0.1]]).astype(np.float32)\n", "x0 = np.array([0.0]).astype(np.float32)\n", "P0 = np.array([[1.0]]).astype(np.float32)\n", "\n", "true = 1.0\n", "z = (true + 0.4*np.random.randn(T)).reshape(T, 1).astype(np.float32)" ] }, { "cell_type": "code", "execution_count": 43, "id": "00472afd", "metadata": {}, "outputs": [], "source": [ "kalman_filter_bench = Benchmarker(\n", " functions=[kalman_filter_pytensor, atrocious_kalman_filter_numba, kalman_filter_numba, kalman_filter_jax, kalman_filter_pytensor_numba, block(kalman_filter_pytensor_jax)], \n", " names=['kalman_filter_pytensor', 'atrocious_kalman_filter_numba', 'kalman_filter_numba', 'kalman_filter_jax', 'kalman_filter_pytensor_numba', 'kalman_filter_pytensor_jax'],\n", " number=10\n", ")" ] }, { "cell_type": "code", "execution_count": 44, "id": "68ec703d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/tmpmbe0qi_u:33: NumbaWarning:\n", "\n", "\u001b[1m\u001b[1mCannot cache compiled function \"scan\" as it uses dynamic globals (such as ctypes pointers and large global arrays)\u001b[0m\u001b[0m\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
kalman_filter_pytensorkalman_filter_inputs105675.5916016334.5666015949.072047159.6177455937.045900191.6000010.16809317
atrocious_kalman_filter_numbakalman_filter_inputs10297.462501330.600000314.41502011.425574315.53340114.2874023.1805105
kalman_filter_numbakalman_filter_inputs10641.612499718.120800684.41580032.973390706.01250062.0249981.4611005
kalman_filter_jaxkalman_filter_inputs10305.941701368.020899333.58010518.368842329.07704922.4343512.99778118
kalman_filter_pytensor_numbakalman_filter_inputs10853.600001933.024997897.57000033.462131918.74999961.3083981.1141195
kalman_filter_pytensor_jaxkalman_filter_inputs10304.041599372.449998330.68582517.271642327.66869824.3562493.02401820
\n", "
" ], "text/plain": [ " Loops Min (us) \\\n", "kalman_filter_pytensor kalman_filter_inputs 10 5675.591601 \n", "atrocious_kalman_filter_numba kalman_filter_inputs 10 297.462501 \n", "kalman_filter_numba kalman_filter_inputs 10 641.612499 \n", "kalman_filter_jax kalman_filter_inputs 10 305.941701 \n", "kalman_filter_pytensor_numba kalman_filter_inputs 10 853.600001 \n", "kalman_filter_pytensor_jax kalman_filter_inputs 10 304.041599 \n", "\n", " Max (us) Mean (us) \\\n", "kalman_filter_pytensor kalman_filter_inputs 6334.566601 5949.072047 \n", "atrocious_kalman_filter_numba kalman_filter_inputs 330.600000 314.415020 \n", "kalman_filter_numba kalman_filter_inputs 718.120800 684.415800 \n", "kalman_filter_jax kalman_filter_inputs 368.020899 333.580105 \n", "kalman_filter_pytensor_numba kalman_filter_inputs 933.024997 897.570000 \n", "kalman_filter_pytensor_jax kalman_filter_inputs 372.449998 330.685825 \n", "\n", " StdDev (us) Median (us) \\\n", "kalman_filter_pytensor kalman_filter_inputs 159.617745 5937.045900 \n", "atrocious_kalman_filter_numba kalman_filter_inputs 11.425574 315.533401 \n", "kalman_filter_numba kalman_filter_inputs 32.973390 706.012500 \n", "kalman_filter_jax kalman_filter_inputs 18.368842 329.077049 \n", "kalman_filter_pytensor_numba kalman_filter_inputs 33.462131 918.749999 \n", "kalman_filter_pytensor_jax kalman_filter_inputs 17.271642 327.668698 \n", "\n", " IQR (us) OPS (Kops/s) \\\n", "kalman_filter_pytensor kalman_filter_inputs 191.600001 0.168093 \n", "atrocious_kalman_filter_numba kalman_filter_inputs 14.287402 3.180510 \n", "kalman_filter_numba kalman_filter_inputs 62.024998 1.461100 \n", "kalman_filter_jax kalman_filter_inputs 22.434351 2.997781 \n", "kalman_filter_pytensor_numba kalman_filter_inputs 61.308398 1.114119 \n", "kalman_filter_pytensor_jax kalman_filter_inputs 24.356249 3.024018 \n", "\n", " Samples \n", "kalman_filter_pytensor kalman_filter_inputs 17 \n", "atrocious_kalman_filter_numba kalman_filter_inputs 5 \n", 
"kalman_filter_numba kalman_filter_inputs 5 \n", "kalman_filter_jax kalman_filter_inputs 18 \n", "kalman_filter_pytensor_numba kalman_filter_inputs 5 \n", "kalman_filter_pytensor_jax kalman_filter_inputs 20 " ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "kalman_filter_bench.run(\n", " inputs={\n", " \"kalman_filter_inputs\": {\"z\": z, \"F\": F, \"H\": H, \"Q\": Q, \"R\": R, \"x0\": x0, \"P0\": P0},\n", " }\n", ")\n", "kalman_filter_bench.summary()" ] }, { "cell_type": "code", "execution_count": 45, "id": "37526f86", "metadata": {}, "outputs": [], "source": [ "xs, Ps, x_pred, P_pred = kalman_filter_jax(z, F, H, Q, R, x0, P0)" ] }, { "cell_type": "code", "execution_count": 46, "id": "231140ba", "metadata": {}, "outputs": [], "source": [ "def compute_pred_intervals(z, x_pred, P_pred, H, R, zscore=1.96):\n", " T = z.shape[0]\n", " m = H.shape[0]\n", " mean = np.zeros((T, m))\n", " lower = np.zeros((T, m))\n", " upper = np.zeros((T, m))\n", " outside = np.zeros(T, dtype=np.bool_)\n", "\n", " for t in range(T):\n", " mean[t] = H @ x_pred[t]\n", " S = H @ P_pred[t] @ H.T + R\n", " std = np.sqrt(np.diag(S))\n", " lower[t] = mean[t] - zscore * std\n", " upper[t] = mean[t] + zscore * std\n", "\n", " # check coverage of actual obs\n", " outside[t] = np.any((z[t] < lower[t]) | (z[t] > upper[t]))\n", "\n", " coverage = 1 - outside.mean()\n", " return mean, lower, upper, coverage\n" ] }, { "cell_type": "code", "execution_count": 47, "id": "ff739003", "metadata": {}, "outputs": [], "source": [ "mean, lower, upper, coverage = compute_pred_intervals(z, x_pred, P_pred, H, R)" ] }, { "cell_type": "code", "execution_count": 48, "id": "7546c95c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "np.float64(0.91)" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "coverage" ] }, { "cell_type": "code", "execution_count": 49, "id": "6b765a37", "metadata": {}, "outputs": [ { "data": { 
"text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig= go.Figure()\n", "fig.add_traces(\n", " [\n", " go.Scatter(\n", " x = np.arange(T),\n", " y = z.ravel(),\n", " mode=\"markers\",\n", " marker_color = \"royalblue\",\n", " name = \"actuals\"\n", " ),\n", " go.Scatter(\n", " x = np.arange(T),\n", " y = xs.ravel(),\n", " mode = \"lines\",\n", " marker_color = \"orange\",\n", " name = \"filtered mean\"\n", " ),\n", " go.Scatter(\n", " name=\"\", \n", " x=np.arange(T), \n", " y=upper.ravel(), \n", " mode=\"lines\", \n", " marker=dict(color=\"#eb8c34\"), \n", " line=dict(width=0), \n", " legendgroup=\"95% CI\",\n", " showlegend=False\n", " ),\n", " go.Scatter(\n", " name=\"95% CI\", \n", " x=np.arange(T), \n", " y=lower.ravel(), \n", " mode=\"lines\", marker=dict(color=\"#eb8c34\"), \n", " line=dict(width=0), \n", " legendgroup=\"95% CI\", \n", " fill='tonexty', \n", " fillcolor='rgba(235, 140, 52, 0.2)'\n", " ),\n", "\n", " ]\n", ")\n", "fig.update_layout(\n", " xaxis=dict(\n", " title = \"Time Index\",\n", " ),\n", " yaxis=dict(\n", " title = \"y\"\n", " ),\n", " template = \"plotly_dark\"\n", ")" ] }, { "cell_type": "markdown", "id": "052b0194", "metadata": {}, "source": [ "## Non-linear Kalman Filter" ] }, { "cell_type": "code", "execution_count": 50, "id": "6b088d6b", "metadata": {}, "outputs": [], "source": [ "@jit(nopython=True)\n", "def loglik_poisson_numba(s, y):\n", " \"\"\"Poisson Log Likelihood\"\"\"\n", " mu = np.exp(s)\n", " return y * np.log(mu + 1e-30) - mu - math.lgamma(y + 1.0) # numba does not support scipy.special gammaln\n", "\n", "@jit(nopython=True)\n", "def particle_filter_1d_predict_numba(A, Q, x0_mean, x0_std, ys, N=1000, seed=2):\n", " \"\"\"\n", " 1D particle filter.\n", " \n", " Parameters\n", " ----------\n", " A: float\n", " State transition\n", " Q: float\n", " Process covariance\n", " x0_mean: float\n", " Prior mean for the latent state\n", " x0_std: float\n", " Prior standard deviation \n", " ys: np.ndarray\n", " 
observations\n", " N: int\n", " number of particles\n", " seed: int\n", " rng seed for reproducibility\n", "\n", " Returns\n", " -------\n", " filtered_means: np.ndarray\n", " The filtered mean for the latent state \n", " filtered_vars: np.ndarray\n", " The filtered variance for the latent state\n", " pred_means: np.ndarray\n", " observation predicted mean \n", " \"\"\"\n", " np.random.seed(seed)\n", " T = ys.shape[0]\n", " particles = np.random.normal(x0_mean, x0_std, size=N)\n", " weights = np.ones(N) / N\n", "\n", " filtered_means = np.zeros(T)\n", " filtered_vars = np.zeros(T)\n", " pred_means = np.zeros(T)\n", "\n", " for t in range(T):\n", " y = ys[t]\n", "\n", " # propagate (vectorized)\n", " particles = A * particles + np.random.normal(0, np.sqrt(Q), size=N)\n", "\n", " # update weights\n", " logw = np.zeros(N)\n", " for i in range(N):\n", " logw[i] = loglik_poisson_numba(particles[i], y)\n", " logw = logw - np.max(logw)\n", " weights *= np.exp(logw)\n", " weights /= np.sum(weights) + 1e-12\n", "\n", " # filtered moments\n", " mean_t = np.sum(weights * particles)\n", " var_t = np.sum(weights * (particles - mean_t) ** 2)\n", "\n", " # predictive mean\n", " pred_mean = np.sum(weights * np.exp(particles))\n", "\n", " filtered_means[t] = mean_t\n", " filtered_vars[t] = var_t\n", " pred_means[t] = pred_mean\n", "\n", " # resample (multinomial resampling) because numba doesn't support np.random.choice\n", " cumulative_sum = np.cumsum(weights)\n", " cumulative_sum[-1] = 1.0 # guard against rounding error\n", " indices = np.searchsorted(cumulative_sum, np.random.rand(N))\n", "\n", " particles = particles[indices]\n", " weights = np.ones(N) / N\n", "\n", " return filtered_means, filtered_vars, pred_means" ] }, { "cell_type": "code", "execution_count": 51, "id": "a468c403", "metadata": {}, "outputs": [], "source": [ "# Had to fix the loglikelihood and key to use benchmarker as is\n", "def loglik_poisson_jax(s, y):\n", " \"\"\"Poisson Log Likelihood\"\"\"\n", " mu = 
jnp.exp(s)\n", " return y * jnp.log(mu + 1e-30) - mu - gammaln(y + 1.0)\n", "\n", "@block\n", "@partial(jax.jit, static_argnums=5)\n", "def particle_filter_1d_predict_jax(\n", " A, Q, x0_mean, x0_std, ys, N=1000,\n", "):\n", " \"\"\"\n", " 1D bootstrap particle filter with a Poisson observation model.\n", " \n", " Parameters\n", " ----------\n", " A: float\n", " State transition\n", " Q: float\n", " Process noise variance (particles move by sqrt(Q) * standard normal noise)\n", " x0_mean: float\n", " Prior mean for the latent state\n", " x0_std: float\n", " Prior standard deviation \n", " ys: np.ndarray\n", " observations\n", " N: int\n", " number of particles; static under jit (static_argnums=5), so each\n", " distinct value triggers a recompile\n", "\n", " The log likelihood is fixed to loglik_poisson_jax and the PRNG key is\n", " fixed to jax.random.PRNGKey(0), so repeated calls are deterministic.\n", "\n", " Returns\n", " -------\n", " filtered_means: jnp.ndarray\n", " The filtered mean for the latent state \n", " filtered_vars: jnp.ndarray\n", " The filtered variance for the latent state\n", " pred_means: jnp.ndarray\n", " observation predicted mean \n", " \"\"\"\n", " # Split the fixed root key so the initial particle draw and the keys consumed\n", " # inside the scan come from independent streams (avoids reusing one key twice)\n", " key, init_key = jax.random.split(jax.random.PRNGKey(0))\n", " T = ys.shape[0]\n", " particles = jax.random.normal(init_key, (N,)) * x0_std + x0_mean # init particles from gaussian priors\n", " weights = jnp.ones(N) / N # particle weights, all particles equally likely prior\n", "\n", " def body_fun(carry, t):\n", " particles, weights, key = carry\n", " y = ys[t]\n", "\n", " # propagate\n", " key, subkey = jax.random.split(key)\n", " particles = A * particles + jax.random.normal(subkey, (N,)) * jnp.sqrt(Q) # state transition model\n", "\n", " # update weights\n", " logw = jax.vmap(lambda x: loglik_poisson_jax(x, y))(particles) # update particles in parallel\n", " logw = logw - jnp.max(logw) # avoid overflow\n", " weights = weights * jnp.exp(logw) # old weights times the likelihood\n", " weights /= jnp.sum(weights) + 1e-12 # normalize so that weights sum to 1\n", "\n", " # filtered moments\n", " mean_t = jnp.sum(weights * particles) # posterior mean of latent state\n", " var_t = jnp.sum(weights * (particles - mean_t)**2) # posterior variance of latent state\n", "\n", " # 
predictive mean\n", " pred_mean = jnp.sum(weights * jnp.exp(particles))\n", "\n", " # resample to prevent dominant particles\n", " key, subkey = jax.random.split(key)\n", " indices = jax.random.choice(subkey, N, p=weights, shape=(N,))\n", " particles = particles[indices]\n", " weights = jnp.ones(N) / N\n", "\n", " carry = (particles, weights, key)\n", " out = (mean_t, var_t, pred_mean)\n", " return carry, out\n", "\n", " _, outputs = jax.lax.scan(body_fun, (particles, weights, key), jnp.arange(T))\n", " return outputs\n" ] }, { "cell_type": "code", "execution_count": 52, "id": "15fd4a0c", "metadata": {}, "outputs": [], "source": [ "from pytensor.tensor.random.utils import RandomStream\n", "\n", "# Random stream for PyTensor\n", "srng = RandomStream(seed=42)\n", "\n", "# Poisson log-likelihood\n", "def loglik_poisson_pytensor(s, y):\n", " mu = pt.exp(s)\n", " return y.flatten() * pt.log(mu + 1e-30) - mu - pt.gammaln(y.flatten() + 1.0)\n" ] }, { "cell_type": "code", "execution_count": 53, "id": "59525da5", "metadata": {}, "outputs": [], "source": [ "ys_symbolic = pt.vector(\"ys\")\n", "x0_mean_symbolic = pt.scalar(\"x0_mean\")\n", "x0_std_symbolic = pt.scalar(\"x0_std\")\n", "A_symbolic = pt.scalar(\"A\")\n", "Q_symbolic = pt.scalar(\"Q\")\n", "N_symbolic = pt.scalar(\"N\", dtype='int64')\n", "\n", "N_STEPS = 300\n", "\n", "# Initialize particles and weights\n", "particles_init = srng.normal(size=(N_symbolic,)) * x0_std_symbolic + x0_mean_symbolic\n", "weights_init = pt.ones((N_symbolic,)) / N_symbolic \n", "\n", "# Step function for scan\n", "def step(y_t, particles_prev, weights_prev, A_symbolic, Q_symbolic):\n", " # Propagate particles\n", " particles_prop = A_symbolic * particles_prev + srng.normal(size=(N_symbolic,)) * pt.sqrt(Q_symbolic)\n", "\n", " # Update weights\n", " # logw = pt.stack([loglik_poisson_pytensor(p, y_t) for p in particles_prop])\n", " logw = loglik_poisson_pytensor(particles_prop, y_t)\n", " logw_stable = logw - pt.max(logw)\n", " w_unnorm = 
weights_prev * pt.exp(logw_stable)\n", " w = w_unnorm / (pt.sum(w_unnorm) + 1e-12) \n", "\n", " # Filtered moments\n", " mean_t = pt.sum(w * particles_prop)\n", " var_t = pt.sum(w * (particles_prop - mean_t) ** 2)\n", " pred_mean = pt.sum(w * pt.exp(particles_prop))\n", "\n", " # Resample particles\n", " idx = srng.choice(size=(N_symbolic,), a=N_symbolic, p=w) \n", " particles_resampled = particles_prop[idx]\n", " weights_resampled = pt.ones((N_symbolic,)) / N_symbolic\n", "\n", " # Return flat tuple\n", " return particles_resampled, weights_resampled, mean_t, var_t, pred_mean\n", "\n", "# first two are recurrent, rest are collected\n", "outputs_info = [\n", " particles_init,\n", " weights_init,\n", " None,\n", " None,\n", " None\n", "]\n", "\n", "(particles_seq, weights_seq, means_seq, vars_seq, preds_seq), updates = pytensor.scan(\n", " fn=step,\n", " sequences=[ys_symbolic],\n", " outputs_info=outputs_info,\n", " non_sequences=[A_symbolic, Q_symbolic],\n", " n_steps=N_STEPS\n", ")\n", "\n", "particle_filter_1d_predict_pytensor = pytensor.function(\n", " [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],\n", " [means_seq, vars_seq, preds_seq],\n", " updates=updates,\n", " no_default_updates=True,\n", " trust_input=True\n", ")\n", "\n", "particle_filter_1d_predict_pytensor_numba = pytensor.function(\n", " [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],\n", " [means_seq, vars_seq, preds_seq],\n", " updates=updates,\n", " no_default_updates=True,\n", " mode=\"NUMBA\", \n", " trust_input=True\n", ")" ] }, { "cell_type": "code", "execution_count": 54, "id": "c033a1d0", "metadata": {}, "outputs": [], "source": [ "key = jax.random.PRNGKey(0)\n", "T = 300\n", "A = 0.95\n", "Q = 0.05\n", "rng = np.random.RandomState(1)\n", "\n", "target_mean = 10.0\n", "latent_var = Q / (1 - A**2)\n", "x0_mean = np.log(target_mean) - 0.5 * latent_var\n", "x0_std = 1.0\n", "\n", "# Simulate latent\n", "x = 
np.zeros(T)\n", "x[0] = rng.normal() * np.sqrt(latent_var) + x0_mean\n", "for t in range(1, T):\n", " x[t] = A * x[t-1] + rng.normal() * np.sqrt(Q)\n", "\n", "ys = np.array(rng.poisson(np.exp(x)), dtype=np.float32)" ] }, { "cell_type": "code", "execution_count": 55, "id": "2a9cbfa5", "metadata": {}, "outputs": [], "source": [ "nonlinear_kalman_filter_bench = Benchmarker(\n", " functions=[particle_filter_1d_predict_pytensor, particle_filter_1d_predict_numba, particle_filter_1d_predict_jax, particle_filter_1d_predict_pytensor_numba,], \n", " names=['particle_filter_1d_predict_pytensor', 'particle_filter_1d_predict_numba', 'particle_filter_1d_predict_jax', 'particle_filter_1d_predict_pytensor_numba',],\n", " number=5 # This takes a while to run reducing number of loops\n", ")" ] }, { "cell_type": "code", "execution_count": 56, "id": "c782c42b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
particle_filter_1d_predict_pytensorkalman_filter_inputs5728815.783397742948.366801734642.3867615002.169131733596.2500026332.2000030.0013615
particle_filter_1d_predict_numbakalman_filter_inputs548678.40840048904.82500248782.92164089.83533848744.633398160.5250000.0204995
particle_filter_1d_predict_jaxkalman_filter_inputs533170.26660133644.35000133410.141721155.13423133411.583403125.9249980.0299315
particle_filter_1d_predict_pytensor_numbakalman_filter_inputs5676612.958201678200.574999677432.478319647.987321677334.9249971278.4665970.0014765
\n", "
" ], "text/plain": [ " Loops \\\n", "particle_filter_1d_predict_pytensor kalman_filter_inputs 5 \n", "particle_filter_1d_predict_numba kalman_filter_inputs 5 \n", "particle_filter_1d_predict_jax kalman_filter_inputs 5 \n", "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 5 \n", "\n", " Min (us) \\\n", "particle_filter_1d_predict_pytensor kalman_filter_inputs 728815.783397 \n", "particle_filter_1d_predict_numba kalman_filter_inputs 48678.408400 \n", "particle_filter_1d_predict_jax kalman_filter_inputs 33170.266601 \n", "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 676612.958201 \n", "\n", " Max (us) \\\n", "particle_filter_1d_predict_pytensor kalman_filter_inputs 742948.366801 \n", "particle_filter_1d_predict_numba kalman_filter_inputs 48904.825002 \n", "particle_filter_1d_predict_jax kalman_filter_inputs 33644.350001 \n", "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 678200.574999 \n", "\n", " Mean (us) \\\n", "particle_filter_1d_predict_pytensor kalman_filter_inputs 734642.386761 \n", "particle_filter_1d_predict_numba kalman_filter_inputs 48782.921640 \n", "particle_filter_1d_predict_jax kalman_filter_inputs 33410.141721 \n", "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 677432.478319 \n", "\n", " StdDev (us) \\\n", "particle_filter_1d_predict_pytensor kalman_filter_inputs 5002.169131 \n", "particle_filter_1d_predict_numba kalman_filter_inputs 89.835338 \n", "particle_filter_1d_predict_jax kalman_filter_inputs 155.134231 \n", "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 647.987321 \n", "\n", " Median (us) \\\n", "particle_filter_1d_predict_pytensor kalman_filter_inputs 733596.250002 \n", "particle_filter_1d_predict_numba kalman_filter_inputs 48744.633398 \n", "particle_filter_1d_predict_jax kalman_filter_inputs 33411.583403 \n", "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 677334.924997 \n", "\n", " IQR (us) \\\n", "particle_filter_1d_predict_pytensor 
kalman_filter_inputs 6332.200003 \n", "particle_filter_1d_predict_numba kalman_filter_inputs 160.525000 \n", "particle_filter_1d_predict_jax kalman_filter_inputs 125.924998 \n", "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 1278.466597 \n", "\n", " OPS (Kops/s) \\\n", "particle_filter_1d_predict_pytensor kalman_filter_inputs 0.001361 \n", "particle_filter_1d_predict_numba kalman_filter_inputs 0.020499 \n", "particle_filter_1d_predict_jax kalman_filter_inputs 0.029931 \n", "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 0.001476 \n", "\n", " Samples \n", "particle_filter_1d_predict_pytensor kalman_filter_inputs 5 \n", "particle_filter_1d_predict_numba kalman_filter_inputs 5 \n", "particle_filter_1d_predict_jax kalman_filter_inputs 5 \n", "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 5 " ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "N_PARTICLES = 2000  # particle count shared by the benchmark and the plotted run below\n",
    "\n",
    "pf_inputs = {\"A\": A, \"Q\": Q, \"x0_mean\": x0_mean, \"x0_std\": x0_std, \"ys\": ys, \"N\": N_PARTICLES}\n",
    "nonlinear_kalman_filter_bench.run(inputs={\"kalman_filter_inputs\": pf_inputs})\n",
    "nonlinear_kalman_filter_bench.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29fe0cc4",
   "metadata": {},
   "source": [
    "The estimates below come from the Numba implementation. Each backend draws from its own random number generator, so the estimates differ slightly between backends and cannot be reproduced 1:1 across implementations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "2bd23897",
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_means, filtered_vars, pred_means = particle_filter_1d_predict_numba(\n",
    "    A, Q, x0_mean, x0_std, ys, N=N_PARTICLES, seed=2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "7caac5a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "steps = np.arange(T)\n",
    "ORANGE = \"#eb8c34\"\n",
    "BAND_FILL = \"rgba(235, 140, 52, 0.2)\"\n",
    "\n",
    "\n",
    "def band_traces(upper, lower, name):\n",
    "    \"\"\"Return two borderless line traces that draw a shaded band between `lower` and `upper`.\n",
    "\n",
    "    The lower trace uses fill=\"tonexty\", so it must be added right after the upper trace.\n",
    "    \"\"\"\n",
    "    common = dict(x=steps, mode=\"lines\", marker=dict(color=ORANGE), line=dict(width=0), legendgroup=name)\n",
    "    return [\n",
    "        go.Scatter(name=\"\", y=upper, showlegend=False, **common),\n",
    "        go.Scatter(name=name, y=lower, fill=\"tonexty\", fillcolor=BAND_FILL, **common),\n",
    "    ]\n",
    "\n",
    "\n",
    "fig = make_subplots(\n",
    "    rows=2, cols=1,\n",
    "    subplot_titles=(\"Observation Predictions\", \"Latent State Estimation\"),\n",
    "    vertical_spacing=0.07,\n",
    "    shared_xaxes=True,\n",
    ")\n",
    "\n",
    "# Approximate 95% band for the counts: mean +/- 2 Poisson standard deviations,\n",
    "# with the lower edge clipped at 0 because counts cannot be negative.\n",
    "# Plain NumPy is used here; the filter outputs are NumPy arrays, so JAX is not needed.\n",
    "obs_sd = np.sqrt(pred_means)\n",
    "fig.add_traces(\n",
    "    [\n",
    "        go.Scatter(x=steps, y=ys, mode=\"markers\", marker_color=\"cornflowerblue\", name=\"actuals\"),\n",
    "        go.Scatter(x=steps, y=pred_means, mode=\"lines\", marker_color=ORANGE, name=\"predicted mean\"),\n",
    "        *band_traces(pred_means + 2 * obs_sd, np.maximum(pred_means - 2 * obs_sd, 0.0), \"predicted mean 95% CI\"),\n",
    "    ],\n",
    "    rows=1, cols=1,\n",
    ")\n",
    "\n",
    "# Approximate 95% band for the latent state: filtered mean +/- 2 filtered standard deviations\n",
    "state_sd = np.sqrt(filtered_vars)\n",
    "fig.add_traces(\n",
    "    [\n",
    "        go.Scatter(x=steps, y=x, mode=\"lines\", marker_color=\"cornflowerblue\", name=\"true latent state\"),\n",
    "        go.Scatter(x=steps, y=filtered_means, mode=\"lines\", marker_color=ORANGE, name=\"filtered state mean\"),\n",
    "        *band_traces(filtered_means + 2 * state_sd, filtered_means - 2 * state_sd, \"filtered state mean 95% CI\"),\n",
    "    ],\n",
    "    rows=2, cols=1,\n",
    ")\n",
    "\n",
    "# Give each subplot row its own legend\n",
    "for i, yaxis in enumerate(fig.select_yaxes(), 1):\n",
    "    legend_name = f\"legend{i}\"\n",
    "    fig.update_layout({legend_name: dict(y=yaxis.domain[1], yanchor=\"top\")}, showlegend=True)\n",
    "    fig.update_traces(row=i, legend=legend_name)\n",
    "\n",
    "fig.update_layout(\n",
    "    height=1000,\n",
    "    width=1200,\n",
    "    template=\"plotly_dark\",\n",
    "    legend1=dict(yanchor=\"top\", y=1.0, xanchor=\"left\", x=0, orientation=\"h\"),\n",
    "    legend2=dict(yanchor=\"top\", y=0.465, xanchor=\"left\", x=0, orientation=\"h\"),\n",
    ")"
   ]
  }
 ],
 "metadata": { "kernelspec": { "display_name": "pytensor-dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.7" } }, "nbformat": 4, "nbformat_minor": 5 }