import pytensor
import pytensor.tensor as pt
import numpy as np
# Number of scan iterations; shared by every implementation benchmarked below.
N_STEPS = 1000
# Symbolic scalar input: the second Fibonacci seed value ("b").
b_symbolic = pt.scalar("b_symbolic", dtype="int32")
def step(a, b):
    """One Fibonacci step: map the pair (a, b) to (a + b, a)."""
    advanced = a + b
    return advanced, a
# Build a symbolic scan that applies `step` N_STEPS times, threading the
# (a, b) pair through `outputs_info`: a starts at the constant 1, b at the
# symbolic input b_symbolic.
(outputs_a, outputs_b), _ = pytensor.scan(
    fn=step,
    outputs_info=[pt.constant(1, dtype="int32"), b_symbolic],
    n_steps=N_STEPS
)
# compile function returning final a
# trust_input=True skips per-call input validation, for benchmark fairness.
fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)
# Same graph compiled through the Numba backend, for comparison.
fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)
import numba
@numba.njit
def fibonacci_numba_scalar(b):
    """Hand-written Numba baseline updating 0-d ("scalar") arrays in place.

    Copies `b` so the caller's array is never mutated, then performs the
    Fibonacci update element-wise on two 0-d int32 arrays — no per-iteration
    allocation.  NOTE(review): int32 Fibonacci wraps around long before
    N_STEPS=1000; all variants here wrap identically, which is all the
    benchmark comparison needs.
    """
    b = b.copy()
    a = np.ones((), dtype=np.int32)
    for _ in range(N_STEPS):
        a[()], b[()] = a[()] + b[()], a[()]
    return a
@numba.njit
def fibonacci_numba_array(b):
    """Numba baseline that rebinds fresh 0-d arrays each iteration.

    `np.asarray(a + b)` allocates a new array every step, so this variant
    measures allocation overhead relative to the in-place scalar version
    above.  The input `b` is only read: rebinding the local name does not
    mutate the caller's array.
    """
    a = np.ones((), dtype=np.int32)
    for _ in range(N_STEPS):
        a, b = np.asarray(a + b), a
    return a
