Coverage for HARK / distributions / discrete.py: 92%
173 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
1from typing import Any, Callable, Dict, List, Optional, Union
3from copy import deepcopy
4import numpy as np
5import xarray as xr
6from scipy import stats
7from scipy.stats import rv_discrete
8from scipy.stats._distn_infrastructure import rv_discrete_frozen
10from HARK.distributions.base import Distribution
class DiscreteFrozenDistribution(rv_discrete_frozen, Distribution):
    """
    Parameterized discrete distribution from scipy.stats with seed management.
    """

    def __init__(
        self, dist: rv_discrete, *args: Any, seed: Optional[int] = None, **kwds: Any
    ) -> None:
        """
        Parameterized discrete distribution from scipy.stats with seed management.

        Parameters
        ----------
        dist : rv_discrete
            Discrete distribution from scipy.stats.
        *args : Any
            Positional arguments forwarded to ``rv_discrete_frozen.__init__``.
        seed : int, optional
            Seed for random number generator, forwarded to
            ``Distribution.__init__``; by default None.
        **kwds : Any
            Keyword arguments forwarded to ``rv_discrete_frozen.__init__``.
        """
        # The two bases take different arguments, so each one is initialised
        # explicitly instead of through a single cooperative super() call.
        rv_discrete_frozen.__init__(self, dist, *args, **kwds)
        Distribution.__init__(self, seed=seed)
class Bernoulli(DiscreteFrozenDistribution):
    """
    A Bernoulli random variable, taking the values 0 and 1.

    Parameters
    ----------
    p : float or [float]
        Probability or probabilities that the outcome is 1 (True).
    seed : int
        Seed for random number generator.
    """

    def __init__(self, p=0.5, seed=None):
        self.p = np.asarray(p)

        # Freeze scipy's bernoulli at this p and set up the RNG
        super().__init__(stats.bernoulli, p=self.p, seed=seed)

        success = self.p
        self.pmv = np.array([1 - success, success])

        # Atoms are kept 2D, shaped like those of the other distributions
        self.atoms = np.array([[0, 1]])

        # The support is always [0, 1]; the arrays stored on the instance and
        # the ones stored in the limit dictionary are distinct objects.
        self.supremum = np.array([1.0])
        self.infimum = np.array([0.0])
        self.limit = dict(
            dist=self,
            infimum=np.array([0.0]),
            supremum=np.array([1.0]),
        )

    def dim(self):
        """
        Shape of a single realization: the shape of self.atoms without its
        last axis, which indexes "atom."
        """
        atoms_shape = self.atoms.shape
        return atoms_shape[:-1]
