import pytensor
import pytensor.tensor as pt
import numpy as np

N_STEPS = 1000

b_symbolic = pt.scalar("b_symbolic", dtype="int32")

def step(a, b):
    # next a is a + b; next b is the previous a (the Fibonacci recurrence)
    return a + b, a

(outputs_a, outputs_b), _ = pytensor.scan(
    fn=step,
    outputs_info=[pt.constant(1, dtype="int32"), b_symbolic],
    n_steps=N_STEPS
)
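
For reference, the recurrence this scan expresses is just the following plain-Python loop (a sketch; a starts at 1 and b is the input):

def fibonacci_reference(b, n_steps=N_STEPS):
    a = 1
    for _ in range(n_steps):
        a, b = a + b, a  # same update as step()
    return a  # matches outputs_a[-1]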

# Compile functions that return the final value of a: one with the default backend, one with Numba
fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)
fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)
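
Since trust_input=True skips PyTensor's input validation, these functions must be handed a value that already matches the declared input type, here a 0-d int32 array (a sketch of a safe call):

b_val = np.asarray(1, dtype=np.int32)  # exactly the dtype/shape the graph expects
fibonacci_pytensor(b_val)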
import numba

@numba.njit
def fibonacci_numba_scalar(b):
    b = b.copy()  # don't mutate the caller's 0-d array
    a = np.ones((), dtype=np.int32)
    for _ in range(N_STEPS):
        # update both 0-d arrays in place; no allocation inside the loop
        a[()], b[()] = a[()] + b[()], a[()]
    return a

@numba.njit
def fibonacci_numba_array(b):
    a = np.ones((), dtype=np.int32)
    for _ in range(N_STEPS):
        # a + b allocates a fresh 0-d array on every iteration
        a, b = np.asarray(a + b), a
    return a
b = np.ones((), dtype=np.int32)
assert fibonacci_numba_array(b) == fibonacci_numba_scalar(b)
%timeit fibonacci_numba_scalar(b)
3.21 μs ± 20.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
%timeit fibonacci_numba_array(b)
32.8 μs ± 2.48 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
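
The ~10x gap is plausibly dominated by allocation: the array version builds a fresh 0-d array via np.asarray(a + b) on every iteration, while the scalar version writes into two preallocated buffers. A minimal sketch that isolates the per-iteration allocation (a hypothetical micro-benchmark, not from the original timings):

@numba.njit
def alloc_per_iteration(n):
    out = np.zeros((), dtype=np.int32)
    for _ in range(n):
        out = np.asarray(out + 1)  # new 0-d array each iteration
    return out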
assert fibonacci_pytensor(b) == fibonacci_numba_scalar(b)
assert fibonacci_pytensor_numba(b) == fibonacci_numba_scalar(b)
%timeit fibonacci_pytensor(b)
2.49 ms ± 327 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit fibonacci_pytensor_numba(b)
175 μs ± 6.13 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit fibonacci_pytensor_numba.vm.jit_fn(b)
158 μs ± 706 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
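
Calling vm.jit_fn directly bypasses the Python-level wrapper that pytensor.function adds, so the ~17 μs difference from the previous timing is wrapper overhead. The compiled function returns a tuple of outputs (see its generated source below), so the raw result has to be unpacked:

raw = fibonacci_pytensor_numba.vm.jit_fn(b)
assert raw[0] == fibonacci_numba_scalar(b)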
_ = fibonacci_pytensor_numba.dprint(print_type=True, print_memory_map=True)  # assign to suppress the returned stream's repr
Subtensor{i} [id A] <Scalar(int32, shape=())> v={0: [0]} 6
 ├─ Scan{scan_fn, while_loop=False, inplace=all}.0 [id B] <Vector(int32, shape=(?,))> d={0: [1], 1: [2]} 5
 │  ├─ 1000 [id C] <Scalar(int16, shape=())>
 │  ├─ SetSubtensor{:stop} [id D] <Vector(int32, shape=(1,))> d={0: [0]} 4
 │  │  ├─ AllocEmpty{dtype='int32'} [id E] <Vector(int32, shape=(1,))> 3
 │  │  │  └─ 1 [id F] <Scalar(int64, shape=())>
 │  │  ├─ [1] [id G] <Vector(int32, shape=(1,))>
 │  │  └─ 1 [id H] <int64>
 │  └─ SetSubtensor{:stop} [id I] <Vector(int32, shape=(2,))> d={0: [0]} 2
 │     ├─ AllocEmpty{dtype='int32'} [id J] <Vector(int32, shape=(2,))> 1
 │     │  └─ 2 [id K] <Scalar(int64, shape=())>
 │     ├─ ExpandDims{axis=0} [id L] <Vector(int32, shape=(1,))> v={0: [0]} 0
 │     │  └─ b_symbolic [id M] <Scalar(int32, shape=())>
 │     └─ 1 [id H] <int64>
 └─ 0 [id N] <uint8>

Inner graphs:

Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1], 1: [2]}
 ← Add [id O] <Scalar(int32, shape=())>
    ├─ *0-<Scalar(int32, shape=())> [id P] <Scalar(int32, shape=())> -> [id D]
    └─ *1-<Scalar(int32, shape=())> [id Q] <Scalar(int32, shape=())> -> [id I]
 ← *0-<Scalar(int32, shape=())> [id P] <Scalar(int32, shape=())> -> [id D]
print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__source__)
def numba_funcified_fgraph(b_symbolic):
    # ExpandDims{axis=0}(b_symbolic)
    tensor_variable = dimshuffle(b_symbolic)
    # AllocEmpty{dtype='int32'}(2)
    tensor_variable_1 = allocempty(tensor_constant)
    # SetSubtensor{:stop}(AllocEmpty{dtype='int32'}.0, ExpandDims{axis=0}.0, 1)
    tensor_variable_2 = set_subtensor(tensor_variable_1, tensor_variable, scalar_constant)
    # AllocEmpty{dtype='int32'}(1)
    tensor_variable_3 = allocempty_1(tensor_constant_1)
    # SetSubtensor{:stop}(AllocEmpty{dtype='int32'}.0, [1], 1)
    tensor_variable_4 = set_subtensor_1(tensor_variable_3, tensor_constant_2, scalar_constant)
    # Scan{scan_fn, while_loop=False, inplace=all}(1000, SetSubtensor{:stop}.0, SetSubtensor{:stop}.0)
    tensor_variable_5, tensor_variable_6 = scan(tensor_constant_3, tensor_variable_4, tensor_variable_2)
    # Subtensor{i}(Scan{scan_fn, while_loop=False, inplace=all}.0, 0)
    tensor_variable_7 = subtensor(tensor_variable_5, scalar_constant_1)
    return (tensor_variable_7,)
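
Each commented Op above dispatches to its own njit-compiled helper stored in the generated function's globals; several of them are printed next. They can also be grabbed directly (a sketch):

helpers = fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__
scan_helper = helpers["scan"]  # a numba dispatcher wrapping the scan loop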
print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__["allocempty"].py_func.__source__)
def allocempty(tensor_constant):
    tensor_constant_item = to_scalar(tensor_constant)
    scalar_shape = (tensor_constant_item, )
    return np.empty(scalar_shape, dtype)
    
print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__["set_subtensor"].py_func.__source__)
def set_subtensor(tensor_variable, tensor_variable_1, scalar_constant):
    z = tensor_variable
    indices = (slice(None, scalar_constant, None),)
    z[indices] = tensor_variable_1
    return np.asarray(z)
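
Note that set_subtensor writes into its first argument rather than copying it; that is what the d={0: [0]} destroy-map annotations in the dprint above record. A quick check against the compiled helper (a sketch; numba will compile a fresh specialization for these argument types):

set_st = fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__["set_subtensor"]
buf = np.zeros(2, dtype=np.int32)
set_st(buf, np.array([7], dtype=np.int32), np.int64(1))
assert buf[0] == 7  # the original buffer was mutated in place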
    
print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__["scan"].py_func.__source__)
def scan(n_steps, outer_in_1, outer_in_2):

    outer_in_1_len = outer_in_1.shape[0]
    outer_in_1_sitsot_storage = outer_in_1
    outer_in_2_len = outer_in_2.shape[0]
    outer_in_2_sitsot_storage = outer_in_2

    outer_in_1_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.int32)
    outer_in_2_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.int32)

    i = 0
    cond = np.array(False)
    while i < n_steps and not cond.item():
        outer_in_1_sitsot_storage_temp_scalar_0[()] = outer_in_1_sitsot_storage[(i) % outer_in_1_len]
        outer_in_2_sitsot_storage_temp_scalar_0[()] = outer_in_2_sitsot_storage[(i) % outer_in_2_len]

        (inner_out_0, inner_out_1) = scan_inner_func(outer_in_1_sitsot_storage_temp_scalar_0, outer_in_2_sitsot_storage_temp_scalar_0)

        outer_in_1_sitsot_storage[(i + 1) % outer_in_1_len] = inner_out_0
        outer_in_2_sitsot_storage[(i + 1) % outer_in_2_len] = inner_out_1
        i += 1

    if 1 < outer_in_1_len < (i + 1):
        outer_in_1_sitsot_storage_shift = (i + 1) % (outer_in_1_len)
        if outer_in_1_sitsot_storage_shift > 0:
            outer_in_1_sitsot_storage_left = outer_in_1_sitsot_storage[:outer_in_1_sitsot_storage_shift]
            outer_in_1_sitsot_storage_right = outer_in_1_sitsot_storage[outer_in_1_sitsot_storage_shift:]
            outer_in_1_sitsot_storage = np.concatenate((outer_in_1_sitsot_storage_right, outer_in_1_sitsot_storage_left))
    if 1 < outer_in_2_len < (i + 1):
        outer_in_2_sitsot_storage_shift = (i + 1) % (outer_in_2_len)
        if outer_in_2_sitsot_storage_shift > 0:
            outer_in_2_sitsot_storage_left = outer_in_2_sitsot_storage[:outer_in_2_sitsot_storage_shift]
            outer_in_2_sitsot_storage_right = outer_in_2_sitsot_storage[outer_in_2_sitsot_storage_shift:]
            outer_in_2_sitsot_storage = np.concatenate((outer_in_2_sitsot_storage_right, outer_in_2_sitsot_storage_left))

    return outer_in_1_sitsot_storage, outer_in_2_sitsot_storage
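
The block after the while loop rotates each circular buffer back into chronological order: values are written modulo the buffer length, so after the loop the oldest entry may not sit at index 0. A standalone illustration of the rotation (values arbitrary):

buf = np.array([13, 8], dtype=np.int32)  # slot 0 was written last
i = 4  # pretend 4 steps completed
shift = (i + 1) % buf.shape[0]  # == 1, so the buffer is out of order
if 1 < buf.shape[0] < (i + 1) and shift > 0:
    buf = np.concatenate((buf[shift:], buf[:shift]))  # -> [8, 13], oldest first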
    
from pytensor.link.numba.dispatch.basic import tuple_setitem, to_scalar  # helpers used by the generated code

@numba.njit
def allocempty(s):
    s_item = to_scalar(s)
    scalar_shape = (s_item,)
    return np.empty(scalar_shape, dtype=np.int32)

@numba.njit
def subtensor(x, idx):
    indices = (idx,)
    z = x[indices]
    return np.asarray(z)

@numba.njit
def inner_scan_func(a, b):
    res = a + b
    return res, a

@numba.njit
def scan_fib(n_steps, a_buf, b_buf):
    a_buf_len = a_buf.shape[0]
    b_buf_len = b_buf.shape[0]

    # scratch 0-d arrays handed to the inner function, as in the generated scan
    tmp_a_scalar = np.empty((), dtype=np.int32)
    tmp_b_scalar = np.empty((), dtype=np.int32)

    i = 0
    while i < n_steps:
        tmp_a_scalar[()] = a_buf[i % a_buf_len]
        tmp_b_scalar[()] = b_buf[i % b_buf_len]
        next_a, next_b = inner_scan_func(tmp_a_scalar, tmp_b_scalar)
        a_buf[(i + 1) % a_buf_len] = next_a
        b_buf[(i + 1) % b_buf_len] = next_b
        i += 1

    # rotate the circular buffers back into chronological order, as in the generated scan
    if 1 < a_buf_len < (i + 1):
        a_buf_shift = (i + 1) % a_buf_len
        if a_buf_shift > 0:
            a_buf = np.concatenate((a_buf[a_buf_shift:], a_buf[:a_buf_shift]))
    if 1 < b_buf_len < (i + 1):
        b_buf_shift = (i + 1) % b_buf_len
        if b_buf_shift > 0:
            b_buf = np.concatenate((b_buf[b_buf_shift:], b_buf[:b_buf_shift]))

    return a_buf, b_buf

@numba.njit
def set_subtensor(x, y, idx):
    indices = (slice(None, idx, None),)
    x[indices] = y
    return np.asarray(x)

@numba.njit
def dimshuffle(x):
    old_shape = x.shape
    old_strides = x.strides

    new_shape = (1,)
    new_strides = (0,)
    new_order = (-1,)
    for i, o in enumerate(new_order):
        if o != -1:
            new_shape = tuple_setitem(new_shape, i, old_shape[o])
            new_strides = tuple_setitem(new_strides, i, old_strides[o])

    return np.lib.stride_tricks.as_strided(x, shape=new_shape, strides=new_strides)
    # equivalently: return np.expand_dims(x, axis=0)
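
As a sanity check, the zero-stride trick yields the same length-1 view as np.expand_dims, without copying (plain NumPy, outside the jitted code):

x0 = np.array(5, dtype=np.int32)
v = np.lib.stride_tricks.as_strided(x0, shape=(1,), strides=(0,))
assert v.shape == (1,) and v[0] == 5  # a stride-0 view over the original scalar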

@numba.njit
def comparable_fibonacci_numba(b):
    a_buf = allocempty(np.array(1, dtype=np.int64))
    # equivalently: a_buf = np.empty(1, dtype=np.int32)
    #               a_buf[:1] = np.array([1], dtype=np.int32)
    a_buf_set = set_subtensor(a_buf, np.array([1], dtype=np.int32), np.int64(1))

    b_buf = allocempty(np.array(2, dtype=np.int64))
    # equivalently: b_buf = np.empty(2, dtype=np.int32)
    #               b_buf[:1] = np.expand_dims(b, axis=0)
    b_expanded = dimshuffle(b)
    b_buf_set = set_subtensor(b_buf, b_expanded, np.int64(1))

    a_buf_updated, b_buf_updated = scan_fib(np.array(N_STEPS, np.int64), a_buf_set, b_buf_set)

    res = subtensor(a_buf_updated, np.uint8(0))

    return (res,)  # mirror the generated function's tuple return
b = np.ones((), dtype=np.int32)
assert comparable_fibonacci_numba(b)[0] == fibonacci_numba_scalar(b)
%timeit comparable_fibonacci_numba(b)
54.6 μs ± 1.28 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)