Coverage for HARK / distributions / discrete.py: 92%
207 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 06:00 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 06:00 +0000
1from typing import Any, Callable, Dict, List, Optional, Union
3from copy import deepcopy
4import numpy as np
5import xarray as xr
6from scipy import stats
7from scipy.stats import rv_discrete
8from scipy.stats._distn_infrastructure import rv_discrete_frozen
10from HARK.distributions.base import Distribution
def _weighted_mean_var(var, pmv):
    """Collapse the ``atom`` dimension of one DataArray into a weighted mean.

    Anything that is not an :class:`xr.DataArray`, or that is a DataArray
    without an ``atom`` dimension, is handed back untouched. The pass-through
    is deliberate: it lets :func:`_weighted_mean` keep entries that carry no
    atom dimension (e.g. scalar metadata or grid coordinates) as they are.

    Parameters
    ----------
    var : xr.DataArray or any
        Candidate variable to average.
    pmv : array-like
        Weights; contracted (axis 0) against the ``atom`` axis of *var*.

    Returns
    -------
    xr.DataArray, float, or the input itself
        A DataArray over the remaining dimensions (keeping the coordinates
        that exist for them), a Python float when ``atom`` was the only
        dimension, or *var* unchanged in the pass-through case.
    """
    has_atom_dim = isinstance(var, xr.DataArray) and "atom" in var.dims
    if not has_atom_dim:
        return var

    dims = var.dims
    # Contract the weights against the atom axis, wherever it sits.
    weighted = np.tensordot(pmv, var.values, axes=([0], [dims.index("atom")]))
    kept_dims = tuple(name for name in dims if name != "atom")

    if not kept_dims:
        # Nothing left but a number: unwrap 0-d results into a plain float.
        if np.ndim(weighted) == 0:
            return float(weighted)
        return weighted

    kept_coords = {}
    for name in kept_dims:
        if name in var.coords:
            kept_coords[name] = var.coords[name]
    return xr.DataArray(weighted, dims=kept_dims, coords=kept_coords)
def _weighted_mean(data, pmv):
    """Average *data* over its ``atom`` dimension with weights *pmv*.

    This is the helper behind the keyword-argument path of
    ``DDL.expected()``. It stands in for the slow
    ``Dataset.weighted(prob).mean("atom")`` by applying ``np.tensordot`` to
    each variable separately. The usual no-kwargs path does not come through
    here; it runs ``np.dot`` on the cached ``_wrapped_atoms`` dict instead.

    Parameters
    ----------
    data : xr.Dataset, dict, or xr.DataArray
        Object to average.
    pmv : array-like
        Probability weights along the ``atom`` dimension.

    Returns
    -------
    xr.Dataset, xr.DataArray, or scalar
        An :class:`xr.Dataset` for Dataset or dict input; an
        :class:`xr.DataArray` or scalar for DataArray input.

    Raises
    ------
    TypeError
        If *data* is none of the supported types.
    """
    if isinstance(data, xr.Dataset):
        entries = data.data_vars.items()
    elif isinstance(data, dict):
        entries = data.items()
    elif isinstance(data, xr.DataArray):
        return _weighted_mean_var(data, pmv)
    else:
        raise TypeError(
            f"_weighted_mean: unsupported data type '{type(data).__name__}'. "
            "Expected xr.Dataset, xr.DataArray, or dict. "
            "Ensure the function passed to expected() returns one of these types "
            "when keyword arguments are used."
        )

    # _weighted_mean_var returns non-DataArray values unchanged, so dict
    # entries of any type can be routed through it without a separate check.
    averaged = {}
    for name, value in entries:
        averaged[name] = _weighted_mean_var(value, pmv)
    return xr.Dataset(averaged)
