1000 lines
32 KiB
Python
1000 lines
32 KiB
Python
![]() |
from typing import Tuple as tTuple
|
||
|
|
||
|
from sympy.core.logic import FuzzyBool
|
||
|
|
||
|
from functools import wraps, reduce
|
||
|
import collections
|
||
|
|
||
|
from sympy.core import S, Symbol, Integer, Basic, Expr, Mul, Add
|
||
|
from sympy.core.decorators import call_highest_priority
|
||
|
from sympy.core.compatibility import SYMPY_INTS, default_sort_key
|
||
|
from sympy.core.symbol import Str
|
||
|
from sympy.core.sympify import SympifyError, _sympify
|
||
|
from sympy.functions import conjugate, adjoint
|
||
|
from sympy.functions.special.tensor_functions import KroneckerDelta
|
||
|
from sympy.matrices.common import NonSquareMatrixError
|
||
|
from sympy.simplify import simplify
|
||
|
from sympy.matrices.matrices import MatrixKind
|
||
|
from sympy.utilities.misc import filldedent
|
||
|
from sympy.multipledispatch import dispatch
|
||
|
|
||
|
|
||
|
def _sympifyit(arg, retval=None):
|
||
|
# This version of _sympifyit sympifies MutableMatrix objects
|
||
|
def deco(func):
|
||
|
@wraps(func)
|
||
|
def __sympifyit_wrapper(a, b):
|
||
|
try:
|
||
|
b = _sympify(b)
|
||
|
return func(a, b)
|
||
|
except SympifyError:
|
||
|
return retval
|
||
|
|
||
|
return __sympifyit_wrapper
|
||
|
|
||
|
return deco
|
||
|
|
||
|
|
||
|
class MatrixExpr(Expr):
    """Superclass for Matrix Expressions

    MatrixExprs represent abstract matrices, linear transformations represented
    within a particular basis.

    Examples
    ========

    >>> from sympy import MatrixSymbol
    >>> A = MatrixSymbol('A', 3, 3)
    >>> y = MatrixSymbol('y', 3, 1)
    >>> x = (A.T*A).I * A * y

    See Also
    ========

    MatrixSymbol, MatAdd, MatMul, Transpose, Inverse
    """

    # Should not be considered iterable by the
    # sympy.core.compatibility.iterable function. Subclass that actually are
    # iterable (i.e., explicit matrices) should set this to True.
    _iterable = False

    # Consulted by `call_highest_priority` on the binary operators below.
    _op_priority = 11.0

    is_Matrix = True  # type: bool
    is_MatrixExpr = True  # type: bool
    is_Identity = None  # type: FuzzyBool
    is_Inverse = False
    is_Transpose = False
    is_ZeroMatrix = False
    is_MatAdd = False
    is_MatMul = False

    is_commutative = False
    is_number = False
    is_symbol = False
    is_scalar = False

    kind = MatrixKind()

    def __new__(cls, *args, **kwargs):
        args = map(_sympify, args)
        return Basic.__new__(cls, *args, **kwargs)

    # The following is adapted from the core Expr object

    @property
    def shape(self) -> tTuple[Expr, Expr]:
        """(rows, cols) of the matrix; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def _add_handler(self):
        return MatAdd

    @property
    def _mul_handler(self):
        return MatMul

    def __neg__(self):
        return MatMul(S.NegativeOne, self).doit()

    def __abs__(self):
        raise NotImplementedError

    # Binary operators: `other` is sympified first; if that fails the
    # operator returns NotImplemented so Python can try the reflected one.

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__radd__')
    def __add__(self, other):
        return MatAdd(self, other, check=True).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__add__')
    def __radd__(self, other):
        return MatAdd(other, self, check=True).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        return MatAdd(self, -other, check=True).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        return MatAdd(other, -self, check=True).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        return MatMul(self, other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmul__')
    def __matmul__(self, other):
        return MatMul(self, other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        return MatMul(other, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__mul__')
    def __rmatmul__(self, other):
        return MatMul(other, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rpow__')
    def __pow__(self, other):
        return MatPow(self, other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__pow__')
    def __rpow__(self, other):
        raise NotImplementedError("Matrix Power not defined")

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rtruediv__')
    def __truediv__(self, other):
        return self * other**S.NegativeOne

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__truediv__')
    def __rtruediv__(self, other):
        # Dividing by a matrix expression is not supported.
        raise NotImplementedError()

    @property
    def rows(self):
        return self.shape[0]

    @property
    def cols(self):
        return self.shape[1]

    @property
    def is_square(self):
        return self.rows == self.cols

    def _eval_conjugate(self):
        # conj(M) expressed as the adjoint of the transpose.
        from sympy.matrices.expressions.adjoint import Adjoint
        from sympy.matrices.expressions.transpose import Transpose
        return Adjoint(Transpose(self))

    def as_real_imag(self, deep=True, **hints):
        """Return ``((M + conj(M))/2, (M - conj(M))/(2*I))``.

        ``deep`` and ``hints`` are accepted but not used.
        """
        from sympy import I
        real = S.Half * (self + self._eval_conjugate())
        im = (self - self._eval_conjugate())/(2*I)
        return (real, im)

    def _eval_inverse(self):
        from sympy.matrices.expressions.inverse import Inverse
        return Inverse(self)

    def _eval_transpose(self):
        return Transpose(self)

    def _eval_power(self, exp):
        """
        Override this in sub-classes to implement simplification of powers. The cases where the exponent
        is -1, 0, 1 are already covered in MatPow.doit(), so implementations can exclude these cases.
        """
        return MatPow(self, exp)

    def _eval_simplify(self, **kwargs):
        # Atoms are returned unchanged; otherwise simplify each argument and
        # rebuild the node.
        if self.is_Atom:
            return self
        else:
            return self.func(*[simplify(x, **kwargs) for x in self.args])

    def _eval_adjoint(self):
        from sympy.matrices.expressions.adjoint import Adjoint
        return Adjoint(self)

    def _eval_derivative_n_times(self, x, n):
        return Basic._eval_derivative_n_times(self, x, n)

    def _eval_derivative(self, x):
        # `x` is a scalar:
        if self.has(x):
            # See if there are other methods using it:
            return super()._eval_derivative(x)
        else:
            return ZeroMatrix(*self.shape)

    @classmethod
    def _check_dim(cls, dim):
        """Helper function to check invalid matrix dimensions

        Raises ValueError only when ``dim`` is known NOT to be a nonnegative
        integer; undecidable (symbolic) dimensions are accepted.
        """
        from sympy.core.assumptions import check_assumptions
        ok = check_assumptions(dim, integer=True, nonnegative=True)
        if ok is False:
            raise ValueError(
                "The dimension specification {} should be "
                "a nonnegative integer.".format(dim))

    def _entry(self, i, j, **kwargs):
        """Return entry ``(i, j)``; subclasses implement this."""
        raise NotImplementedError(
            "Indexing not implemented for %s" % self.__class__.__name__)

    def adjoint(self):
        return adjoint(self)

    def as_coeff_Mul(self, rational=False):
        """Efficiently extract the coefficient of a product. """
        return S.One, self

    def conjugate(self):
        return conjugate(self)

    def transpose(self):
        from sympy.matrices.expressions.transpose import transpose
        return transpose(self)

    @property
    def T(self):
        '''Matrix transposition'''
        return self.transpose()

    def inverse(self):
        """Return the inverse; raises NonSquareMatrixError if not square."""
        if not self.is_square:
            raise NonSquareMatrixError('Inverse of non-square matrix')
        return self._eval_inverse()

    def inv(self):
        return self.inverse()

    @property
    def I(self):
        return self.inverse()

    def valid_index(self, i, j):
        """Return False only when ``(i, j)`` is known to be out of range.

        Comparisons that cannot be decided (symbolic indices or shapes) are
        treated as valid, hence the ``!= False`` tests.
        """
        def is_valid(idx):
            return isinstance(idx, (int, Integer, Symbol, Expr))
        return (is_valid(i) and is_valid(j) and
                (self.rows is None or
                (0 <= i) != False and (i < self.rows) != False) and
                (0 <= j) != False and (j < self.cols) != False)

    def __getitem__(self, key):
        """Index the matrix.

        Supports ``M[i, j]``, slicing (returns a MatrixSlice) and, when the
        number of columns is a known Integer, a single row-major index.
        Raises IndexError for invalid keys.
        """
        if isinstance(key, slice):
            from sympy.matrices.expressions.slice import MatrixSlice
            return MatrixSlice(self, key, (0, None, 1))
        if isinstance(key, tuple) and len(key) == 2:
            i, j = key
            if isinstance(i, slice) or isinstance(j, slice):
                from sympy.matrices.expressions.slice import MatrixSlice
                return MatrixSlice(self, i, j)
            i, j = _sympify(i), _sympify(j)
            if self.valid_index(i, j) != False:
                return self._entry(i, j)
            else:
                raise IndexError("Invalid indices (%s, %s)" % (i, j))
        elif isinstance(key, (SYMPY_INTS, Integer)):
            # row-wise decomposition of matrix
            rows, cols = self.shape
            # allow single indexing if number of columns is known
            if not isinstance(cols, Integer):
                raise IndexError(filldedent('''
                    Single indexing is only supported when the number
                    of columns is known.'''))
            key = _sympify(key)
            i = key // cols
            j = key % cols
            if self.valid_index(i, j) != False:
                return self._entry(i, j)
            else:
                raise IndexError("Invalid index %s" % key)
        elif isinstance(key, (Symbol, Expr)):
            raise IndexError(filldedent('''
                Only integers may be used when addressing the matrix
                with a single index.'''))
        raise IndexError("Invalid index, wanted %s[i,j]" % self)

    def as_explicit(self):
        """
        Returns a dense Matrix with elements represented explicitly

        Returns an object of type ImmutableDenseMatrix.

        Examples
        ========

        >>> from sympy import Identity
        >>> I = Identity(3)
        >>> I
        I
        >>> I.as_explicit()
        Matrix([
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

        See Also
        ========
        as_mutable: returns mutable Matrix type

        """
        if (not isinstance(self.rows, (SYMPY_INTS, Integer))
                or not isinstance(self.cols, (SYMPY_INTS, Integer))):
            raise ValueError(
                'Matrix with symbolic shape '
                'cannot be represented explicitly.')
        from sympy.matrices.immutable import ImmutableDenseMatrix
        return ImmutableDenseMatrix([[self[i, j]
                                      for j in range(self.cols)]
                                     for i in range(self.rows)])

    def as_mutable(self):
        """
        Returns a dense, mutable matrix with elements represented explicitly

        Examples
        ========

        >>> from sympy import Identity
        >>> I = Identity(3)
        >>> I
        I
        >>> I.shape
        (3, 3)
        >>> I.as_mutable()
        Matrix([
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

        See Also
        ========
        as_explicit: returns ImmutableDenseMatrix
        """
        return self.as_explicit().as_mutable()

    def __array__(self, dtype=None, copy=None):
        """NumPy array protocol: return the entries as an ndarray.

        The entries are collected into an ``object`` array. NumPy passes
        ``dtype`` when the caller requests one (e.g.
        ``numpy.array(M, dtype=float)``); the object array is then cast to
        it. ``copy`` is accepted for compatibility with NumPy 2, which may
        pass it; a new array is always built.
        """
        from numpy import empty
        a = empty(self.shape, dtype=object)
        for i in range(self.rows):
            for j in range(self.cols):
                a[i, j] = self[i, j]
        if dtype is not None:
            a = a.astype(dtype, copy=False)
        return a

    def equals(self, other):
        """
        Test elementwise equality between matrices, potentially of different
        types

        >>> from sympy import Identity, eye
        >>> Identity(3).equals(eye(3))
        True
        """
        return self.as_explicit().equals(other)

    def canonicalize(self):
        return self

    def as_coeff_mmul(self):
        return 1, MatMul(self)

    @staticmethod
    def from_index_summation(expr, first_index=None, last_index=None, dimensions=None):
        r"""
        Parse expression of matrices with explicitly summed indices into a
        matrix expression without indices, if possible.

        This transformation expressed in mathematical notation:

        `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}`

        Optional parameter ``first_index``: specify which free index to use as
        the index starting the expression.

        Raises ValueError when index ranges do not start at zero or do not
        match the matrix shapes, or when the dimension of a KroneckerDelta
        cannot be determined (neither summed over nor given via
        ``dimensions``).

        Examples
        ========

        >>> from sympy import MatrixSymbol, MatrixExpr, Sum
        >>> from sympy.abc import i, j, k, l, N
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)
        >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A*B

        Transposition is detected:

        >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A.T*B

        Detect the trace:

        >>> expr = Sum(A[i, i], (i, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        Trace(A)

        More complicated expressions:

        >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A*B.T*A.T
        """
        from sympy import Sum, Mul, Add, MatMul, transpose, trace
        from sympy.strategies.traverse import bottom_up

        def remove_matelement(expr, i1, i2):
            # Replace matrix elements indexed by (i1, .) with the matrix and
            # (., i1) with its transpose, then flatten products into MatMul.

            def repl_match(pos):
                def func(x):
                    if not isinstance(x, MatrixElement):
                        return False
                    if x.args[pos] != i1:
                        return False
                    if x.args[3-pos] == 0:
                        if x.args[0].shape[2-pos] == 1:
                            return True
                        else:
                            return False
                    return True
                return func

            expr = expr.replace(repl_match(1),
                lambda x: x.args[0])
            expr = expr.replace(repl_match(2),
                lambda x: transpose(x.args[0]))

            # Make sure that all Mul are transformed to MatMul and that they
            # are flattened:
            rule = bottom_up(lambda x: reduce(lambda a, b: a*b, x.args) if isinstance(x, (Mul, MatMul)) else x)
            return rule(expr)

        def recurse_expr(expr, index_ranges=None):
            # Returns a list of (factor, indices) pairs; `indices` is None
            # for scalar factors and an (i, j) pair for matrix elements.
            if index_ranges is None:
                index_ranges = {}
            if expr.is_Mul:
                nonmatargs = []
                pos_arg = []
                pos_ind = []
                dlinks = {}
                link_ind = []
                counter = 0
                args_ind = []
                for arg in expr.args:
                    # every branch of recurse_expr returns a list of pairs
                    args_ind.extend(recurse_expr(arg, index_ranges))
                # Link factors that share an index (contraction partners).
                for arg_symbol, arg_indices in args_ind:
                    if arg_indices is None:
                        nonmatargs.append(arg_symbol)
                        continue
                    if isinstance(arg_symbol, MatrixElement):
                        arg_symbol = arg_symbol.args[0]
                    pos_arg.append(arg_symbol)
                    pos_ind.append(arg_indices)
                    link_ind.append([None]*len(arg_indices))
                    for i, ind in enumerate(arg_indices):
                        if ind in dlinks:
                            other_i = dlinks[ind]
                            link_ind[counter][i] = other_i
                            link_ind[other_i[0]][other_i[1]] = (counter, i)
                        dlinks[ind] = (counter, i)
                    counter += 1
                # Walk each chain of linked factors from a free index to the
                # other free end, collecting the matrices (transposed when
                # entered through their column index).
                counter2 = 0
                lines = {}
                while counter2 < len(link_ind):
                    for i, e in enumerate(link_ind):
                        if None in e:
                            line_start_index = (i, e.index(None))
                            break
                    cur_ind_pos = line_start_index
                    cur_line = []
                    index1 = pos_ind[cur_ind_pos[0]][cur_ind_pos[1]]
                    while True:
                        d, r = cur_ind_pos
                        if pos_arg[d] != 1:
                            if r % 2 == 1:
                                cur_line.append(transpose(pos_arg[d]))
                            else:
                                cur_line.append(pos_arg[d])
                        next_ind_pos = link_ind[d][1-r]
                        counter2 += 1
                        # Mark as visited, there will be no `None` anymore:
                        link_ind[d] = (-1, -1)
                        if next_ind_pos is None:
                            index2 = pos_ind[d][1-r]
                            lines[(index1, index2)] = cur_line
                            break
                        cur_ind_pos = next_ind_pos
                lines = {k: MatMul.fromiter(v) if len(v) != 1 else v[0] for k, v in lines.items()}
                return [(Mul.fromiter(nonmatargs), None)] + [
                    (MatrixElement(a, i, j), (i, j)) for (i, j), a in lines.items()
                ]
            elif expr.is_Add:
                # NOTE(review): addends are recursed without `index_ranges`,
                # so range checks and trace detection do not apply inside
                # sums of terms, and an addend with no indexed factor is not
                # carried into the result — confirm this is intended.
                res = [recurse_expr(i) for i in expr.args]
                d = collections.defaultdict(list)
                for res_addend in res:
                    scalar = 1
                    for elem, indices in res_addend:
                        if indices is None:
                            scalar = elem
                            continue
                        indices = tuple(sorted(indices, key=default_sort_key))
                        d[indices].append(scalar*remove_matelement(elem, *indices))
                        scalar = 1
                return [(MatrixElement(Add.fromiter(v), *k), k) for k, v in d.items()]
            elif isinstance(expr, KroneckerDelta):
                i1, i2 = expr.args
                shape = dimensions
                if shape is None:
                    # Infer the identity size from the summation ranges of
                    # the delta's indices.
                    shape = []
                    for kr_ind in expr.args:
                        if kr_ind not in index_ranges:
                            continue
                        r1, r2 = index_ranges[kr_ind]
                        if r1 != 0:
                            raise ValueError(f"index ranges should start from zero: {index_ranges}")
                        shape.append(r2)
                    if len(shape) == 0:
                        shape = None
                    elif len(shape) == 1:
                        shape = (shape[0] + 1, shape[0] + 1)
                    else:
                        shape = (shape[0] + 1, shape[1] + 1)
                        if shape[0] != shape[1]:
                            raise ValueError(f"upper index ranges should be equal: {index_ranges}")
                if shape is None:
                    # Previously this fell through to `None[0]` (TypeError).
                    raise ValueError(
                        f"cannot determine the dimension of {expr}: its "
                        f"indices are not summed over and no `dimensions` "
                        f"were given")

                identity = Identity(shape[0])
                return [(MatrixElement(identity, i1, i2), (i1, i2))]
            elif isinstance(expr, MatrixElement):
                matrix_symbol, i1, i2 = expr.args
                if i1 in index_ranges:
                    r1, r2 = index_ranges[i1]
                    if r1 != 0 or matrix_symbol.shape[0] != r2+1:
                        raise ValueError("index range mismatch: {} vs. (0, {})".format(
                            (r1, r2), matrix_symbol.shape[0]))
                if i2 in index_ranges:
                    r1, r2 = index_ranges[i2]
                    if r1 != 0 or matrix_symbol.shape[1] != r2+1:
                        raise ValueError("index range mismatch: {} vs. (0, {})".format(
                            (r1, r2), matrix_symbol.shape[1]))
                if (i1 == i2) and (i1 in index_ranges):
                    return [(trace(matrix_symbol), None)]
                return [(MatrixElement(matrix_symbol, i1, i2), (i1, i2))]
            elif isinstance(expr, Sum):
                return recurse_expr(
                    expr.args[0],
                    index_ranges={i[0]: i[1:] for i in expr.args[1:]}
                )
            else:
                return [(expr, None)]

        retvals = recurse_expr(expr)
        factors, indices = zip(*retvals)
        retexpr = Mul.fromiter(factors)
        if len(indices) == 0 or list(set(indices)) == [None]:
            return retexpr
        if first_index is None:
            # Use the first non-scalar index pair found.
            for i in indices:
                if i is not None:
                    ind0 = i
                    break
            return remove_matelement(retexpr, *ind0)
        else:
            return remove_matelement(retexpr, first_index, last_index)

    def applyfunc(self, func):
        """Return ``ElementwiseApplyFunction(func, self)``."""
        from .applyfunc import ElementwiseApplyFunction
        return ElementwiseApplyFunction(func, self)
|
||
|
|
||
|
@dispatch(MatrixExpr, Expr)
def _eval_is_eq(lhs, rhs): # noqa:F811
    """Compare a MatrixExpr with an Expr not matched by a more specific
    overload: always reported as not equal."""
    return False


@dispatch(MatrixExpr, MatrixExpr) # type: ignore
def _eval_is_eq(lhs, rhs): # noqa:F811
    """Compare two matrix expressions.

    Returns False when the shapes differ, True when ``lhs - rhs`` is known to
    be a zero matrix, and None (undecided) otherwise.
    """
    if lhs.shape != rhs.shape:
        return False
    difference = lhs - rhs
    return True if difference.is_ZeroMatrix else None
|
||
|
|
||
|
def get_postprocessor(cls):
    """Return a constructor postprocessor for ``cls`` (``Mul`` or ``Add``).

    The returned callable splits ``expr.args`` into matrix expressions and
    other terms and rebuilds the result with MatMul/MatAdd as appropriate.
    """
    def _postprocessor(expr):
        # To avoid circular imports, we can't have MatMul/MatAdd on the top level
        mat_class = {Mul: MatMul, Add: MatAdd}[cls]
        matrices = [term for term in expr.args if isinstance(term, MatrixExpr)]
        nonmatrices = [term for term in expr.args if not isinstance(term, MatrixExpr)]

        if not matrices:
            return cls._from_args(nonmatrices)

        if nonmatrices and cls != Mul:
            # Maintain the ability to create Add(scalar, matrix) without
            # raising an exception. That way different algorithms can
            # replace matrix expressions with non-commutative symbols to
            # manipulate them like non-commutative scalars.
            return cls._from_args(nonmatrices + [mat_class(*matrices).doit(deep=False)])

        if nonmatrices:
            # Here cls is Mul. If one of the matrices is explicit, absorb the
            # scalar into it (doit will combine all explicit matrices into
            # one, so it doesn't matter which).
            for position, matrix in enumerate(matrices):
                if not matrix.is_MatrixExpr:
                    matrices[position] = matrix.__mul__(cls._from_args(nonmatrices))
                    nonmatrices = []
                    break

        if mat_class == MatAdd:
            return mat_class(*matrices).doit(deep=False)
        return mat_class(cls._from_args(nonmatrices), *matrices).doit(deep=False)
    return _postprocessor
|
||
|
|
||
|
|
||
|
# Register the Mul/Add postprocessors in Basic's constructor-postprocessor
# table under the MatrixExpr key.
Basic._constructor_postprocessor_mapping[MatrixExpr] = dict(
    Mul=[get_postprocessor(Mul)],
    Add=[get_postprocessor(Add)],
)
|
||
|
|
||
|
|
||
|
def _matrix_derivative(expr, x):
    """Differentiate the matrix expression ``expr`` with respect to ``x``.

    The derivative lines from ``expr._eval_derivative_matrix_lines(x)`` are
    built and converted from array form back to matrix form. If the first
    resulting term has at most two non-trivial dimensions, each term is
    contracted to a single matrix expression and the terms are summed;
    otherwise an unevaluated ``ArrayDerivative(expr, x)`` is returned.
    """
    from sympy.tensor.array.array_derivatives import ArrayDerivative
    from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix

    lines = expr._eval_derivative_matrix_lines(x)

    parts = [[convert_array_to_matrix(j) for j in line.build()] for line in lines]

    def _get_shape(elem):
        # Non-matrix factors count as 1x1.
        if isinstance(elem, MatrixExpr):
            return elem.shape
        return 1, 1

    def get_rank(parts):
        # Number of dimensions that are neither 1 nor None.
        return sum(j not in (1, None) for i in parts for j in _get_shape(i))

    # Only the first term's rank decides the representation.
    rank = get_rank(parts[0])

    def contract_one_dims(parts):
        # Multiply the first line by the (transposed) second line, skipping
        # 1x1 identities, then scale by any remaining scalar factors.
        if len(parts) == 1:
            return parts[0]
        else:
            p1, p2 = parts[:2]
            if p2.is_Matrix:
                p2 = p2.T
            if p1 == Identity(1):
                pbase = p2
            elif p2 == Identity(1):
                pbase = p1
            else:
                pbase = p1*p2
            if len(parts) == 2:
                return pbase
            else:  # len(parts) > 2
                if pbase.is_Matrix:
                    raise ValueError(
                        "cannot contract derivative term: extra factors "
                        "follow a matrix-valued product")
                return pbase*Mul.fromiter(parts[2:])

    if rank <= 2:
        return Add.fromiter([contract_one_dims(i) for i in parts])

    return ArrayDerivative(expr, x)
|
||
|
|
||
|
|
||
|
class MatrixElement(Expr):
    """A single entry ``parent[i, j]`` of a matrix expression.

    The arguments are stored as ``(parent, i, j)``.
    """
    parent = property(lambda self: self.args[0])
    i = property(lambda self: self.args[1])
    j = property(lambda self: self.args[2])
    _diff_wrt = True
    is_symbol = True
    is_commutative = True

    def __new__(cls, name, n, m):
        n = _sympify(n)
        m = _sympify(m)
        from sympy import MatrixBase
        # An explicit matrix indexed by concrete integers: return the entry.
        if isinstance(name, MatrixBase) and n.is_Integer and m.is_Integer:
            return name[n, m]
        if isinstance(name, str):
            name = Symbol(name)
        else:
            name = _sympify(name)
            if not isinstance(name.kind, MatrixKind):
                raise TypeError("First argument of MatrixElement should be a matrix")
        return Expr.__new__(cls, name, n, m)

    def doit(self, **kwargs):
        """Evaluate ``parent[i, j]``, after ``doit`` on the args if ``deep``."""
        if kwargs.get('deep', True):
            parent, row, col = [arg.doit(**kwargs) for arg in self.args]
        else:
            parent, row, col = self.args
        return parent[row, col]

    @property
    def indices(self):
        return self.args[1:]

    def _eval_derivative(self, v):
        from sympy import Sum, symbols, Dummy

        # Scalar variable: only explicit matrices can depend on it here.
        if not isinstance(v, MatrixElement):
            from sympy import MatrixBase
            if not isinstance(self.parent, MatrixBase):
                return S.Zero
            return self.parent.diff(v)[self.i, self.j]

        mat = self.args[0]
        rows, cols = self.parent.shape

        # d M[i, j] / d M[k, l] = delta(i, k) * delta(j, l)
        if mat == v.args[0]:
            return KroneckerDelta(self.args[1], v.args[1], (0, rows-1)) * \
                KroneckerDelta(self.args[2], v.args[2], (0, cols-1))

        # d (Y^-1)[i, j] = -sum_{a, b} Y^-1[i, a] * dY[a, b] * Y^-1[b, j]
        if isinstance(mat, Inverse):
            row, col = self.args[1:]
            a, b = symbols("z1, z2", cls=Dummy)
            base = mat.args[0]
            n_rows, n_cols = base.shape
            summand = mat[row, a]*base[a, b].diff(v)*mat[b, col]
            return -Sum(summand, (a, 0, n_rows-1), (b, 0, n_cols-1))

        # Depends on v's matrix in some other way: leave unevaluated (None).
        return None if self.has(v.args[0]) else S.Zero
|
||
|
|
||
|
|
||
|
class MatrixSymbol(MatrixExpr):
    """Symbolic representation of a Matrix object

    Creates a SymPy Symbol to represent a Matrix. This matrix has a shape and
    can be included in Matrix Expressions

    Examples
    ========

    >>> from sympy import MatrixSymbol, Identity
    >>> A = MatrixSymbol('A', 3, 4) # A 3 by 4 Matrix
    >>> B = MatrixSymbol('B', 4, 3) # A 4 by 3 Matrix
    >>> A.shape
    (3, 4)
    >>> 2*A*B + Identity(3)
    I + 2*A*B
    """
    is_commutative = False
    is_symbol = True
    _diff_wrt = True

    def __new__(cls, name, n, m):
        n, m = _sympify(n), _sympify(m)

        # Validate the column count first, then the row count.
        for dim in (m, n):
            cls._check_dim(dim)

        # String names are wrapped in Str; other names are stored as given.
        if isinstance(name, str):
            name = Str(name)
        return Basic.__new__(cls, name, n, m)

    @property
    def shape(self):
        return self.args[1:3]

    @property
    def name(self):
        return self.args[0].name

    def _entry(self, i, j, **kwargs):
        return MatrixElement(self, i, j)

    @property
    def free_symbols(self):
        return {self}

    def _eval_simplify(self, **kwargs):
        return self

    def _eval_derivative(self, x):
        # x is a scalar: a matrix symbol does not depend on it.
        return ZeroMatrix(*self.shape)

    def _eval_derivative_matrix_lines(self, x):
        rows, cols = self.shape
        if self != x:
            # Unrelated to x: zero lines (scalar zero for unit dimensions).
            first = ZeroMatrix(x.shape[0], rows) if rows != 1 else S.Zero
            second = ZeroMatrix(x.shape[1], cols) if cols != 1 else S.Zero
        else:
            # d X / d X: identity lines (scalar one for unit dimensions).
            first = Identity(rows) if rows != 1 else S.One
            second = Identity(cols) if cols != 1 else S.One
        return [_LeftRightArgs([first, second])]
|
||
|
|
||
|
|
||
|
def matrix_symbols(expr):
    """Return a list of the free symbols of ``expr`` whose ``is_Matrix`` is true."""
    return list(filter(lambda symbol: symbol.is_Matrix, expr.free_symbols))
|
||
|
|
||
|
|
||
|
class _LeftRightArgs:
    r"""
    Helper class to compute matrix derivatives.

    The logic: when an expression is derived by a matrix `X_{mn}`, two lines of
    matrix multiplications are created: the one contracted to `m` (first line),
    and the one contracted to `n` (second line).

    Transposition flips the side by which new matrices are connected to the
    lines.

    The trace connects the end of the two lines.
    """

    def __init__(self, lines, higher=S.One):
        # Copy so that the pointers below index a list owned by this object.
        self._lines = list(lines)
        # A pointer is a (parent container, index) pair; both start on the
        # top-level list, at slots 0 and 1.
        self._first_pointer_parent = self._lines
        self._first_pointer_index = 0
        self._first_line_index = 0
        self._second_pointer_parent = self._lines
        self._second_pointer_index = 1
        self._second_line_index = 1
        self.higher = higher

    @property
    def first_pointer(self):
        return self._first_pointer_parent[self._first_pointer_index]

    @first_pointer.setter
    def first_pointer(self, value):
        self._first_pointer_parent[self._first_pointer_index] = value

    @property
    def second_pointer(self):
        return self._second_pointer_parent[self._second_pointer_index]

    @second_pointer.setter
    def second_pointer(self, value):
        self._second_pointer_parent[self._second_pointer_index] = value

    @property
    def first(self):
        """The current first line, built into an expression.

        Used by ``matrix_form`` and ``rank``, which previously referenced
        this attribute without it being defined.
        """
        return self._build(self._lines[self._first_line_index])

    @property
    def second(self):
        """The current second line, built into an expression."""
        return self._build(self._lines[self._second_line_index])

    def __repr__(self):
        built = [self._build(i) for i in self._lines]
        return "_LeftRightArgs(lines=%s, higher=%s)" % (
            built,
            self.higher,
        )

    def transpose(self):
        """Swap the roles of the first and second lines (in place)."""
        self._first_pointer_parent, self._second_pointer_parent = self._second_pointer_parent, self._first_pointer_parent
        self._first_pointer_index, self._second_pointer_index = self._second_pointer_index, self._first_pointer_index
        self._first_line_index, self._second_line_index = self._second_line_index, self._first_line_index
        return self

    @staticmethod
    def _build(expr):
        # ExprBuilder -> build it; [f, args] -> f(*built args); [x] -> x;
        # anything else is returned unchanged.
        from sympy.core.expr import ExprBuilder
        if isinstance(expr, ExprBuilder):
            return expr.build()
        if isinstance(expr, list):
            if len(expr) == 1:
                return expr[0]
            else:
                return expr[0](*[_LeftRightArgs._build(i) for i in expr[1]])
        else:
            return expr

    def build(self):
        """Return the built lines, followed by ``higher`` when it is not 1."""
        data = [self._build(i) for i in self._lines]
        if self.higher != 1:
            data += [self._build(self.higher)]
        return data

    def matrix_form(self):
        """Return the derivative as ``first * second.T`` when representable."""
        if self.first != 1 and self.higher != 1:
            raise ValueError("higher dimensional array cannot be represented")

        def _get_shape(elem):
            if isinstance(elem, MatrixExpr):
                return elem.shape
            return (None, None)

        if _get_shape(self.first)[1] != _get_shape(self.second)[1]:
            # Remove one-dimensional identity matrices:
            # (this is needed by `a.diff(a)` where `a` is a vector)
            if _get_shape(self.second) == (1, 1):
                return self.first*self.second[0, 0]
            if _get_shape(self.first) == (1, 1):
                # The only entry of a 1x1 matrix is at (0, 0); the previous
                # index (1, 1) was out of range.
                return self.first[0, 0]*self.second.T
            raise ValueError("incompatible shapes")
        if self.first != 1:
            return self.first*self.second.T
        else:
            return self.higher

    def rank(self):
        """
        Number of dimensions different from trivial (warning: not related to
        matrix rank).
        """
        rank = 0
        if self.first != 1:
            rank += sum([i != 1 for i in self.first.shape])
        if self.second != 1:
            rank += sum([i != 1 for i in self.second.shape])
        if self.higher != 1:
            rank += 2
        return rank

    def _multiply_pointer(self, pointer, other):
        # Builder for contracting axes 1 and 2 of the tensor product
        # pointer (x) other, i.e. a matrix product.
        from sympy.core.expr import ExprBuilder
        from ...tensor.array.expressions.array_expressions import ArrayTensorProduct
        from ...tensor.array.expressions.array_expressions import ArrayContraction

        subexpr = ExprBuilder(
            ArrayContraction,
            [
                ExprBuilder(
                    ArrayTensorProduct,
                    [
                        pointer,
                        other
                    ]
                ),
                (1, 2)
            ],
            validator=ArrayContraction._validate
        )

        return subexpr

    def append_first(self, other):
        self.first_pointer *= other

    def append_second(self, other):
        self.second_pointer *= other
||
|
|
||
|
|
||
|
def _make_matrix(x):
    """Return ``x`` unchanged if it is a MatrixExpr, else wrap it as a 1x1
    ImmutableDenseMatrix."""
    from sympy import ImmutableDenseMatrix
    return x if isinstance(x, MatrixExpr) else ImmutableDenseMatrix([[x]])
|
||
|
|
||
|
|
||
|
from .matmul import MatMul
|
||
|
from .matadd import MatAdd
|
||
|
from .matpow import MatPow
|
||
|
from .transpose import Transpose
|
||
|
from .inverse import Inverse
|
||
|
from .special import ZeroMatrix, Identity
|