Coverage for HARK / distributions / discrete.py: 92%
202 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-10 06:19 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-10 06:19 +0000
1from typing import Any, Callable, Dict, List, Optional, Union
3from copy import deepcopy
4import numpy as np
5import xarray as xr
6from scipy import stats
7from scipy.stats import rv_discrete
8from scipy.stats._distn_infrastructure import rv_discrete_frozen
10from HARK.distributions.base import Distribution
def _weighted_mean_var(var, pmv):
    """Collapse the ``atom`` dimension of one variable into its weighted mean.

    Anything that is not an :class:`xr.DataArray`, or a DataArray that has
    no ``atom`` dimension, is handed back untouched. That pass-through is
    deliberate: it lets :func:`_weighted_mean` carry non-stochastic entries
    (e.g. scalar metadata or grid coordinates) through unchanged.

    Parameters
    ----------
    var : xr.DataArray or any
        The variable to average.
    pmv : array-like
        Weights contracted against the ``atom`` axis of ``var``.

    Returns
    -------
    xr.DataArray, float, or the input itself
        A DataArray over the non-atom dimensions when any remain, a Python
        float when ``atom`` was the only dimension, or ``var`` unchanged in
        the pass-through case.
    """
    is_labeled = isinstance(var, xr.DataArray)
    if not (is_labeled and "atom" in var.dims):
        return var

    # Contract axis 0 of the weights against the atom axis of the values.
    weighted = np.tensordot(
        pmv, var.values, axes=([0], [var.dims.index("atom")])
    )

    kept_dims = tuple(name for name in var.dims if name != "atom")
    if not kept_dims:
        # Nothing left but a number: unwrap 0-d results to a plain float.
        if np.ndim(weighted) == 0:
            return float(weighted)
        return weighted

    # Only carry over coordinates that actually exist for the kept dims.
    kept_coords = {}
    for name in kept_dims:
        if name in var.coords:
            kept_coords[name] = var.coords[name]
    return xr.DataArray(weighted, dims=kept_dims, coords=kept_coords)
def _weighted_mean(data, pmv):
    """Weighted mean over the ``atom`` dimension, computed with numpy.

    This serves the keyword-argument path of
    ``DiscreteDistributionLabeled.expected()``, averaging each variable with
    ``np.tensordot`` (via :func:`_weighted_mean_var`) rather than going
    through ``Dataset.weighted(prob).mean("atom")``. The no-kwargs path of
    ``expected()`` does not come through here.

    Parameters
    ----------
    data : xr.Dataset, xr.DataArray, or dict
        Object(s) to average. Dict values that are not DataArrays are kept
        as they are.
    pmv : array-like
        Probability weights along the ``atom`` dimension.

    Returns
    -------
    xr.Dataset, xr.DataArray, or scalar
        A Dataset when ``data`` is a Dataset or dict; whatever
        :func:`_weighted_mean_var` yields when ``data`` is a DataArray.

    Raises
    ------
    TypeError
        If ``data`` is none of the supported types.
    """
    # A lone DataArray is averaged directly; no Dataset wrapper is built.
    if isinstance(data, xr.DataArray):
        return _weighted_mean_var(data, pmv)

    if isinstance(data, xr.Dataset):
        members = data.data_vars
    elif isinstance(data, dict):
        members = data
    else:
        raise TypeError(
            f"_weighted_mean: unsupported data type '{type(data).__name__}'. "
            "Expected xr.Dataset, xr.DataArray, or dict. "
            "Ensure the function passed to expected() returns one of these types "
            "when keyword arguments are used."
        )

    # _weighted_mean_var returns non-DataArray entries unchanged, so plain
    # dict values need no separate type check here.
    averaged = {}
    for name, entry in members.items():
        averaged[name] = _weighted_mean_var(entry, pmv)
    return xr.Dataset(averaged)
