{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fib-scan-title",
   "metadata": {},
   "source": [
    "# Fibonacci recurrence: PyTensor `scan` vs. hand-written Numba\n",
    "\n",
    "Benchmarks an `N_STEPS`-step recurrence `(a, b) -> (a + b, a)`, starting from `a = 1` and a 0-d `int32` array `b`, implemented as:\n",
    "\n",
    "1. hand-written Numba that updates 0-d arrays in place (`fibonacci_numba_scalar`);\n",
    "2. hand-written Numba that rebinds `a` and `b` every step (`fibonacci_numba_array`);\n",
    "3. PyTensor `scan`, compiled in the default mode (`fibonacci_pytensor`) and in `NUMBA` mode (`fibonacci_pytensor_numba`);\n",
    "4. a hand-written Numba replica of the function PyTensor generates in `NUMBA` mode (`comparable_fibonacci_numba`).\n",
    "\n",
    "Recorded timings from the saved run (machine-dependent):\n",
    "\n",
    "| implementation | time per call |\n",
    "|---|---|\n",
    "| `fibonacci_numba_scalar` | 3.21 μs |\n",
    "| `fibonacci_numba_array` | 32.8 μs |\n",
    "| `comparable_fibonacci_numba` | 54.6 μs |\n",
    "| `fibonacci_pytensor_numba.vm.jit_fn` | 158 μs |\n",
    "| `fibonacci_pytensor_numba` | 175 μs |\n",
    "| `fibonacci_pytensor` (default mode) | 2.49 ms |\n",
    "\n",
    "Notes:\n",
    "\n",
    "- All arithmetic is `int32`. With `b = 1` (as used below), the exact value after 1000 steps is far outside the `int32` range. The asserts therefore only check that the implementations agree with each other, not that they return the true Fibonacci number.\n",
    "- The saved cells were run out of order (non-sequential execution counts), and one assertion cell was never run. Use *Restart & Run All* before relying on the outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-10-07T10:20:23.561430Z",
     "start_time": "2025-10-07T10:20:21.620124Z"
    }
   },
   "outputs": [],
   "source": [
    "import numba\n",
    "import numpy as np\n",
    "import pytensor\n",
    "import pytensor.tensor as pt\n",
    "from pytensor.link.numba.dispatch.basic import to_scalar, tuple_setitem\n",
    "\n",
    "# Number of recurrence steps. The @numba.njit functions below read this global,\n",
    "# which Numba freezes as a compile-time constant.\n",
    "N_STEPS = 1000\n",
    "\n",
    "b_symbolic = pt.scalar(\"b_symbolic\", dtype=\"int32\")\n",
    "\n",
    "\n",
    "def step(a, b):\n",
    "    \"\"\"One recurrence step: (a, b) -> (a + b, a).\"\"\"\n",
    "    return a + b, a\n",
    "\n",
    "\n",
    "(outputs_a, outputs_b), _ = pytensor.scan(\n",
    "    fn=step,\n",
    "    outputs_info=[pt.constant(1, dtype=\"int32\"), b_symbolic],\n",
    "    n_steps=N_STEPS,\n",
    ")\n",
    "\n",
    "# Compile functions returning the final `a`, once in the default mode and once in NUMBA mode.\n",
    "# trust_input=True skips PyTensor's input checks, so inputs should already be 0-d int32 arrays.\n",
    "fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)\n",
    "fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode=\"NUMBA\", trust_input=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "numba-reference-md",
   "metadata": {},
   "source": [
    "## Reference implementations in plain Numba\n",
    "\n",
    "Both functions run `N_STEPS` steps starting from `a = 1`. The first writes into 0-d arrays in place. The second rebinds `a` and `b` every step. The assert checks that they agree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "b1d657d366647ada",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:32:18.289971Z",
     "start_time": "2025-10-07T10:32:18.284515Z"
    }
   },
   "outputs": [],
   "source": [
    "@numba.njit\n",
    "def fibonacci_numba_scalar(b):\n",
    "    \"\"\"Run the recurrence by updating two 0-d arrays in place.\n",
    "\n",
    "    `b` is copied first so the caller's array is not modified. Returns the final `a`.\n",
    "    \"\"\"\n",
    "    b = b.copy()\n",
    "    a = np.ones((), dtype=np.int32)\n",
    "    for _ in range(N_STEPS):\n",
    "        # The right-hand side tuple is built before either slot is written.\n",
    "        a[()], b[()] = a[()] + b[()], a[()]\n",
    "    return a\n",
    "\n",
    "\n",
    "@numba.njit\n",
    "def fibonacci_numba_array(b):\n",
    "    \"\"\"Run the recurrence by rebinding `a` to `np.asarray(a + b)` and `b` to the previous `a`.\n",
    "\n",
    "    Nothing is written in place. Returns the final `a`.\n",
    "    \"\"\"\n",
    "    a = np.ones((), dtype=np.int32)\n",
    "    for _ in range(N_STEPS):\n",
    "        a, b = np.asarray(a + b), a\n",
    "    return a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "7f45c87d259852e6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:32:19.607842Z",
     "start_time": "2025-10-07T10:32:19.423324Z"
    }
   },
   "outputs": [],
   "source": [
    "b = np.ones((), dtype=np.int32)\n",
    "assert fibonacci_numba_array(b) == fibonacci_numba_scalar(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "b01c8978960c6e3d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:32:22.705191Z",
     "start_time": "2025-10-07T10:32:20.090353Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.21 μs ± 20.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit fibonacci_numba_scalar(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "bfc8794b219db03e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:32:25.876514Z",
     "start_time": "2025-10-07T10:32:23.122275Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32.8 μs ± 2.48 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit fibonacci_numba_array(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "pytensor-md",
   "metadata": {},
   "source": [
    "## PyTensor `scan`\n",
    "\n",
    "Check both compiled PyTensor functions against the reference, then time them.\n",
    "\n",
    "`fibonacci_pytensor_numba.vm.jit_fn` is the Numba-compiled graph function used by the `NUMBA`-mode function. Calling it directly skips the `pytensor.function` call wrapper.\n",
    "\n",
    "**TODO:** the assertion cell below has no execution count in the saved notebook, so it was not run in the recorded session."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2185c1de1297a11",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert fibonacci_pytensor(b) == fibonacci_numba_scalar(b)\n",
    "assert fibonacci_pytensor_numba(b) == fibonacci_numba_scalar(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "f1e8bb6a0c673c8f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:29:44.724064Z",
     "start_time": "2025-10-07T10:29:42.655693Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.49 ms ± 327 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit fibonacci_pytensor(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "17cd2859b4c6d3bd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:29:58.922566Z",
     "start_time": "2025-10-07T10:29:44.752331Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "175 μs ± 6.13 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit fibonacci_pytensor_numba(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "6deb056f63953a42",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:30:11.832294Z",
     "start_time": "2025-10-07T10:29:59.016709Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "158 μs ± 706 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit fibonacci_pytensor_numba.vm.jit_fn(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "inspect-graph-md",
   "metadata": {},
   "source": [
    "## Inspecting the `NUMBA`-mode graph\n",
    "\n",
    "First print the compiled graph with its memory map. Then print the Python source that PyTensor generated (and Numba compiled) for the whole graph and for the `allocempty`, `set_subtensor` and `scan` helpers. The replica in the next section follows this source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "17580448648fdbcf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:20:58.015849Z",
     "start_time": "2025-10-07T10:20:58.006831Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Subtensor{i} [id A] v={0: [0]} 6\n",
      " ├─ Scan{scan_fn, while_loop=False, inplace=all}.0 [id B] d={0: [1], 1: [2]} 5\n",
      " │ ├─ 1000 [id C] \n",
      " │ ├─ SetSubtensor{:stop} [id D] d={0: [0]} 4\n",
      " │ │ ├─ AllocEmpty{dtype='int32'} [id E] 3\n",
      " │ │ │ └─ 1 [id F] \n",
      " │ │ ├─ [1] [id G] \n",
      " │ │ └─ 1 [id H] \n",
      " │ └─ SetSubtensor{:stop} [id I] d={0: [0]} 2\n",
      " │ ├─ AllocEmpty{dtype='int32'} [id J] 1\n",
      " │ │ └─ 2 [id K] \n",
      " │ ├─ ExpandDims{axis=0} [id L] v={0: [0]} 0\n",
      " │ │ └─ b_symbolic [id M] \n",
      " │ └─ 1 [id H] \n",
      " └─ 0 [id N] \n",
      "\n",
      "Inner graphs:\n",
      "\n",
      "Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1], 1: [2]}\n",
      " ← Add [id O] \n",
      " ├─ *0- [id P] -> [id D]\n",
      " └─ *1- [id Q] -> [id I]\n",
      " ← *0- [id P] -> [id D]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fibonacci_pytensor_numba.dprint(print_type=True, print_memory_map=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f9806651f5146bbd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:20:58.063585Z",
     "start_time": "2025-10-07T10:20:58.059985Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def numba_funcified_fgraph(b_symbolic):\n",
      " # ExpandDims{axis=0}(b_symbolic)\n",
      " tensor_variable = dimshuffle(b_symbolic)\n",
      " # AllocEmpty{dtype='int32'}(2)\n",
      " tensor_variable_1 = allocempty(tensor_constant)\n",
      " # SetSubtensor{:stop}(AllocEmpty{dtype='int32'}.0, ExpandDims{axis=0}.0, 1)\n",
      " tensor_variable_2 = set_subtensor(tensor_variable_1, tensor_variable, scalar_constant)\n",
      " # AllocEmpty{dtype='int32'}(1)\n",
      " tensor_variable_3 = allocempty_1(tensor_constant_1)\n",
      " # SetSubtensor{:stop}(AllocEmpty{dtype='int32'}.0, [1], 1)\n",
      " tensor_variable_4 = set_subtensor_1(tensor_variable_3, tensor_constant_2, scalar_constant)\n",
      " # Scan{scan_fn, while_loop=False, inplace=all}(1000, SetSubtensor{:stop}.0, SetSubtensor{:stop}.0)\n",
      " tensor_variable_5, tensor_variable_6 = scan(tensor_constant_3, tensor_variable_4, tensor_variable_2)\n",
      " # Subtensor{i}(Scan{scan_fn, while_loop=False, inplace=all}.0, 0)\n",
      " tensor_variable_7 = subtensor(tensor_variable_5, scalar_constant_1)\n",
      " return (tensor_variable_7,)\n"
     ]
    }
   ],
   "source": [
    "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__source__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9995392081dcbffb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:20:58.113352Z",
     "start_time": "2025-10-07T10:20:58.109693Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "def allocempty(tensor_constant):\n",
      " tensor_constant_item = to_scalar(tensor_constant)\n",
      " scalar_shape = (tensor_constant_item, )\n",
      " return np.empty(scalar_shape, dtype)\n",
      " \n"
     ]
    }
   ],
   "source": [
    "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__[\"allocempty\"].py_func.__source__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1f89dfa8b172fde9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:20:58.162062Z",
     "start_time": "2025-10-07T10:20:58.158525Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "def set_subtensor(tensor_variable, tensor_variable_1, scalar_constant):\n",
      " z = tensor_variable\n",
      " indices = (slice(None, scalar_constant, None),)\n",
      " z[indices] = tensor_variable_1\n",
      " return np.asarray(z)\n",
      " \n"
     ]
    }
   ],
   "source": [
    "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__[\"set_subtensor\"].py_func.__source__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "648cd8952121141b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:20:58.211349Z",
     "start_time": "2025-10-07T10:20:58.207479Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "def scan(n_steps, outer_in_1, outer_in_2):\n",
      "\n",
      " outer_in_1_len = outer_in_1.shape[0]\n",
      " outer_in_1_sitsot_storage = outer_in_1\n",
      " outer_in_2_len = outer_in_2.shape[0]\n",
      " outer_in_2_sitsot_storage = outer_in_2\n",
      "\n",
      " outer_in_1_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.int32)\n",
      " outer_in_2_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.int32)\n",
      "\n",
      " i = 0\n",
      " cond = np.array(False)\n",
      " while i < n_steps and not cond.item():\n",
      " outer_in_1_sitsot_storage_temp_scalar_0[()] = outer_in_1_sitsot_storage[(i) % outer_in_1_len]\n",
      " outer_in_2_sitsot_storage_temp_scalar_0[()] = outer_in_2_sitsot_storage[(i) % outer_in_2_len]\n",
      "\n",
      " (inner_out_0, inner_out_1) = scan_inner_func(outer_in_1_sitsot_storage_temp_scalar_0, outer_in_2_sitsot_storage_temp_scalar_0)\n",
      "\n",
      " outer_in_1_sitsot_storage[(i + 1) % outer_in_1_len] = inner_out_0\n",
      " outer_in_2_sitsot_storage[(i + 1) % outer_in_2_len] = inner_out_1\n",
      " i += 1\n",
      "\n",
      " if 1 < outer_in_1_len < (i + 1):\n",
      " outer_in_1_sitsot_storage_shift = (i + 1) % (outer_in_1_len)\n",
      " if outer_in_1_sitsot_storage_shift > 0:\n",
      " outer_in_1_sitsot_storage_left = outer_in_1_sitsot_storage[:outer_in_1_sitsot_storage_shift]\n",
      " outer_in_1_sitsot_storage_right = outer_in_1_sitsot_storage[outer_in_1_sitsot_storage_shift:]\n",
      " outer_in_1_sitsot_storage = np.concatenate((outer_in_1_sitsot_storage_right, outer_in_1_sitsot_storage_left))\n",
      " if 1 < outer_in_2_len < (i + 1):\n",
      " outer_in_2_sitsot_storage_shift = (i + 1) % (outer_in_2_len)\n",
      " if outer_in_2_sitsot_storage_shift > 0:\n",
      " outer_in_2_sitsot_storage_left = outer_in_2_sitsot_storage[:outer_in_2_sitsot_storage_shift]\n",
      " outer_in_2_sitsot_storage_right = outer_in_2_sitsot_storage[outer_in_2_sitsot_storage_shift:]\n",
      " outer_in_2_sitsot_storage = np.concatenate((outer_in_2_sitsot_storage_right, outer_in_2_sitsot_storage_left))\n",
      "\n",
      " return outer_in_1_sitsot_storage, outer_in_2_sitsot_storage\n",
      " \n"
     ]
    }
   ],
   "source": [
    "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__[\"scan\"].py_func.__source__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "replica-md",
   "metadata": {},
   "source": [
    "## Hand-written replica of the generated code\n",
    "\n",
    "`comparable_fibonacci_numba` calls hand-written equivalents of the generated helpers, in the same order as `numba_funcified_fgraph` above. This lets its timing be compared with `fibonacci_pytensor_numba.vm.jit_fn`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "bcefae049d4d2540",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:30:25.495418Z",
     "start_time": "2025-10-07T10:30:25.481386Z"
    }
   },
   "outputs": [],
   "source": [
    "@numba.njit\n",
    "def allocempty(s):\n",
    "    \"\"\"Allocate an uninitialized 1-d int32 array of length `to_scalar(s)` (mirrors `AllocEmpty`).\"\"\"\n",
    "    s_item = to_scalar(s)\n",
    "    scalar_shape = (s_item,)\n",
    "    return np.empty(scalar_shape, dtype=np.int32)\n",
    "\n",
    "\n",
    "@numba.njit\n",
    "def subtensor(x, idx):\n",
    "    \"\"\"Return `x[idx]` wrapped in `np.asarray` (mirrors `Subtensor{i}`).\"\"\"\n",
    "    indices = (idx,)\n",
    "    z = x[indices]\n",
    "    return np.asarray(z)\n",
    "\n",
    "\n",
    "@numba.njit\n",
    "def inner_scan_func(a, b):\n",
    "    \"\"\"Inner scan step: (a, b) -> (a + b, a).\"\"\"\n",
    "    res = a + b\n",
    "    return res, a\n",
    "\n",
    "\n",
    "@numba.njit\n",
    "def scan_fib(n_steps, a_buf, b_buf):\n",
    "    \"\"\"Run `n_steps` recurrence steps over the circular buffers `a_buf` and `b_buf`.\n",
    "\n",
    "    Step `i` reads slot `i % len` of each buffer and writes slot `(i + 1) % len` in place.\n",
    "    After the loop, any buffer with more than one slot that was wrapped around is rotated\n",
    "    with `np.concatenate`, as in the generated `scan` printed above.\n",
    "    Returns the (possibly rotated) buffers.\n",
    "    \"\"\"\n",
    "    a_buf_len = a_buf.shape[0]\n",
    "    b_buf_len = b_buf.shape[0]\n",
    "\n",
    "    # 0-d scratch arrays handed to the inner function on every step.\n",
    "    tmp_a_scalar = np.empty((), dtype=np.int32)\n",
    "    tmp_b_scalar = np.empty((), dtype=np.int32)\n",
    "\n",
    "    i = 0\n",
    "    while i < n_steps:\n",
    "        tmp_a_scalar[()] = a_buf[i % a_buf_len]\n",
    "        tmp_b_scalar[()] = b_buf[i % b_buf_len]\n",
    "        next_a, next_b = inner_scan_func(tmp_a_scalar, tmp_b_scalar)\n",
    "        a_buf[(i + 1) % a_buf_len] = next_a\n",
    "        b_buf[(i + 1) % b_buf_len] = next_b\n",
    "        i += 1\n",
    "\n",
    "    if 1 < a_buf_len < (i + 1):\n",
    "        a_buf_shift = (i + 1) % a_buf_len\n",
    "        if a_buf_shift > 0:\n",
    "            a_buf = np.concatenate((a_buf[a_buf_shift:], a_buf[:a_buf_shift]))\n",
    "    if 1 < b_buf_len < (i + 1):\n",
    "        b_buf_shift = (i + 1) % b_buf_len\n",
    "        if b_buf_shift > 0:\n",
    "            b_buf = np.concatenate((b_buf[b_buf_shift:], b_buf[:b_buf_shift]))\n",
    "\n",
    "    return a_buf, b_buf\n",
    "\n",
    "\n",
    "@numba.njit\n",
    "def set_subtensor(x, y, idx):\n",
    "    \"\"\"Write `y` into `x[:idx]` in place and return `x` (mirrors `SetSubtensor{:stop}`).\"\"\"\n",
    "    indices = (slice(None, idx, None),)\n",
    "    x[indices] = y\n",
    "    return np.asarray(x)\n",
    "\n",
    "\n",
    "@numba.njit\n",
    "def dimshuffle(x):\n",
    "    \"\"\"Add a leading length-1 axis to a 0-d array as a stride-0 view (mirrors `ExpandDims{axis=0}`).\n",
    "\n",
    "    For a 0-d input the result has the same values as `np.expand_dims(x, axis=0)`.\n",
    "    `new_order = (-1,)` marks the only output axis as new, so here the loop never copies\n",
    "    a shape or stride from `x`.\n",
    "    \"\"\"\n",
    "    old_shape = x.shape\n",
    "    old_strides = x.strides\n",
    "\n",
    "    new_shape = (1,)\n",
    "    new_strides = (0,)\n",
    "    new_order = (-1,)\n",
    "    for i, o in enumerate(new_order):\n",
    "        if o != -1:\n",
    "            new_shape = tuple_setitem(new_shape, i, old_shape[o])\n",
    "            new_strides = tuple_setitem(new_strides, i, old_strides[o])\n",
    "\n",
    "    return np.lib.stride_tricks.as_strided(x, shape=new_shape, strides=new_strides)\n",
    "\n",
    "\n",
    "@numba.njit\n",
    "def comparable_fibonacci_numba(b):\n",
    "    \"\"\"Hand-written replica of `numba_funcified_fgraph`; returns a 1-tuple holding the final `a`.\"\"\"\n",
    "    # Buffer for `a`: length 1, initialised to [1]\n",
    "    # (plain NumPy: a_buf = np.empty(1, np.int32); a_buf[:1] = [1]).\n",
    "    a_buf = allocempty(np.array(1, dtype=np.int64))\n",
    "    a_buf_set = set_subtensor(a_buf, np.array([1], dtype=np.int32), np.int64(1))\n",
    "\n",
    "    # Buffer for `b`: length 2, first slot initialised to `b`\n",
    "    # (plain NumPy: b_buf = np.empty(2, np.int32); b_buf[:1] = np.expand_dims(b, axis=0)).\n",
    "    b_buf = allocempty(np.array(2, dtype=np.int64))\n",
    "    b_expanded = dimshuffle(b)\n",
    "    b_buf_set = set_subtensor(b_buf, b_expanded, np.int64(1))\n",
    "\n",
    "    a_buf_updated, b_buf_updated = scan_fib(np.array(N_STEPS, np.int64), a_buf_set, b_buf_set)\n",
    "\n",
    "    # Subtensor{i} with i = 0 on the first scan output.\n",
    "    res = subtensor(a_buf_updated, np.uint8(0))\n",
    "\n",
    "    return (res,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "65887ebba21f46c3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:30:31.559493Z",
     "start_time": "2025-10-07T10:30:30.263832Z"
    }
   },
   "outputs": [],
   "source": [
    "b = np.ones((), dtype=np.int32)\n",
    "# comparable_fibonacci_numba returns a 1-tuple (like the generated function); unpack it before comparing.\n",
    "(comparable_result,) = comparable_fibonacci_numba(b)\n",
    "assert comparable_result == fibonacci_numba_scalar(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "2e0aba9917097009",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-07T10:30:35.999409Z",
     "start_time": "2025-10-07T10:30:31.567997Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "54.6 μs ± 1.28 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit comparable_fibonacci_numba(b)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}