Skip to content

Instantly share code, notes, and snippets.

@WuSiYu
Created April 29, 2019 12:37
Show Gist options
  • Select an option

  • Save WuSiYu/a0edf1bc816459ba57bbb0f8c958c9f3 to your computer and use it in GitHub Desktop.

Select an option

Save WuSiYu/a0edf1bc816459ba57bbb0f8c958c9f3 to your computer and use it in GitHub Desktop.
Simple matrix store and process module (test only)
#!/usr/bin/env python3
"""Simple matrix store and process module.

A Matrix keeps a list of rows; each row is either a list of Fraction objects
(the default) or an array.array of a given typecode. Matrices support
addition, subtraction, multiplication (by another matrix or by a scalar),
inversion and transposition, and the module-level functions
plu_factorization / plu_solve compute a PLU factorization and solve a linear
system with it. The module is not efficient for large-scale matrix
calculations, but it can help you with your linear algebra homework.

Example:
    m1 = Matrix(3, 3)
    m2 = Matrix(2, 2, initializer=((1, 2), (3, 4)))
    print(m2 * m2.transpose())
    print(m2.inverse())
    plu_matrix_tuple = plu_factorization(m2)
    plu_solve(plu_matrix_tuple, (2, 5))
"""
__author__ = "SiYu Wu"
import array
import copy
from fractions import Fraction
class Matrix(object):
    """A dense M-by-N matrix stored as a list of rows.

    Each row is a list of Fraction objects (typecode "[fraction]", the
    default) or an array.array of the given typecode. ``matrix[i]`` returns
    the row object itself, so ``matrix[i][j] = value`` writes in place.

    Attributes:
        M: number of rows.
        N: number of columns.
        type: the typecode passed to the constructor.
    """

    def __init__(self, m, n, typecode='[fraction]', initializer=None):
        """Create an m-by-n matrix.

        Args:
            m: number of rows.
            n: number of columns.
            typecode: "[fraction]" to store Fraction values in lists, or an
                array.array typecode (e.g. 'd', 'i') to store rows as arrays.
            initializer: optional indexable of at least m rows (another
                Matrix works too); row x is copied from initializer[x].
                When omitted (or empty) every element is zero.
        """
        self.M = m
        self.N = n
        self.type = typecode
        if typecode == '[fraction]':
            if initializer:
                self._content = [[Fraction(num) for num in initializer[x]] for x in range(m)]
            else:
                self._content = [[Fraction(0) for _ in range(n)] for _ in range(m)]
        else:
            if initializer:
                self._content = [array.array(typecode, initializer[x]) for x in range(m)]
            else:
                self._content = [array.array(typecode, [0] * n) for _ in range(m)]

    def __str__(self):
        """Render the matrix as right-aligned columns, one "| ... |" line per row.

        The returned text always ends with a newline.
        """
        # Width of each column = longest rendered element in that column;
        # default=0 keeps an empty (0-row) matrix from crashing max().
        col_widths = [
            max((len(str(self[m][n])) for m in range(self.M)), default=0)
            for n in range(self.N)
        ]
        lines = []
        for m in range(self.M):
            cells = ('{0:>{1}}'.format(str(self[m][n]), col_widths[n]) for n in range(self.N))
            lines.append('| ' + ' '.join(cells) + ' |')
        return '\n'.join(lines) + '\n'

    def __repr__(self):
        """Return a one-line size/type header followed by str(self)."""
        info = 'Matrix object %d by %d, type=%s\n' % (self.M, self.N, self.type)
        return info + self.__str__()

    def __getitem__(self, i):
        """Return row i (the stored list/array, not a copy)."""
        return self._content[i]

    def __add__(self, other):
        """Element-wise sum; the result always uses "[fraction]" storage.

        Raises:
            ValueError: if the two matrices differ in size.
        """
        if isinstance(other, Matrix):
            if self.M != other.M or self.N != other.N:
                raise ValueError("Invalid matrices size")
            res = Matrix(self.M, self.N)
            for m in range(self.M):
                for n in range(self.N):
                    res[m][n] = self[m][n] + other[m][n]
            return res
        else:
            return NotImplemented

    def __sub__(self, other):
        """Element-wise difference; the result always uses "[fraction]" storage.

        Raises:
            ValueError: if the two matrices differ in size.
        """
        if isinstance(other, Matrix):
            if self.M != other.M or self.N != other.N:
                raise ValueError("Invalid matrices size")
            res = Matrix(self.M, self.N)
            for m in range(self.M):
                for n in range(self.N):
                    res[m][n] = self[m][n] - other[m][n]
            return res
        else:
            return NotImplemented

    def __mul__(self, other):
        """Matrix product with another Matrix, or scaling by an int/float/Fraction.

        The result always uses "[fraction]" storage.

        Raises:
            ValueError: if self.N != other.M for a matrix product.
        """
        if isinstance(other, Matrix):
            if self.N != other.M:
                raise ValueError("Invalid matrices size")
            res = Matrix(self.M, other.N)
            for m in range(self.M):
                for n in range(other.N):
                    res[m][n] = sum(self[m][i] * other[i][n] for i in range(self.N))
            return res
        elif isinstance(other, (int, float, Fraction)):
            # Fraction is the module's own element type, so scaling by it
            # must be accepted as well as plain ints and floats.
            res = Matrix(self.M, self.N)
            for m in range(self.M):
                for n in range(self.N):
                    res[m][n] = self[m][n] * other
            return res
        else:
            return NotImplemented

    def __rmul__(self, other):
        """Support ``scalar * matrix`` (scalar multiplication commutes)."""
        return self.__mul__(other)

    def __getattr__(self, attr):
        """Provide ``m.T`` (transpose) and ``m.i`` (inverse) shorthands.

        Both are recomputed on every access. Any other missing attribute
        raises AttributeError.
        """
        if attr == 'T':
            return self.transpose()
        elif attr == 'i':
            return self.inverse()
        else:
            raise AttributeError('\'Matrix\' object has no attribute \'%s\'' % attr)

    def swap_row(self, row1, row2):
        """Swap rows row1 and row2 in place."""
        self._content[row1], self._content[row2] = self._content[row2], self._content[row1]

    def inverse(self):
        """Return the inverse of this square matrix via Gauss-Jordan elimination.

        When the diagonal entry of a column is zero, the first row below it
        with a non-zero entry in that column is swapped up. This means
        invertible matrices such as [[0, 1], [1, 1]] are handled. The result
        uses "[fraction]" storage.

        Raises:
            ValueError: if the matrix is not square or is singular.
        """
        if self.M != self.N:
            raise ValueError("This matrix is not invertible")
        size = self.M
        # Plain-list copies: array storage with an integer typecode could not
        # hold the float/Fraction intermediates produced by elimination.
        work = [list(row) for row in self._content]
        result = Matrix(size, size)
        for x in range(size):
            result[x][x] = 1
        # Forward elimination to upper-triangular form, mirrored onto result.
        for i in range(size):
            pivot_row = next((r for r in range(i, size) if work[r][i]), None)
            if pivot_row is None:
                raise ValueError("This matrix is not invertible")
            if pivot_row != i:
                work[i], work[pivot_row] = work[pivot_row], work[i]
                result.swap_row(i, pivot_row)
            pivot = work[i][i]
            for j in range(i + 1, size):
                factor = work[j][i] / pivot
                for k in range(i, size):
                    work[j][k] -= work[i][k] * factor
                for k in range(size):
                    result[j][k] -= result[i][k] * factor
        # Back substitution: rows below i are already normalized, so clearing
        # work[i][j] only needs result row j; then scale row i by its pivot.
        for i in reversed(range(size)):
            for j in range(i + 1, size):
                for k in range(size):
                    result[i][k] -= result[j][k] * work[i][j]
            for k in range(size):
                result[i][k] /= work[i][i]
        return result

    def transpose(self):
        """Return a new N-by-M matrix with the same typecode holding the transpose."""
        T_matrix = Matrix(self.N, self.M, self.type)
        for i in range(self.M):
            for j in range(self.N):
                T_matrix[j][i] = self[i][j]
        return T_matrix