# Sanity check: both hand-written Numba variants agree for the same seed.
b = np.ones((), dtype=np.int32)
assert fibonacci_numba_array(b) == fibonacci_numba_scalar(b)
%timeit fibonacci_numba_scalar(b)
3.21 μs ± 20.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
%timeit fibonacci_numba_array(b)
32.8 μs ± 2.48 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
assert fibonacci_pytensor(b) == fibonacci_numba_scalar(b)
assert fibonacci_pytensor_numba(b) == fibonacci_numba_scalar(b)
%timeit fibonacci_pytensor(b)
2.49 ms ± 327 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit fibonacci_pytensor_numba(b)
175 μs ± 6.13 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit fibonacci_pytensor_numba.vm.jit_fn(b)
158 μs ± 706 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
fibonacci_pytensor_numba.dprint(print_type=True, print_memory_map=True)
Subtensor{i} [id A] <Scalar(int32, shape=())> v={0: [0]} 6
├─ Scan{scan_fn, while_loop=False, inplace=all}.0 [id B] <Vector(int32, shape=(?,))> d={0: [1], 1: [2]} 5
│ ├─ 1000 [id C] <Scalar(int16, shape=())>
│ ├─ SetSubtensor{:stop} [id D] <Vector(int32, shape=(1,))> d={0: [0]} 4
│ │ ├─ AllocEmpty{dtype='int32'} [id E] <Vector(int32, shape=(1,))> 3
│ │ │ └─ 1 [id F] <Scalar(int64, shape=())>
│ │ ├─ [1] [id G] <Vector(int32, shape=(1,))>
│ │ └─ 1 [id H] <int64>
│ └─ SetSubtensor{:stop} [id I] <Vector(int32, shape=(2,))> d={0: [0]} 2
│ ├─ AllocEmpty{dtype='int32'} [id J] <Vector(int32, shape=(2,))> 1
│ │ └─ 2 [id K] <Scalar(int64, shape=())>
│ ├─ ExpandDims{axis=0} [id L] <Vector(int32, shape=(1,))> v={0: [0]} 0
│ │ └─ b_symbolic [id M] <Scalar(int32, shape=())>
│ └─ 1 [id H] <int64>
└─ 0 [id N] <uint8>
Inner graphs:
Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1], 1: [2]}
← Add [id O] <Scalar(int32, shape=())>
├─ *0-<Scalar(int32, shape=())> [id P] <Scalar(int32, shape=())> -> [id D]
└─ *1-<Scalar(int32, shape=())> [id Q] <Scalar(int32, shape=())> -> [id I]
← *0-<Scalar(int32, shape=())> [id P] <Scalar(int32, shape=())> -> [id D]
<ipykernel.iostream.OutStream at 0x7fa4fbfacdf0>
print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__source__)
def numba_funcified_fgraph(b_symbolic):
# ExpandDims{axis=0}(b_symbolic)
tensor_variable = dimshuffle(b_symbolic)
# AllocEmpty{dtype='int32'}(2)
tensor_variable_1 = allocempty(tensor_constant)
# SetSubtensor{:stop}(AllocEmpty{dtype='int32'}.0, ExpandDims{axis=0}.0, 1)
tensor_variable_2 = set_subtensor(tensor_variable_1, tensor_variable, scalar_constant)
# AllocEmpty{dtype='int32'}(1)
tensor_variable_3 = allocempty_1(tensor_constant_1)
# SetSubtensor{:stop}(AllocEmpty{dtype='int32'}.0, [1], 1)
tensor_variable_4 = set_subtensor_1(tensor_variable_3, tensor_constant_2, scalar_constant)
# Scan{scan_fn, while_loop=False, inplace=all}(1000, SetSubtensor{:stop}.0, SetSubtensor{:stop}.0)
tensor_variable_5, tensor_variable_6 = scan(tensor_constant_3, tensor_variable_4, tensor_variable_2)
# Subtensor{i}(Scan{scan_fn, while_loop=False, inplace=all}.0, 0)
tensor_variable_7 = subtensor(tensor_variable_5, scalar_constant_1)
return (tensor_variable_7,)
print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__["allocempty"].py_func.__source__)
def allocempty(tensor_constant):
tensor_constant_item = to_scalar(tensor_constant)
scalar_shape = (tensor_constant_item, )
return np.empty(scalar_shape, dtype)
print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__["set_subtensor"].py_func.__source__)
def set_subtensor(tensor_variable, tensor_variable_1, scalar_constant):
z = tensor_variable
indices = (slice(None, scalar_constant, None),)
z[indices] = tensor_variable_1
return np.asarray(z)
print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__["scan"].py_func.__source__)
def scan(n_steps, outer_in_1, outer_in_2):
outer_in_1_len = outer_in_1.shape[0]
outer_in_1_sitsot_storage = outer_in_1
outer_in_2_len = outer_in_2.shape[0]
outer_in_2_sitsot_storage = outer_in_2
outer_in_1_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.int32)
outer_in_2_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.int32)
i = 0
cond = np.array(False)
while i < n_steps and not cond.item():
outer_in_1_sitsot_storage_temp_scalar_0[()] = outer_in_1_sitsot_storage[(i) % outer_in_1_len]
outer_in_2_sitsot_storage_temp_scalar_0[()] = outer_in_2_sitsot_storage[(i) % outer_in_2_len]
(inner_out_0, inner_out_1) = scan_inner_func(outer_in_1_sitsot_storage_temp_scalar_0, outer_in_2_sitsot_storage_temp_scalar_0)
outer_in_1_sitsot_storage[(i + 1) % outer_in_1_len] = inner_out_0
outer_in_2_sitsot_storage[(i + 1) % outer_in_2_len] = inner_out_1
i += 1
if 1 < outer_in_1_len < (i + 1):
outer_in_1_sitsot_storage_shift = (i + 1) % (outer_in_1_len)
if outer_in_1_sitsot_storage_shift > 0:
outer_in_1_sitsot_storage_left = outer_in_1_sitsot_storage[:outer_in_1_sitsot_storage_shift]
outer_in_1_sitsot_storage_right = outer_in_1_sitsot_storage[outer_in_1_sitsot_storage_shift:]
outer_in_1_sitsot_storage = np.concatenate((outer_in_1_sitsot_storage_right, outer_in_1_sitsot_storage_left))
if 1 < outer_in_2_len < (i + 1):
outer_in_2_sitsot_storage_shift = (i + 1) % (outer_in_2_len)
if outer_in_2_sitsot_storage_shift > 0:
outer_in_2_sitsot_storage_left = outer_in_2_sitsot_storage[:outer_in_2_sitsot_storage_shift]
outer_in_2_sitsot_storage_right = outer_in_2_sitsot_storage[outer_in_2_sitsot_storage_shift:]
outer_in_2_sitsot_storage = np.concatenate((outer_in_2_sitsot_storage_right, outer_in_2_sitsot_storage_left))
return outer_in_1_sitsot_storage, outer_in_2_sitsot_storage
from pytensor.link.numba.dispatch.basic import tuple_setitem, to_scalar
@numba.njit
def allocempty(s):
    """Hand copy of the generated `allocempty` helper: allocate an
    uninitialized int32 vector whose length is given by the 0-d array `s`."""
    length = to_scalar(s)
    return np.empty((length,), dtype=np.int32)
@numba.njit
def subtensor(x, idx):
    """Hand copy of the generated `subtensor` helper: index `x` at `idx`
    and return the result as an ndarray."""
    picked = x[(idx,)]
    return np.asarray(picked)
@numba.njit
def inner_scan_func(a, b):
    """Inner scan body: one Fibonacci step, returning (a + b, a)."""
    advanced = a + b
    return advanced, a