class DiscreteFrozenDistribution(rv_discrete_frozen, Distribution):
    """
    Parameterized discrete distribution from scipy.stats with seed management.
    """

    def __init__(
        self,
        dist: rv_discrete,
        *args: Any,
        seed: Optional[int] = None,
        **kwds: Any,
    ) -> None:
        """
        Parameterized discrete distribution from scipy.stats with seed management.

        Parameters
        ----------
        dist : rv_discrete
            Discrete distribution from scipy.stats.
        \\*args, \\*\\*kwds :
            Shape parameters forwarded to ``rv_discrete_frozen.__init__``.
        seed : int, optional
            Seed for random number generator, forwarded to
            ``Distribution.__init__``. Defaults to None.
        """
        # Both bases are initialised explicitly (rather than through a single
        # super() call) because they take different arguments.
        rv_discrete_frozen.__init__(self, dist, *args, **kwds)
        Distribution.__init__(self, seed=seed)
class Bernoulli(DiscreteFrozenDistribution):
    """
    A Bernoulli distribution.

    Parameters
    ----------
    p : float or [float]
        Probability or probabilities of the event occurring (True).
    seed : int
        Seed for random number generator.
    """

    def __init__(self, p=0.5, seed=None):
        self.p = np.asarray(p)
        # Freeze scipy's Bernoulli at p; this also sets up the RNG
        super().__init__(stats.bernoulli, p=self.p, seed=seed)

        success = self.p
        self.pmv = np.array([1 - success, success])
        # Two-dimensional on purpose, so atoms is shaped like the atoms of
        # the other distributions
        self.atoms = np.array([[0, 1]])

        self.infimum = np.array([0.0])
        self.supremum = np.array([1.0])
        self.limit = dict(
            dist=self,
            infimum=np.array([0.0]),
            supremum=np.array([1.0]),
        )

    def dim(self):
        """
        Shape of a single realization: every axis of self.atoms except the
        last one, which indexes "atom."
        """
        atoms_shape = self.atoms.shape
        return atoms_shape[:-1]