def plu_factorization(mat):
    """Compute a PLU factorization of a square matrix.

    Gaussian elimination with row swaps: for each column, the first row at or
    below the diagonal with a non-zero entry becomes the pivot row. A column
    with no such entry is skipped, so singular matrices are factorized too
    (U then has a zero on its diagonal).

    Args:
        mat: a square Matrix. With array storage the typecode must be able to
            hold the quotients produced by elimination (e.g. 'd'); integer
            typecodes cannot store the float multipliers.

    Returns:
        A tuple (P, L, U) of Matrix objects with mat's typecode such that
        P * mat == L * U, where P is a permutation matrix, L is unit lower
        triangular and U is upper triangular.

    Raises:
        ValueError: if mat is not square.
    """
    if mat.M != mat.N:
        raise ValueError("PLU factorization requires a square matrix")
    size = mat.M
    identity_matrix = Matrix(size, size, mat.type)
    for x in range(size):
        identity_matrix[x][x] = 1
    P_matrix = Matrix(size, size, mat.type, identity_matrix)
    L_matrix = Matrix(size, size, mat.type, identity_matrix)
    U_matrix = Matrix(size, size, mat.type, mat)
    for i in range(size):
        # Search for a non-zero entry in column i to use as the pivot.
        pivot_row = next((p for p in range(i, size) if U_matrix[p][i]), None)
        if pivot_row is None:
            continue  # column already zero below the diagonal
        if pivot_row != i:
            U_matrix.swap_row(i, pivot_row)
            P_matrix.swap_row(i, pivot_row)
            # Only the multipliers already computed (columns < i) travel with
            # the swapped rows; L's unit diagonal stays where it is.
            L_matrix[i][:i], L_matrix[pivot_row][:i] = L_matrix[pivot_row][:i], L_matrix[i][:i]
        pivot = U_matrix[i][i]
        for j in range(i + 1, size):
            factor = U_matrix[j][i] / pivot
            L_matrix[j][i] = factor
            for k in range(i, size):
                U_matrix[j][k] -= U_matrix[i][k] * factor
    return P_matrix, L_matrix, U_matrix
