import numpy as np
import scipy as sp

def decompose(a, c, e):
  """
  Computes the LU decomposition of a tridiagonal matrix.

  The matrix is given by its three diagonals:
    a -- main diagonal (length n + 1)
    c -- superdiagonal (length n)
    e -- subdiagonal (length n)

  Returns a tuple (α, β) where α (length n) is the subdiagonal of the
  unit lower triangular factor and β (length n + 1) is the main diagonal
  of the upper triangular factor. The superdiagonal of the upper factor
  is c itself.

  No pivoting is performed, so every pivot β[i] must be non-zero.
  The inputs are not modified.
  """
  a = np.asarray(a)
  c = np.asarray(c)
  e = np.asarray(e)

  # Work in (at least) double precision: with integer inputs the in-place
  # pivot update below would otherwise be truncated on every step.
  dtype = np.result_type(a.dtype, c.dtype, e.dtype, np.float64)

  α = np.zeros(len(c), dtype=dtype)
  β = a.astype(dtype)  # astype copies, so the caller's diagonal is untouched

  for i in range(len(c)):
    α[i] = e[i] / β[i]
    β[i + 1] -= c[i] * α[i]

  # Sanity check (small systems only): the factors must multiply back
  # to the original matrix.
  if len(a) <= 10:
    assert np.allclose(
      to_array(a, c, e),
      to_array(*from_lower(α)) @ to_array(*from_upper(β, c))
    )

  return (α, β)

def to_csr(a, c, e):
  """
  Converts a tridiagonal matrix into a scipy csr sparse matrix.

  The matrix is given by its three diagonals:
    a -- main diagonal (length n + 1)
    c -- superdiagonal (length n)
    e -- subdiagonal (length n)

  Returns an (n + 1) x (n + 1) csr_array. Every entry of the three
  diagonals is stored explicitly, including zeros. Values are promoted
  to at least float64; complex inputs stay complex.
  """
  a = np.asarray(a)
  c = np.asarray(c)
  e = np.asarray(e)
  n = len(c)
  nnz = n * 3 + 1

  # Row-major layout of the stored entries:
  #   a[0] c[0] | e[0] a[1] c[1] | ... | e[n-1] a[n]
  # so a, c and e each occupy every third slot.
  values = np.zeros(
    nnz, dtype=np.result_type(a.dtype, c.dtype, e.dtype, np.float64)
  )
  values[::3] = a
  values[1::3] = c
  values[2::3] = e

  # Index arrays are built as integers up front rather than relying on
  # scipy to cast floating-point indices.
  col_indices = np.zeros(nnz, dtype=np.intp)
  col_indices[1::3] = np.arange(1, n + 1)  # c[i] sits in column i + 1
  col_indices[2::3] = np.arange(0, n)      # e[i] sits in column i
  col_indices[3::3] = np.arange(1, n + 1)  # a[i] (i >= 1) sits in column i

  # Row 0 holds two entries, the interior rows three, the last row two.
  index_ptr = np.zeros(n + 2, dtype=np.intp)
  index_ptr[1:n+1] = np.arange(2, n * 3 + 2, 3)
  index_ptr[n+1] = nnz

  return sp.sparse.csr_array(
    (values, col_indices, index_ptr), shape=(n + 1, n + 1)
  )

def to_array(a, c, e):
  """
  Builds the dense numpy array of a tridiagonal matrix.

  The diagonals (a, c, e) are first assembled into a sparse matrix by
  to_csr, which is then expanded to a dense array.
  """
  sparse_form = to_csr(a, c, e)
  return sparse_form.toarray()

def from_lower(α):
  """
  Expands the lower vector of a decomposition into the three diagonals
  (main, super, sub) of the lower triangular factor: a main diagonal of
  ones, a superdiagonal of zeros, and α itself as the subdiagonal.

  Example usage:
  ```py
  α, β = decompose(a, c, e)
  print(from_lower(α))
  ```
  """
  size = len(α)
  unit_diagonal = np.ones(size + 1)
  empty_upper = np.zeros(size)
  return (unit_diagonal, empty_upper, α)

