Generateurv2/backend/env/lib/python3.10/site-packages/sympy/stats/joint_rv.py
2022-06-24 17:14:37 +02:00

418 lines
15 KiB
Python

"""
Joint Random Variables Module
See Also
========
sympy.stats.rv
sympy.stats.frv
sympy.stats.crv
sympy.stats.drv
"""
from sympy import (Basic, Lambda, sympify, Indexed, Symbol, ProductSet, S,
Dummy, prod)
from sympy.concrete.products import Product
from sympy.concrete.summations import Sum, summation
from sympy.core.compatibility import iterable
from sympy.core.containers import Tuple
from sympy.integrals.integrals import Integral, integrate
from sympy.matrices import ImmutableMatrix, matrix2numpy, list2numpy
from sympy.stats.crv import SingleContinuousDistribution, SingleContinuousPSpace
from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace
from sympy.stats.rv import (ProductPSpace, NamedArgsMixin, Distribution,
ProductDomain, RandomSymbol, random_symbols,
SingleDomain, _symbol_converter)
from sympy.utilities.misc import filldedent
from sympy.external import import_module
# __all__ = ['marginal_distribution']
class JointPSpace(ProductPSpace):
    """
    Represents a joint probability space. Represented using symbols for
    each component and a distribution.

    Construction short-circuits to ``SingleContinuousPSpace`` or
    ``SingleDiscretePSpace`` when ``dist`` is a single continuous or
    discrete distribution respectively.
    """

    def __new__(cls, sym, dist):
        if isinstance(dist, SingleContinuousDistribution):
            return SingleContinuousPSpace(sym, dist)
        if isinstance(dist, SingleDiscreteDistribution):
            return SingleDiscretePSpace(sym, dist)
        sym = _symbol_converter(sym)
        return Basic.__new__(cls, sym, dist)

    @property
    def set(self):
        return self.domain.set

    @property
    def symbol(self):
        return self.args[0]

    @property
    def distribution(self):
        return self.args[1]

    @property
    def value(self):
        return JointRandomSymbol(self.symbol, self)

    @property
    def component_count(self):
        """Number of components of the distribution's set.

        ``len`` of a ``ProductSet``, the upper limit of a ``Product`` (which
        may be symbolic), otherwise ``S.One``.
        """
        _set = self.distribution.set
        if isinstance(_set, ProductSet):
            return S(len(_set.args))
        elif isinstance(_set, Product):
            return _set.limits[0][-1]
        return S.One

    @property
    def pdf(self):
        """The distribution evaluated at ``symbol[0], ..., symbol[n-1]``."""
        sym = [Indexed(self.symbol, i) for i in range(self.component_count)]
        return self.distribution(*sym)

    @property
    def domain(self):
        rvs = random_symbols(self.distribution)
        if not rvs:
            return SingleDomain(self.symbol, self.distribution.set)
        return ProductDomain(*[rv.pspace.domain for rv in rvs])

    def component_domain(self, index):
        return self.set.args[index]

    def marginal_distribution(self, *indices):
        """Return a ``Lambda`` for the marginal density over ``indices``.

        Every component whose index is not in ``indices`` is integrated
        (continuous) or summed (discrete) out over its set.

        Raises
        ======
        ValueError
            If the number of components is symbolic.
        NotImplementedError
            If the distribution is neither continuous nor discrete.
        """
        count = self.component_count
        if count.atoms(Symbol):
            raise ValueError("Marginal distributions cannot be computed "
                             "for symbolic dimensions. It is a work under progress.")
        orig = [Indexed(self.symbol, i) for i in range(count)]
        # Plain Symbols stand in for the Indexed components while
        # integrating; they are mapped back through ``replace_dict`` at the end.
        all_syms = [Symbol(str(i)) for i in orig]
        replace_dict = dict(zip(all_syms, orig))
        sym = tuple(Symbol(str(Indexed(self.symbol, i))) for i in indices)
        # One (symbol, set) limit per component that is marginalised out.
        limits = [(all_syms[i], self.distribution.set.args[i])
                  for i in range(count) if i not in indices]
        if self.distribution.is_Continuous:
            f = Lambda(sym, integrate(self.distribution(*all_syms), *limits))
        elif self.distribution.is_Discrete:
            f = Lambda(sym, summation(self.distribution(*all_syms), *limits))
        else:
            # Previously fell through with ``f`` unbound (UnboundLocalError).
            raise NotImplementedError(
                "Marginal distributions can only be computed for continuous "
                "or discrete distributions.")
        return f.xreplace(replace_dict)

    def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
        """Return an unevaluated ``Integral`` of ``expr*pdf`` over all components.

        ``expr`` is returned unchanged when none of this space's indexed
        components appear in ``rvs``. ``evaluate`` and ``kwargs`` are
        accepted for interface compatibility but are not used here.
        """
        syms = tuple(self.value[i] for i in range(self.component_count))
        rvs = rvs or syms
        if not any(i in rvs for i in syms):
            return expr
        expr = expr*self.pdf
        for rv in rvs:
            if isinstance(rv, Indexed):
                expr = expr.xreplace({rv: Indexed(str(rv.base), rv.args[1])})
            elif isinstance(rv, RandomSymbol):
                expr = expr.xreplace({rv: rv.symbol})
        if self.value in random_symbols(expr):
            raise NotImplementedError(filldedent('''
            Expectations of expression with unindexed joint random symbols
            cannot be calculated yet.'''))
        limits = tuple((Indexed(str(rv.base), rv.args[1]),
                        self.distribution.set.args[rv.args[1]]) for rv in syms)
        return Integral(expr, *limits)

    def where(self, condition):
        raise NotImplementedError()

    def compute_density(self, expr):
        raise NotImplementedError()

    def sample(self, size=(), library='scipy', seed=None):
        """
        Internal sample method

        Returns dictionary mapping RandomSymbol to realization value.
        """
        return {RandomSymbol(self.symbol, self): self.distribution.sample(size,
                    library=library, seed=seed)}

    def probability(self, condition):
        raise NotImplementedError()