class DiscreteDistribution(Distribution):
    """
    A representation of a discrete probability distribution.

    Parameters
    ----------
    pmv : np.array
        An array of floats representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        For multivariate distributions, the last dimension of atoms must index
        "atom" or the random realization. For instance, if atoms.shape == (2,6,4),
        the random variable has 4 possible realizations and each of them has shape (2,6).
    limit : dict
        Dictionary with information about the continuous distribution from which
        this distribution was generated. The reference distribution is in the entry
        called 'dist'. When omitted, a dict holding only the per-variable
        'infimum' and 'supremum' of the atoms is built.
    seed : int
        Seed for random number generator.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(seed=seed)

        self.pmv = np.asarray(pmv)
        self.atoms = np.atleast_2d(atoms)

        # Check that pmv and atoms have compatible dimensions before deriving
        # anything from them.
        if self.pmv.size != self.atoms.shape[-1]:
            raise ValueError(
                "Provided pmv and atoms arrays have incompatible dimensions. "
                + "The length of the pmv must be equal to that of atoms's last dimension."
            )

        if limit is None:
            limit = {
                "infimum": np.min(self.atoms, axis=-1),
                "supremum": np.max(self.atoms, axis=-1),
            }
        self.limit = limit

    def __repr__(self):
        infimum = self.limit["infimum"]
        supremum = self.limit["supremum"]
        if self.atoms.shape[0] > 1:
            # Multivariate: show all bounds as a tuple
            inf_text, sup_text = str(tuple(infimum)), str(tuple(supremum))
        else:
            # Univariate: show the single bound without a wrapper
            inf_text, sup_text = str(infimum[0]), str(supremum[0])
        return (
            f"{self.__class__.__name__} with {self.pmv.size} atoms, "
            f"inf={inf_text}, sup={sup_text}, seed={self.seed}"
        )

    def dim(self) -> tuple:
        """
        Shape of a single realization: every axis of self.atoms except the
        last one, which indexes "atom."
        """
        return self.atoms.shape[:-1]

    def draw_events(self, N: int) -> np.ndarray:
        """
        Draws N 'events' from the distribution PMF.
        These events are indices into atoms.

        Parameters
        ----------
        N : int
            Number of events to draw.

        Returns
        -------
        indices : np.array
            Integer indices into the last axis of atoms.
        """
        base_draws = self._rng.uniform(size=N)
        cum_dist = np.cumsum(self.pmv)

        # Convert the basic uniform draws into discrete draws
        indices = cum_dist.searchsorted(base_draws)

        # If rounding leaves cum_dist[-1] slightly below 1, a draw above it
        # would map to an index one past the last atom; send it to the last atom.
        return np.minimum(indices, self.pmv.size - 1)

    def draw(
        self,
        N: int,
        atoms: Union[None, int, np.ndarray] = None,
        shuffle: bool = False,
    ) -> np.ndarray:
        """
        Simulates N draws from a discrete distribution with probabilities P and outcomes atoms.

        Parameters
        ----------
        N : int
            Number of draws to simulate.
        atoms : None, int, or np.array
            If None, then use this distribution's atoms for point values.
            If an int (Python or NumPy integer), then the index of atoms for
            the point values.
            If an np.array, use the array for the point values.
        shuffle : boolean
            Whether the draws should "shuffle" the discrete distribution, matching
            proportions of outcomes as closely as possible to the probabilities given
            finite draws. When True, returned draws are a random permutation of the
            N-length list that best fits the discrete distribution. When False
            (default), each draw is independent from the others and the result could
            deviate from the probabilities.

        Returns
        -------
        draws : np.array
            An array of draws from the discrete distribution; each element is a value in atoms.
        """
        if atoms is None:
            atoms = self.atoms
        elif isinstance(atoms, (int, np.integer)):
            atoms = self.atoms[atoms]

        # "Shuffle" an almost-exact population of draws based on the pmv
        if shuffle:
            P = self.pmv
            J = P.size
            K_exact = N * P  # slots per outcome in real numbers
            K = np.floor(K_exact).astype(int)  # whole slots allocated to each atom
            M = N - np.sum(K)  # number of unallocated slots
            # Fractional slot still owed to each atom. This equals N times the
            # "missing" probability mass P - K / N; only the ratios matter
            # because the weights are renormalized on every pass. Never
            # negative, since K is the floor of K_exact.
            Q = K_exact - K
            draws = self._rng.random(M)  # uniform draws for "extra" slots

            # Fill in each unallocated slot, one by one
            for m in range(M):
                Q_sum = np.cumsum(Q / np.sum(Q))  # cumulative weights for this pass
                # Clamp in case rounding leaves Q_sum[-1] just below the draw
                j = min(np.searchsorted(Q_sum, draws[m]), J - 1)
                K[j] += 1  # increment its allocated slots
                Q[j] = 0.0  # zero out its weight because we used it

            # Atom index j repeated K[j] times; integer dtype even when empty
            events = np.repeat(np.arange(J), K)

            # Draw a random permutation of the indices
            indices = self._rng.permutation(events)

        # Draw event indices randomly from the discrete distribution
        else:
            indices = self.draw_events(N)

        # Pick the point values that correspond to the event indices
        draws = atoms[..., indices]

        # TODO: some models expect univariate draws to just be a 1d vector. Fix those models.
        if draws.ndim == 2 and draws.shape[0] == 1:
            draws = draws.flatten()

        return draws

    def expected(
        self, func: Optional[Callable] = None, *args: np.ndarray
    ) -> np.ndarray:
        """
        Expected value of a function, given an array of configurations of its
        inputs along with a DiscreteDistribution object that specifies the
        probability of each configuration.

        If no function is provided, it's much faster to go straight to dot
        product instead of calling the dummy function.

        If a function is provided, we need to add one more dimension,
        the atom dimension, to any inputs that are n-dim arrays.
        This allows numpy to easily broadcast the function's output.
        For more information on broadcasting, see:
        https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values
            and return either arrays of arbitrary shape or scalars.
            It may also take other arguments \\*args.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        if func is None:
            return np.dot(self.atoms, self.pmv)

        # Give array arguments a trailing axis so they broadcast against the
        # atom axis; everything else is passed through as given.
        expanded = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        return np.dot(func(self.atoms, *expanded), self.pmv)

    def dist_of_func(
        self, func: Callable[..., float] = lambda x: x, *args: Any
    ) -> "DiscreteDistribution":
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.

        Returns
        -------
        f_dstn : DiscreteDistribution
            The distribution of func(dstn). It has the same probabilities and
            seed as this distribution.
        """
        # we need to add one more dimension,
        # the atom dimension, to any inputs that are n-dim arrays.
        # This allows numpy to easily broadcast the function's output.
        expanded = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        f_query = func(self.atoms, *expanded)

        # list(...) hands the new distribution its own copy of the probabilities
        return DiscreteDistribution(list(self.pmv), f_query, seed=self.seed)

    def discretize(self, N: int, *args: Any, **kwargs: Any) -> "DiscreteDistribution":
        """
        `DiscreteDistribution` is already an approximation, so this method
        returns the distribution itself (the same object, not a copy); all
        arguments are ignored.

        TODO: print warning message?
        """
        return self

    def make_univariate(self, dim_to_keep, seed=0):
        """
        Make a univariate discrete distribution from this distribution, keeping
        only the specified dimension.

        Parameters
        ----------
        dim_to_keep : int
            Index of the distribution to be kept. Any other dimensions will be
            "collapsed" into the univariate atoms, combining probabilities.
        seed : int, optional
            Seed for random number generator of univariate distribution.
            Not used when the distribution is already univariate, in which
            case a deep copy of self is returned.

        Returns
        -------
        univariate_dstn : DiscreteDistribution
            Univariate distribution with only the specified index.

        Raises
        ------
        ValueError
            If dim_to_keep is not smaller than the number of variables.
        """
        # Do basic validity and triviality checks
        if (self.atoms.shape[0] == 1) and (dim_to_keep == 0):
            return deepcopy(self)  # Return copy of self if only one dimension
        if dim_to_keep >= self.atoms.shape[0]:
            raise ValueError("dim_to_keep exceeds dimensionality of distribution.")

        # Construct values and probabilities for univariate distribution
        atoms_temp = self.atoms[dim_to_keep]
        vals_to_keep = np.unique(atoms_temp)
        # Sum the mass of every atom sharing each unique value. The result
        # takes its dtype from pmv, not from the atoms, so integer-valued
        # atoms do not truncate the probabilities.
        probs_to_keep = np.array(
            [np.sum(self.pmv[atoms_temp == val]) for val in vals_to_keep]
        )

        # Make and return the univariate distribution
        return DiscreteDistribution(pmv=probs_to_keep, atoms=vals_to_keep, seed=seed)
class DiscreteDistributionLabeled(DiscreteDistribution):
    """
    A representation of a discrete probability distribution stored in an underlying `xarray.Dataset`.

    Parameters
    ----------
    pmv : np.array
        An array of values representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        The last dimension of atoms must index "atom" or the random
        realization. Only arrays with at most two dimensions are supported:
        one row per variable, one column per realization.
    seed : int
        Seed for random number generator.
    limit : dict
        Dictionary with information about the distribution this one was
        generated from. Its entries are also copied into the dataset's attrs.
    name : str
        Name of the distribution.
    attrs : dict
        Attributes for the distribution.
    var_names : list of str
        Names of the variables in the distribution.
    var_attrs : list of dict
        Attributes of the variables in the distribution.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
        name: str = "DiscreteDistributionLabeled",
        attrs: Optional[Dict[str, Any]] = None,
        var_names: Optional[List[str]] = None,
        var_attrs: Optional[List[Optional[Dict[str, Any]]]] = None,
    ):
        super().__init__(pmv, atoms, seed=seed, limit=limit)

        # vector-value distributions
        if self.atoms.ndim > 2:
            raise NotImplementedError(
                "Only vector-valued distributions are supported for now."
            )

        # Work on a copy so the caller's attrs dict is never modified.
        attrs = {} if attrs is None else dict(attrs)
        # The parent constructor stored the limit, building the default
        # infimum/supremum dict when none was given.
        attrs.update(self.limit)
        attrs["name"] = name

        n_var = self.atoms.shape[0]

        # give dummy names to variables if none are provided
        if var_names is None:
            var_names = [f"var_{i}" for i in range(n_var)]

        # Kept as an assert so the exception type seen by callers is unchanged.
        assert len(var_names) == n_var, (
            "Number of variable names does not match number of variables."
        )

        # give dummy attributes to variables if none are provided
        if var_attrs is None:
            var_attrs = [None] * n_var

        # a DiscreteDistributionLabeled is an xr.Dataset where the only
        # dimension is "atom", which indexes the random realizations.
        self.dataset = xr.Dataset(
            {
                var_names[i]: xr.DataArray(
                    self.atoms[i],
                    dims=("atom",),
                    attrs=var_attrs[i],
                )
                for i in range(n_var)
            },
            attrs=attrs,
        )

        # the probability mass values are stored in
        # a DataArray with dimension "atom"
        self.probability = xr.DataArray(self.pmv, dims=("atom",))

        # cache for fast labeled access in expected()
        self._var_names = var_names
        self._wrapped_atoms = dict(zip(var_names, self.atoms))

    @classmethod
    def from_unlabeled(
        cls,
        dist,
        name="DiscreteDistributionLabeled",
        attrs=None,
        var_names=None,
        var_attrs=None,
    ):
        """
        Build a labeled distribution from an existing distribution, reusing
        its pmv, atoms, seed and limit.

        Parameters
        ----------
        dist : DiscreteDistribution
            Source distribution; its ``pmv``, ``atoms``, ``seed`` and
            ``limit`` attributes are read.
        name, attrs, var_names, var_attrs :
            Passed to the constructor unchanged.

        Returns
        -------
        ldd : DiscreteDistributionLabeled
            The labeled distribution.
        """
        return cls(
            dist.pmv,
            dist.atoms,
            seed=dist.seed,
            limit=dist.limit,
            name=name,
            attrs=attrs,
            var_names=var_names,
            var_attrs=var_attrs,
        )

    @classmethod
    def from_dataset(cls, x_obj, pmf, seed=0):
        """
        Build a labeled distribution directly around xarray data, bypassing
        ``__init__``.

        Parameters
        ----------
        x_obj : xr.Dataset, xr.DataArray, or dict
            Data of the distribution. A DataArray is wrapped into a Dataset
            under its own name; a dict is passed to ``xr.Dataset``.
        pmf : array-like
            Probability of each atom. Stored as given in ``probability`` and
            as a numpy array in ``pmv``.
        seed : int, optional
            Seed for random number generator, by default 0.

        Returns
        -------
        ldd : DiscreteDistributionLabeled
            The labeled distribution.

        Raises
        ------
        TypeError
            If x_obj is not one of the supported types.
        """
        ldd = cls.__new__(cls)

        if isinstance(x_obj, xr.Dataset):
            ldd.dataset = x_obj
        elif isinstance(x_obj, xr.DataArray):
            ldd.dataset = xr.Dataset({x_obj.name: x_obj})
        elif isinstance(x_obj, dict):
            ldd.dataset = xr.Dataset(x_obj)
        else:
            raise TypeError(
                f"from_dataset: 'x_obj' must be an xr.Dataset, xr.DataArray, "
                f"or dict, got {type(x_obj).__name__}."
            )

        ldd.probability = pmf
        ldd.pmv = np.asarray(pmf)
        ldd._var_names = list(ldd.dataset.data_vars)

        # Derive atoms from the one-dimensional dataset variables whose only
        # dimension is "atom"; other variables do not enter atoms.
        atom_vars = [
            v
            for v in ldd._var_names
            if "atom" in ldd.dataset[v].dims and ldd.dataset[v].ndim == 1
        ]
        if atom_vars:
            ldd.atoms = np.stack([ldd.dataset[v].values for v in atom_vars])
        else:
            # No usable variable: fall back to a single row of zeros
            ldd.atoms = np.atleast_2d(np.zeros(len(ldd.pmv)))

        ldd.limit = {
            "infimum": np.min(ldd.atoms, axis=-1),
            "supremum": np.max(ldd.atoms, axis=-1),
        }
        # __init__ is bypassed, so the seed and RNG are set up here.
        ldd.seed = seed
        ldd._rng = np.random.default_rng(seed)
        # cache for fast labeled access in expected()
        ldd._wrapped_atoms = dict(zip(atom_vars, ldd.atoms))

        return ldd

    @property
    def variables(self):
        """
        A dict-like container of DataArrays corresponding to
        the variables of the distribution.
        """
        return self.dataset.data_vars

    @property
    def name(self):
        """
        The distribution's name.
        """
        # NOTE(review): presumably resolved through xarray's attribute-style
        # access to dataset.attrs["name"] - confirm.
        return self.dataset.name

    @property
    def attrs(self):
        """
        The distribution's attributes.
        """
        return self.dataset.attrs

    def dist_of_func(
        self, func: Callable = lambda x: x, *args, **kwargs
    ) -> DiscreteDistribution:
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            Without keyword arguments it receives a dict mapping variable
            names to numpy arrays; with keyword arguments it receives the
            full ``xr.Dataset``.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.
        \\*\\*kwargs :
            Additional keyword arguments for func. Must be xarray compatible
            in order to work with xarray broadcasting.

        Returns
        -------
        f_dstn : DiscreteDistribution or DiscreteDistributionLabeled
            The distribution of func(dstn). Labeled when keyword arguments
            are used, unlabeled otherwise.
        """
        if kwargs:
            # Positional arguments are forwarded here as well, matching the
            # keyword path of expected().
            f_query = func(self.dataset, *args, **kwargs)
            return DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

        # Label the rows of atoms with the same names expected() uses, so the
        # names always line up with the rows they describe.
        labels = list(self._wrapped_atoms)

        def func_wrapper(x: np.ndarray, *inner_args: Any) -> np.ndarray:
            """
            Wrapper function for `func` that handles labeled indexing.
            """
            return func(dict(zip(labels, x)), *inner_args)

        return super().dist_of_func(func_wrapper, *args)

    def expected(
        self, func: Optional[Callable] = None, *args: Any, **kwargs: Any
    ) -> Union[float, np.ndarray]:
        """
        Expectation of a function, given an array of configurations of its inputs
        along with a DiscreteDistributionLabeled object that specifies the probability
        of each configuration.

        Parameters
        ----------
        func : function
            The function to be evaluated.
            By default, func receives a dict mapping variable names to numpy
            arrays, e.g. ``{"perm_shk": array(...), "tran_shk": array(...)}``.
            When ``labels=False``, func receives the raw numpy atoms array
            instead (integer indexing, like ``DiscreteDistribution``).
            When extra keyword arguments are passed, func receives the full
            ``xr.Dataset``.
            It may also take other arguments \\*args.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.
        labels : bool, optional
            If True (default), func receives a dict of labeled arrays.
            If False, func receives the raw numpy atoms array, making DDL
            compatible with functions written for ``DiscreteDistribution``.
            Has no effect when other keyword arguments are passed.
        \\*\\*kwargs :
            Any other keyword arguments are forwarded to func and switch to
            the xarray path, whose result comes from ``_weighted_mean``.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        # "labels" is an option of this method, not an argument for func
        labels = kwargs.pop("labels", True)

        # xarray path: func works on the dataset, args are passed unmodified
        if kwargs:
            f_query = (
                self.dataset if func is None else func(self.dataset, *args, **kwargs)
            )
            return _weighted_mean(f_query, self.pmv)

        if func is None:
            return np.dot(self.atoms, self.pmv)

        # Give array arguments a trailing axis so they broadcast against the
        # atom axis.
        expanded = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        # Labeled (default): the cached dict {var_name: np.ndarray}, which
        # avoids xarray overhead. Unlabeled: the raw atoms array, as in
        # DiscreteDistribution.
        values = self._wrapped_atoms if labels else self.atoms
        return np.dot(func(values, *expanded), self.pmv)