Coverage for HARK/distributions/discrete.py: 92%
156 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-02 05:14 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-02 05:14 +0000
1from typing import Any, Callable, Dict, List, Optional, Union
3from copy import deepcopy
4import numpy as np
5import xarray as xr
6from scipy import stats
7from scipy.stats import rv_discrete
8from scipy.stats._distn_infrastructure import rv_discrete_frozen
10from HARK.distributions.base import Distribution
class DiscreteFrozenDistribution(rv_discrete_frozen, Distribution):
    """
    A frozen (parameterized) scipy.stats discrete distribution that also
    carries the seed handling provided by ``Distribution``.
    """

    def __init__(
        self, dist: rv_discrete, *args: Any, seed: int = 0, **kwds: Any
    ) -> None:
        """
        Freeze a scipy.stats discrete distribution and attach a seed to it.

        Parameters
        ----------
        dist : rv_discrete
            Discrete distribution from scipy.stats.
        *args, **kwds :
            Parameters of ``dist``; forwarded unchanged to
            ``rv_discrete_frozen.__init__``.
        seed : int, optional
            Seed for random number generator, by default 0. Forwarded to
            ``Distribution.__init__``.
        """
        # Both bases are initialized explicitly (rather than through a single
        # super() call) because they take different arguments.
        rv_discrete_frozen.__init__(self, dist, *args, **kwds)
        Distribution.__init__(self, seed=seed)
class Bernoulli(DiscreteFrozenDistribution):
    """
    Bernoulli distribution over the outcomes 0 and 1.

    Parameters
    ----------
    p : float or [float]
        Probability or probabilities of the event occurring (True).
    seed : int
        Seed for random number generator.
    """

    def __init__(self, p=0.5, seed=0):
        self.p = np.asarray(p)

        # Freeze scipy's bernoulli with this p and set up the RNG
        super().__init__(stats.bernoulli, p=self.p, seed=seed)

        # Probabilities of the two outcomes, ordered like the atoms below
        prob_false = 1 - self.p
        prob_true = self.p
        self.pmv = np.array([prob_false, prob_true])

        # Two-dimensional so that atoms is shaped like other distributions
        self.atoms = np.array([[0, 1]])

        # Support bounds, stored both on the instance and in the limit dict
        self.infimum = np.array([0.0])
        self.supremum = np.array([1.0])
        self.limit = {
            "dist": self,
            "infimum": np.array([0.0]),
            "supremum": np.array([1.0]),
        }

    def dim(self):
        """
        Shape of a single realization: the shape of self.atoms without its
        last axis, which indexes "atom."
        """
        atoms_shape = self.atoms.shape
        return atoms_shape[:-1]