class SampleJointScipy:
    """Draw samples of a joint distribution with SciPy.

    Calling the class returns the samples directly (``__new__`` does not
    build an instance), or ``None`` when the distribution is unsupported.
    """

    def __new__(cls, dist, size, seed=None):
        return cls._sample_scipy(dist, size, seed)

    @classmethod
    def _sample_scipy(cls, dist, size, seed):
        """Sample from SciPy.

        ``seed`` that is ``None`` or an ``int`` seeds a fresh
        ``numpy.random.default_rng``; any other value is used as the random
        state as-is. ``size`` is concatenated with a tuple, so it is
        expected to be a tuple. Returns the samples reshaped to
        ``size + <event shape>``, or ``None`` if ``dist``'s class is not
        handled.
        """
        import numpy
        if seed is not None and not isinstance(seed, int):
            rng = seed
        else:
            rng = numpy.random.default_rng(seed=seed)
        from scipy import stats as scipy_stats

        def draw_normal(d, shape):
            return scipy_stats.multivariate_normal.rvs(
                mean=matrix2numpy(d.mu).flatten(),
                cov=matrix2numpy(d.sigma), size=shape, random_state=rng)

        def draw_beta(d, shape):
            return scipy_stats.dirichlet.rvs(
                alpha=list2numpy(d.alpha, float).flatten(), size=shape,
                random_state=rng)

        def draw_multinomial(d, shape):
            return scipy_stats.multinomial.rvs(
                n=int(d.n), p=list2numpy(d.p, float).flatten(), size=shape,
                random_state=rng)

        # class name -> (sampler, function giving one draw's shape)
        handlers = {
            'MultivariateNormalDistribution':
                (draw_normal, lambda d: matrix2numpy(d.mu).flatten().shape),
            'MultivariateBetaDistribution':
                (draw_beta, lambda d: list2numpy(d.alpha).flatten().shape),
            'MultinomialDistribution':
                (draw_multinomial, lambda d: list2numpy(d.p).flatten().shape),
        }
        name = type(dist).__name__
        if name not in handlers:
            return None
        draw, event_shape = handlers[name]
        raw = draw(dist, size)
        return raw.reshape(size + event_shape(dist))
class SampleJointNumpy:
    """Draw samples of a joint distribution with NumPy.

    Calling the class returns the samples directly (``__new__`` does not
    build an instance), or ``None`` when the distribution is unsupported.
    """

    def __new__(cls, dist, size, seed=None):
        return cls._sample_numpy(dist, size, seed)

    @classmethod
    def _sample_numpy(cls, dist, size, seed):
        """Sample from NumPy.

        ``seed`` that is ``None`` or an ``int`` seeds a fresh
        ``numpy.random.default_rng``; any other value is used as the
        generator as-is. ``prod(size)`` draws are taken in one call and then
        reshaped to ``size + <event shape>``. Returns ``None`` if ``dist``'s
        class is not handled.
        """
        import numpy
        if seed is not None and not isinstance(seed, int):
            rng = seed
        else:
            rng = numpy.random.default_rng(seed=seed)

        def draw_normal(d, count):
            return rng.multivariate_normal(
                mean=matrix2numpy(d.mu, float).flatten(),
                cov=matrix2numpy(d.sigma, float), size=count)

        def draw_beta(d, count):
            return rng.dirichlet(
                alpha=list2numpy(d.alpha, float).flatten(), size=count)

        def draw_multinomial(d, count):
            return rng.multinomial(
                n=int(d.n), pvals=list2numpy(d.p, float).flatten(), size=count)

        # class name -> (sampler, function giving one draw's shape)
        handlers = {
            'MultivariateNormalDistribution':
                (draw_normal, lambda d: matrix2numpy(d.mu).flatten().shape),
            'MultivariateBetaDistribution':
                (draw_beta, lambda d: list2numpy(d.alpha).flatten().shape),
            'MultinomialDistribution':
                (draw_multinomial, lambda d: list2numpy(d.p).flatten().shape),
        }
        name = type(dist).__name__
        if name not in handlers:
            return None
        draw, event_shape = handlers[name]
        flat = draw(dist, prod(size))
        return flat.reshape(size + event_shape(dist))
class SampleJointPymc:
    """Draw samples of a joint distribution with PyMC3.

    Calling the class returns the samples directly (``__new__`` does not
    build an instance), or ``None`` when the distribution is unsupported.
    """

    def __new__(cls, dist, size, seed=None):
        return cls._sample_pymc3(dist, size, seed)

    @classmethod
    def _sample_pymc3(cls, dist, size, seed):
        """Sample from PyMC3.

        Builds a model variable named ``'X'`` for ``dist``, draws
        ``prod(size)`` samples on a single chain, and reshapes them to
        ``size + <event shape>``. Returns ``None`` if ``dist``'s class is
        not handled. Side effect: sets the ``"pymc3"`` logger to ERROR.
        """
        import pymc3

        def build_normal(d):
            return pymc3.MvNormal('X', mu=matrix2numpy(d.mu, float).flatten(),
                                  cov=matrix2numpy(d.sigma, float),
                                  shape=(1, d.mu.shape[0]))

        def build_beta(d):
            return pymc3.Dirichlet('X', a=list2numpy(d.alpha, float).flatten())

        def build_multinomial(d):
            return pymc3.Multinomial('X', n=int(d.n),
                                     p=list2numpy(d.p, float).flatten(),
                                     shape=(1, len(d.p)))

        # class name -> (model builder, function giving one draw's shape)
        handlers = {
            'MultivariateNormalDistribution':
                (build_normal, lambda d: matrix2numpy(d.mu).flatten().shape),
            'MultivariateBetaDistribution':
                (build_beta, lambda d: list2numpy(d.alpha).flatten().shape),
            'MultinomialDistribution':
                (build_multinomial, lambda d: list2numpy(d.p).flatten().shape),
        }
        name = type(dist).__name__
        if name not in handlers:
            return None
        build, event_shape = handlers[name]
        import logging
        logging.getLogger("pymc3").setLevel(logging.ERROR)
        with pymc3.Model():
            build(dist)
            # Sampling must happen inside the model context it draws from.
            trace = pymc3.sample(draws=prod(size), chains=1, progressbar=False,
                                 random_seed=seed, return_inferencedata=False,
                                 compute_convergence_checks=False)
            samples = trace[:]['X']
        return samples.reshape(size + event_shape(dist))