def from_upper(β, c):
  """
  Expands the upper vectors of a decomposition into the three diagonals
  (main, super, sub) of the upper triangular factor: β as the main
  diagonal, c as the superdiagonal, and a subdiagonal of zeros.

  Example usage:
  ```py
  α, β = decompose(a, c, e)
  print(from_upper(β, c))
  ```
  """
  empty_lower = np.zeros(len(c))
  return (β, c, empty_lower)

def solve_lower(α, rhs):
  """
  Solve a linear system of equations Mx = v
  where M is a lower triangular matrix constructed
  by LU decomposing a tridiagonal matrix.

  M has a unit main diagonal and α as its subdiagonal, so the system is
  solved by forward substitution.

    α   -- subdiagonal of M (length n)
    rhs -- right-hand side v (length n + 1)

  Returns the solution x as a new array; rhs is not modified.
  """
  α = np.asarray(α)
  rhs = np.asarray(rhs)

  # Allocate the result in (at least) double precision: inheriting an
  # integer dtype from rhs would truncate every intermediate value.
  x = np.zeros_like(rhs, dtype=np.result_type(α.dtype, rhs.dtype, np.float64))

  x[0] = rhs[0]

  for i in range(1, len(rhs)):
    x[i] = rhs[i] - α[i - 1] * x[i - 1]

  # Sanity check (small systems only): M @ x must reproduce rhs.
  if len(α) <= 10:
    assert np.allclose(to_array(*from_lower(α)) @ x, rhs)

  return x

def solve_upper(β, c, rhs):
  """
  Solve a linear system of equations Mx = v
  where M is an upper triangular matrix constructed
  by LU decomposing a tridiagonal matrix.

  M has β as its main diagonal and c as its superdiagonal, so the system
  is solved by back substitution.

    β   -- main diagonal of M (length n + 1), every entry non-zero
    c   -- superdiagonal of M (length n)
    rhs -- right-hand side v (length n + 1)

  Returns the solution x as a new array; rhs is not modified.
  """
  β = np.asarray(β)
  c = np.asarray(c)
  rhs = np.asarray(rhs)

  # Allocate the result in (at least) double precision: inheriting an
  # integer dtype from rhs would truncate every quotient.
  x = np.zeros_like(
    rhs, dtype=np.result_type(β.dtype, c.dtype, rhs.dtype, np.float64)
  )

  x[-1] = rhs[-1] / β[-1]

  # Walk the rows from the second-to-last one up to the first.
  for i in reversed(range(len(rhs) - 1)):
    x[i] = (rhs[i] - c[i] * x[i+1]) / β[i]

  # Sanity check (small systems only): M @ x must reproduce rhs.
  if len(β) <= 10:
    assert np.allclose(to_array(*from_upper(β, c)) @ x, rhs)

  return x

def solve(a, c, e, rhs):
  """
  Solves the tridiagonal system Mx = rhs, where M has main diagonal a,
  superdiagonal c and subdiagonal e.

  The matrix is LU decomposed, then the lower factor is solved by
  solve_lower and the upper factor by solve_upper. For small systems
  the result is checked against the original matrix.
  """
  lower, upper = decompose(a, c, e)
  intermediate = solve_lower(lower, rhs)
  x = solve_upper(upper, c, intermediate)

  # Large systems skip the dense verification.
  if len(lower) > 10:
    return x

  reconstructed = to_array(a, c, e) @ x
  assert np.allclose(reconstructed, rhs)
  return x

# Small sanity check for the above code
def main():
  """
  Solves a small 4x4 example system and prints the matrix, the
  right-hand side, the solution, and the matrix applied to the solution.
  """
  a = np.ones(4)
  c = np.full(3, 2.0)
  e = np.full(3, 3.0)
  rhs = np.ones(4)

  result = solve(a, c, e, rhs)

  print(f"m={to_array(a, c, e)}")
  print(f"{rhs=}")
  print(f"{result=}")
  # Multiplying back should reproduce rhs.
  print(to_array(a, c, e) @ result)

# Run the sanity check only when executed as a script, so that importing
# this module has no side effects.
if __name__ == "__main__":
  main()