# Source code for dit.divergences.cross_entropy

"""
The cross entropy.
"""

import numpy as np

from ..exceptions import InvalidOutcome
from ..helpers import normalize_rvs
from ..utils import flatten, unitful

__all__ = ('cross_entropy',
)

def get_prob(d, o):
"""
Get the probability of o, if it's not in the sample space return 0.

Parameters
----------
d : Distribution
The distribution to get the outcomes of.
o : object
The event to get the probability of.

Returns
-------
p : float
The probability of o.
"""
try:
p = d[o]
except InvalidOutcome:
p = 0
return p

def get_pmfs_like(d1, d2, rvs, rv_mode=None):
"""
Get the pmf from d1 for rvs, and the pmf from d2 for the events in
d1

Parameters
----------
d1 : Distribution
The distribution to get the pmf for.
d2 : Distribution
The distribution to get the pmf for, with the outcomes from d1.
rvs : list, None
The random variables to get the pmf for.
rv_mode : str, None
Specifies how to interpret rvs and crvs. Valid options are:
{'indices', 'names'}. If equal to 'indices', then the elements of
crvs and rvs are interpreted as random variable indices. If equal
to 'names', the the elements are interpreted as random variable names.
If None, then the value of dist._rv_mode is consulted, which
defaults to 'indices'.

Returns
-------
ps : ndarray
The pmf of d1.
qs : ndarray
A matching pmf from d2.
"""
dp = d1.marginal(rvs, rv_mode)
dq = d2.marginal(rvs, rv_mode)
ps = dp.pmf
qs = np.asarray([get_prob(dq, o) for o in dp.outcomes])
return ps, qs

[docs]@unitful
def cross_entropy(dist1, dist2, rvs=None, crvs=None, rv_mode=None):
"""
The cross entropy between dist1 and dist2.

Parameters
----------
dist1 : Distribution
The first distribution in the cross entropy.
dist2 : Distribution
The second distribution in the cross entropy.
rvs : list, None
The indexes of the random variable used to calculate the cross entropy
between. If None, then the cross entropy is calculated over all random
variables.
rv_mode : str, None
Specifies how to interpret rvs and crvs. Valid options are:
{'indices', 'names'}. If equal to 'indices', then the elements of
crvs and rvs are interpreted as random variable indices. If equal
to 'names', the the elements are interpreted as random variable names.
If None, then the value of dist._rv_mode is consulted, which
defaults to 'indices'.

Returns
-------
xh : float
The cross entropy between dist1 and dist2.

Raises
------
ditException
Raised if either dist1 or dist2 doesn't have rvs or, if rvs is
None, if dist2 has an outcome length different than dist1.
"""
rvs, crvs, rv_mode = normalize_rvs(dist1, rvs, crvs, rv_mode)
rvs, crvs = list(flatten(rvs)), list(flatten(crvs))
normalize_rvs(dist2, rvs, crvs, rv_mode)

p1s, q1s = get_pmfs_like(dist1, dist2, rvs+crvs, rv_mode)
xh = -np.nansum(p1s * np.log2(q1s))

if crvs:
p2s, q2s = get_pmfs_like(dist1, dist2, crvs, rv_mode)
xh2 = -np.nansum(p2s * np.log2(q2s))
xh -= xh2

return xh