# Library name (as passed to ``JointDistribution.sample``) -> sampler class.
_get_sample_class_jrv = dict(
    scipy=SampleJointScipy,
    pymc3=SampleJointPymc,
    numpy=SampleJointNumpy,
)
class JointDistribution(Distribution, NamedArgsMixin):
    """
    Represented by the random variables part of the joint distribution.
    Contains methods for PDF, CDF, sampling, marginal densities, etc.
    """

    _argnames = ('pdf', )

    def __new__(cls, *args):
        # Sympify every argument; lists (left as lists by sympify) are
        # stored as ImmutableMatrix instances.
        args = [ImmutableMatrix(a) if isinstance(a, list) else a
                for a in map(sympify, args)]
        return Basic.__new__(cls, *args)

    @property
    def domain(self):
        return ProductDomain(self.symbols)

    @property
    def pdf(self):
        # NOTE(review): relies on a ``density`` attribute that is not defined
        # in this class; subclasses presumably provide it or override ``pdf``.
        return self.density.args[1]

    def cdf(self, other):
        """Return the unevaluated CDF as nested integrals/sums of the PDF.

        Parameters
        ==========
        other : dict
            Maps each random variable to its upper limit. The lower limit
            of the i-th key is the infimum of the i-th set of the domain.

        Raises
        ======
        ValueError
            If ``other`` is not a non-empty dict.
        """
        if not isinstance(other, dict):
            raise ValueError("%s should be of type dict, got %s"%(other, type(other)))
        if not other:
            raise ValueError("At least one random variable is needed to compute the CDF.")
        # dict views are not subscriptable in Python 3; materialise the keys.
        rvs = list(other)
        _set = self.domain.set.sets
        density = self.pdf(tuple(i.args[0] for i in self.symbols))
        # Nest one integration/summation per variable instead of overwriting
        # the result each iteration (which kept only the last variable).
        for i, rv in enumerate(rvs):
            if rv.is_Continuous:
                density = Integral(density, (rv, _set[i].inf, other[rv]))
            elif rv.is_Discrete:
                density = Sum(density, (rv, _set[i].inf, other[rv]))
        return density

    def sample(self, size=(), library='scipy', seed=None):
        """ A random realization from the distribution

        ``library`` must be one of ``'scipy'``, ``'numpy'`` or ``'pymc3'``
        (otherwise ``NotImplementedError``); ``ValueError`` is raised if the
        library cannot be imported, and ``NotImplementedError`` if that
        library has no sampler for this distribution.
        """
        libraries = ['scipy', 'numpy', 'pymc3']
        if library not in libraries:
            raise NotImplementedError("Sampling from %s is not supported yet."
                                      % str(library))
        if not import_module(library):
            raise ValueError("Failed to import %s" % library)
        samps = _get_sample_class_jrv[library](self, size, seed=seed)
        if samps is not None:
            return samps
        raise NotImplementedError(
            "Sampling for %s is not currently implemented from %s"
            % (self.__class__.__name__, library)
        )

    def __call__(self, *args):
        return self.pdf(*args)