def plu_solve(plu, b_vector):
    """Solve ``A x = b`` from the factorization returned by plu_factorization.

    plu_factorization returns (P, L, U) with ``P * A == L * U``, so
    ``A x = b`` is equivalent to ``L (U x) = P b``. The right-hand side is
    permuted first, then forward substitution solves ``L y = P b`` and back
    substitution solves ``U x = y``.

    Args:
        plu: the (P, L, U) tuple produced by plu_factorization.
        b_vector: sequence of numbers, the right-hand side b; its length
            sets the system size.

    Returns:
        The solution x as a len(b_vector)-by-1 Matrix.

    Raises:
        ZeroDivisionError: if U has a zero on its diagonal (A is singular).
    """
    m_size = len(b_vector)
    P_matrix, L_matrix, U_matrix = plu[0], plu[1], plu[2]
    # Apply the row permutation to b before substituting; applying P to the
    # result afterwards (as the factors might suggest) gives a wrong answer.
    pb_vector = [sum(P_matrix[i][j] * b_vector[j] for j in range(m_size))
                 for i in range(m_size)]
    # Forward substitution: L y = P b.
    y_vector = [0] * m_size
    for i in range(m_size):
        b_i = pb_vector[i]
        for j in range(i):
            b_i -= L_matrix[i][j] * y_vector[j]
        y_vector[i] = b_i / L_matrix[i][i]
    # Back substitution: U x = y (only the already-solved entries j > i).
    x_vector = Matrix(m_size, 1)
    for i in reversed(range(m_size)):
        y_i = y_vector[i]
        for j in range(i + 1, m_size):
            y_i -= U_matrix[i][j] * x_vector[j][0]
        x_vector[i][0] = y_i / U_matrix[i][i]
    return x_vector
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment