Coverage for HARK / distributions / discrete.py: 92%

173 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-08 05:31 +0000

1from typing import Any, Callable, Dict, List, Optional, Union 

2 

3from copy import deepcopy 

4import numpy as np 

5import xarray as xr 

6from scipy import stats 

7from scipy.stats import rv_discrete 

8from scipy.stats._distn_infrastructure import rv_discrete_frozen 

9 

10from HARK.distributions.base import Distribution 

11 

12 

class DiscreteFrozenDistribution(rv_discrete_frozen, Distribution):
    """
    Parameterized discrete distribution from scipy.stats with seed management.
    """

    def __init__(
        self, dist: rv_discrete, *args: Any, seed: Optional[int] = None, **kwds: Any
    ) -> None:
        """
        Parameterized discrete distribution from scipy.stats with seed management.

        Parameters
        ----------
        dist : rv_discrete
            Discrete distribution from scipy.stats.
        *args, **kwds :
            Distribution parameters, forwarded unchanged to
            ``rv_discrete_frozen.__init__`` to freeze ``dist``.
        seed : int, optional
            Seed for random number generator, by default None. It is passed
            through to ``Distribution.__init__``, which decides how None is
            interpreted.
        """
        # Freeze the scipy distribution with its parameters first, then let
        # HARK's Distribution base class handle the seed.
        rv_discrete_frozen.__init__(self, dist, *args, **kwds)
        Distribution.__init__(self, seed=seed)

34 

35 

class Bernoulli(DiscreteFrozenDistribution):
    """
    A Bernoulli distribution over the outcomes 0 (False) and 1 (True).

    Parameters
    ----------
    p : float or [float]
        Probability or probabilities of the event occurring (True).
        Defaults to 0.5.

    seed : int, optional
        Seed for random number generator.
    """

    def __init__(self, p=0.5, seed=None):
        # Keep p as an array and freeze scipy's bernoulli with it; the
        # parent constructor also receives the seed for the RNG.
        self.p = np.asarray(p)
        super().__init__(stats.bernoulli, p=self.p, seed=seed)

        # Discrete representation: mass 1 - p on atom 0, mass p on atom 1.
        self.pmv = np.array([1 - self.p, self.p])
        # A 2-D atoms array (one variable by two atoms), matching the
        # layout that DiscreteDistribution uses.
        self.atoms = np.array([[0, 1]])

        lower_bound = np.array([0.0])
        upper_bound = np.array([1.0])
        self.limit = {
            "dist": self,
            "infimum": lower_bound,
            "supremum": upper_bound,
        }
        # Separate array objects from those stored in self.limit.
        self.infimum = np.array([0.0])
        self.supremum = np.array([1.0])

    def dim(self):
        """
        Return the shape of a single realization.

        The last axis of ``self.atoms`` indexes atoms, so all preceding axes
        describe one draw; with the atoms set in ``__init__`` this is ``(1,)``.
        """
        *realization_shape, _n_atoms = self.atoms.shape
        return tuple(realization_shape)

71 

72 

class DiscreteDistribution(Distribution):
    """
    A representation of a discrete probability distribution.

    Parameters
    ----------
    pmv : np.array
        An array of floats representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        For multivariate distributions, the last dimension of atoms must index
        "atom" or the random realization. For instance, if atoms.shape == (2,6,4),
        the random variable has 4 possible realizations and each of them has shape (2,6).
        One-dimensional input is promoted to shape (1, n) with ``np.atleast_2d``.
    seed : int, optional
        Seed for random number generator.
    limit : dict, optional
        Dictionary with information about the continuous distribution from which
        this distribution was generated. The reference distribution is in the entry
        called 'dist'. If omitted, it is set to the "infimum" and "supremum" of
        the atoms taken over the atom (last) axis.

    Raises
    ------
    ValueError
        If the number of probability masses differs from the length of the
        last dimension of ``atoms``.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(seed=seed)

        self.pmv = np.asarray(pmv)
        # Promote 1-D atoms to shape (1, n) so the last axis always indexes atoms.
        self.atoms = np.atleast_2d(atoms)
        if limit is None:
            # Default bounds: elementwise min/max over the atom axis.
            limit = {
                "infimum": np.min(self.atoms, axis=-1),
                "supremum": np.max(self.atoms, axis=-1),
            }
        self.limit = limit

        # Check that pmv and atoms have compatible dimensions.
        if self.pmv.size != self.atoms.shape[-1]:
            raise ValueError(
                "Provided pmv and atoms arrays have incompatible dimensions. "
                + "The length of the pmv must be equal to that of atoms's last dimension."
            )

    def __repr__(self) -> str:
        out = self.__class__.__name__ + " with " + str(self.pmv.size) + " atoms, "
        if self.atoms.shape[0] > 1:
            # Multivariate: report the bounds of every variable as tuples.
            out += "inf=" + str(tuple(self.limit["infimum"])) + ", "
            out += "sup=" + str(tuple(self.limit["supremum"])) + ", "
        else:
            # Univariate: report the single bound values.
            out += "inf=" + str(self.limit["infimum"][0]) + ", "
            out += "sup=" + str(self.limit["supremum"][0]) + ", "
        out += "seed=" + str(self.seed)
        return out

    def dim(self) -> tuple:
        """
        Shape of a single realization.

        The last dimension of self.atoms indexes "atom", so this returns all
        preceding dimensions, e.g. ``(2, 6)`` when atoms.shape == (2, 6, 4).
        """
        return self.atoms.shape[:-1]

    def draw_events(self, N: int) -> np.ndarray:
        """
        Draws N 'events' from the distribution PMF.
        These events are indices into atoms (along its last axis).

        Parameters
        ----------
        N : int
            Number of events to draw.

        Returns
        -------
        indices : np.ndarray
            Integer array of length N with values in ``range(self.pmv.size)``.

        Raises
        ------
        ValueError
            If the probability masses do not have a positive (finite) total.
        """
        # Generate a cumulative distribution.
        cum_dist = np.cumsum(self.pmv)
        total = cum_dist[-1]
        if not total > 0:  # also rejects NaN
            raise ValueError("The pmv must have a positive total probability mass.")
        # Rescale so the final entry is exactly 1.0 (x / x == 1.0 in IEEE
        # arithmetic). Otherwise rounding, e.g. a total of 0.9999999999999999,
        # lets a uniform draw exceed the last entry and searchsorted returns
        # pmv.size, an index one past the end of atoms.
        cum_dist = cum_dist / total

        base_draws = self._rng.uniform(size=N)

        # Convert the basic uniform draws into discrete draws.
        indices = cum_dist.searchsorted(base_draws)

        return indices

    def _shuffle_slot_counts(self, N: int) -> np.ndarray:
        """
        Number of the N "shuffled" draws allocated to each atom.

        Each atom first gets floor(N * p) slots. The remaining slots are then
        handed out one at a time, without replacement, with probability
        proportional to each atom's leftover fractional slot.
        """
        P = self.pmv
        K_exact = N * P  # slots per outcome in real numbers
        K = np.floor(K_exact).astype(int)  # number of slots allocated to each atom
        M = N - np.sum(K)  # number of unallocated slots
        # Leftover fractional slots. This is proportional to the "missing"
        # probability mass P - K/N; only proportions matter because it is
        # renormalized on each pass below. (The previous K_exact - K/N mixed
        # units and biased extra slots toward high-probability atoms.)
        Q = K_exact - K
        extra_draws = self._rng.random(M)  # uniform draws for "extra" slots

        # Fill in each unallocated slot, one by one.
        for u in extra_draws:
            Q_sum = np.cumsum(Q)
            # Scale the draw to the running total instead of normalizing Q, so
            # the target is strictly below Q_sum[-1] and the index is in range.
            # side="right" never selects an atom whose Q is zero.
            j = np.searchsorted(Q_sum, u * Q_sum[-1], side="right")
            K[j] += 1  # increment its allocated slots
            Q[j] = 0.0  # zero out its probability because we used it

        return K

    def draw(
        self,
        N: int,
        atoms: Union[None, int, np.ndarray] = None,
        shuffle: bool = False,
    ) -> np.ndarray:
        """
        Simulates N draws from a discrete distribution with probabilities P and outcomes atoms.

        Parameters
        ----------
        N : int
            Number of draws to simulate.
        atoms : None, int, or np.array
            If None, then use this distribution's atoms for point values.
            If an int (Python or NumPy integer), then the index of atoms for the point values.
            If an np.array (or array-like), use the array for the point values;
            its last axis is indexed by the drawn events.
        shuffle : boolean
            Whether the draws should "shuffle" the discrete distribution, matching
            proportions of outcomes as closely as possible to the probabilities given
            finite draws. When True, returned draws are a random permutation of the
            N-length list that best fits the discrete distribution. When False
            (default), each draw is independent from the others and the result could
            deviate from the probabilities.

        Returns
        -------
        draws : np.array
            An array of draws from the discrete distribution; each element is a value in atoms.
            A 2-D result with a single row is flattened to 1-D.
        """
        if atoms is None:
            atoms = self.atoms
        elif isinstance(atoms, (int, np.integer)):
            atoms = self.atoms[atoms]
        else:
            atoms = np.asarray(atoms)

        if shuffle:
            # "Shuffle" an almost-exact population of draws based on the pmv:
            # repeat each atom index by its slot count, then permute.
            K = self._shuffle_slot_counts(N)
            events = np.repeat(np.arange(self.pmv.size), K)
            indices = self._rng.permutation(events)
        else:
            # Draw event indices randomly from the discrete distribution.
            indices = self.draw_events(N)

        # Create and fill in the output array of draws based on the event indices.
        draws = atoms[..., indices]

        # TODO: some models expect univariate draws to just be a 1d vector. Fix those models.
        if draws.ndim == 2 and draws.shape[0] == 1:
            draws = draws.flatten()

        return draws

    def expected(
        self, func: Optional[Callable] = None, *args: np.ndarray
    ) -> np.ndarray:
        """
        Expected value of a function, given an array of configurations of its
        inputs along with a DiscreteDistribution object that specifies the
        probability of each configuration.

        If no function is provided, it's much faster to go straight to dot
        product instead of calling the dummy function.

        If a function is provided, we need to add one more dimension,
        the atom dimension, to any inputs that are n-dim arrays.
        This allows numpy to easily broadcast the function's output.
        For more information on broadcasting, see:
        https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values
            and return either arrays of arbitrary shape or scalars.
            It may also take other arguments \\*args.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        if func is None:
            f_query = self.atoms
        else:
            # Append a trailing atom axis to array arguments so they
            # broadcast against the atoms.
            args = [
                np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
                for arg in args
            ]
            f_query = func(self.atoms, *args)

        # Contract the last (atom) axis against the probability masses.
        return np.dot(f_query, self.pmv)

    def dist_of_func(
        self, func: Callable[..., float] = lambda x: x, *args: Any
    ) -> "DiscreteDistribution":
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.

        Returns
        -------
        f_dstn : DiscreteDistribution
            The distribution of func(dstn), with the same probability masses
            and seed as this distribution.
        """
        # We need to add one more dimension, the atom dimension, to any
        # inputs that are n-dim arrays. This allows numpy to easily
        # broadcast the function's output.
        args = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        f_query = func(self.atoms, *args)

        # Copy the masses so the new distribution does not share the array.
        return DiscreteDistribution(self.pmv.copy(), f_query, seed=self.seed)

    def discretize(self, N: int, *args: Any, **kwargs: Any) -> "DiscreteDistribution":
        """
        `DiscreteDistribution` is already an approximation, so this method
        returns this same object (not a copy); all arguments are ignored.

        TODO: print warning message?
        """
        return self

    def make_univariate(self, dim_to_keep, seed=0):
        """
        Make a univariate discrete distribution from this distribution, keeping
        only the specified dimension.

        Parameters
        ----------
        dim_to_keep : int
            Index of the distribution to be kept. Any other dimensions will be
            "collapsed" into the univariate atoms, combining probabilities.
        seed : int, optional
            Seed for random number generator of univariate distribution.
            Not used when this distribution is already univariate and
            dim_to_keep is 0; the copy then keeps this distribution's seed.

        Returns
        -------
        univariate_dstn : DiscreteDistribution
            Univariate distribution with only the specified index. Its atoms
            are the sorted unique values of that dimension.

        Raises
        ------
        ValueError
            If dim_to_keep is not smaller than the number of dimensions.
        """
        # Do basic validity and triviality checks.
        if (self.atoms.shape[0] == 1) and (dim_to_keep == 0):
            return deepcopy(self)  # Return copy of self if only one dimension
        if dim_to_keep >= self.atoms.shape[0]:
            raise ValueError("dim_to_keep exceeds dimensionality of distribution.")

        # Construct values and probabilities for univariate distribution.
        atoms_temp = self.atoms[dim_to_keep]
        vals_to_keep, inverse = np.unique(atoms_temp, return_inverse=True)
        # Sum the mass of all atoms sharing each unique value. bincount
        # returns floats, so integer-valued atoms no longer truncate the
        # probabilities to 0 (as np.zeros_like(vals_to_keep) did).
        probs_to_keep = np.bincount(
            np.ravel(inverse), weights=np.ravel(self.pmv), minlength=vals_to_keep.size
        )

        # Make and return the univariate distribution.
        univariate_dstn = DiscreteDistribution(
            pmv=probs_to_keep, atoms=vals_to_keep, seed=seed
        )
        return univariate_dstn

360 

361 

class DiscreteDistributionLabeled(DiscreteDistribution):
    """
    A representation of a discrete probability distribution stored in an underlying `xarray.Dataset`.

    Parameters
    ----------
    pmv : np.array
        An array of values representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass. At most two
        dimensions are supported: the first indexes variables and the last
        indexes "atom", the random realization. For instance, if
        atoms.shape == (2, 4), there are 2 variables with 4 possible
        realizations. 1-D input is treated as a single variable.
    seed : int, optional
        Seed for random number generator.
    limit : dict, optional
        Passed to DiscreteDistribution; its entries are also copied into the
        dataset attributes.
    name : str
        Name of the distribution.
    attrs : dict, optional
        Attributes for the distribution. The dict is copied, so the caller's
        object is not modified.
    var_names : list of str, optional
        Names of the variables in the distribution. Defaults to
        "var_0", "var_1", ...
    var_attrs : list of dict, optional
        Attributes of the variables in the distribution.

    Raises
    ------
    NotImplementedError
        If atoms has more than two dimensions.
    AssertionError
        If the number of variable names differs from the number of variables.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
        name: str = "DiscreteDistributionLabeled",
        attrs: Optional[Dict[str, Any]] = None,
        var_names: Optional[List[str]] = None,
        var_attrs: Optional[List[Optional[Dict[str, Any]]]] = None,
    ):
        super().__init__(pmv, atoms, seed=seed, limit=limit)

        # Only vector-valued distributions are supported.
        if self.atoms.ndim > 2:
            raise NotImplementedError(
                "Only vector-valued distributions are supported for now."
            )

        # Copy attrs so adding the limit entries and the name below does not
        # mutate a dict owned by the caller.
        attrs = {} if attrs is None else dict(attrs)
        # DiscreteDistribution.__init__ already stored either the given limit
        # or the default infimum/supremum of the atoms in self.limit.
        attrs.update(self.limit)
        attrs["name"] = name

        n_var = self.atoms.shape[0]

        # Give dummy names to variables if none are provided.
        if var_names is None:
            var_names = ["var_" + str(i) for i in range(n_var)]

        assert len(var_names) == n_var, (
            "Number of variable names does not match number of variables."
        )

        # Give dummy attributes to variables if none are provided.
        if var_attrs is None:
            var_attrs = [None] * n_var

        # A DiscreteDistributionLabeled is an xr.Dataset where the only
        # dimension is "atom", which indexes the random realizations.
        self.dataset = xr.Dataset(
            {
                var_names[i]: xr.DataArray(
                    self.atoms[i],
                    dims=("atom",),
                    attrs=var_attrs[i],
                )
                for i in range(n_var)
            },
            attrs=attrs,
        )

        # The probability mass values are stored in a DataArray with
        # dimension "atom".
        self.probability = xr.DataArray(self.pmv, dims=("atom",))

    @classmethod
    def from_unlabeled(
        cls,
        dist,
        name="DiscreteDistributionLabeled",
        attrs=None,
        var_names=None,
        var_attrs=None,
    ):
        """
        Build a labeled distribution from an unlabeled DiscreteDistribution,
        reusing its pmv, atoms, seed, and limit.
        """
        ldd = cls(
            dist.pmv,
            dist.atoms,
            seed=dist.seed,
            limit=dist.limit,
            name=name,
            attrs=attrs,
            var_names=var_names,
            var_attrs=var_attrs,
        )

        return ldd

    @classmethod
    def from_dataset(cls, x_obj, pmf):
        """
        Build a labeled distribution directly from xarray data.

        Parameters
        ----------
        x_obj : xr.Dataset, xr.DataArray, or dict
            Values of the variables. A DataArray is stored under its ``name``;
            a dict is passed to ``xr.Dataset``.
        pmf :
            Probability masses; presumably an xr.DataArray over "atom"
            (callers in this class pass ``self.probability``).

        Returns
        -------
        ldd : DiscreteDistributionLabeled
            ``__init__`` is bypassed, so only ``dataset`` and ``probability``
            are set; attributes such as ``pmv``, ``atoms`` and ``seed`` are not.

        Raises
        ------
        TypeError
            If x_obj is not one of the supported types.
        """
        ldd = cls.__new__(cls)

        if isinstance(x_obj, xr.Dataset):
            ldd.dataset = x_obj
        elif isinstance(x_obj, xr.DataArray):
            ldd.dataset = xr.Dataset({x_obj.name: x_obj})
        elif isinstance(x_obj, dict):
            ldd.dataset = xr.Dataset(x_obj)
        else:
            # Previously this fell through and returned an object without a
            # dataset, which failed later with a confusing AttributeError.
            raise TypeError(
                "x_obj must be an xarray Dataset, an xarray DataArray, or a dict; got "
                + type(x_obj).__name__
                + "."
            )

        ldd.probability = pmf

        return ldd

    @property
    def _weighted(self):
        """
        Returns a DatasetWeighted object for the distribution.
        """
        return self.dataset.weighted(self.probability)

    @property
    def variables(self):
        """
        A dict-like container of DataArrays corresponding to
        the variables of the distribution.
        """
        return self.dataset.data_vars

    @property
    def name(self):
        """
        The distribution's name.

        NOTE(review): __init__ stores the name in ``dataset.attrs["name"]``;
        this attribute-style lookup presumably resolves to it through
        xarray's attribute access. Confirm; a data variable called "name"
        would shadow it.
        """
        return self.dataset.name

    @property
    def attrs(self):
        """
        The distribution's attributes.
        """
        return self.dataset.attrs

    def _labeled_func(self, func: Callable) -> Callable:
        """
        Adapt ``func`` to the parent class's calling convention.

        The parent calls ``f(self.atoms, *args)``. The wrapper instead passes
        ``func`` a dict mapping each variable name to its row of atoms.
        """

        def func_wrapper(x: np.ndarray, *args: Any) -> np.ndarray:
            wrapped = dict(zip(self.variables.keys(), x))
            return func(wrapped, *args)

        return func_wrapper

    def dist_of_func(
        self, func: Callable = lambda x: x, *args, **kwargs
    ) -> DiscreteDistribution:
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.
        \\*\\*kwargs :
            Additional keyword arguments for func. Must be xarray compatible
            in order to work with xarray broadcasting. When any are given,
            func is applied to the underlying xr.Dataset (with \\*args too)
            and a DiscreteDistributionLabeled is returned.

        Returns
        -------
        f_dstn : DiscreteDistribution or DiscreteDistributionLabeled
            The distribution of func(dstn).
        """
        if kwargs:
            # Positional args are now forwarded as well (previously dropped),
            # matching how `expected` calls func in its keyword branch.
            f_query = func(self.dataset, *args, **kwargs)
            return DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

        return super().dist_of_func(self._labeled_func(func), *args)

    def expected(
        self, func: Optional[Callable] = None, *args: Any, **kwargs: Any
    ) -> Union[float, np.ndarray, xr.Dataset]:
        """
        Expectation of a function, given an array of configurations of its inputs
        along with a DiscreteDistributionLabeled object that specifies the probability
        of each configuration.

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values
            and return either arrays of arbitrary shape or scalars.
            It may also take other arguments \\*args.
            Without keyword arguments, func receives a dict mapping each
            variable name to its array of atoms, e.g.
            ``lambda x: x["var_0"] + x["var_1"]``.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.
        \\*\\*kwargs :
            Keyword arguments for func. When any are given, func is applied to
            the underlying xr.Dataset and the probability-weighted mean over
            the "atom" dimension is returned.

        Returns
        -------
        f_exp : np.array, scalar, or xarray object
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        if kwargs:
            f_query = func(self.dataset, *args, **kwargs)
            ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability)
            return ldd._weighted.mean("atom")

        if func is None:
            return super().expected()

        return super().expected(self._labeled_func(func), *args)