@numba.njit
def scan_fib(n_steps, a_buf, b_buf):
    """Hand copy of the generated Scan loop (see the printed source above).

    `a_buf` and `b_buf` act as circular buffers for the scan taps (here of
    lengths 1 and 2), so every read/write is taken modulo the buffer
    length.  Returns the two buffers after `n_steps` iterations, rotated
    into chronological order where necessary.
    """
    a_buf_len = a_buf.shape[0]
    b_buf_len = b_buf.shape[0]
    # 0-d scratch arrays: each iteration copies the current tap out of the
    # circular buffer so the inner function sees a standalone "scalar".
    tmp_a_scalar = np.empty((), dtype=np.int32)
    tmp_b_scalar = np.empty((), dtype=np.int32)
    i = 0
    while i < n_steps:
        tmp_a_scalar[()] = a_buf[i % a_buf_len]
        tmp_b_scalar[()] = b_buf[i % b_buf_len]
        next_a, next_b = inner_scan_func(tmp_a_scalar, tmp_b_scalar)
        # Write results one slot ahead (the next iteration's inputs).
        a_buf[(i + 1) % a_buf_len] = next_a
        b_buf[(i + 1) % b_buf_len] = next_b
        i += 1
    # Post-loop rotation: when a buffer is longer than 1 but shorter than
    # the number of writes, its entries are out of order; rotate so the
    # oldest entry comes first.
    if 1 < a_buf_len < (i + 1):
        a_buf_shift = (i + 1) % a_buf_len
        if a_buf_shift > 0:
            a_buf = np.concatenate((a_buf[a_buf_shift:], a_buf[:a_buf_shift]))
    if 1 < b_buf_len < (i + 1):
        b_buf_shift = (i + 1) % b_buf_len
        if b_buf_shift > 0:
            b_buf = np.concatenate((b_buf[b_buf_shift:], b_buf[:b_buf_shift]))
    return a_buf, b_buf
@numba.njit
def set_subtensor(x, y, idx):
    """Hand copy of the generated `set_subtensor` helper: assign `y` into
    x[:idx] in place and return `x` as an ndarray."""
    x[(slice(None, idx, None),)] = y
    return np.asarray(x)
@numba.njit
def dimshuffle(x):
    """Hand copy of the generated ExpandDims/DimShuffle helper.

    Builds a shape-(1,) view of the 0-d input by constructing new
    shape/stride tuples and calling as_strided — no data copy.  In
    new_order, -1 marks an inserted (broadcast) axis; since the only output
    axis here is inserted, the loop body never fires for this graph.
    """
    old_shape = x.shape
    old_strides = x.strides
    new_shape = (1,)
    new_strides = (0,)
    new_order = (-1,)
    for i, o in enumerate(new_order):
        if o != -1:
            # Carry over the extent and stride of the original axis `o`.
            new_shape = tuple_setitem(new_shape, i, old_shape[o])
            new_strides = tuple_setitem(new_strides, i, old_strides[o])
    return np.lib.stride_tricks.as_strided(x, shape=new_shape, strides=new_strides)
    # return np.expand_dims(x, axis=0)
@numba.njit
def comparable_fibonacci_numba(b):
    """Step-for-step replica of the Numba code PyTensor generated above
    (see `numba_funcified_fgraph`), assembled from the hand-copied helpers,
    to locate where the backend's overhead comes from."""
    # AllocEmpty{dtype='int32'}(1) + SetSubtensor{:stop} — the `a` buffer.
    a_buf = allocempty(np.array(1, dtype=np.int64))
    # a_buf = np.empty(1, dtype=np.int32)
    # a_buf[:1] = np.array([1], dtype=np.int32)
    a_buf_set = set_subtensor(a_buf, np.array([1], dtype=np.int32), np.int64(1))
    # AllocEmpty{dtype='int32'}(2) + ExpandDims + SetSubtensor — `b` buffer.
    b_buf = allocempty(np.array(2, dtype=np.int64))
    # b_buf = np.empty(2, dtype=np.int32)
    # b_buf[:1] = np.expand_dims(b, axis=0)
    b_expanded = dimshuffle(b)
    b_buf_set = set_subtensor(b_buf, b_expanded, np.int64(1))
    # The Scan node itself.
    a_buf_updated, b_buf_updated = scan_fib(np.array(N_STEPS, np.int64), a_buf_set, b_buf_set)
    # Final Subtensor{i}: the last `a` lives at index 0 of the length-1 buffer.
    res = subtensor(a_buf_updated, np.uint8(0))
    # Tuple return mirrors the generated function's `(tensor_variable_7,)`.
    return (res,)
# Verify the replica matches the hand-written scalar baseline.
b = np.ones((), dtype=np.int32)
assert comparable_fibonacci_numba(b) == fibonacci_numba_scalar(b)
%timeit comparable_fibonacci_numba(b)
54.6 μs ± 1.28 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)