class DiscreteFrozenDistribution(rv_discrete_frozen, Distribution):
    """
    Parameterized discrete distribution from scipy.stats with seed management.
    """

    def __init__(
        self,
        dist: rv_discrete,
        *args: Any,
        seed: Optional[int] = None,
        **kwds: Any,
    ) -> None:
        """
        Parameterized discrete distribution from scipy.stats with seed management.

        Parameters
        ----------
        dist : rv_discrete
            Discrete distribution from scipy.stats.
        *args, **kwds
            Shape/location parameters, forwarded unchanged to
            ``rv_discrete_frozen.__init__``.
        seed : int, optional
            Seed for random number generator. Defaults to ``None``, which is
            passed through to ``Distribution.__init__`` as-is.
        """
        # Both bases are initialised explicitly because they take different
        # arguments: scipy's frozen distribution gets the parameters, while
        # Distribution only gets the seed.
        rv_discrete_frozen.__init__(self, dist, *args, **kwds)
        Distribution.__init__(self, seed=seed)
class Bernoulli(DiscreteFrozenDistribution):
    """
    Bernoulli random variable backed by ``scipy.stats.bernoulli``.

    Parameters
    ----------
    p : float or [float]
        Probability or probabilities of the event occurring (True).
    seed : int
        Seed for random number generator.
    """

    def __init__(self, p=0.5, seed=None):
        success_prob = np.asarray(p)
        self.p = success_prob
        # Freeze scipy's bernoulli at ``p`` and set up the RNG.
        super().__init__(stats.bernoulli, p=success_prob, seed=seed)

        failure_prob = 1 - success_prob
        self.pmv = np.array([failure_prob, success_prob])
        # Kept 2-D so the last axis indexes the atoms, as in the other
        # discrete distributions.
        self.atoms = np.array([[0, 1]])
        self.limit = dict(
            dist=self,
            infimum=np.array([0.0]),
            supremum=np.array([1.0]),
        )
        self.infimum = np.array([0.0])
        self.supremum = np.array([1.0])

    def dim(self):
        """
        Shape of one realization: every axis of ``self.atoms`` except the
        last, which indexes "atom."
        """
        atoms_shape = self.atoms.shape
        return atoms_shape[:-1]
