Coverage for HARK / distributions / discrete.py: 92%
173 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 05:31 +0000
1from typing import Any, Callable, Dict, List, Optional, Union
3from copy import deepcopy
4import numpy as np
5import xarray as xr
6from scipy import stats
7from scipy.stats import rv_discrete
8from scipy.stats._distn_infrastructure import rv_discrete_frozen
10from HARK.distributions.base import Distribution
class DiscreteFrozenDistribution(rv_discrete_frozen, Distribution):
    """
    Parameterized (frozen) discrete distribution from scipy.stats with seed management.

    Combines scipy's ``rv_discrete_frozen`` (which holds the frozen
    distribution and its parameters) with HARK's ``Distribution`` base
    (which handles the seed).
    """

    def __init__(
        self, dist: rv_discrete, *args: Any, seed: Optional[int] = None, **kwds: Any
    ) -> None:
        """
        Parameterized discrete distribution from scipy.stats with seed management.

        Parameters
        ----------
        dist : rv_discrete
            Discrete distribution from scipy.stats.
        *args, **kwds
            Shape/location parameters forwarded unchanged to
            ``rv_discrete_frozen.__init__`` to freeze ``dist``.
        seed : int, optional
            Seed for the random number generator, by default None. The value
            is passed unchanged to ``Distribution.__init__``; how ``None`` is
            resolved is determined there (not visible in this module).
        """
        # Initialize each base explicitly: the two parents take different
        # arguments, so a single cooperative super().__init__ call would not fit.
        rv_discrete_frozen.__init__(self, dist, *args, **kwds)
        Distribution.__init__(self, seed=seed)
class Bernoulli(DiscreteFrozenDistribution):
    """
    A Bernoulli distribution with outcomes 0 (False) and 1 (True).

    Parameters
    ----------
    p : float or [float]
        Probability, or array of probabilities, that the event occurs
        (i.e. that the outcome is 1 / True).

    seed : int
        Seed for random number generator.
    """

    def __init__(self, p=0.5, seed=None):
        success_prob = np.asarray(p)
        self.p = success_prob
        # Freeze scipy's bernoulli at p; the parent class also sets up the RNG.
        super().__init__(stats.bernoulli, p=success_prob, seed=seed)

        # Probability masses in atom order: P(0) = 1 - p, then P(1) = p.
        self.pmv = np.array([1 - success_prob, success_prob])
        # Stored with shape (1, 2) so the last axis indexes the atom,
        # matching the layout used by the other discrete distributions.
        self.atoms = np.array([[0, 1]])

        self.limit = dict(
            dist=self,
            infimum=np.array([0.0]),
            supremum=np.array([1.0]),
        )
        self.infimum = np.array([0.0])
        self.supremum = np.array([1.0])

    def dim(self):
        """
        Shape of a single realization: every axis of self.atoms except the
        last one, which indexes the "atom".
        """
        *leading_axes, _atom_axis = self.atoms.shape
        return tuple(leading_axes)