class DiscreteDistribution(Distribution):
    """
    A representation of a discrete probability distribution.

    Parameters
    ----------
    pmv : np.array
        An array of floats representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        For multivariate distributions, the last dimension of atoms must index
        "atom" or the random realization. For instance, if atoms.shape == (2,6,4),
        the random variable has 4 possible realizations and each of them has shape (2,6).
    limit : dict
        Dictionary with information about the continuous distribution from which
        this distribution was generated. The reference distribution is in the entry
        called 'dist'. If omitted, a dictionary holding only the per-variable
        minimum ("infimum") and maximum ("supremum") of atoms is built.
    seed : int
        Seed for random number generator.

    Raises
    ------
    ValueError
        If the number of probabilities differs from the size of the last
        dimension of atoms.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: int = 0,
        limit: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(seed=seed)

        self.pmv = np.asarray(pmv)
        # Univariate inputs are promoted to shape (1, n_atoms)
        self.atoms = np.atleast_2d(atoms)
        if limit is None:
            limit = {
                "infimum": np.min(self.atoms, axis=-1),
                "supremum": np.max(self.atoms, axis=-1),
            }
        self.limit = limit

        # Check that pmv and atoms have compatible dimensions.
        if self.pmv.size != self.atoms.shape[-1]:
            raise ValueError(
                "Provided pmv and atoms arrays have incompatible dimensions. "
                "The length of the pmv must be equal to that of atoms's last dimension."
            )

    def dim(self) -> tuple:
        """
        Shape of a single realization: the shape of self.atoms without its
        last axis, which indexes "atom."
        """
        return self.atoms.shape[:-1]

    def draw_events(self, N: int) -> np.ndarray:
        """
        Draws N 'events' from the distribution PMF.
        These events are indices into atoms.

        Parameters
        ----------
        N : int
            Number of events to draw.

        Returns
        -------
        indices : np.array
            Integer indices in ``[0, pmv.size - 1]``.
        """
        # Generate a cumulative distribution
        base_draws = self._rng.uniform(size=N)
        cum_dist = np.cumsum(self.pmv)

        # Convert the basic uniform draws into discrete draws
        indices = cum_dist.searchsorted(base_draws)

        # Floating point error can leave cum_dist[-1] slightly below 1, in which
        # case a draw above it maps to index pmv.size. Assign it to the last atom.
        return np.minimum(indices, self.pmv.size - 1)

    def draw(
        self,
        N: int,
        atoms: Union[None, int, np.ndarray] = None,
        exact_match: bool = False,
    ) -> np.ndarray:
        """
        Simulates N draws from a discrete distribution with probabilities P and outcomes atoms.

        Parameters
        ----------
        N : int
            Number of draws to simulate.
        atoms : None, int, or np.array
            If None, then use this distribution's atoms for point values.
            If an int, then the index of atoms for the point values.
            If an np.array, use the array for the point values.
        exact_match : boolean
            Whether the draws should "exactly" match the discrete distribution (as
            closely as possible given finite draws). When True, returned draws are
            a random permutation of the N-length list that best fits the discrete
            distribution. When False (default), each draw is independent from the
            others and the result could deviate from the input.

        Returns
        -------
        draws : np.array
            An array of draws from the discrete distribution; each element is a value in atoms.
        """
        if atoms is None:
            atoms = self.atoms
        elif isinstance(atoms, (int, np.integer)):
            # NumPy integer scalars are accepted as a row index as well
            atoms = self.atoms[atoms]

        if exact_match:
            # Cutoff points between discrete outcomes
            cutoffs = np.round(np.cumsum(self.pmv) * N).astype(int)

            # Event j is repeated (cutoffs[j] - cutoffs[j-1]) times, giving a list
            # of event indices that closely matches the discrete distribution.
            # np.repeat keeps an integer dtype even when there are no events.
            counts = np.diff(cutoffs, prepend=0)
            event_list = np.repeat(np.arange(self.pmv.size), counts)

            # Randomly permute the event indices
            indices = self._rng.permutation(event_list)
        else:
            # Draw event indices randomly from the discrete distribution
            indices = self.draw_events(N)

        # Create and fill in the output array of draws based on the output of event indices
        draws = atoms[..., indices]

        # TODO: some models expect univariate draws to just be a 1d vector. Fix those models.
        if len(draws.shape) == 2 and draws.shape[0] == 1:
            draws = draws.flatten()

        return draws

    def expected(
        self, func: Optional[Callable] = None, *args: np.ndarray
    ) -> np.ndarray:
        """
        Expected value of a function, given an array of configurations of its
        inputs along with a DiscreteDistribution object that specifies the
        probability of each configuration.

        If no function is provided, it's much faster to go straight to dot
        product instead of calling the dummy function.

        If a function is provided, we need to add one more dimension,
        the atom dimension, to any inputs that are n-dim arrays.
        This allows numpy to easily broadcast the function's output.
        For more information on broadcasting, see:
        https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values
            and return either arrays of arbitrary shape or scalars.
            It may also take other arguments ``*args``.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        ``*args`` :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        if func is None:
            f_query = self.atoms
        else:
            # Append the atom axis to array arguments so they broadcast
            # against self.atoms inside func.
            args = [
                np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
                for arg in args
            ]
            f_query = func(self.atoms, *args)

        # The last axis of f_query indexes atoms; weight it by the probabilities
        f_exp = np.dot(f_query, self.pmv)

        return f_exp

    def dist_of_func(
        self, func: Callable[..., float] = lambda x: x, *args: Any
    ) -> "DiscreteDistribution":
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments ``*args``.
        ``*args`` :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.

        Returns
        -------
        f_dstn : DiscreteDistribution
            The distribution of func(dstn). It has the same probabilities
            and seed as this distribution.
        """
        # we need to add one more dimension,
        # the atom dimension, to any inputs that are n-dim arrays.
        # This allows numpy to easily broadcast the function's output.
        args = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        f_query = func(self.atoms, *args)

        # Copy the probabilities so the two distributions do not share an array
        f_dstn = DiscreteDistribution(self.pmv.copy(), f_query, seed=self.seed)

        return f_dstn

    def discretize(self, N: int, *args: Any, **kwargs: Any) -> "DiscreteDistribution":
        """
        `DiscreteDistribution` is already an approximation, so this method
        returns the distribution itself (the same object, not a copy).
        All arguments are ignored.

        TODO: print warning message?
        """
        return self

    def make_univariate(self, dim_to_keep, seed=0):
        """
        Make a univariate discrete distribution from this distribution, keeping
        only the specified dimension.

        Parameters
        ----------
        dim_to_keep : int
            Index of the distribution to be kept. Any other dimensions will be
            "collapsed" into the univariate atoms, combining probabilities.
        seed : int, optional
            Seed for random number generator of univariate distribution

        Returns
        -------
        univariate_dstn : DiscreteDistribution
            Univariate distribution with only the specified index.

        Raises
        ------
        ValueError
            If dim_to_keep is not smaller than the number of variables.
        """
        # Do basic validity and triviality checks
        if (self.atoms.shape[0] == 1) and (dim_to_keep == 0):
            return deepcopy(self)  # Return copy of self if only one dimension
        if dim_to_keep >= self.atoms.shape[0]:
            raise ValueError("dim_to_keep exceeds dimensionality of distribution.")

        # Construct values and probabilities for univariate distribution
        atoms_temp = self.atoms[dim_to_keep]
        vals_to_keep = np.unique(atoms_temp)

        # Sum the probability of every atom sharing each unique value. The
        # result is explicitly float: taking the dtype of the atoms would
        # truncate the probabilities to zero when the atoms are integers.
        probs_to_keep = np.array(
            [np.sum(self.pmv[atoms_temp == val]) for val in vals_to_keep],
            dtype=float,
        )

        # Make and return the univariate distribution
        univariate_dstn = DiscreteDistribution(
            pmv=probs_to_keep, atoms=vals_to_keep, seed=seed
        )
        return univariate_dstn
class DiscreteDistributionLabeled(DiscreteDistribution):
    """
    A representation of a discrete probability distribution
    stored in an underlying `xarray.Dataset`.

    Parameters
    ----------
    pmv : np.array
        An array of values representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        For multivariate distributions, the last dimension of atoms must index
        "atom" or the random realization. Only vector-valued distributions
        (atoms with at most two dimensions) are supported.
    seed : int
        Seed for random number generator.
    limit : dict
        Information about the distribution's limit, as in
        `DiscreteDistribution`. Its entries are also copied into the
        dataset's attributes.
    name : str
        Name of the distribution.
    attrs : dict
        Attributes for the distribution. The dictionary passed in is not modified.
    var_names : list of str
        Names of the variables in the distribution.
    var_attrs : list of dict
        Attributes of the variables in the distribution.

    Raises
    ------
    NotImplementedError
        If atoms has more than two dimensions.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: int = 0,
        limit: Optional[Dict[str, Any]] = None,
        name: str = "DiscreteDistributionLabeled",
        attrs: Optional[Dict[str, Any]] = None,
        var_names: Optional[List[str]] = None,
        var_attrs: Optional[List[Optional[Dict[str, Any]]]] = None,
    ):
        super().__init__(pmv, atoms, seed=seed, limit=limit)

        # vector-value distributions
        if self.atoms.ndim > 2:
            raise NotImplementedError(
                "Only vector-valued distributions are supported for now."
            )

        # Work on a copy so the caller's dictionary is left untouched
        attrs = {} if attrs is None else dict(attrs)
        # The parent constructor has already stored the limit, building the
        # infimum/supremum default when none was given.
        attrs.update(self.limit)
        attrs["name"] = name

        n_var = self.atoms.shape[0]

        # give dummy names to variables if none are provided
        if var_names is None:
            var_names = ["var_" + str(i) for i in range(n_var)]

        assert len(var_names) == n_var, (
            "Number of variable names does not match number of variables."
        )

        # give dummy attributes to variables if none are provided
        if var_attrs is None:
            var_attrs = [None] * n_var

        # a DiscreteDistributionLabeled is an xr.Dataset where the only
        # dimension is "atom", which indexes the random realizations.
        self.dataset = xr.Dataset(
            {
                var_names[i]: xr.DataArray(
                    self.atoms[i],
                    dims=("atom",),
                    attrs=var_attrs[i],
                )
                for i in range(n_var)
            },
            attrs=attrs,
        )

        # the probability mass values are stored in
        # a DataArray with dimension "atom"
        self.probability = xr.DataArray(self.pmv, dims=("atom",))

    @classmethod
    def from_unlabeled(
        cls,
        dist,
        name="DiscreteDistributionLabeled",
        attrs=None,
        var_names=None,
        var_attrs=None,
    ):
        """
        Build a labeled distribution from an unlabeled one, reusing its
        pmv, atoms, seed and limit.
        """
        return cls(
            dist.pmv,
            dist.atoms,
            seed=dist.seed,
            limit=dist.limit,
            name=name,
            attrs=attrs,
            var_names=var_names,
            var_attrs=var_attrs,
        )

    @classmethod
    def from_dataset(cls, x_obj, pmf):
        """
        Build a labeled distribution directly from xarray objects.

        The constructor is bypassed: only `dataset` and `probability` are set
        on the returned object.

        Parameters
        ----------
        x_obj : xr.Dataset, xr.DataArray or dict
            Values of the distribution. A DataArray is wrapped in a Dataset
            under its own name; a dict is passed to `xr.Dataset`.
        pmf : xr.DataArray
            Probability of each atom.

        Raises
        ------
        TypeError
            If x_obj is none of the supported types.
        """
        ldd = cls.__new__(cls)

        if isinstance(x_obj, xr.Dataset):
            ldd.dataset = x_obj
        elif isinstance(x_obj, xr.DataArray):
            ldd.dataset = xr.Dataset({x_obj.name: x_obj})
        elif isinstance(x_obj, dict):
            ldd.dataset = xr.Dataset(x_obj)
        else:
            # Without this the object would be returned with no dataset and
            # fail later with a less informative AttributeError.
            raise TypeError(
                "x_obj must be an xarray.Dataset, an xarray.DataArray or a dict, "
                "not " + type(x_obj).__name__ + "."
            )

        ldd.probability = pmf

        return ldd

    @property
    def _weighted(self):
        """
        Returns a DatasetWeighted object for the distribution.
        """
        return self.dataset.weighted(self.probability)

    @property
    def variables(self):
        """
        A dict-like container of DataArrays corresponding to
        the variables of the distribution.
        """
        return self.dataset.data_vars

    @property
    def name(self):
        """
        The distribution's name.
        """
        return self.dataset.name

    @property
    def attrs(self):
        """
        The distribution's attributes.
        """
        return self.dataset.attrs

    def _wrap_labeled(self, func: Callable) -> Callable:
        """
        Wrap `func` so that it receives a dict mapping each variable name to
        its row of values instead of the raw array of atoms.
        """

        def func_wrapper(x: np.ndarray, *args: Any) -> np.ndarray:
            labeled = dict(zip(self.variables.keys(), x))
            return func(labeled, *args)

        return func_wrapper

    def dist_of_func(
        self, func: Callable = lambda x: x, *args, **kwargs
    ) -> DiscreteDistribution:
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            Without keyword arguments it receives a dict mapping variable
            names to their values; with keyword arguments it receives the
            underlying `xarray.Dataset`.
            It may also take other arguments ``*args``.
        ``*args`` :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.
        ``**kwargs`` :
            Additional keyword arguments for func. Must be xarray compatible
            in order to work with xarray broadcasting.

        Returns
        -------
        f_dstn : DiscreteDistribution or DiscreteDistributionLabeled
            The distribution of func(dstn). Labeled when keyword arguments
            are given, unlabeled otherwise.
        """
        if kwargs:
            # Positional arguments are forwarded too, as in `expected`
            f_query = func(self.dataset, *args, **kwargs)
            return DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

        return super().dist_of_func(self._wrap_labeled(func), *args)

    def expected(
        self, func: Optional[Callable] = None, *args: Any, **kwargs: Any
    ) -> Union[float, np.ndarray]:
        """
        Expectation of a function, given an array of configurations of its inputs
        along with a DiscreteDistributionLabeled object that specifies the probability
        of each configuration.

        Parameters
        ----------
        func : function
            The function to be evaluated.
            Without keyword arguments it receives a dict mapping variable
            names to their values, so it can use labeled indexing such as
            ``lambda x: x["a"] + x["b"]``. With keyword arguments it receives
            the underlying `xarray.Dataset`.
            It may also take other arguments ``*args``.
            If omitted (and no keyword arguments are given), the expectation
            of the atoms themselves is returned.
        ``*args`` :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.
        ``**kwargs`` :
            Additional keyword arguments for func. When any are given, the
            expectation is computed by xarray as a probability-weighted mean
            over the "atom" dimension.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        if kwargs:
            f_query = func(self.dataset, *args, **kwargs)
            ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability)
            return ldd._weighted.mean("atom")

        if func is None:
            return super().expected()

        return super().expected(self._wrap_labeled(func), *args)