""" Classes and functions useful for rewriting expressions for optimized code generation. Some languages (or standards thereof), e.g. C99, offer specialized math functions for better performance and/or precision. Using the ``optimize`` function in this module, together with a collection of rules (represented as instances of ``Optimization``), one can rewrite the expressions for this purpose:: >>> from sympy import Symbol, exp, log >>> from sympy.codegen.rewriting import optimize, optims_c99 >>> x = Symbol('x') >>> optimize(3*exp(2*x) - 3, optims_c99) 3*expm1(2*x) >>> optimize(exp(2*x) - 1 - exp(-33), optims_c99) expm1(2*x) - exp(-33) >>> optimize(log(3*x + 3), optims_c99) log1p(x) + log(3) >>> optimize(log(2*x + 3), optims_c99) log(2*x + 3) The ``optims_c99`` imported above is tuple containing the following instances (which may be imported from ``sympy.codegen.rewriting``): - ``expm1_opt`` - ``log1p_opt`` - ``exp2_opt`` - ``log2_opt`` - ``log2const_opt`` """ from sympy import cos, exp, log, Max, Min, Wild, expand_log, sign, sin, sinc, S from sympy.assumptions import Q, ask from sympy.codegen.cfunctions import log1p, log2, exp2, expm1 from sympy.codegen.matrix_nodes import MatrixSolve from sympy.core.expr import UnevaluatedExpr from sympy.core.power import Pow from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 from sympy.codegen.scipy_nodes import cosm1 from sympy.core.mul import Mul from sympy.matrices.expressions.matexpr import MatrixSymbol from sympy.utilities.iterables import sift class Optimization: """ Abstract base class for rewriting optimization. Subclasses should implement ``__call__`` taking an expression as argument. Parameters ========== cost_function : callable returning number priority : number """ def __init__(self, cost_function=None, priority=1): self.cost_function = cost_function self.priority=priority def cheapest(self, *args): return sorted(args, key=self.cost_function)[0] class ReplaceOptim(Optimization): """ Rewriting optimization calling replace on expressions. Explanation =========== The instance can be used as a function on expressions for which it will apply the ``replace`` method (see :meth:`sympy.core.basic.Basic.replace`). Parameters ========== query : First argument passed to replace. value : Second argument passed to replace. Examples ======== >>> from sympy import Symbol >>> from sympy.codegen.rewriting import ReplaceOptim >>> from sympy.codegen.cfunctions import exp2 >>> x = Symbol('x') >>> exp2_opt = ReplaceOptim(lambda p: p.is_Pow and p.base == 2, ... lambda p: exp2(p.exp)) >>> exp2_opt(2**x) exp2(x) """ def __init__(self, query, value, **kwargs): super().__init__(**kwargs) self.query = query self.value = value def __call__(self, expr): return expr.replace(self.query, self.value) def optimize(expr, optimizations): """ Apply optimizations to an expression. Parameters ========== expr : expression optimizations : iterable of ``Optimization`` instances The optimizations will be sorted with respect to ``priority`` (highest first). Examples ======== >>> from sympy import log, Symbol >>> from sympy.codegen.rewriting import optims_c99, optimize >>> x = Symbol('x') >>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99) log1p(x**2) + log2(x + 3) """ for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True): new_expr = optim(expr) if optim.cost_function is None: expr = new_expr else: expr = optim.cheapest(expr, new_expr) return expr exp2_opt = ReplaceOptim( lambda p: p.is_Pow and p.base == 2, lambda p: exp2(p.exp) ) _d = Wild('d', properties=[lambda x: x.is_Dummy]) _u = Wild('u', properties=[lambda x: not x.is_number and not x.is_Add]) _v = Wild('v') _w = Wild('w') _n = Wild('n', properties=[lambda x: x.is_number]) sinc_opt1 = ReplaceOptim( sin(_w)/_w, sinc(_w) ) sinc_opt2 = ReplaceOptim( sin(_n*_w)/_w, _n*sinc(_n*_w) ) sinc_opts = (sinc_opt1, sinc_opt2) log2_opt = ReplaceOptim(_v*log(_w)/log(2), _v*log2(_w), cost_function=lambda expr: expr.count( lambda e: ( # division & eval of transcendentals are expensive floating point operations... e.is_Pow and e.exp.is_negative # division or (isinstance(e, (log, log2)) and not e.args[0].is_number)) # transcendental ) ) log2const_opt = ReplaceOptim(log(2)*log2(_w), log(_w)) logsumexp_2terms_opt = ReplaceOptim( lambda l: (isinstance(l, log) and l.args[0].is_Add and len(l.args[0].args) == 2 and all(isinstance(t, exp) for t in l.args[0].args)), lambda l: ( Max(*[e.args[0] for e in l.args[0].args]) + log1p(exp(Min(*[e.args[0] for e in l.args[0].args]))) ) ) class FuncMinusOneOptim(ReplaceOptim): """Specialization of ReplaceOptim for functions evaluating "f(x) - 1". Explanation =========== Numerical functions which go toward one as x go toward zero is often best implemented by a dedicated function in order to avoid catastrophic cancellation. One such example is ``expm1(x)`` in the C standard library which evaluates ``exp(x) - 1``. Such functions preserves many more significant digits when its argument is much smaller than one, compared to subtracting one afterwards. Parameters ========== func : The function which is subtracted by one. func_m_1 : The specialized function evaluating ``func(x) - 1``. opportunistic : bool When ``True``, apply the transformation as long as the magnitude of the remaining number terms decreases. When ``False``, only apply the transformation if it completely eliminates the number term. Examples ======== >>> from sympy import symbols, exp >>> from sympy.codegen.rewriting import FuncMinusOneOptim >>> from sympy.codegen.cfunctions import expm1 >>> x, y = symbols('x y') >>> expm1_opt = FuncMinusOneOptim(exp, expm1) >>> expm1_opt(exp(x) + 2*exp(5*y) - 3) expm1(x) + 2*expm1(5*y) """ def __init__(self, func, func_m_1, opportunistic=True): weight = 10 # <-- this is an arbitrary number (heuristic) super().__init__(lambda e: e.is_Add, self.replace_in_Add, cost_function=lambda expr: expr.count_ops() - weight*expr.count(func_m_1)) self.func = func self.func_m_1 = func_m_1 self.opportunistic = opportunistic def _group_Add_terms(self, add): numbers, non_num = sift(add.args, lambda arg: arg.is_number, binary=True) numsum = sum(numbers) terms_with_func, other = sift(non_num, lambda arg: arg.has(self.func), binary=True) return numsum, terms_with_func, other def replace_in_Add(self, e): """ passed as second argument to Basic.replace(...) """ numsum, terms_with_func, other_non_num_terms = self._group_Add_terms(e) if numsum == 0: return e substituted, untouched = [], [] for with_func in terms_with_func: if with_func.is_Mul: func, coeff = sift(with_func.args, lambda arg: arg.func == self.func, binary=True) if len(func) == 1 and len(coeff) == 1: func, coeff = func[0], coeff[0] else: coeff = None elif with_func.func == self.func: func, coeff = with_func, S.One else: coeff = None if coeff is not None and coeff.is_number and sign(coeff) == -sign(numsum): if self.opportunistic: do_substitute = abs(coeff+numsum) < abs(numsum) else: do_substitute = coeff+numsum == 0 if do_substitute: # advantageous substitution numsum += coeff substituted.append(coeff*self.func_m_1(*func.args)) continue untouched.append(with_func) return e.func(numsum, *substituted, *untouched, *other_non_num_terms) def __call__(self, expr): alt1 = super().__call__(expr) alt2 = super().__call__(expr.factor()) return self.cheapest(alt1, alt2) expm1_opt = FuncMinusOneOptim(exp, expm1) cosm1_opt = FuncMinusOneOptim(cos, cosm1) log1p_opt = ReplaceOptim( lambda e: isinstance(e, log), lambda l: expand_log(l.replace( log, lambda arg: log(arg.factor()) )).replace(log(_u+1), log1p(_u)) ) def create_expand_pow_optimization(limit, *, base_req=lambda b: b.is_symbol): """ Creates an instance of :class:`ReplaceOptim` for expanding ``Pow``. Explanation =========== The requirements for expansions are that the base needs to be a symbol and the exponent needs to be an Integer (and be less than or equal to ``limit``). Parameters ========== limit : int The highest power which is expanded into multiplication. base_req : function returning bool Requirement on base for expansion to happen, default is to return the ``is_symbol`` attribute of the base. Examples ======== >>> from sympy import Symbol, sin >>> from sympy.codegen.rewriting import create_expand_pow_optimization >>> x = Symbol('x') >>> expand_opt = create_expand_pow_optimization(3) >>> expand_opt(x**5 + x**3) x**5 + x*x*x >>> expand_opt(x**5 + x**3 + sin(x)**3) x**5 + sin(x)**3 + x*x*x >>> opt2 = create_expand_pow_optimization(3 , base_req=lambda b: not b.is_Function) >>> opt2((x+1)**2 + sin(x)**2) sin(x)**2 + (x + 1)*(x + 1) """ return ReplaceOptim( lambda e: e.is_Pow and base_req(e.base) and e.exp.is_Integer and abs(e.exp) <= limit, lambda p: ( UnevaluatedExpr(Mul(*([p.base]*+p.exp), evaluate=False)) if p.exp > 0 else 1/UnevaluatedExpr(Mul(*([p.base]*-p.exp), evaluate=False)) )) # Optimization procedures for turning A**(-1) * x into MatrixSolve(A, x) def _matinv_predicate(expr): # TODO: We should be able to support more than 2 elements if expr.is_MatMul and len(expr.args) == 2: left, right = expr.args if left.is_Inverse and right.shape[1] == 1: inv_arg = left.arg if isinstance(inv_arg, MatrixSymbol): return bool(ask(Q.fullrank(left.arg))) return False def _matinv_transform(expr): left, right = expr.args inv_arg = left.arg return MatrixSolve(inv_arg, right) matinv_opt = ReplaceOptim(_matinv_predicate, _matinv_transform) logaddexp_opt = ReplaceOptim(log(exp(_v)+exp(_w)), logaddexp(_v, _w)) logaddexp2_opt = ReplaceOptim(log(Pow(2, _v)+Pow(2, _w)), logaddexp2(_v, _w)*log(2)) # Collections of optimizations: optims_c99 = (expm1_opt, log1p_opt, exp2_opt, log2_opt, log2const_opt) optims_numpy = optims_c99 + (logaddexp_opt, logaddexp2_opt,) + sinc_opts optims_scipy = (cosm1_opt,)