class DiscreteDistribution(Distribution):
    """
    A representation of a discrete probability distribution.

    Parameters
    ----------
    pmv : np.array
        An array of floats representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        For multivariate distributions, the last dimension of atoms must index
        "atom" or the random realization. For instance, if atoms.shape == (2,6,4),
        the random variable has 4 possible realizations and each of them has shape (2,6).
    limit : dict
        Dictionary with information about the continuous distribution from which
        this distribution was generated. The reference distribution is in the entry
        called 'dist'. When omitted, a dictionary holding the per-dimension
        minimum ("infimum") and maximum ("supremum") of atoms is built.
    seed : int
        Seed for random number generator.

    Raises
    ------
    ValueError
        If the number of entries in pmv differs from the size of the last
        dimension of atoms.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(seed=seed)

        self.pmv = np.asarray(pmv)
        # Univariate inputs are promoted to shape (1, n_atoms)
        self.atoms = np.atleast_2d(atoms)
        if limit is None:
            limit = {
                "infimum": np.min(self.atoms, axis=-1),
                "supremum": np.max(self.atoms, axis=-1),
            }
        self.limit = limit

        # Check that pmv and atoms have compatible dimensions.
        if not self.pmv.size == self.atoms.shape[-1]:
            raise ValueError(
                "Provided pmv and atoms arrays have incompatible dimensions. "
                + "The length of the pmv must be equal to that of atoms's last dimension."
            )

    def __repr__(self) -> str:
        out = self.__class__.__name__ + " with " + str(self.pmv.size) + " atoms, "
        if self.atoms.shape[0] > 1:
            # Multivariate: show the bounds of every dimension
            out += "inf=" + str(tuple(self.limit["infimum"])) + ", "
            out += "sup=" + str(tuple(self.limit["supremum"])) + ", "
        else:
            out += "inf=" + str(self.limit["infimum"][0]) + ", "
            out += "sup=" + str(self.limit["supremum"][0]) + ", "
        out += "seed=" + str(self.seed)
        return out

    def dim(self) -> tuple:
        """
        Shape of a single realization: the shape of self.atoms without its
        last dimension, which indexes "atom."
        """
        return self.atoms.shape[:-1]

    def draw_events(self, N: int) -> np.ndarray:
        """
        Draws N 'events' from the distribution PMF.
        These events are indices into atoms.

        Parameters
        ----------
        N : int
            Number of events to draw.

        Returns
        -------
        indices : np.array
            Index of the drawn atom for each of the N events.
        """
        # Uniform draws are mapped through the cumulative distribution
        base_draws = self._rng.uniform(size=N)
        cum_dist = np.cumsum(self.pmv)

        # Convert the basic uniform draws into discrete draws
        return cum_dist.searchsorted(base_draws)

    def _draw_shuffled_events(self, N: int) -> np.ndarray:
        """
        Draws N atom indices whose frequencies match the pmv as closely as N allows.

        Each atom first receives floor(N * pmv) slots. The slots left over are
        then handed out one at a time, without replacement, with probability
        proportional to the fractional part of N * pmv, so the expected count
        of every atom equals N * pmv. Finally the indices are randomly permuted.

        Parameters
        ----------
        N : int
            Number of events to draw.

        Returns
        -------
        indices : np.array
            Randomly ordered indices into atoms, of length N.
        """
        slots_exact = N * self.pmv  # slots per outcome in real numbers
        slots = np.floor(slots_exact).astype(int)  # whole slots per atom
        n_extra = N - np.sum(slots)  # number of unallocated slots

        # Fraction of a slot each atom is still owed. Both terms are in slot
        # units, so the weights are non-negative and proportional to the
        # probability mass not yet represented by whole slots.
        owed = slots_exact - slots
        extra_draws = self._rng.random(n_extra)  # uniform draws for "extra" slots

        # Fill in each unallocated slot, one by one
        for m in range(n_extra):
            cum_owed = np.cumsum(owed / np.sum(owed))  # probabilities for this pass
            j = np.searchsorted(cum_owed, extra_draws[m])  # find index for this draw
            slots[j] += 1  # increment its allocated slots
            owed[j] = 0.0  # an atom can receive at most one extra slot

        # Repeat each atom index by its final slot count; the result has an
        # integer dtype even when N is zero.
        events = np.repeat(np.arange(self.pmv.size), slots)

        # Draw a random permutation of the indices
        return self._rng.permutation(events)

    def draw(
        self,
        N: int,
        atoms: Union[None, int, np.ndarray] = None,
        shuffle: bool = False,
    ) -> np.ndarray:
        """
        Simulates N draws from a discrete distribution with probabilities P and outcomes atoms.

        Parameters
        ----------
        N : int
            Number of draws to simulate.
        atoms : None, int, or np.array
            If None, then use this distribution's atoms for point values.
            If an int, then the index of atoms for the point values.
            If an np.array, use the array for the point values.
        shuffle : boolean
            Whether the draws should "shuffle" the discrete distribution, matching
            proportions of outcomes as closely as possible to the probabilities given
            finite draws. When True, returned draws are a random permutation of the
            N-length list that best fits the discrete distribution. When False
            (default), each draw is independent from the others and the result could
            deviate from the probabilities.

        Returns
        -------
        draws : np.array
            An array of draws from the discrete distribution; each element is a value in atoms.
        """
        if atoms is None:
            atoms = self.atoms
        elif isinstance(atoms, (int, np.integer)):
            # NumPy integers are accepted as well as Python ints
            atoms = self.atoms[atoms]

        if shuffle:
            # "Shuffle" an almost-exact population of draws based on the pmv
            indices = self._draw_shuffled_events(N)
        else:
            # Draw event indices randomly from the discrete distribution
            indices = self.draw_events(N)

        # Create and fill in the output array of draws based on the output of event indices
        draws = atoms[..., indices]

        # TODO: some models expect univariate draws to just be a 1d vector. Fix those models.
        if len(draws.shape) == 2 and draws.shape[0] == 1:
            draws = draws.flatten()

        return draws

    def expected(
        self, func: Optional[Callable] = None, *args: np.ndarray
    ) -> np.ndarray:
        """
        Expected value of a function, given an array of configurations of its
        inputs along with a DiscreteDistribution object that specifies the
        probability of each configuration.

        If no function is provided, it's much faster to go straight to dot
        product instead of calling the dummy function.

        If a function is provided, we need to add one more dimension,
        the atom dimension, to any inputs that are n-dim arrays.
        This allows numpy to easily broadcast the function's output.
        For more information on broadcasting, see:
        https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values
            and return either arrays of arbitrary shape or scalars.
            It may also take other arguments \\*args.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """

        if func is None:
            f_query = self.atoms
        else:
            # Array arguments get a trailing axis so they broadcast over atoms
            args = [
                np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
                for arg in args
            ]

            f_query = func(self.atoms, *args)

        # The last axis of f_query indexes atoms; weight it by the pmv
        f_exp = np.dot(f_query, self.pmv)

        return f_exp

    def dist_of_func(
        self, func: Callable[..., float] = lambda x: x, *args: Any
    ) -> "DiscreteDistribution":
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.

        Returns
        -------
        f_dstn : DiscreteDistribution
            The distribution of func(dstn). It has the same probabilities and
            the same seed as this distribution.
        """
        # we need to add one more dimension,
        # the atom dimension, to any inputs that are n-dim arrays.
        # This allows numpy to easily broadcast the function's output.
        args = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        f_query = func(self.atoms, *args)

        f_dstn = DiscreteDistribution(list(self.pmv), f_query, seed=self.seed)

        return f_dstn

    def discretize(self, N: int, *args: Any, **kwargs: Any) -> "DiscreteDistribution":
        """
        `DiscreteDistribution` is already an approximation, so this method
        returns the distribution itself (the same object, not a copy).
        All arguments are ignored.

        TODO: print warning message?
        """
        return self

    def make_univariate(self, dim_to_keep, seed=0):
        """
        Make a univariate discrete distribution from this distribution, keeping
        only the specified dimension.

        Parameters
        ----------
        dim_to_keep : int
            Index of the distribution to be kept. Any other dimensions will be
            "collapsed" into the univariate atoms, combining probabilities.
        seed : int, optional
            Seed for random number generator of univariate distribution

        Returns
        -------
        univariate_dstn : DiscreteDistribution
            Univariate distribution with only the specified index.

        Raises
        ------
        ValueError
            If dim_to_keep is not smaller than the number of dimensions.
        """
        # Do basic validity and triviality checks
        if (self.atoms.shape[0] == 1) and (dim_to_keep == 0):
            return deepcopy(self)  # Return copy of self if only one dimension
        if dim_to_keep >= self.atoms.shape[0]:
            raise ValueError("dim_to_keep exceeds dimensionality of distribution.")

        # Construct values and probabilities for univariate distribution
        atoms_temp = self.atoms[dim_to_keep]
        vals_to_keep = np.unique(atoms_temp)

        # Probabilities must be floats whatever the dtype of the atoms;
        # borrowing an integer dtype from the atoms would truncate them to 0.
        probs_to_keep = np.zeros(vals_to_keep.size, dtype=float)
        for i, val in enumerate(vals_to_keep):
            these = atoms_temp == val
            probs_to_keep[i] = np.sum(self.pmv[these])

        # Make and return the univariate distribution
        univariate_dstn = DiscreteDistribution(
            pmv=probs_to_keep, atoms=vals_to_keep, seed=seed
        )
        return univariate_dstn
