# Source code for dit.other.perplexity

"""
The perplexity of a distribution.
"""

from ..helpers import RV_MODES
from ..shannon import conditional_entropy, entropy
from ..utils.misc import flatten


def perplexity(dist, rvs=None, crvs=None, rv_mode=None):
    """
    Compute the perplexity of a distribution.

    The perplexity is `base` raised to the power of an entropy, where `base`
    is `dist.get_base(numerical=True)` when `dist.is_log()` is true and 2
    otherwise. For a joint distribution the exponent is the conditional
    entropy of `rvs` given `crvs`; for a non-joint distribution it is the
    entropy of `dist`, and `rvs`, `crvs` and `rv_mode` are ignored.

    Parameters
    ----------
    dist : Distribution
        The distribution from which the perplexity is calculated.
    rvs : list, None
        The indexes of the random variables used to calculate the perplexity.
        Both flat (``[0, 1, 2]``) and nested (``[[0, 1], [2]]``) forms are
        accepted; nested input is flattened. If None, then the perplexity is
        calculated over all random variables.
    crvs : list, None
        The indexes of the random variables to condition on. If None, then no
        variables are conditioned on.
    rv_mode : str, None
        Specifies how to interpret the elements of `rvs`. Valid options are:
        {'indices', 'names'}. If equal to 'indices', then the elements of
        `rvs` are interpreted as random variable indices. If equal to 'names',
        then the elements are interpreted as random variable names. If `None`,
        then the value of `dist._rv_mode` is consulted.

    Returns
    -------
    P : float
        The perplexity.
    """
    if dist.is_log():
        base = dist.get_base(numerical=True)
    else:
        base = 2

    # A non-joint distribution has no component variables to select from,
    # so the perplexity is taken over the distribution as a whole.
    if not dist.is_joint():
        return base ** entropy(dist)

    if rvs is None:
        # Default to every random variable, addressed by position.
        # NOTE(review): this overrides the caller's `rv_mode`, and the same
        # mode is passed below together with `crvs` -- confirm that callers
        # never combine rvs=None with name-based `crvs`.
        targets = list(range(dist.outcome_length()))
        mode = RV_MODES.INDICES
    else:
        # Flattening accepts both [0, 1, 2] and [[0, 1], [2]], giving uniform
        # behavior with the mutual-information-like measures.
        targets = set(flatten(rvs))
        mode = rv_mode

    given = [] if crvs is None else crvs

    exponent = conditional_entropy(dist, targets, given, rv_mode=mode)
    return base ** exponent