1
Fork 0

python(denoising): Implemented iterative method

Signed-off-by: prescientmoon <git@moonythm.dev>
This commit is contained in:
Matei Adriel 2023-03-29 03:05:50 +02:00 committed by prescientmoon
parent 4c8d96a79b
commit 58c07abff0
Signed by: prescientmoon
SSH key fingerprint: SHA256:UUF9JT2s8Xfyv76b8ZuVL7XrmimH4o49p4b+iexbVH4
3 changed files with 253 additions and 27 deletions

Binary file not shown.

File diff suppressed because one or more lines are too long

View file

@ -1,10 +1,26 @@
import numpy as np
import scipy as sp
def assert_valid_tridiagonal(a, c, e):
"""
Validates a tridiagonal matrix.
"""
assert len(a) == len(c) + 1 == len(e) + 1
def create(a, c, e):
"""
Validates a tridiagonal matrix before creating it.
"""
assert_valid_tridiagonal(a, c, e)
return (a, c, e)
def decompose(a, c, e):
"""
Computes the LU decomposition of a tridiagonal matrix.
"""
assert_valid_tridiagonal(a, c, e)
α = np.zeros(len(c))
β = a.copy()
@ -25,6 +41,8 @@ def to_csr(a, c, e):
"""
Converts a tridiagonal matrix into a scipy csr sparse matrix.
"""
assert_valid_tridiagonal(a, c, e)
n = len(c)
values = np.zeros(n * 3 + 1)
@ -47,6 +65,7 @@ def to_array(a, c, e):
"""
Converts a tridiagonal matrix into a numpy matrix.
"""
assert_valid_tridiagonal(a, c, e)
return to_csr(a, c, e).toarray()
def from_lower(α):
@ -59,7 +78,7 @@ def from_lower(α):
print(from_lower(α))
```
"""
return (np.ones(len(α) + 1), np.zeros(len(α)), α)
return create(np.ones(len(α) + 1), np.zeros(len(α)), α)
def from_upper(β, c):
"""
@ -71,7 +90,7 @@ def from_upper(β, c):
print(from_upper(β, c))
```
"""
return (β, c, np.zeros(len(c)))
return create(β, c, np.zeros(len(c)))
def solve_lower(α, rhs):
"""
@ -110,6 +129,11 @@ def solve_upper(β, c, rhs):
return x
def solve(a, c, e, rhs):
"""
Solves a system of linear equations defined by a tridiagonal matrix.
"""
assert_valid_tridiagonal(a, c, e)
α, β = decompose(a, c, e)
x = solve_upper(β, c, solve_lower(α, rhs))
@ -118,9 +142,97 @@ def solve(a, c, e, rhs):
return x
def multiply_vector(a, c, e, x):
"""
Performs a matrix-vector multiplication where the matrix is tridiagonal.
"""
assert_valid_tridiagonal(a, c, e)
assert len(x) == len(a)
result = np.zeros_like(x)
for i in range(len(x)):
result[i] = a[i] * x[i]
if i > 0:
result[i] += e[i - 1] * x[i - 1]
if i < len(x) - 1:
result[i] += c[i] * x[i + 1]
# Sanity check
if len(a) <= 10:
assert np.allclose(to_array(a, c, e) @ x, result)
return result
def largest_eigenvalue(a, c, e, initial_x, kmax):
"""
Computes the largest eigenvalue of a positive definite tridiagonal matrix.
"""
assert_valid_tridiagonal(a, c, e)
x = initial_x
for i in range(kmax):
q = multiply_vector(a, c, e, x)
assert not np.allclose(q, 0)
x = q/np.linalg.norm(q)
# Computes the eigenvalue from the eigenvector
eigenvalue = (x @ multiply_vector(a, c, e, x)) / (x @ x)
# Sanity check
if len(initial_x) <= 10:
actual_eigenvalues, _ = np.linalg.eig(to_array(a, c, e))
assert np.allclose(
0,
np.min(
np.abs(
actual_eigenvalues - eigenvalue
)
)
)
return eigenvalue
def add(a, b):
"""
Adds two tridiagonal matrices.
"""
assert_valid_tridiagonal(*a)
assert_valid_tridiagonal(*b)
assert len(a[0]) == len(b[0])
result = create(a[0] + b[0], a[1] + b[1], a[2] + b[2])
# Sanity check
if len(a[0]) <= 10:
assert np.allclose(to_array(*result), to_array(*a) + to_array(*b))
return result
def identity(n):
"""
Returns the tridiagonal identity n*n matrix.
"""
assert n > 0 # Sanity check
return create(np.ones(n), np.zeros(n - 1), np.zeros(n - 1))
def scale(s, m):
"""
Multiplies a tridiagonal matrix by a scalar
"""
assert_valid_tridiagonal(*m)
result = create(s * m[0], s * m[1], s * m[2])
# Sanity check
if len(m[0]) <= 10:
assert np.allclose(result, s * to_array(*m))
return result
# Small sanity check for the above code
def main():
a, c, e = (np.ones(4), 2*np.ones(3), 3*np.ones(3))
a, c, e = create(3*np.ones(4), 2*np.ones(3), 3*np.ones(3))
rhs = np.ones(4)
result = solve(a, c, e, rhs)
@ -128,5 +240,6 @@ def main():
print(f"{rhs=}")
print(f"{result=}")
print(to_array(a, c, e) @ result)
print(largest_eigenvalue(a, c, e, np.ones(4), 50))
main()
# main()