class DiscreteDistributionLabeled(DiscreteDistribution):
    """
    A representation of a discrete probability distribution
    stored in an underlying `xarray.Dataset`.

    Parameters
    ----------
    pmv : np.array
        An array of values representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        The last dimension of atoms must index "atom" or the random
        realization. Only atoms with at most two dimensions
        (variable, atom) are supported.
    seed : int
        Seed for random number generator.
    limit : dict
        Dictionary with information about the distribution's limits. Its
        entries are also copied into the dataset's attributes.
    name : str
        Name of the distribution.
    attrs : dict
        Attributes for the distribution.
    var_names : list of str
        Names of the variables in the distribution.
    var_attrs : list of dict
        Attributes of the variables in the distribution.

    Raises
    ------
    NotImplementedError
        If atoms has more than two dimensions.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
        name: str = "DiscreteDistributionLabeled",
        attrs: Optional[Dict[str, Any]] = None,
        var_names: Optional[List[str]] = None,
        var_attrs: Optional[List[Optional[Dict[str, Any]]]] = None,
    ):
        super().__init__(pmv, atoms, seed=seed, limit=limit)

        # vector-value distributions

        if self.atoms.ndim > 2:
            raise NotImplementedError(
                "Only vector-valued distributions are supported for now."
            )

        # Work on a copy so the caller's dictionary is not modified by the
        # updates below.
        attrs = {} if attrs is None else dict(attrs)

        # The parent constructor has already stored the limit, filling in the
        # per-variable min/max of atoms when none was given.
        attrs.update(self.limit)
        attrs["name"] = name

        n_var = self.atoms.shape[0]

        # give dummy names to variables if none are provided
        if var_names is None:
            var_names = ["var_" + str(i) for i in range(n_var)]

        assert len(var_names) == n_var, (
            "Number of variable names does not match number of variables."
        )

        # give dummy attributes to variables if none are provided
        if var_attrs is None:
            var_attrs = [None] * n_var

        # a DiscreteDistributionLabeled is an xr.Dataset where the only
        # dimension is "atom", which indexes the random realizations.
        self.dataset = xr.Dataset(
            {
                var_names[i]: xr.DataArray(
                    self.atoms[i],
                    dims=("atom",),
                    attrs=var_attrs[i],
                )
                for i in range(n_var)
            },
            attrs=attrs,
        )

        # the probability mass values are stored in
        # a DataArray with dimension "atom"
        self.probability = xr.DataArray(self.pmv, dims=("atom",))

    @classmethod
    def from_unlabeled(
        cls,
        dist,
        name="DiscreteDistributionLabeled",
        attrs=None,
        var_names=None,
        var_attrs=None,
    ):
        """
        Build a labeled distribution from an existing distribution.

        Parameters
        ----------
        dist : DiscreteDistribution
            Source of the pmv, atoms, seed and limit of the new distribution.
        name : str
            Name of the distribution.
        attrs : dict
            Attributes for the distribution.
        var_names : list of str
            Names of the variables in the distribution.
        var_attrs : list of dict
            Attributes of the variables in the distribution.

        Returns
        -------
        ldd : DiscreteDistributionLabeled
            Labeled version of dist.
        """
        ldd = cls(
            dist.pmv,
            dist.atoms,
            seed=dist.seed,
            limit=dist.limit,
            name=name,
            attrs=attrs,
            var_names=var_names,
            var_attrs=var_attrs,
        )

        return ldd

    @classmethod
    def from_dataset(cls, x_obj, pmf):
        """
        Wrap existing xarray data and probabilities in a labeled distribution.

        The instance is created without running ``__init__``, so only the
        ``dataset`` and ``probability`` attributes are set on it.

        Parameters
        ----------
        x_obj : xr.Dataset, xr.DataArray or dict
            Values of the distribution. A DataArray is stored under its own
            name; a dict is passed to the ``xr.Dataset`` constructor.
        pmf : xr.DataArray
            Probability of each realization.

        Returns
        -------
        ldd : DiscreteDistributionLabeled
            Distribution backed by the given data.

        Raises
        ------
        TypeError
            If x_obj is not one of the supported types.
        """
        ldd = cls.__new__(cls)

        if isinstance(x_obj, xr.Dataset):
            ldd.dataset = x_obj
        elif isinstance(x_obj, xr.DataArray):
            ldd.dataset = xr.Dataset({x_obj.name: x_obj})
        elif isinstance(x_obj, dict):
            ldd.dataset = xr.Dataset(x_obj)
        else:
            # Fail here rather than return an object without a dataset
            raise TypeError(
                "x_obj must be an xarray.Dataset, xarray.DataArray or dict, "
                f"not {type(x_obj).__name__}."
            )

        ldd.probability = pmf

        return ldd

    @property
    def _weighted(self):
        """
        Returns a DatasetWeighted object for the distribution.
        """
        return self.dataset.weighted(self.probability)

    @property
    def variables(self):
        """
        A dict-like container of DataArrays corresponding to
        the variables of the distribution.
        """
        return self.dataset.data_vars

    @property
    def name(self):
        """
        The distribution's name.
        """
        # Looked up through attribute-style access on the dataset
        return self.dataset.name

    @property
    def attrs(self):
        """
        The distribution's attributes.
        """

        return self.dataset.attrs

    def dist_of_func(
        self, func: Callable = lambda x: x, *args, **kwargs
    ) -> DiscreteDistribution:
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.
        \\*\\*kwargs :
            Additional keyword arguments for func. Must be xarray compatible
            in order to work with xarray broadcasting. When any are given,
            func is called on the underlying dataset as
            ``f(dataset, *args, **kwargs)``.

        Returns
        -------
        f_dstn : DiscreteDistribution or DiscreteDistributionLabeled
            The distribution of func(dstn). Labeled when keyword arguments
            are given, unlabeled otherwise.
        """

        if kwargs:
            # Positional arguments are forwarded here too, matching `expected`
            f_query = func(self.dataset, *args, **kwargs)
            ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

            return ldd

        def func_wrapper(x: np.ndarray, *args: Any) -> np.ndarray:
            """
            Wrapper function for `func` that handles labeled indexing.
            """

            # Pair each variable name with its row of atoms
            idx = self.variables.keys()
            wrapped = dict(zip(idx, x))

            return func(wrapped, *args)

        return super().dist_of_func(func_wrapper, *args)

    def expected(
        self, func: Optional[Callable] = None, *args: Any, **kwargs: Any
    ) -> Union[float, np.ndarray]:
        """
        Expectation of a function, given an array of configurations of its inputs
        along with a DiscreteDistributionLabeled object that specifies the probability
        of each configuration.

        Parameters
        ----------
        func : function
            The function to be evaluated.
            Without keyword arguments, it receives a dict mapping each variable
            name to its array of atoms, so it can use labeled indexing such as
            `lambda x: x["a"] + x["b"]`. With keyword arguments, it receives
            the underlying dataset instead.
            It may also take other arguments \\*args.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
            If None (and no keyword arguments are given), the expectation of
            the atoms themselves is returned.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.
        \\*\\*kwargs :
            Additional keyword arguments for func. When any are given, func is
            called as ``f(dataset, *args, **kwargs)`` and the result is
            averaged over the "atom" dimension, weighted by the probabilities.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """

        if kwargs:
            f_query = func(self.dataset, *args, **kwargs)
            ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

            return ldd._weighted.mean("atom")

        if func is None:
            return super().expected()

        def func_wrapper(x, *args):
            """
            Wrapper function for `func` that handles labeled indexing.
            """

            # Pair each variable name with its row of atoms
            idx = self.variables.keys()
            wrapped = dict(zip(idx, x))

            return func(wrapped, *args)

        return super().expected(func_wrapper, *args)