class DiscreteDistribution(Distribution):
    """
    A representation of a discrete probability distribution.

    Parameters
    ----------
    pmv : np.array
        An array of floats representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        For multivariate distributions, the last dimension of atoms must index
        "atom" or the random realization. For instance, if atoms.shape == (2,6,4),
        the random variable has 4 possible realizations and each of them has shape (2,6).
    limit : dict
        Dictionary with information about the continuous distribution from which
        this distribution was generated. The reference distribution is in the entry
        called 'dist'. When omitted, it holds the per-variable minimum and
        maximum of ``atoms`` under the keys 'infimum' and 'supremum'.
    seed : int, optional
        Seed for random number generator.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(seed=seed)

        self.pmv = np.asarray(pmv)
        self.atoms = np.atleast_2d(atoms)

        # Check that pmv and atoms have compatible dimensions before
        # deriving anything from them.
        if self.pmv.size != self.atoms.shape[-1]:
            raise ValueError(
                "Provided pmv and atoms arrays have incompatible dimensions. "
                "The length of the pmv must be equal to that of atoms's last dimension."
            )

        if limit is None:
            limit = {
                "infimum": np.min(self.atoms, axis=-1),
                "supremum": np.max(self.atoms, axis=-1),
            }
        self.limit = limit

    def __repr__(self):
        # Multivariate distributions show the bounds as tuples; univariate
        # ones show the single bound value.
        if self.atoms.shape[0] > 1:
            inf = tuple(self.limit["infimum"])
            sup = tuple(self.limit["supremum"])
        else:
            inf = self.limit["infimum"][0]
            sup = self.limit["supremum"][0]
        return (
            f"{self.__class__.__name__} with {self.pmv.size} atoms, "
            f"inf={inf}, sup={sup}, seed={self.seed}"
        )

    def dim(self) -> tuple:
        """
        Shape of one realization: every axis of ``self.atoms`` except the
        last, which indexes "atom."
        """
        return self.atoms.shape[:-1]

    def draw_events(self, N: int) -> np.ndarray:
        """
        Draws N 'events' from the distribution PMF.
        These events are indices into atoms.

        Parameters
        ----------
        N : int
            Number of events to draw.

        Returns
        -------
        indices : np.array
            Integer indices into the last axis of ``self.atoms``.
        """
        base_draws = self._rng.uniform(size=N)
        cum_dist = np.cumsum(self.pmv)

        # Convert the basic uniform draws into discrete draws
        indices = cum_dist.searchsorted(base_draws)

        # Floating-point rounding can leave cum_dist[-1] slightly below 1, so
        # a draw just under 1 would land one past the last atom. Clamp it.
        return np.minimum(indices, cum_dist.size - 1)

    def _shuffled_events(self, N: int) -> np.ndarray:
        """
        Build N atom indices whose frequencies match ``self.pmv`` as closely
        as N allows, returned in random order.

        Each atom first gets ``floor(N * pmv)`` slots. The remaining slots are
        handed out one at a time, with probability proportional to each
        atom's fractional remainder, and an atom's remainder is zeroed once
        it has received an extra slot.
        """
        pmv = self.pmv
        n_atoms = pmv.size
        slots_exact = N * pmv  # slots per outcome in real numbers
        slots = np.floor(slots_exact).astype(int)  # whole slots per atom
        n_extra = N - np.sum(slots)  # number of unallocated slots
        leftover = slots_exact - slots  # fractional slot still owed per atom
        extra_draws = self._rng.random(n_extra)  # uniform draws for extra slots

        # Fill in each unallocated slot, one by one
        for u in extra_draws:
            # If pmv does not sum to one the remainders can run out before
            # the extra slots do; fall back to pmv so the weights stay valid.
            weights = leftover if np.sum(leftover) > 0.0 else pmv
            cum_weights = np.cumsum(weights) / np.sum(weights)
            # Clamp in case rounding leaves the final cumulative weight < 1.
            j = min(int(np.searchsorted(cum_weights, u)), n_atoms - 1)
            slots[j] += 1
            leftover[j] = 0.0  # this atom has now been topped up

        # np.repeat yields an integer index array even when it is empty.
        events = np.repeat(np.arange(n_atoms), slots)
        return self._rng.permutation(events)

    def draw(
        self,
        N: int,
        atoms: Union[None, int, np.ndarray] = None,
        shuffle: bool = False,
    ) -> np.ndarray:
        """
        Simulates N draws from a discrete distribution with probabilities P and outcomes atoms.

        Parameters
        ----------
        N : int
            Number of draws to simulate.
        atoms : None, int, or np.array
            If None, then use this distribution's atoms for point values.
            If an int (Python or numpy integer), then the index of atoms for
            the point values.
            If an np.array, use the array for the point values.
        shuffle : boolean
            Whether the draws should "shuffle" the discrete distribution, matching
            proportions of outcomes as closely as possible to the probabilities given
            finite draws. When True, returned draws are a random permutation of the
            N-length list that best fits the discrete distribution. When False
            (default), each draw is independent from the others and the result could
            deviate from the probabilities.

        Returns
        -------
        draws : np.array
            An array of draws from the discrete distribution; each element is a value in atoms.
        """
        if atoms is None:
            atoms = self.atoms
        elif isinstance(atoms, (int, np.integer)):
            atoms = self.atoms[atoms]

        if shuffle:
            indices = self._shuffled_events(N)
        else:
            indices = self.draw_events(N)

        # Look up the drawn values along the atom (last) axis
        draws = atoms[..., indices]

        # TODO: some models expect univariate draws to just be a 1d vector. Fix those models.
        if len(draws.shape) == 2 and draws.shape[0] == 1:
            draws = draws.flatten()

        return draws

    def expected(
        self, func: Optional[Callable] = None, *args: np.ndarray
    ) -> np.ndarray:
        """
        Expected value of a function, given an array of configurations of its
        inputs along with a DiscreteDistribution object that specifies the
        probability of each configuration.

        If no function is provided, it's much faster to go straight to dot
        product instead of calling the dummy function.

        If a function is provided, we need to add one more dimension,
        the atom dimension, to any inputs that are n-dim arrays.
        This allows numpy to easily broadcast the function's output.
        For more information on broadcasting, see:
        https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values
            and return either arrays of arbitrary shape or scalars.
            It may also take other arguments \\*args.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, manipulates arrays, or uses branching or logical
            indexing, use `expected(func, dstn, vectorized=False)` instead.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        if func is None:
            return np.dot(self.atoms, self.pmv)
        return self._dot_with_pmv(func, self.atoms, args)

    def _dot_with_pmv(self, func, source, args):
        """Apply ``func`` to ``source`` and the broadcast-prepared ``*args``,
        then take the dot product with ``self.pmv``.

        Every ``np.ndarray`` in ``args`` gets a trailing length-one axis so it
        broadcasts against the atom axis; other arguments are passed through
        unchanged.
        """
        prepared = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        return np.dot(func(source, *prepared), self.pmv)

    def dist_of_func(
        self, func: Callable[..., float] = lambda x: x, *args: Any
    ) -> "DiscreteDistribution":
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.

        Returns
        -------
        f_dstn : DiscreteDistribution
            The distribution of func(dstn).
        """
        # Add the atom dimension to any inputs that are n-dim arrays so that
        # numpy can broadcast the function's output.
        prepared = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        f_query = func(self.atoms, *prepared)

        # Copy the pmv so the new distribution does not share it with this one.
        return DiscreteDistribution(self.pmv.copy(), f_query, seed=self.seed)

    def discretize(self, N: int, *args: Any, **kwargs: Any) -> "DiscreteDistribution":
        """
        `DiscreteDistribution` is already an approximation, so this method
        returns the distribution itself (the same object, not a copy). All
        arguments are ignored.

        TODO: print warning message?
        """
        return self

    def make_univariate(self, dim_to_keep, seed=0):
        """
        Make a univariate discrete distribution from this distribution, keeping
        only the specified dimension.

        Parameters
        ----------
        dim_to_keep : int
            Index of the distribution to be kept. Any other dimensions will be
            "collapsed" into the univariate atoms, combining probabilities.
        seed : int, optional
            Seed for random number generator of univariate distribution

        Returns
        -------
        univariate_dstn : DiscreteDistribution
            Univariate distribution with only the specified index.

        Raises
        ------
        ValueError
            If ``dim_to_keep`` is not smaller than the number of variables.
        """
        # Do basic validity and triviality checks
        n_vars = self.atoms.shape[0]
        if (n_vars == 1) and (dim_to_keep == 0):
            return deepcopy(self)  # Return copy of self if only one dimension
        if dim_to_keep >= n_vars:
            raise ValueError("dim_to_keep exceeds dimensionality of distribution.")

        # Group identical atom values and sum the probability of each group.
        # bincount with weights returns floats, so integer-valued atoms no
        # longer truncate the probabilities to zero.
        atoms_temp = self.atoms[dim_to_keep]
        vals_to_keep, group = np.unique(atoms_temp, return_inverse=True)
        probs_to_keep = np.bincount(
            group.ravel(), weights=self.pmv, minlength=vals_to_keep.size
        )

        # Make and return the univariate distribution
        return DiscreteDistribution(pmv=probs_to_keep, atoms=vals_to_keep, seed=seed)
class DiscreteDistributionLabeled(DiscreteDistribution):
    """
    A representation of a discrete probability distribution stored in an underlying `xarray.Dataset`.

    Parameters
    ----------
    pmv : np.array
        An array of values representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass. The last dimension
        must index "atom" or the random realization. Only arrays of at most
        two dimensions (variable, atom) are supported.
    seed : int, optional
        Seed for random number generator.
    limit : dict, optional
        Limit information for the distribution; its entries are also copied
        into the dataset attributes. When omitted, the per-variable minimum
        and maximum of ``atoms`` are used.
    name : str
        Name of the distribution.
    attrs : dict
        Attributes for the distribution. The dict is copied, not modified.
    var_names : list of str
        Names of the variables in the distribution.
    var_attrs : list of dict
        Attributes of the variables in the distribution.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
        name: str = "DiscreteDistributionLabeled",
        attrs: Optional[Dict[str, Any]] = None,
        var_names: Optional[List[str]] = None,
        var_attrs: Optional[List[Optional[Dict[str, Any]]]] = None,
    ):
        super().__init__(pmv, atoms, seed=seed, limit=limit)

        # vector-value distributions
        if self.atoms.ndim > 2:
            raise NotImplementedError(
                "Only vector-valued distributions are supported for now."
            )

        # Work on a copy so the caller's attrs dict is never mutated by the
        # updates below.
        attrs = {} if attrs is None else dict(attrs)
        if limit is None:
            limit = {
                "infimum": np.min(self.atoms, axis=-1),
                "supremum": np.max(self.atoms, axis=-1),
            }
        self.limit = limit
        attrs.update(limit)
        attrs["name"] = name

        n_var = self.atoms.shape[0]

        # give dummy names to variables if none are provided
        if var_names is None:
            var_names = ["var_" + str(i) for i in range(n_var)]

        assert len(var_names) == n_var, (
            "Number of variable names does not match number of variables."
        )

        # give dummy attributes to variables if none are provided
        if var_attrs is None:
            var_attrs = [None] * n_var

        # a DiscreteDistributionLabeled is an xr.Dataset where the only
        # dimension is "atom", which indexes the random realizations.
        self.dataset = xr.Dataset(
            {
                var_name: xr.DataArray(values, dims="atom", attrs=var_attr)
                for var_name, values, var_attr in zip(
                    var_names, self.atoms, var_attrs
                )
            },
            attrs=attrs,
        )

        # the probability mass values are stored in
        # a DataArray with dimension "atom"
        self.probability = xr.DataArray(self.pmv, dims="atom")

        # cache for fast labeled access in expected()
        self._var_names = var_names
        self._wrapped_atoms = dict(zip(var_names, self.atoms))

    @classmethod
    def from_unlabeled(
        cls,
        dist,
        name="DiscreteDistributionLabeled",
        attrs=None,
        var_names=None,
        var_attrs=None,
    ):
        """
        Build a labeled distribution from an existing distribution, reusing
        its ``pmv``, ``atoms``, ``seed`` and ``limit``.
        """
        return cls(
            dist.pmv,
            dist.atoms,
            seed=dist.seed,
            limit=dist.limit,
            name=name,
            attrs=attrs,
            var_names=var_names,
            var_attrs=var_attrs,
        )

    @classmethod
    def from_dataset(cls, x_obj, pmf):
        """
        Build a labeled distribution directly from xarray data, bypassing
        ``__init__``.

        Parameters
        ----------
        x_obj : xr.Dataset, xr.DataArray, or dict
            Data for the distribution. A DataArray is wrapped in a Dataset
            under its own name; a dict is passed to ``xr.Dataset``.
        pmf : array-like
            Probability mass values; stored as given in ``probability`` and
            as a numpy array in ``pmv``.

        Returns
        -------
        ldd : DiscreteDistributionLabeled
            The new distribution, with seed 0.

        Raises
        ------
        TypeError
            If ``x_obj`` is none of the supported types.
        """
        ldd = cls.__new__(cls)

        if isinstance(x_obj, xr.Dataset):
            ldd.dataset = x_obj
        elif isinstance(x_obj, xr.DataArray):
            ldd.dataset = xr.Dataset({x_obj.name: x_obj})
        elif isinstance(x_obj, dict):
            ldd.dataset = xr.Dataset(x_obj)
        else:
            raise TypeError(
                f"from_dataset: 'x_obj' must be an xr.Dataset, xr.DataArray, "
                f"or dict, got {type(x_obj).__name__}."
            )

        ldd.probability = pmf
        ldd.pmv = np.asarray(pmf)
        ldd._var_names = list(ldd.dataset.data_vars)

        # Derive atoms from the variables that are one-dimensional over
        # "atom"; anything else cannot be stacked into the atoms array.
        atom_vars = [
            v
            for v in ldd._var_names
            if "atom" in ldd.dataset[v].dims and ldd.dataset[v].ndim == 1
        ]
        if atom_vars:
            ldd.atoms = np.stack([ldd.dataset[v].values for v in atom_vars])
        else:
            # No usable variable: fall back to a single row of zeros.
            ldd.atoms = np.atleast_2d(np.zeros(len(ldd.pmv)))

        ldd.limit = {
            "infimum": np.min(ldd.atoms, axis=-1),
            "supremum": np.max(ldd.atoms, axis=-1),
        }
        # No seed argument available; default to 0 for consistency.
        ldd.seed = 0
        ldd._rng = np.random.default_rng(0)
        # cache for fast labeled access in expected()
        ldd._wrapped_atoms = dict(zip(atom_vars, ldd.atoms))

        return ldd

    @property
    def variables(self):
        """
        A dict-like container of DataArrays corresponding to
        the variables of the distribution.
        """
        return self.dataset.data_vars

    @property
    def name(self):
        """
        The distribution's name.
        """
        # NOTE(review): this presumably resolves to dataset.attrs["name"]
        # through xarray's attribute-style access -- confirm.
        return self.dataset.name

    @property
    def attrs(self):
        """
        The distribution's attributes.
        """
        return self.dataset.attrs

    def dist_of_func(
        self, func: Callable = lambda x: x, *args, **kwargs
    ) -> DiscreteDistribution:
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            Without keyword arguments it receives a dict mapping variable
            names to numpy arrays; with keyword arguments it receives the
            full ``xr.Dataset``.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.
        \\*\\*kwargs :
            Additional keyword arguments for func. Must be xarray compatible
            in order to work with xarray broadcasting.

        Returns
        -------
        f_dstn : DiscreteDistribution or DiscreteDistributionLabeled
            The distribution of func(dstn). Labeled when keyword arguments
            are given, unlabeled otherwise.
        """

        def func_wrapper(x: np.ndarray, *args: Any) -> np.ndarray:
            """
            Wrapper function for `func` that handles labeled indexing.
            """
            # Use the cached names: they are aligned with the rows of
            # ``atoms`` even when the dataset holds non-atom variables.
            wrapped = dict(zip(self._wrapped_atoms, x))
            return func(wrapped, *args)

        if kwargs:
            # Positional arguments are forwarded here too, matching
            # ``expected``.
            f_query = func(self.dataset, *args, **kwargs)
            return DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

        return super().dist_of_func(func_wrapper, *args)

    def expected(
        self, func: Optional[Callable] = None, *args: Any, **kwargs: Any
    ) -> Union[float, np.ndarray]:
        """
        Expectation of a function, given an array of configurations of its inputs
        along with a DiscreteDistributionLabeled object that specifies the probability
        of each configuration.

        Parameters
        ----------
        func : function
            The function to be evaluated.
            By default, func receives a dict mapping variable names to numpy
            arrays, e.g. ``{"perm_shk": array(...), "tran_shk": array(...)}``.
            When ``labels=False``, func receives the raw numpy atoms array
            instead (integer indexing, like ``DiscreteDistribution``).
            When extra keyword arguments are passed, func receives the full
            ``xr.Dataset``.
            It may also take other arguments \\*args.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, manipulates arrays, or uses branching or logical
            indexing, use `expected(func, dstn, vectorized=False)` instead.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.
        labels : bool, optional
            If True (default), func receives a dict of labeled arrays.
            If False, func receives the raw numpy atoms array, making DDL
            compatible with functions written for ``DiscreteDistribution``.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        # ``labels`` is consumed here so it is never forwarded to func.
        labels = kwargs.pop("labels", True)

        # Any remaining keyword argument selects the xarray path.
        if kwargs:
            if func is None:
                return _weighted_mean(self.dataset, self.pmv)
            f_query = func(self.dataset, *args, **kwargs)
            return _weighted_mean(f_query, self.pmv)

        if func is None:
            return np.dot(self.atoms, self.pmv)

        # labels=False: pass raw numpy atoms, like DiscreteDistribution.
        # Otherwise use cached _wrapped_atoms (a {var_name: np.ndarray} dict)
        # to avoid xarray overhead in the hot path.
        source = self._wrapped_atoms if labels else self.atoms
        return self._dot_with_pmv(func, source, args)