lua-matrix/lua/matrix.lua

1231 lines
31 KiB
Lua
Raw Blame History

--[[
matrix v 0.2.9
Lua 5.1 compatible
'matrix' provides a good selection of matrix functions.
With simple matrices this script is quite useful, though for more
exact calculations, one would probably use a program like Matlab instead.
Matrices of size 100x100 can still be handled very well.
The error for the determinant and the inverted matrix is around 10^-9
with a 100x100 matrix and an element range from -100 to 100.
Characteristics:
- functions called via matrix.<function> should be able to handle
any table matrix of structure t[i][j] = value
- can handle a type of complex matrix
- can handle symbolic matrices. (Symbolic matrices cannot be
used with complex matrices.)
- arithmetic functions do not change the matrix itself
but build and return a new matrix
- functions are intended to be light on checks
since one gets a Lua error on incorrect use anyways
- uses mainly Gauss-Jordan elimination
- for Lua tables optimised determinant calculation (fast)
but not invoking any checks for special types of matrices
- vectors can be set up via vec1 = matrix{{ 1,2,3 }}^'T' or matrix{1,2,3}
- vectors can be multiplied scalar via num = vec1^'T' * vec2
where num will be a matrix with the result in mtx[1][1],
or use num = vec1:scalar( vec2 ), where num is a number
Site: http://luamatrix.luaforge.net
matrix function list:
matrix.add
matrix.columns
matrix.concath
matrix.concatv
matrix.copy
matrix.cross
matrix.det
matrix.div
matrix.divnum
matrix.dogauss
matrix.elementstostrings
matrix.getelement
matrix.gsub
matrix.invert
matrix.ipairs
matrix.latex
matrix.len
matrix.mul
matrix.mulnum
matrix:new
matrix.normf
matrix.normmax
matrix.pow
matrix.print
matrix.random
matrix.replace
matrix.root
matrix.rotl
matrix.rotr
matrix.round
matrix.rows
matrix.scalar
matrix.setelement
matrix.size
matrix.solve
matrix.sqrt
matrix.sub
matrix.subm
matrix.tostring
matrix.transpose
matrix.type
Licensed under the same terms as Lua itself.
Developers:
Michael Lutz (chillcode)
David Manura http://lua-users.org/wiki/DavidManura
]]--
--////////////
--// matrix //
--////////////
local matrix = {}
-- access to the metatable we set at the end of the file
local matrix_meta = {}
--/////////////////////////////
--// Get 'new' matrix object //
--/////////////////////////////
--// matrix:new ( rows [, columns [, value]] )
-- if rows is a table of tables, it is adopted as the matrix itself (no copy)
-- if rows is a flat table such as {1,2,3}, a column vector is built from it,
--   one row per element; vectors longer than 3 are no longer truncated
-- if rows and columns are numbers, returns a rows x columns matrix
--   filled with value (default 0)
-- if rows is a number and columns is "I", returns the rows x rows identity matrix
function matrix:new( rows, columns, value )
  if type( rows ) == "table" then
    if type( rows[1] ) ~= "table" then
      -- flat table: expect a vector.
      -- At least three rows are always created, as the previous
      -- implementation did, so short inputs keep their old shape.
      local count = math.max( #rows, 3 )
      local vec = {}
      for i = 1,count do
        vec[i] = { rows[i] }
      end
      return setmetatable( vec, matrix_meta )
    end
    return setmetatable( rows, matrix_meta )
  end
  local fill = value or 0
  local mtx = {}
  if columns == "I" then
    -- identity matrix of size rows x rows
    for i = 1,rows do
      local row = {}
      for j = 1,rows do
        row[j] = (i == j) and 1 or 0
      end
      mtx[i] = row
    end
  else
    -- rows x columns matrix with every element set to fill
    for i = 1,rows do
      local row = {}
      for j = 1,columns do
        row[j] = fill
      end
      mtx[i] = row
    end
  end
  -- every matrix shares the same metatable
  return setmetatable( mtx, matrix_meta )
end
--// matrix ( rows [, columns [, value]] )
-- make the module table callable:
-- matrix( ... ) forwards to matrix.new with the module table as first argument
setmetatable( matrix, {
  __call = function( self, ... )
    return matrix.new( self, ... )
  end,
} )
-- functions are designed to be light on checks
-- so we get Lua errors instead on wrong input
-- matrix.<functions> should handle any table of structure t[i][j] = value
-- we always return a matrix with scripts metatable
-- cause its faster than setmetatable( mtx, getmetatable( input matrix ) )
--///////////////////////////////
--// matrix 'matrix' functions //
--///////////////////////////////
--// for real, complex and symbolic matrices //--
-- note: real and complex matrices may be added, subtracted, etc.
-- real and symbolic matrices may also be added, subtracted, etc.
-- but one should avoid using symbolic matrices with complex ones
-- since it is not clear which metatable then is used
--// matrix.add ( m1, m2 )
-- Element-wise sum; the result has the size of m1,
-- so m2 may be bigger than m1
function matrix.add( m1, m2 )
  local sum = {}
  for i = 1,#m1 do
    local left, right, row = m1[i], m2[i], {}
    for j = 1,#m1[1] do
      row[j] = left[j] + right[j]
    end
    sum[i] = row
  end
  return setmetatable( sum, matrix_meta )
end
--// matrix.sub ( m1 ,m2 )
-- Element-wise difference m1 - m2; the result has the size of m1,
-- so m2 may be bigger than m1
function matrix.sub( m1, m2 )
  local diff = {}
  for i = 1,#m1 do
    local left, right, row = m1[i], m2[i], {}
    for j = 1,#m1[1] do
      row[j] = left[j] - right[j]
    end
    diff[i] = row
  end
  return setmetatable( diff, matrix_meta )
end
--// matrix.mul ( m1, m2 )
-- Matrix product m1 * m2; the column count of m1 must equal
-- the row count of m2, i.e. #m1[1] == #m2
function matrix.mul( m1, m2 )
  local product = {}
  for i = 1,#m1 do
    local left, row = m1[i], {}
    for j = 1,#m2[1] do
      -- row i of m1 times column j of m2
      local acc = left[1] * m2[1][j]
      for k = 2,#m1[1] do
        acc = acc + left[k] * m2[k][j]
      end
      row[j] = acc
    end
    product[i] = row
  end
  return setmetatable( product, matrix_meta )
end
--// matrix.div ( m1, m2 )
-- "Divide" m1 by m2, computed as m1 * m2^-1
-- m2 must be square so it can be inverted and #m1[1] == #m2 must hold
-- if m2 cannot be inverted, returns nil and the rank reported by matrix.invert
function matrix.div( m1, m2 )
  local inverse, rank = matrix.invert( m2 )
  if inverse then
    return matrix.mul( m1, inverse )
  end
  -- singular
  return inverse, rank
end
--// matrix.mulnum ( m1, num )
-- Multiply every element of m1 with num and return the result as a new matrix
-- num may be anything the elements' "*" operator accepts on its right-hand side
function matrix.mulnum( m1, num )
  local scaled = {}
  for i = 1,#m1 do
    local source, row = m1[i], {}
    for j = 1,#m1[1] do
      row[j] = source[j] * num
    end
    scaled[i] = row
  end
  return setmetatable( scaled, matrix_meta )
end
--// matrix.divnum ( m1, num )
-- Divide every element of m1 by num and return the result as a new matrix
-- num may be anything the elements' "/" operator accepts on its right-hand side
function matrix.divnum( m1, num )
  local scaled = {}
  for i = 1,#m1 do
    local source, row = m1[i], {}
    for j = 1,#m1[1] do
      row[j] = source[j] / num
    end
    scaled[i] = row
  end
  return setmetatable( scaled, matrix_meta )
end
--// for real and complex matrices only //--
--// matrix.pow ( m1, num )
-- Integer power of a square matrix; mtx^(num)
-- num == 0 gives the identity matrix of size #m1
-- num < 0 raises the inverse of m1 to -num; if the inversion fails,
-- nil and the rank reported by matrix.invert are returned
function matrix.pow( m1, num )
  assert(num == math.floor(num), "exponent not an integer")
  if num == 0 then
    return matrix:new( #m1,"I" )
  end
  local base, exponent = m1, num
  if exponent < 0 then
    local rank
    base, rank = matrix.invert( base )
    if not base then return base, rank end -- singular
    exponent = -exponent
  end
  -- repeated multiplication: start from a copy, multiply exponent-1 times
  local result = matrix.copy( base )
  for _ = 2,exponent do
    result = matrix.mul( result,base )
  end
  return result
end
-- squared magnitude of a plain number; used where an element type
-- does not provide its own norm2
local function number_norm2( value )
  local squared = value * value
  return squared
end
--// matrix.det ( m1 )
-- Calculate the determinant of a matrix
-- m1 needs to be square (asserted)
-- Sizes 1x1, 2x2 and 3x3 use closed formulas built only from
-- "*", "+" and "-", so they also work for symbolic matrices
-- Bigger matrices are handled by elimination on a copy of m1:
-- for every column the pivot is the non-zero element whose squared
-- magnitude is nearest to 1, so that usually
-- |mtx[i][j]/subdet| > 1 or mtx[i][j];
-- element types that are tables supply their own 'zero' and 'norm2'
-- fields, which are used for the zero test and the pivot comparison
function matrix.det( m1 )
-- check if matrix is quadratic
assert(#m1 == #m1[1], "matrix not square")
local size = #m1
-- closed formulas for the small cases
if size == 1 then
return m1[1][1]
end
if size == 2 then
return m1[1][1]*m1[2][2] - m1[2][1]*m1[1][2]
end
if size == 3 then
-- rule of Sarrus
return ( m1[1][1]*m1[2][2]*m1[3][3] + m1[1][2]*m1[2][3]*m1[3][1] + m1[1][3]*m1[2][1]*m1[3][2]
- m1[1][3]*m1[2][2]*m1[3][1] - m1[1][1]*m1[2][3]*m1[3][2] - m1[1][2]*m1[2][1]*m1[3][3] )
end
--// no symbolic matrix supported below here
-- take the zero value and the squared-magnitude function from the
-- element type if the elements are tables, else use the number versions
local e = m1[1][1]
local zero = type(e) == "table" and e.zero or 0
local norm2 = type(e) == "table" and e.norm2 or number_norm2
--// matrix is bigger than 3x3
-- get determinant
-- using Gauss elimination and Laplace
-- start eliminating from below better for removals
-- get copy of matrix, set initial determinant
local mtx = matrix.copy( m1 )
local det = 1
-- get det up to the last element
-- one column is consumed per pass; the pivot row is moved to the end
-- and removed, so #mtx shrinks by one each time
for j = 1,#mtx[1] do
-- get smallest element so that |factor| > 1
-- and set it as last element
local rows = #mtx
local subdet,xrow
for i = 1,rows do
-- get element
local e = mtx[i][j]
-- if no subdet has been found
if not subdet then
-- check if element it is not zero
if e ~= zero then
-- use element as new subdet
subdet,xrow = e,i
end
-- check for elements nearest to 1 or -1
elseif e ~= zero and math.abs(norm2(e)-1) < math.abs(norm2(subdet)-1) then
subdet,xrow = e,i
end
end
-- only carry on if subdet is found
if subdet then
-- check if xrow is the last row,
-- else switch lines and multiply det by -1
if xrow ~= rows then
mtx[rows],mtx[xrow] = mtx[xrow],mtx[rows]
det = -det
end
-- traverse all fields setting element to zero
-- we don't set to zero cause we don't use that column anymore then anyways
for i = 1,rows-1 do
-- factor is the dividor of the first element
-- if element is not already zero
if mtx[i][j] ~= zero then
local factor = mtx[i][j]/subdet
-- update all remaining fields of the matrix, with value from xrow
-- (the pivot row now sits at index 'rows')
for n = j+1,#mtx[1] do
mtx[i][n] = mtx[i][n] - factor * mtx[rows][n]
end
end
end
-- update determinant and remove row
-- sign of the expansion: the pivot is in the last row and in the first
-- column still in use, so the sign flips when the row count is even
if math.fmod( rows,2 ) == 0 then
det = -det
end
det = det * subdet
-- drop the pivot row (the last one)
table.remove( mtx )
else
-- the whole column is zero, so the determinant is 0;
-- det * 0 is returned rather than a literal 0
return det * 0
end
end
-- det ready to return
return det
end
--// matrix.dogauss ( mtx )
-- Gauss elimination, Gauss-Jordan Method
-- this function changes the matrix itself
-- returns on success: true,
-- returns on failure: false,'rank of matrix'
-- locals
-- Pivot selection for matrix.dogauss.
-- Looks in column j, from row i downwards, for the non-zero element whose
-- squared magnitude (norm2) is nearest to 1 and swaps that row into row i;
-- this way the factor of the evolving number division should be > 1 or the
-- divided number itself, what gives better results
-- mtx:   matrix, changed in place (rows may be swapped)
-- i,j:   row and column of the pivot position
-- zero:  zero value of the element type
-- one:   one value of the element type
-- norm2: function returning the squared magnitude of an element
-- returns true if a non-zero pivot now sits at mtx[i][j], nothing otherwise
local setelementtosmallest = function( mtx,i,j,zero,one,norm2 )
  -- nothing to do if the element already is one
  if mtx[i][j] == one then return true end
  local _ilow
  for _i = i,#mtx do
    local e = mtx[_i][j]
    if e == one then
      -- an exact one is the best possible pivot: take this row and stop.
      -- It has to be recorded before leaving the loop; otherwise a column
      -- like (0,1) would be reported as having no pivot at all.
      _ilow = _i
      break
    end
    if not _ilow then
      -- first non-zero candidate
      if e ~= zero then
        _ilow = _i
      end
    elseif (e ~= zero) and math.abs(norm2(e)-1) < math.abs(norm2(mtx[_ilow][j])-1) then
      -- nearer to magnitude 1 than the current candidate
      _ilow = _i
    end
  end
  if _ilow then
    -- switch lines if not input line
    -- legal operation
    if _ilow ~= i then
      mtx[i],mtx[_ilow] = mtx[_ilow],mtx[i]
    end
    return true
  end
end
-- duplicate an element: table elements are copied through their own
-- 'copy' field (falling back to the element itself if that yields
-- nothing), everything else is returned unchanged
local function copy( x )
  if type( x ) == "table" then
    return x.copy( x ) or x
  end
  return x
end
-- note: in --// ... //-- we have a way that does no divison,
-- however with big number and matrices we get problems since we do no reducing
-- matrix.dogauss ( mtx )
-- Gauss-Jordan elimination, done in place on mtx
-- phase 1 (top to bottom) picks a pivot for every diagonal position and
-- clears the column below it; phase 2 (bottom to top) scales each pivot
-- row and clears the column above the pivot
-- returns true on success,
-- returns false and the number of pivots found so far if a column has no pivot
function matrix.dogauss( mtx )
-- take zero, one and norm2 from the element type if elements are tables
local e = mtx[1][1]
local zero = type(e) == "table" and e.zero or 0
local one = type(e) == "table" and e.one or 1
local norm2 = type(e) == "table" and e.norm2 or number_norm2
local rows,columns = #mtx,#mtx[1]
-- stairs left -> right
for j = 1,rows do
-- check if element can be setted to one
-- (the helper may swap a lower row into row j)
if setelementtosmallest( mtx,j,j,zero,one,norm2 ) then
-- start parsing rows
for i = j+1,rows do
-- check if element is not already zero
if mtx[i][j] ~= zero then
-- we may add x*otherline row, to set element to zero
-- tozero - x*mtx[j][j] = 0; x = tozero/mtx[j][j]
local factor = mtx[i][j]/mtx[j][j]
--// this should not be used although it does no division,
-- yet with big matrices (since we do no reducing and other things)
-- we get too big numbers
--local factor1,factor2 = mtx[i][j],mtx[j][j] //--
-- the eliminated element is set directly instead of being computed
mtx[i][j] = copy(zero)
for _j = j+1,columns do
--// mtx[i][_j] = mtx[i][_j] * factor2 - factor1 * mtx[j][_j] //--
mtx[i][_j] = mtx[i][_j] - factor * mtx[j][_j]
end
end
end
else
-- return false and the rank of the matrix
return false,j-1
end
end
-- stairs right <- left
for j = rows,1,-1 do
-- set element to one
-- do division here
-- only the columns right of the pivot are divided;
-- the pivot itself is overwritten with 'one' at the end of this pass
local div = mtx[j][j]
for _j = j+1,columns do
mtx[j][_j] = mtx[j][_j] / div
end
-- start parsing rows
for i = j-1,1,-1 do
-- check if element is not already zero
if mtx[i][j] ~= zero then
-- row j is already scaled, so the element itself is the factor
local factor = mtx[i][j]
for _j = j+1,columns do
mtx[i][_j] = mtx[i][_j] - factor * mtx[j][_j]
end
mtx[i][j] = copy(zero)
end
end
mtx[j][j] = copy(one)
end
return true
end
--// matrix.invert ( m1 )
-- Get the inverse of m1
-- m1 must be square (asserted) and not singular
-- works on a copy: [ m1 | identity ] is reduced with matrix.dogauss
-- and the right half is returned
-- on success: returns inverted matrix
-- on failure: returns nil,'rank of matrix'
function matrix.invert( m1 )
  assert(#m1 == #m1[1], "matrix not square")
  local work = matrix.copy( m1 )
  local ident = setmetatable( {},matrix_meta )
  -- zero and one come from the element type if elements are tables
  local first = m1[1][1]
  local is_object = type(first) == "table"
  local zero = is_object and first.zero or 0
  local one = is_object and first.one or 1
  local n = #m1
  for i = 1,n do
    local row = {}
    ident[i] = row
    for j = 1,n do
      row[j] = copy((i == j) and one or zero)
    end
  end
  local augmented = matrix.concath( work,ident )
  local done,rank = matrix.dogauss( augmented )
  if not done then
    return nil,rank
  end
  -- the right half of the reduced matrix is the inverse
  local width = #augmented[1]
  return matrix.subm( augmented, 1,(width/2)+1,#augmented,width )
end
--// matrix.sqrt ( m1 [,iters] )
-- calculate the square root of a matrix using the "Denman-Beavers square root iteration"
-- condition: matrix rows == matrix columns; must have an inverse matrix and a square root
-- if called without additional arguments, the function finds the first nearest square root to
-- the input matrix; there are others, but the error between them is very small
-- if called with argument iters, the function will return the matrix after that number of iterations
-- the script returns:
-- as first argument, matrix^.5
-- as second argument, matrix^-.5
-- as third argument, the average error between (matrix^.5)^2-inputmatrix
-- you have to determine for yourself if the result is sufficient for you
-- local average error
-- mean absolute difference between the elements of m1 and m2,
-- taken over the rows x columns entries of m1
-- table elements use the 'abs' field of the element type, numbers math.abs
local function get_abs_avg( m1, m2 )
  local dist = 0
  local e = m1[1][1]
  local abs = type(e) == "table" and e.abs or math.abs
  local rows, columns = #m1, #m1[1]
  for i = 1,rows do
    for j = 1,columns do
      dist = dist + abs(m1[i][j]-m2[i][j])
    end
  end
  -- norm by number of entries (rows * columns, not rows * 2)
  return dist/(rows*columns)
end
-- square root function
-- iterates y, z = (y + z^-1)/2, (z + y^-1)/2 starting from y = m1, z = identity
-- without iters the loop stops as soon as the change of y no longer
-- shrinks and returns the values of the previous step
-- returns y, z and get_abs_avg( y*y, m1 )
function matrix.sqrt( m1, iters )
  assert(#m1 == #m1[1], "matrix not square")
  local limit = iters or math.huge
  local unbounded = limit == math.huge
  local y = matrix.copy( m1 )
  local z = matrix(#y, 'I')
  local previous_dist = math.huge
  for _ = 1,limit do
    local lasty,lastz = y,z
    -- both updates are computed from the values of the previous step
    local next_y = matrix.divnum( matrix.add( lasty, matrix.invert( lastz ) ), 2 )
    local next_z = matrix.divnum( matrix.add( lastz, matrix.invert( lasty ) ), 2 )
    y, z = next_y, next_z
    local dist = get_abs_avg( y,lasty )
    if unbounded and dist >= previous_dist then
      return lasty,lastz,get_abs_avg(matrix.mul(lasty,lasty),m1)
    end
    previous_dist = dist
  end
  return y,z,get_abs_avg(matrix.mul(y,y),m1)
end
--// matrix.root ( m1, root [,iters] )
-- calculate any root of a matrix
-- source: http://www.dm.unipi.it/~cortona04/slides/bruno.pdf
-- m1 and root have to be given;(m1 = matrix, root = number)
-- root is passed to matrix.pow (as root, root-1 and root-2), which asserts
-- an integer exponent, so root has to be an integer
-- conditions same as matrix.sqrt
-- returns mx, my and get_abs_avg( mx^root, m1 ); without iters the loop
-- stops once the change of mx no longer shrinks and the values of the
-- previous step are returned
function matrix.root( m1, root, iters )
assert(#m1 == #m1[1], "matrix not square")
local iters = iters or math.huge
local mx = matrix.copy( m1 )
-- start value: my = mx^-1 * mx^(root-1)
local my = matrix.mul(mx:invert(),mx:pow(root-1))
local dist = math.huge
-- iterate, and get the average error
for n=1,iters do
local lastx,lasty = mx,my
-- calc root of matrix
--mx,my = ((p-1)*mx + my^-1)/p,
-- ((((p-1)*my + mx^-1)/p)*my^-1)^(p-2) *
-- ((p-1)*my + mx^-1)/p
-- both right-hand sides are evaluated with the old mx and my before
-- either is assigned
-- the second expression is one method chain, applied left to right:
-- ((p-1)*my + mx^-1)/p, times (my^-1)^(p-2),
-- times ((p-1)*my + mx^-1), divided by p
-- NOTE(review): this grouping differs from the formula in the comment
-- above, which raises the whole product to (p-2) - confirm against the source
mx,my = mx:mulnum(root-1):add(my:invert()):divnum(root),
my:mulnum(root-1):add(mx:invert()):divnum(root)
:mul(my:invert():pow(root-2)):mul(my:mulnum(root-1)
:add(mx:invert())):divnum(root)
-- change of mx in this step
local dist1 = get_abs_avg(mx,lastx)
if iters == math.huge then
if dist1 >= dist then
return lastx,lasty,get_abs_avg(matrix.pow(lastx,root),m1)
end
end
dist = dist1
end
return mx,my,get_abs_avg(matrix.pow(mx,root),m1)
end
--// Norm functions //--
--// matrix.normf ( mtx )
-- calculates the Frobenius norm of the matrix.
-- ||mtx||_F = sqrt(SUM_{i,j} |a_{i,j}|^2)
-- non-number elements are passed through their abs method first; the final
-- root uses math.sqrt for a number sum and the sum's own sqrt otherwise
-- http://en.wikipedia.org/wiki/Frobenius_norm#Frobenius_norm
function matrix.normf(mtx)
  local is_number = matrix.type(mtx) == "number"
  local total = 0
  for i = 1,#mtx do
    local row = mtx[i]
    for j = 1,#mtx[1] do
      local e = row[j]
      if not is_number then e = e:abs() end
      total = total + e^2
    end
  end
  if type(total) == "number" then
    return math.sqrt(total)
  end
  return total.sqrt(total)
end
--// matrix.normmax ( mtx )
-- calculates the max norm of the matrix.
-- ||mtx||_{max} = max{|a_{i,j}|}
-- number matrices use math.abs, others the abs field of the first element
-- Does not work with symbolic matrices
-- http://en.wikipedia.org/wiki/Frobenius_norm#Max_norm
function matrix.normmax(mtx)
  local abs = (matrix.type(mtx) == "number") and math.abs or mtx[1][1].abs
  local largest = 0
  for i = 1,#mtx do
    local row = mtx[i]
    for j = 1,#mtx[1] do
      local magnitude = abs(row[j])
      if largest < magnitude then
        largest = magnitude
      end
    end
  end
  return largest
end
--// only for number and complex type //--
-- Functions changing the matrix itself
--// matrix.round ( mtx [, idp] )
-- perform round on elements
-- round a number to the nearest multiple of 1/mult (halves are rounded up)
local numround = function( num,mult )
  local shifted = num * mult + 0.5
  return math.floor( shifted ) / mult
end
-- round every entry of the sequence t in place (same rule as numround)
-- and return t
local tround = function( t,mult )
  for index,entry in ipairs(t) do
    local shifted = entry * mult + 0.5
    t[index] = math.floor( shifted ) / mult
  end
  return t
end
-- round all elements of mtx in place to idp decimal places (default 0)
-- number matrices round each element, other matrices round the entries
-- of each element table
-- returns mtx
function matrix.round( mtx, idp )
  local mult = 10^( idp or 0 )
  local fround = tround
  if matrix.type( mtx ) == "number" then
    fround = numround
  end
  for i = 1,#mtx do
    local row = mtx[i]
    for j = 1,#mtx[1] do
      row[j] = fround(row[j],mult)
    end
  end
  return mtx
end
--// matrix.random( mtx [,start] [, stop] [, idip] )
-- fillmatrix with random values
-- random value for a number element: math.random( start,stop ) divided by idp
-- the first argument (the old element) is ignored
local numfill = function( _,start,stop,idp )
  local drawn = math.random( start,stop )
  return drawn / idp
end
-- overwrite every entry of the sequence t with
-- math.random( start,stop ) / idp and return t
local tfill = function( t,start,stop,idp )
  for index in ipairs(t) do
    local drawn = math.random( start,stop )
    t[index] = drawn / idp
  end
  return t
end
-- fill mtx in place with random values and return it
-- start/stop are the bounds passed to math.random (default -10 and 10),
-- every drawn value is divided by idp (default 1)
function matrix.random( mtx,start,stop,idp )
  local lower, upper, divisor = start or -10, stop or 10, idp or 1
  local ffill = tfill
  if matrix.type( mtx ) == "number" then
    ffill = numfill
  end
  for i = 1,#mtx do
    local row = mtx[i]
    for j = 1,#mtx[1] do
      row[j] = ffill( row[j], lower, upper, divisor )
    end
  end
  return mtx
end
--//////////////////////////////
--// Object Utility Functions //
--//////////////////////////////
--// for all types and matrices //--
--// matrix.type ( mtx )
-- get the type of a matrix, decided by its first element mtx[1][1]:
-- "number" if it is not a table,
-- the result of its type method if it is a table that has one,
-- "tensor" for any other table
function matrix.type( mtx )
  local first = mtx[1][1]
  if type(first) ~= "table" then
    return "number"
  end
  if first.type then
    return first:type()
  end
  return "tensor"
end
-- local functions to copy matrix values
-- plain values are returned as they are
local function num_copy( num )
  return num
end
-- table elements: new table with the same metatable and
-- the same sequence entries (entries themselves are not copied)
local t_copy = function( t )
  local duplicate = setmetatable( {}, getmetatable( t ) )
  for index,entry in ipairs( t ) do
    duplicate[index] = entry
  end
  return duplicate
end
--// matrix.copy ( m1 )
-- Copy a matrix
-- returns a new matrix with the same number of rows and columns as m1;
-- number elements are copied by value, table elements with t_copy
-- (new table, same metatable, same sequence entries)
-- simple copy, one can write other functions oneself
function matrix.copy( m1 )
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  local mtx = {}
  -- rows run over #m1 and columns over #m1[1]; with the bounds the other
  -- way round only square matrices were copied correctly
  for i = 1,#m1 do
    local source, row = m1[i], {}
    for j = 1,#m1[1] do
      row[j] = docopy( source[j] )
    end
    mtx[i] = row
  end
  return setmetatable( mtx, matrix_meta )
end
--// matrix.transpose ( m1 )
-- Transpose a matrix: element (i,j) of the result is element (j,i) of m1
-- elements are copied like in matrix.copy
function matrix.transpose( m1 )
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  local flipped = {}
  for i = 1,#m1[1] do
    local row = {}
    flipped[i] = row
    for j = 1,#m1 do
      row[j] = docopy( m1[j][i] )
    end
  end
  return setmetatable( flipped, matrix_meta )
end
--// matrix.subm ( m1, i1, j1, i2, j2 )
-- Submatrix out of a matrix
-- i1,j1: row and column of the first element taken
-- i2,j2: row and column of the last element taken
-- condition: i1,j1,i2,j2 are elements of the matrix
-- the result is indexed from 1 again
function matrix.subm( m1,i1,j1,i2,j2 )
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  local part = {}
  for i = i1,i2 do
    local row = {}
    part[i-i1+1] = row
    for j = j1,j2 do
      row[j-j1+1] = docopy( m1[i][j] )
    end
  end
  return setmetatable( part, matrix_meta )
end
--// matrix.concath( m1, m2 )
-- Concatenate two matrices, horizontal
-- will return m1m2; both need the same number of rows (asserted)
-- e.g.: #m1 == #m2
-- the copy function for all elements is chosen from the type of m1
function matrix.concath( m1,m2 )
  assert(#m1 == #m2, "matrix size mismatch")
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  local joined = {}
  local offset = #m1[1]
  for i = 1,#m1 do
    local row = {}
    joined[i] = row
    -- left part from m1
    for j = 1,offset do
      row[j] = docopy( m1[i][j] )
    end
    -- right part from m2, shifted behind the columns of m1
    for j = 1,#m2[1] do
      row[offset+j] = docopy( m2[i][j] )
    end
  end
  return setmetatable( joined, matrix_meta )
end
--// matrix.concatv ( m1, m2 )
-- Concatenate two matrices, vertical
-- will return m1
--             m2
-- both need the same number of columns (asserted); e.g.: #m1[1] == #m2[1]
-- the copy function for all elements is chosen from the type of m1
function matrix.concatv( m1,m2 )
  assert(#m1[1] == #m2[1], "matrix size mismatch")
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  local stacked = {}
  -- upper part from m1
  for i = 1,#m1 do
    local row = {}
    stacked[i] = row
    for j = 1,#m1[1] do
      row[j] = docopy( m1[i][j] )
    end
  end
  -- lower part from m2, appended below the rows of m1
  local offset = #stacked
  for i = 1,#m2 do
    local row = {}
    stacked[offset+i] = row
    for j = 1,#m2[1] do
      row[j] = docopy( m2[i][j] )
    end
  end
  return setmetatable( stacked, matrix_meta )
end
--// matrix.rotl ( m1 )
-- Rotate Left, 90 degrees
-- element (i,j) of m1 ends up at (columns-j+1, i) of the new matrix
function matrix.rotl( m1 )
  local rotated = matrix:new( #m1[1],#m1 )
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  for i = 1,#m1 do
    local source = m1[i]
    for j = 1,#m1[1] do
      local target_row = #m1[1]-j+1
      rotated[target_row][i] = docopy( source[j] )
    end
  end
  return rotated
end
--// matrix.rotr ( m1 )
-- Rotate Right, 90 degrees
-- element (i,j) of m1 ends up at (j, rows-i+1) of the new matrix
function matrix.rotr( m1 )
  local rotated = matrix:new( #m1[1],#m1 )
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  for i = 1,#m1 do
    local source = m1[i]
    local target_column = #m1-i+1
    for j = 1,#m1[1] do
      rotated[j][target_column] = docopy( source[j] )
    end
  end
  return rotated
end
-- string form of a tensor element: "[a,b,c]"
-- with fstr every entry is formatted through string.format first
local function tensor_tostring( t,fstr )
  local parts = t
  if fstr then
    parts = {}
    for index,entry in ipairs( t ) do
      parts[index] = string.format( fstr,entry )
    end
  end
  return "["..table.concat(parts,",").."]"
end
-- string form of a plain element: formatted with fstr if one is given,
-- otherwise the element itself is returned unchanged
local function number_tostring( e,fstr )
  if fstr then
    return string.format( fstr,e )
  end
  return e
end
--// matrix.tostring ( mtx, formatstr )
-- tostring function
-- elements of a row are joined with tabs, rows with newlines
-- tensors use tensor_tostring, other table elements their own tostring
-- field if they have one, everything else number_tostring
function matrix.tostring( mtx, formatstr )
  local mtype = matrix.type( mtx )
  local first = mtx[1][1]
  local element_tostring
  if mtype == "tensor" then
    element_tostring = tensor_tostring
  elseif type(first) == "table" and first.tostring then
    element_tostring = first.tostring
  else
    element_tostring = number_tostring
  end
  local lines = {}
  for i = 1,#mtx do
    local cells = {}
    for j = 1,#mtx[1] do
      cells[j] = element_tostring(mtx[i][j],formatstr)
    end
    lines[i] = table.concat(cells, "\t")
  end
  return table.concat(lines, "\n")
end
--// matrix.print ( mtx [, formatstr] )
-- print out the matrix; all arguments are handed to matrix.tostring
function matrix.print( ... )
  local text = matrix.tostring( ... )
  print( text )
end
--// matrix.latex ( mtx [, align] )
-- LaTeX output
-- returns the matrix as a LaTeX array wrapped in \left( ... \right),
-- one line per row, columns separated by " & "
function matrix.latex( mtx, align )
-- align : option to align the elements
-- c = center; l = left; r = right
-- \usepackage{dcolumn}; D{.}{,}{-1}; aligns number by . replaces it with ,
local align = align or "c"
-- the column specification repeats align once per column
local str = "$\\left( \\begin{array}{"..string.rep( align, #mtx[1] ).."}\n"
-- getstr is called without a format string, so for non-tensor matrices the
-- element itself is concatenated; table elements then have to support ".."
local getstr = matrix.type( mtx ) == "tensor" and tensor_tostring or number_tostring
for i = 1,#mtx do
str = str.."\t"..getstr(mtx[i][1])
for j = 2,#mtx[1] do
str = str.." & "..getstr(mtx[i][j])
end
-- close line
-- every row except the last one ends with the LaTeX row separator \\
if i == #mtx then
str = str.."\n"
else
str = str.." \\\\\n"
end
end
return str.."\\end{array} \\right)$"
end
--// Functions not changing the matrix
--// matrix.rows ( mtx )
-- return number of rows, i.e. #mtx
function matrix.rows( mtx )
  local count = #mtx
  return count
end
--// matrix.columns ( mtx )
-- return number of columns, taken from the first row
function matrix.columns( mtx )
  local first_row = mtx[1]
  return #first_row
end
--// matrix.size ( mtx )
-- get the matrix size as numbers: rows,columns
-- tensors return a third value, the length of the first element
function matrix.size( mtx )
  if matrix.type( mtx ) ~= "tensor" then
    return #mtx,#mtx[1]
  end
  return #mtx,#mtx[1],#mtx[1][1]
end
--// matrix.getelement ( mtx, i, j )
-- return specific element ( row,column )
-- returns the element if row i exists and holds a true value at j,
-- nothing otherwise
function matrix.getelement( mtx,i,j )
  local row = mtx[i]
  if row and row[j] then
    return row[j]
  end
end
--// matrix.setelement( mtx, i, j, value )
-- set an element ( i, j, value )
-- only positions for which matrix.getelement finds an element can be set;
-- value itself is not checked
-- returns 1 on success and nothing on failure
function matrix.setelement( mtx,i,j,value )
  local present = matrix.getelement( mtx,i,j )
  if present then
    mtx[i][j] = value
    return 1
  end
end
--// matrix.ipairs ( mtx )
-- iteration, same for complex
-- returns an iterator that yields the index pairs i,j row by row;
-- rows and columns are read once when the iterator is created
function matrix.ipairs( mtx )
  local rows, columns = #mtx, #mtx[1]
  local i, j = 1, 0
  return function()
    j = j + 1
    if j > columns then
      -- continue with the first element of the next row
      i, j = i + 1, 1
    end
    if i <= rows then
      return i,j
    end
  end
end
--///////////////////////////////
--// matrix 'vector' functions //
--///////////////////////////////
-- a vector is defined as a 3x1 matrix
-- get a vector; vec = matrix{{ 1,2,3 }}^'T'
--// matrix.scalar ( m1, m2 )
-- returns the Scalar Product of two 3x1 matrices (vectors):
-- the sum of the products of the three first-column elements
function matrix.scalar( m1, m2 )
  local product = m1[1][1]*m2[1][1]
  product = product + m1[2][1]*m2[2][1]
  product = product + m1[3][1]*m2[3][1]
  return product
end
--// matrix.cross ( m1, m2 )
-- returns the Cross Product of two 3x1 matrices (vectors)
-- as a new 3x1 matrix
function matrix.cross( m1, m2 )
  local x = m1[2][1]*m2[3][1] - m1[3][1]*m2[2][1]
  local y = m1[3][1]*m2[1][1] - m1[1][1]*m2[3][1]
  local z = m1[1][1]*m2[2][1] - m1[2][1]*m2[1][1]
  return setmetatable( { { x }, { y }, { z } }, matrix_meta )
end
--// matrix.len ( m1 )
-- returns the Length of a 3x1 matrix (vector):
-- the square root of the sum of the squared first-column elements
function matrix.len( m1 )
  local squares = m1[1][1]^2 + m1[2][1]^2 + m1[3][1]^2
  return math.sqrt( squares )
end
--// matrix.replace (mtx, func, ...)
-- returns a new matrix in which each element e of the input
-- is replaced with func( e, ... ); the input itself is not changed
-- each row is walked over its own length
function matrix.replace( m1, func, ... )
  local replaced = {}
  for i = 1,#m1 do
    local source, row = m1[i], {}
    replaced[i] = row
    for j = 1,#source do
      row[j] = func( source[j], ... )
    end
  end
  return setmetatable( replaced, matrix_meta )
end
--// matrix.elementstostrings ( mtx )
-- returns a new matrix with every element converted by the tostring field
-- of the first element if it is a table that has one,
-- otherwise by the global tostring
-- IMPROVE: tostring v.s. tostringelements confusing
function matrix.elementstostrings( mtx )
  local first = mtx[1][1]
  local convert = tostring
  if type(first) == "table" and first.tostring then
    convert = first.tostring
  end
  return matrix.replace(mtx, convert)
end
--// matrix.solve ( m1 )
-- solve; tries to solve a symbolic matrix to a number
-- the text of every symbol is compiled as the Lua expression
-- "return <text>", run, and the result passed to tonumber
-- m1 must be of type 'symbol' (asserted)
-- raises an error with the compiler message if a symbol is not a valid
-- Lua expression
-- NOTE: this executes the symbol text as Lua code; do not use it on
-- symbols built from untrusted input
function matrix.solve( m1 )
  assert( matrix.type( m1 ) == "symbol", "matrix not of type 'symbol'" )
  -- loadstring only exists in Lua 5.1; later versions compile strings with load
  local compile = loadstring or load
  local mtx = {}
  for i = 1,#m1 do
    local row = {}
    mtx[i] = row
    for j = 1,#m1[1] do
      -- assert turns a compile failure into an error carrying the parser
      -- message instead of "attempt to call a nil value"
      local chunk = assert( compile( "return "..m1[i][j][1] ) )
      row[j] = tonumber( chunk() )
    end
  end
  return setmetatable( mtx, matrix_meta )
end
--////////////////////////--
--// METATABLE HANDLING //--
--////////////////////////--
--// MetaTable
-- as we declared at the top of the file
-- local/shared metatable
-- matrix_meta
-- note '...' is always faster than 'arg1,arg2,...' if it can be used
-- Set add "+" behaviour: m1 + m2 calls matrix.add
function matrix_meta.__add( m1, ... )
  return matrix.add( m1, ... )
end
-- Set subtract "-" behaviour: m1 - m2 calls matrix.sub
function matrix_meta.__sub( m1, ... )
  return matrix.sub( m1, ... )
end
-- Set multiply "*" behaviour
-- an operand that does not carry matrix_meta is treated as a factor for
-- matrix.mulnum; two matrices are multiplied with matrix.mul
function matrix_meta.__mul( m1,m2 )
  local left_is_matrix = getmetatable( m1 ) == matrix_meta
  if not left_is_matrix then
    return matrix.mulnum( m2,m1 )
  end
  local right_is_matrix = getmetatable( m2 ) == matrix_meta
  if not right_is_matrix then
    return matrix.mulnum( m1,m2 )
  end
  return matrix.mul( m1,m2 )
end
-- Set division "/" behaviour
-- value / matrix : inverse of the matrix times the value
-- matrix / value : matrix.divnum
-- matrix / matrix: matrix.div
function matrix_meta.__div( m1,m2 )
  if getmetatable( m1 ) ~= matrix_meta then
    local inverse = matrix.invert( m2 )
    return matrix.mulnum( inverse,m1 )
  end
  if getmetatable( m2 ) == matrix_meta then
    return matrix.div( m1,m2 )
  end
  return matrix.divnum( m1,m2 )
end
-- Set unary minus "-" behaviour: every element is multiplied with -1
function matrix_meta.__unm( mtx )
  return matrix.mulnum( mtx,-1 )
end
-- Set power "^" behaviour
-- if opt is 'T' the transposed matrix is returned
-- if opt is '*' matrix.conjugate is called (only for complex)
--   NOTE(review): matrix.conjugate is not defined in the part of the file
--   visible here - confirm where it comes from
-- any other opt is handed to matrix.pow, which expects an integer;
-- only the first result of matrix.pow is returned, so the answer is nil
-- if it doesn't exist
local option = {
  -- only for complex
  ["*"] = function( m1 ) return matrix.conjugate( m1 ) end,
  -- for both
  ["T"] = function( m1 ) return matrix.transpose( m1 ) end,
}
function matrix_meta.__pow( m1, opt )
  local handler = option[opt]
  if handler then
    local result = handler( m1 )
    if result then
      return result
    end
  end
  -- the parentheses keep the result to a single value
  return ( matrix.pow( m1,opt ) )
end
-- Set equal "==" behaviour
-- two matrices are equal if they have the same matrix type, the same
-- number of rows and columns and all elements compare equal with "=="
function matrix_meta.__eq( m1, m2 )
  local same_type = matrix.type( m1 ) == matrix.type( m2 )
  if not same_type then
    return false
  end
  local same_size = #m1 == #m2 and #m1[1] == #m2[1]
  if not same_size then
    return false
  end
  for i = 1,#m1 do
    local left, right = m1[i], m2[i]
    for j = 1,#m1[1] do
      if left[j] ~= right[j] then
        return false
      end
    end
  end
  return true
end
-- Set tostring "tostring( mtx )" behaviour: forwards to matrix.tostring
function matrix_meta.__tostring( mtx, ... )
  return matrix.tostring( mtx, ... )
end
-- set __call "mtx( [formatstr] )" behaviour: prints the matrix
-- through matrix.print and returns nothing
function matrix_meta.__call( mtx, ... )
  matrix.print( mtx, ... )
end
--// __index handling
-- matrix objects find their methods in a separate table that holds a
-- snapshot of everything stored in 'matrix' at this point of the file
local methods = {}
for name,member in pairs( matrix ) do
  methods[name] = member
end
matrix_meta.__index = methods
--/////////////////////////////////
--// symbol class implementation
--/////////////////////////////////
-- access to the symbolic metatable
-- A symbol is a table { "<text>" } that carries symbol_meta as metatable.
-- symbol_meta is method table (__index), metatable and namespace at once;
-- 'symbol' is just a second name for it.
local symbol_meta = {}; symbol_meta.__index = symbol_meta
local symbol = symbol_meta

-- symbol.new( o ) / symbol.to( o )
-- returns a new symbol holding tostring( o )
function symbol_meta.new(o)
  return setmetatable({tostring(o)}, symbol_meta)
end
symbol_meta.to = symbol_meta.new

-- symbol( arg )
-- same as symbol.to( arg )
-- set __call behaviour of symbol
setmetatable( symbol_meta, { __call = function( _,s ) return symbol_meta.to( s ) end } )

-- Converts object to string, optionally with formatting.
-- Without fstr the symbol text is returned as it is; passing a nil format
-- on to string.format raised an error, which broke matrix.tostring and
-- matrix.elementstostrings for symbolic matrices.
function symbol_meta.tostring( e,fstr )
  if fstr then
    return string.format( fstr,e[1] )
  end
  return e[1]
end

-- Returns "symbol" if object is a symbol type, else nothing.
function symbol_meta:type()
  if getmetatable(self) == symbol_meta then
    return "symbol"
  end
end

-- Performs string.gsub on symbol.
-- for use in matrix.replace
-- (only the replaced string is kept; symbol.to ignores the match count)
function symbol_meta:gsub(from, to)
  return symbol.to( string.gsub( self[1],from,to ) )
end

-- creates function that replaces one letter by something else
-- makereplacer( "a",4,"b",7, ... )(x)
-- will replace a with 4 and b with 7 in symbol x.
-- for use in matrix.replace
function symbol_meta.makereplacer( ... )
  -- arguments come in pairs: letter, replacement
  local tosub = {}
  local args = {...}
  for i = 1,#args,2 do
    tosub[args[i]] = args[i+1]
  end
  -- letters without a replacement stay as they are
  local function func( a ) return tosub[a] or a end
  return function(sym)
    return symbol.to( string.gsub( sym[1], "%a", func ) )
  end
end

-- applies abs function to symbol: wraps the text as "(text):abs()"
function symbol_meta.abs(a)
  return symbol.to("(" .. a[1] .. "):abs()")
end

-- applies sqrt function to symbol: wraps the text as "(text):sqrt()"
function symbol_meta.sqrt(a)
  return symbol.to("(" .. a[1] .. "):sqrt()")
end

-- arithmetic builds the text of the resulting expression;
-- the operands are joined through __concat, so either side may be a
-- symbol or a plain value
function symbol_meta.__add(a,b)
  return symbol.to(a .. "+" .. b)
end
function symbol_meta.__sub(a,b)
  return symbol.to(a .. "-" .. b)
end
function symbol_meta.__mul(a,b)
  return symbol.to("(" .. a .. ")*(" .. b .. ")")
end
function symbol_meta.__div(a,b)
  return symbol.to("(" .. a .. ")/(" .. b .. ")")
end
function symbol_meta.__pow(a,b)
  return symbol.to("(" .. a .. ")^(" .. b .. ")")
end

-- symbols are equal if their texts are equal
function symbol_meta.__eq(a,b)
  return a[1] == b[1]
end
function symbol_meta.__tostring(a)
  return a[1]
end
-- concatenation works on the string forms of both operands
function symbol_meta.__concat(a,b)
  return tostring(a) .. tostring(b)
end
-- expose the symbol class as matrix.symbol
-- (set after the method table of matrix objects has been filled,
-- so matrix objects themselves have no 'symbol' field)
matrix.symbol = symbol
-- return matrix
return matrix
--///////////////--
--// chillcode //--
--///////////////--