class DiscreteDistribution(Distribution):
    """
    A representation of a discrete probability distribution.

    Parameters
    ----------
    pmv : np.array
        An array of floats representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        For multivariate distributions, the last dimension of atoms must index
        "atom" or the random realization. For instance, if atoms.shape == (2,6,4),
        the random variable has 4 possible realizations and each of them has shape (2,6).
    seed : int, optional
        Seed for random number generator.
    limit : dict, optional
        Dictionary with information about the continuous distribution from which
        this distribution was generated. The reference distribution is in the entry
        called 'dist'. If omitted, a dict holding only 'infimum' and 'supremum'
        (the min and max of atoms along the last axis) is built instead.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(seed=seed)

        self.pmv = np.asarray(pmv)
        # Univariate atoms are promoted to shape (1, n_atoms).
        self.atoms = np.atleast_2d(atoms)
        if limit is None:
            limit = {
                "infimum": np.min(self.atoms, axis=-1),
                "supremum": np.max(self.atoms, axis=-1),
            }
        self.limit = limit

        # Check that pmv and atoms have compatible dimensions.
        if not self.pmv.size == self.atoms.shape[-1]:
            raise ValueError(
                "Provided pmv and atoms arrays have incompatible dimensions. "
                + "The length of the pmv must be equal to that of atoms's last dimension."
            )

    def __repr__(self):
        out = self.__class__.__name__ + " with " + str(self.pmv.size) + " atoms, "
        if self.atoms.shape[0] > 1:
            out += "inf=" + str(tuple(self.limit["infimum"])) + ", "
            out += "sup=" + str(tuple(self.limit["supremum"])) + ", "
        else:
            out += "inf=" + str(self.limit["infimum"][0]) + ", "
            out += "sup=" + str(self.limit["supremum"][0]) + ", "
        out += "seed=" + str(self.seed)
        return out

    def dim(self) -> tuple:
        """
        Shape of a single realization. The last dimension of self.atoms
        indexes "atom", so this is every axis except the last.
        """
        return self.atoms.shape[:-1]

    def draw_events(self, N: int) -> np.ndarray:
        """
        Draws N 'events' from the distribution PMF.
        These events are indices into atoms (the last axis of self.atoms).

        Parameters
        ----------
        N : int
            Number of events to draw.

        Returns
        -------
        indices : np.array
            Integer array of length N with values in [0, pmv.size - 1].
        """
        base_draws = self._rng.uniform(size=N)
        # Generate a cumulative distribution
        cum_dist = np.cumsum(self.pmv)

        # Convert the basic uniform draws into discrete draws
        indices = cum_dist.searchsorted(base_draws)

        # If the pmv sums to slightly less than 1 due to round-off, a uniform
        # draw can exceed cum_dist[-1]; searchsorted then returns pmv.size,
        # which is not a valid atom index. Map such draws to the last atom.
        return np.minimum(indices, self.pmv.size - 1)

    def draw(
        self,
        N: int,
        atoms: Union[None, int, np.ndarray] = None,
        shuffle: bool = False,
    ) -> np.ndarray:
        """
        Simulates N draws from a discrete distribution with probabilities P and outcomes atoms.

        Parameters
        ----------
        N : int
            Number of draws to simulate.
        atoms : None, int, or np.array
            If None, then use this distribution's atoms for point values.
            If an int (Python or numpy integer), then the index of atoms for the point values.
            If an np.array, use the array for the point values.
        shuffle : boolean
            Whether the draws should "shuffle" the discrete distribution, matching
            proportions of outcomes as closely as possible to the probabilities given
            finite draws. When True, returned draws are a random permutation of the
            N-length list that best fits the discrete distribution. When False
            (default), each draw is independent from the others and the result could
            deviate from the probabilities.

        Returns
        -------
        draws : np.array
            An array of draws from the discrete distribution; each element is a value in atoms.
        """
        if atoms is None:
            atoms = self.atoms
        elif isinstance(atoms, (int, np.integer)):
            atoms = self.atoms[atoms]

        if shuffle:
            # "Shuffle" an almost-exact population of draws based on the pmv
            P = self.pmv.ravel()
            J = P.size
            K_exact = N * P  # slots per outcome in real numbers
            K = np.floor(K_exact).astype(int)  # whole slots allocated to each atom
            M = N - np.sum(K)  # number of unallocated slots
            # Fractional slot remainder of each atom. It is proportional to the
            # "missing" probability mass P - K/N. (It must not be N*P - K/N,
            # which mixes slot and probability units and biases the leftover
            # slots toward high-probability atoms.)
            Q = K_exact - K
            extra_draws = self._rng.random(M)  # uniform draws for the leftover slots

            # Fill in each unallocated slot, one by one; an atom that receives
            # a leftover slot is removed from later passes.
            for m in range(M):
                Q_sum = np.cumsum(Q / np.sum(Q))  # cumulative probabilities for this pass
                # Clip in case round-off leaves Q_sum[-1] below the draw.
                j = min(int(np.searchsorted(Q_sum, extra_draws[m])), J - 1)
                K[j] += 1
                Q[j] = 0.0

            # Atom index j repeated K[j] times; stays integer-typed even when N == 0.
            events = np.repeat(np.arange(J), K)

            # Draw a random permutation of the indices
            indices = self._rng.permutation(events)
        else:
            # Draw event indices randomly from the discrete distribution
            indices = self.draw_events(N)

        # Create the output array of draws from the event indices
        out = atoms[..., indices]

        # TODO: some models expect univariate draws to just be a 1d vector. Fix those models.
        if len(out.shape) == 2 and out.shape[0] == 1:
            out = out.flatten()

        return out

    def expected(
        self, func: Optional[Callable] = None, *args: np.ndarray
    ) -> np.ndarray:
        """
        Expected value of a function, given an array of configurations of its
        inputs along with a DiscreteDistribution object that specifies the
        probability of each configuration.

        If no function is provided, it's much faster to go straight to dot
        product instead of calling the dummy function.

        If a function is provided, we need to add one more dimension,
        the atom dimension, to any inputs that are n-dim arrays.
        This allows numpy to easily broadcast the function's output.
        For more information on broadcasting, see:
        https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values
            and return either arrays of arbitrary shape or scalars.
            It may also take other arguments \\*args.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        if func is None:
            f_query = self.atoms
        else:
            # Trailing length-1 axis lets array arguments broadcast against
            # the atom axis (the last axis of self.atoms).
            args = [
                np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
                for arg in args
            ]
            f_query = func(self.atoms, *args)

        # Contract the trailing atom axis against the probability masses.
        return np.dot(f_query, self.pmv)

    def dist_of_func(
        self, func: Callable[..., float] = lambda x: x, *args: Any
    ) -> "DiscreteDistribution":
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.

        Returns
        -------
        f_dstn : DiscreteDistribution
            The distribution of func(dstn), with the same pmv and seed.
        """
        # Add the atom dimension to any n-dim array inputs so numpy can
        # broadcast them against self.atoms.
        args = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        f_query = func(self.atoms, *args)

        return DiscreteDistribution(list(self.pmv), f_query, seed=self.seed)

    def discretize(self, N: int, *args: Any, **kwargs: Any) -> "DiscreteDistribution":
        """
        `DiscreteDistribution` is already an approximation, so this method
        returns the distribution itself (not a copy); all arguments are ignored.

        TODO: print warning message?
        """
        return self

    def make_univariate(self, dim_to_keep, seed=0):
        """
        Make a univariate discrete distribution from this distribution, keeping
        only the specified dimension.

        Parameters
        ----------
        dim_to_keep : int
            Index of the distribution to be kept. Any other dimensions will be
            "collapsed" into the univariate atoms, combining probabilities.
        seed : int, optional
            Seed for random number generator of univariate distribution

        Returns
        -------
        univariate_dstn : DiscreteDistribution
            Univariate distribution with only the specified index. Its atoms
            are the sorted unique values of that dimension; each one carries
            the summed probability of all original atoms that share it.

        Raises
        ------
        ValueError
            If dim_to_keep exceeds the dimensionality of the distribution.
        """
        # Do basic validity and triviality checks
        if (self.atoms.shape[0] == 1) and (dim_to_keep == 0):
            return deepcopy(self)  # Return copy of self if only one dimension
        if dim_to_keep >= self.atoms.shape[0]:
            raise ValueError("dim_to_keep exceeds dimensionality of distribution.")

        # Construct values and probabilities for univariate distribution
        atoms_temp = self.atoms[dim_to_keep]
        vals_to_keep, inverse = np.unique(atoms_temp, return_inverse=True)
        # Sum the masses of all atoms that map to each unique value. bincount
        # with weights always returns floats, so integer-valued atoms no longer
        # truncate the probabilities (zeros_like(vals) inherited their dtype).
        probs_to_keep = np.bincount(
            np.ravel(inverse),
            weights=np.ravel(self.pmv),
            minlength=vals_to_keep.size,
        )

        # Make and return the univariate distribution
        return DiscreteDistribution(pmv=probs_to_keep, atoms=vals_to_keep, seed=seed)
class DiscreteDistributionLabeled(DiscreteDistribution):
    """
    A representation of a discrete probability distribution stored in an underlying `xarray.Dataset`.

    Parameters
    ----------
    pmv : np.array
        An array of values representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        The last dimension of atoms must index "atom" or the random
        realization; only 1-D or 2-D (vector-valued) atoms are supported.
    seed : int, optional
        Seed for random number generator.
    limit : dict, optional
        Information about the limiting distribution; also copied into the
        dataset attributes.
    name : str
        Name of the distribution.
    attrs : dict, optional
        Attributes for the distribution. The dict is copied, never mutated.
    var_names : list of str, optional
        Names of the variables in the distribution (default ``var_0``, ``var_1``, ...).
    var_attrs : list of dict, optional
        Attributes of the variables in the distribution.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
        name: str = "DiscreteDistributionLabeled",
        attrs: Optional[Dict[str, Any]] = None,
        var_names: Optional[List[str]] = None,
        var_attrs: Optional[List[Optional[Dict[str, Any]]]] = None,
    ):
        # The parent validates pmv/atoms and fills in self.limit (computing
        # infimum/supremum when limit is None).
        super().__init__(pmv, atoms, seed=seed, limit=limit)

        # vector-value distributions
        if self.atoms.ndim > 2:
            raise NotImplementedError(
                "Only vector-valued distributions are supported for now."
            )

        # Copy so the caller's dict is not modified by the updates below.
        attrs = {} if attrs is None else dict(attrs)
        attrs.update(self.limit)
        attrs["name"] = name

        n_var = self.atoms.shape[0]

        # give dummy names to variables if none are provided
        if var_names is None:
            var_names = ["var_" + str(i) for i in range(n_var)]

        assert len(var_names) == n_var, (
            "Number of variable names does not match number of variables."
        )

        # give dummy attributes to variables if none are provided
        if var_attrs is None:
            var_attrs = [None] * n_var

        # a DiscreteDistributionLabeled is an xr.Dataset where the only
        # dimension is "atom", which indexes the random realizations.
        self.dataset = xr.Dataset(
            {
                var_names[i]: xr.DataArray(
                    self.atoms[i],
                    dims=("atom",),
                    attrs=var_attrs[i],
                )
                for i in range(n_var)
            },
            attrs=attrs,
        )

        # the probability mass values are stored in
        # a DataArray with dimension "atom"
        self.probability = xr.DataArray(self.pmv, dims=("atom",))

    @classmethod
    def from_unlabeled(
        cls,
        dist,
        name="DiscreteDistributionLabeled",
        attrs=None,
        var_names=None,
        var_attrs=None,
    ):
        """
        Build a labeled distribution from an unlabeled one, reusing its pmv,
        atoms, seed and limit.
        """
        return cls(
            dist.pmv,
            dist.atoms,
            seed=dist.seed,
            limit=dist.limit,
            name=name,
            attrs=attrs,
            var_names=var_names,
            var_attrs=var_attrs,
        )

    @classmethod
    def from_dataset(cls, x_obj, pmf):
        """
        Wrap an xarray object (or a dict of variables) and a probability
        DataArray as a labeled distribution.

        Note: ``__init__`` is bypassed, so only ``dataset`` and
        ``probability`` are set on the result.

        Raises
        ------
        TypeError
            If x_obj is not an xr.Dataset, xr.DataArray, or dict.
        """
        ldd = cls.__new__(cls)

        if isinstance(x_obj, xr.Dataset):
            ldd.dataset = x_obj
        elif isinstance(x_obj, xr.DataArray):
            ldd.dataset = xr.Dataset({x_obj.name: x_obj})
        elif isinstance(x_obj, dict):
            ldd.dataset = xr.Dataset(x_obj)
        else:
            # Fail fast instead of returning an object without a dataset.
            raise TypeError(
                "x_obj must be an xr.Dataset, xr.DataArray or dict, got "
                + type(x_obj).__name__
            )

        ldd.probability = pmf

        return ldd

    @property
    def _weighted(self):
        """
        Returns a DatasetWeighted object for the distribution.
        """
        return self.dataset.weighted(self.probability)

    @property
    def variables(self):
        """
        A dict-like container of DataArrays corresponding to
        the variables of the distribution.
        """
        return self.dataset.data_vars

    @property
    def name(self):
        """
        The distribution's name, stored in the dataset attribute ``"name"``.
        """
        # Read attrs explicitly: attribute-style access (dataset.name) would
        # return a data variable called "name" if the distribution had one.
        try:
            return self.dataset.attrs["name"]
        except KeyError:
            raise AttributeError("Distribution dataset has no 'name' attribute.") from None

    @property
    def attrs(self):
        """
        The distribution's attributes.
        """
        return self.dataset.attrs

    def _wrap_labeled(self, func: Callable) -> Callable:
        """
        Adapt ``func`` so it receives a dict mapping each variable name to its
        row of atoms instead of the raw atoms array.
        """

        def func_wrapper(x: np.ndarray, *args: Any) -> Any:
            wrapped = dict(zip(self.variables.keys(), x))
            return func(wrapped, *args)

        return func_wrapper

    def dist_of_func(
        self, func: Callable = lambda x: x, *args, **kwargs
    ) -> DiscreteDistribution:
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.
        \\*\\*kwargs :
            Additional keyword arguments for func. Must be xarray compatible
            in order to work with xarray broadcasting.

        Returns
        -------
        f_dstn : DiscreteDistribution or DiscreteDistributionLabeled
            The distribution of func(dstn). A DiscreteDistributionLabeled when
            kwargs are given (func is applied to the dataset), otherwise an
            unlabeled DiscreteDistribution.
        """
        if len(kwargs):
            # Forward positional args too, consistent with `expected`.
            f_query = func(self.dataset, *args, **kwargs)
            return DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

        return super().dist_of_func(self._wrap_labeled(func), *args)

    def expected(
        self, func: Optional[Callable] = None, *args: Any, **kwargs: Any
    ) -> Union[float, np.ndarray]:
        """
        Expectation of a function, given an array of configurations of its inputs
        along with a DiscreteDistributionLabeled object that specifies the probability
        of each configuration.

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values
            and return either arrays of arbitrary shape or scalars.
            It may also take other arguments \\*args.
            Without kwargs, func receives a dict mapping variable names to
            their atoms, e.g. ``lambda x: x["a"] + x["b"]``.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.
        \\*\\*kwargs :
            Keyword arguments for func. When given, func is applied to the
            underlying xr.Dataset and the result is averaged over "atom"
            with the probability weights.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        if len(kwargs):
            f_query = func(self.dataset, *args, **kwargs)
            ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability)
            return ldd._weighted.mean("atom")

        if func is None:
            return super().expected()
        return super().expected(self._wrap_labeled(func), *args)