class JointRandomSymbol(RandomSymbol):
    """
    Representation of random symbols with joint probability distributions
    to allow indexing.
    """

    def __getitem__(self, key):
        """Return ``Indexed(self, key)`` when the pspace is a ``JointPSpace``.

        Raises ``ValueError`` when ``key`` is definitely not below the
        component count. Returns ``None`` for any other pspace.
        """
        space = self.pspace
        if not isinstance(space, JointPSpace):
            return None
        n_components = space.component_count
        # ``== True`` keeps symbolic (undecidable) comparisons from raising.
        if (n_components <= key) == True:  # noqa: E712
            raise ValueError("Index keys for %s can only up to %s." %
                             (self.name, n_components - 1))
        return Indexed(self, key)
class MarginalDistribution(Distribution):
    """
    Represents the marginal distribution of a joint probability space.

    Initialised using a probability distribution and random variables (or
    their indexed components) which should be a part of the resultant
    distribution.
    """

    def __new__(cls, dist, *rvs):
        if len(rvs) == 1 and iterable(rvs[0]):
            rvs = tuple(rvs[0])
        # The test must be inside ``all``; wrapping it in a list made every
        # element truthy, so invalid arguments were silently accepted.
        if not all(isinstance(rv, (Indexed, RandomSymbol)) for rv in rvs):
            raise ValueError(filldedent('''Marginal distribution can be
            intitialised only in terms of random variables or indexed random
            variables'''))
        rvs = Tuple.fromiter(rv for rv in rvs)
        if not isinstance(dist, JointDistribution) and len(random_symbols(dist)) == 0:
            return dist
        return Basic.__new__(cls, dist, rvs)

    def check(self):
        pass

    @property
    def set(self):
        rvs = [i for i in self.args[1] if isinstance(i, RandomSymbol)]
        return ProductSet(*[rv.pspace.set for rv in rvs])

    @property
    def symbols(self):
        rvs = self.args[1]
        return {rv.pspace.symbol for rv in rvs}

    def pdf(self, *x):
        """Evaluate the marginal density at the point ``x``.

        Random symbols of the wrapped expression that are not among the
        kept variables are marginalised out first.
        """
        expr, rvs = self.args[0], self.args[1]
        marginalise_out = [i for i in random_symbols(expr) if i not in rvs]
        if isinstance(expr, JointDistribution):
            count = len(expr.domain.args)
            # Use a separate name so the evaluation point ``x`` is not lost.
            base = Dummy('x', real=True, finite=True)
            syms = tuple(Indexed(base, i) for i in range(count))
            expr = expr.pdf(syms)
        else:
            syms = tuple(rv.pspace.symbol if isinstance(rv, RandomSymbol) else rv.args[0] for rv in rvs)
        return Lambda(syms, self.compute_pdf(expr, marginalise_out))(*x)

    def compute_pdf(self, expr, rvs):
        """Multiply in each random symbol's pdf and marginalise it out."""
        for rv in rvs:
            lpdf = 1
            if isinstance(rv, RandomSymbol):
                lpdf = rv.pspace.pdf
            expr = self.marginalise_out(expr*lpdf, rv)
        return expr

    def marginalise_out(self, expr, rv):
        """Wrap ``expr`` in an ``Integral`` (continuous) or ``Sum`` (discrete)
        over the domain of ``rv``."""
        if isinstance(rv, RandomSymbol):
            dom = rv.pspace.set
        elif isinstance(rv, Indexed):
            dom = rv.base.component_domain(
                rv.pspace.component_domain(rv.args[1]))
        expr = expr.xreplace({rv: rv.pspace.symbol})
        if rv.pspace.is_Continuous:
            #TODO: Modify to support integration
            #for all kinds of sets.
            expr = Integral(expr, (rv.pspace.symbol, dom))
        elif rv.pspace.is_Discrete:
            #incorporate this into `Sum`/`summation`
            if dom in (S.Integers, S.Naturals, S.Naturals0):
                dom = (dom.inf, dom.sup)
            expr = Sum(expr, (rv.pspace.symbol, dom))
        return expr

    def __call__(self, *args):
        return self.pdf(*args)