Generateurv2/backend/env/lib/python3.10/site-packages/sympy/codegen/matrix_nodes.py
2022-06-24 17:14:37 +02:00

67 lines
2.1 KiB
Python

"""
Additional AST nodes for operations on matrices. The nodes in this module
are meant to represent optimization of matrix expressions within codegen's
target languages that cannot be represented by SymPy expressions.
As an example, we can use :meth:`sympy.codegen.rewriting.optimize` and the
``matinv_opt`` optimization provided in :mod:`sympy.codegen.rewriting` to
transform matrix multiplication under certain assumptions:
>>> from sympy import symbols, MatrixSymbol
>>> n = symbols('n', integer=True)
>>> A = MatrixSymbol('A', n, n)
>>> x = MatrixSymbol('x', n, 1)
>>> expr = A**(-1) * x
>>> from sympy.assumptions import assuming, Q
>>> from sympy.codegen.rewriting import matinv_opt, optimize
>>> with assuming(Q.fullrank(A)):
...     optimize(expr, [matinv_opt])
MatrixSolve(A, vector=x)
"""
from .ast import Token
from sympy.matrices import MatrixExpr
from sympy.core.sympify import sympify
class MatrixSolve(Token, MatrixExpr):
    """Represents an operation to solve a linear matrix equation.

    The node stands for the solution of ``matrix * result = vector`` and
    lets code printers emit a dedicated "solve" call instead of an explicit
    matrix inverse followed by a multiplication.

    Parameters
    ==========

    matrix : MatrixSymbol

      Matrix representing the coefficients of variables in the linear
      equation. This matrix must be square and full-rank (i.e. all columns must
      be linearly independent) for the solving operation to be valid.

    vector : MatrixSymbol

      One-column matrix representing the right-hand side of the equations
      whose coefficients are given in ``matrix``. The result of the solve
      has the same shape as this argument.

    Examples
    ========

    >>> from sympy import symbols, MatrixSymbol
    >>> from sympy.codegen.matrix_nodes import MatrixSolve
    >>> n = symbols('n', integer=True)
    >>> A = MatrixSymbol('A', n, n)
    >>> x = MatrixSymbol('x', n, 1)
    >>> from sympy.printing.numpy import NumPyPrinter
    >>> NumPyPrinter().doprint(MatrixSolve(A, x))
    'numpy.linalg.solve(A, x)'
    >>> from sympy.printing import octave_code
    >>> octave_code(MatrixSolve(A, x))
    'A \\\\ x'

    """
    __slots__ = ('matrix', 'vector')

    # Converters for the constructor arguments (presumably picked up by
    # Token through its ``_construct_<attr>`` naming convention -- confirm in
    # ``.ast``). Both operands are sympified so that ``vector`` is handled
    # the same way as ``matrix``; sympifying an object that is already a
    # SymPy expression returns it unchanged.
    _construct_matrix = staticmethod(sympify)
    _construct_vector = staticmethod(sympify)

    @property
    def shape(self):
        """Shape of the solve result, which is the shape of ``vector``."""
        return self.vector.shape