Coverage for HARK / distributions / discrete.py: 92%

173 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-25 05:22 +0000

1from typing import Any, Callable, Dict, List, Optional, Union 

2 

3from copy import deepcopy 

4import numpy as np 

5import xarray as xr 

6from scipy import stats 

7from scipy.stats import rv_discrete 

8from scipy.stats._distn_infrastructure import rv_discrete_frozen 

9 

10from HARK.distributions.base import Distribution 

11 

12 

13class DiscreteFrozenDistribution(rv_discrete_frozen, Distribution): 

14 """ 

15 Parameterized discrete distribution from scipy.stats with seed management. 

16 """ 

17 

18 def __init__( 

19 self, dist: rv_discrete, *args: Any, seed: int = None, **kwds: Any 

20 ) -> None: 

21 """ 

22 Parameterized discrete distribution from scipy.stats with seed management. 

23 

24 Parameters 

25 ---------- 

26 dist : rv_discrete 

27 Discrete distribution from scipy.stats. 

28 seed : int, optional 

29 Seed for random number generator, by default 0 

30 """ 

31 

32 rv_discrete_frozen.__init__(self, dist, *args, **kwds) 

33 Distribution.__init__(self, seed=seed) 

34 

35 

36class Bernoulli(DiscreteFrozenDistribution): 

37 """ 

38 A Bernoulli distribution. 

39 

40 Parameters 

41 ---------- 

42 p : float or [float] 

43 Probability or probabilities of the event occurring (True). 

44 

45 seed : int 

46 Seed for random number generator. 

47 """ 

48 

49 def __init__(self, p=0.5, seed=None): 

50 self.p = np.asarray(p) 

51 # Set up the RNG 

52 super().__init__(stats.bernoulli, p=self.p, seed=seed) 

53 

54 self.pmv = np.array([1 - self.p, self.p]) 

55 self.atoms = np.array( 

56 [[0, 1]] 

57 ) # Ensure atoms is properly shaped like other distributions 

58 self.limit = { 

59 "dist": self, 

60 "infimum": np.array([0.0]), 

61 "supremum": np.array([1.0]), 

62 } 

63 self.infimum = np.array([0.0]) 

64 self.supremum = np.array([1.0]) 

65 

66 def dim(self): 

67 """ 

68 Last dimension of self.atoms indexes "atom." 

69 """ 

70 return self.atoms.shape[:-1] 

71 

72 

73class DiscreteDistribution(Distribution): 

74 """ 

75 A representation of a discrete probability distribution. 

76 

77 Parameters 

78 ---------- 

79 pmv : np.array 

80 An array of floats representing a probability mass function. 

81 atoms : np.array 

82 Discrete point values for each probability mass. 

83 For multivariate distributions, the last dimension of atoms must index 

84 "atom" or the random realization. For instance, if atoms.shape == (2,6,4), 

85 the random variable has 4 possible realizations and each of them has shape (2,6). 

86 limit : dict 

87 Dictionary with information about the continuous distribution from which 

88 this distribution was generated. The reference distribution is in the entry 

89 called 'dist'. 

90 seed : int 

91 Seed for random number generator. 

92 """ 

93 

94 def __init__( 

95 self, 

96 pmv: np.ndarray, 

97 atoms: np.ndarray, 

98 seed: int = None, 

99 limit: Optional[Dict[str, Any]] = None, 

100 ) -> None: 

101 super().__init__(seed=seed) 

102 

103 self.pmv = np.asarray(pmv) 

104 self.atoms = np.atleast_2d(atoms) 

105 if limit is None: 

106 limit = { 

107 "infimum": np.min(self.atoms, axis=-1), 

108 "supremum": np.max(self.atoms, axis=-1), 

109 } 

110 self.limit = limit 

111 

112 # Check that pmv and atoms have compatible dimensions. 

113 if not self.pmv.size == self.atoms.shape[-1]: 

114 raise ValueError( 

115 "Provided pmv and atoms arrays have incompatible dimensions. " 

116 + "The length of the pmv must be equal to that of atoms's last dimension." 

117 ) 

118 

119 def __repr__(self): 

120 out = self.__class__.__name__ + " with " + str(self.pmv.size) + " atoms, " 

121 if self.atoms.shape[0] > 1: 

122 out += "inf=" + str(tuple(self.limit["infimum"])) + ", " 

123 out += "sup=" + str(tuple(self.limit["supremum"])) + ", " 

124 else: 

125 out += "inf=" + str(self.limit["infimum"][0]) + ", " 

126 out += "sup=" + str(self.limit["supremum"][0]) + ", " 

127 out += "seed=" + str(self.seed) 

128 return out 

129 

130 def dim(self) -> int: 

131 """ 

132 Last dimension of self.atoms indexes "atom." 

133 """ 

134 return self.atoms.shape[:-1] 

135 

136 def draw_events(self, N: int) -> np.ndarray: 

137 """ 

138 Draws N 'events' from the distribution PMF. 

139 These events are indices into atoms. 

140 """ 

141 # Generate a cumulative distribution 

142 base_draws = self._rng.uniform(size=N) 

143 cum_dist = np.cumsum(self.pmv) 

144 

145 # Convert the basic uniform draws into discrete draws 

146 indices = cum_dist.searchsorted(base_draws) 

147 

148 return indices 

149 

150 def draw( 

151 self, 

152 N: int, 

153 atoms: Union[None, int, np.ndarray] = None, 

154 shuffle: bool = False, 

155 ) -> np.ndarray: 

156 """ 

157 Simulates N draws from a discrete distribution with probabilities P and outcomes atoms. 

158 

159 Parameters 

160 ---------- 

161 N : int 

162 Number of draws to simulate. 

163 atoms : None, int, or np.array 

164 If None, then use this distribution's atoms for point values. 

165 If an int, then the index of atoms for the point values. 

166 If an np.array, use the array for the point values. 

167 shuffle : boolean 

168 Whether the draws should "shuffle" the discrete distribution, matching 

169 proportions of outcomes as closely as possible to the probabilities given 

170 finite draws. When True, returned draws are a random permutation of the 

171 N-length list that best fits the discrete distribution. When False 

172 (default), each draw is independent from the others and the result could 

173 deviate from the probabilities. 

174 

175 Returns 

176 ------- 

177 draws : np.array 

178 An array of draws from the discrete distribution; each element is a value in atoms. 

179 """ 

180 if atoms is None: 

181 atoms = self.atoms 

182 elif isinstance(atoms, int): 

183 atoms = self.atoms[atoms] 

184 

185 # "Shuffle" an almost-exact population of draws based on the pmv 

186 if shuffle: 

187 P = self.pmv 

188 K_exact = N * P # slots per outcome in real numbers 

189 K = np.floor(K_exact).astype(int) # number of slots allocated to each atom 

190 M = N - np.sum(K) # number of unallocated slots 

191 J = P.size 

192 eps = 1.0 / N 

193 Q = K_exact - eps * K # "missing" probability mass 

194 draws = self._rng.random(M) # uniform draws for "extra" slots 

195 

196 # Fill in each unallocated slot, one by one 

197 for m in range(M): 

198 Q_adj = Q / np.sum(Q) # probabilities for this pass 

199 Q_sum = np.cumsum(Q_adj) 

200 j = np.searchsorted(Q_sum, draws[m]) # find index for this draw 

201 K[j] += 1 # increment its allocated slots 

202 Q[j] = 0.0 # zero out its probability because we used it 

203 

204 # Make an array of atom indices based on the final slot counts 

205 nested_events = [K[j] * [j] for j in range(J)] 

206 events = np.array([i for sublist in nested_events for i in sublist]) 

207 

208 # Draw a random permutation of the indices 

209 indices = self._rng.permutation(events) 

210 

211 # Draw event indices randomly from the discrete distribution 

212 else: 

213 indices = self.draw_events(N) 

214 

215 # Create and fill in the output array of draws based on the output of event indices 

216 draws = atoms[..., indices] 

217 

218 # TODO: some models expect univariate draws to just be a 1d vector. Fix those models. 

219 if len(draws.shape) == 2 and draws.shape[0] == 1: 

220 draws = draws.flatten() 

221 

222 return draws 

223 

224 def expected( 

225 self, func: Optional[Callable] = None, *args: np.ndarray 

226 ) -> np.ndarray: 

227 """ 

228 Expected value of a function, given an array of configurations of its 

229 inputs along with a DiscreteDistribution object that specifies the 

230 probability of each configuration. 

231 

232 If no function is provided, it's much faster to go straight to dot 

233 product instead of calling the dummy function. 

234 

235 If a function is provided, we need to add one more dimension, 

236 the atom dimension, to any inputs that are n-dim arrays. 

237 This allows numpy to easily broadcast the function's output. 

238 For more information on broadcasting, see: 

239 https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules 

240 

241 Parameters 

242 ---------- 

243 func : function 

244 The function to be evaluated. 

245 This function should take the full array of distribution values 

246 and return either arrays of arbitrary shape or scalars. 

247 It may also take other arguments \\*args. 

248 This function differs from the standalone `calc_expectation` 

249 method in that it uses numpy's vectorization and broadcasting 

250 rules to avoid costly iteration. 

251 Note: If you need to use a function that acts on single outcomes 

252 of the distribution, consider `distribution.calc_expectation`. 

253 \\*args : 

254 Other inputs for func, representing the non-stochastic arguments. 

255 The the expectation is computed at ``f(dstn, *args)``. 

256 

257 Returns 

258 ------- 

259 f_exp : np.array or scalar 

260 The expectation of the function at the queried values. 

261 Scalar if only one value. 

262 """ 

263 

264 if func is None: 

265 f_query = self.atoms 

266 else: 

267 args = [ 

268 np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg 

269 for arg in args 

270 ] 

271 

272 f_query = func(self.atoms, *args) 

273 

274 f_exp = np.dot(f_query, self.pmv) 

275 

276 return f_exp 

277 

278 def dist_of_func( 

279 self, func: Callable[..., float] = lambda x: x, *args: Any 

280 ) -> "DiscreteDistribution": 

281 """ 

282 Finds the distribution of a random variable Y that is a function 

283 of discrete random variable atoms, Y=f(atoms). 

284 

285 Parameters 

286 ---------- 

287 func : function 

288 The function to be evaluated. 

289 This function should take the full array of distribution values. 

290 It may also take other arguments \\*args. 

291 \\*args : 

292 Additional non-stochastic arguments for func, 

293 The function is computed as ``f(dstn, *args)``. 

294 

295 Returns 

296 ------- 

297 f_dstn : DiscreteDistribution 

298 The distribution of func(dstn). 

299 """ 

300 # we need to add one more dimension, 

301 # the atom dimension, to any inputs that are n-dim arrays. 

302 # This allows numpy to easily broadcast the function's output. 

303 args = [ 

304 np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg 

305 for arg in args 

306 ] 

307 f_query = func(self.atoms, *args) 

308 

309 f_dstn = DiscreteDistribution(list(self.pmv), f_query, seed=self.seed) 

310 

311 return f_dstn 

312 

313 def discretize(self, N: int, *args: Any, **kwargs: Any) -> "DiscreteDistribution": 

314 """ 

315 `DiscreteDistribution` is already an approximation, so this method 

316 returns a copy of the distribution. 

317 

318 TODO: print warning message? 

319 """ 

320 return self 

321 

322 def make_univariate(self, dim_to_keep, seed=0): 

323 """ 

324 Make a univariate discrete distribution from this distribution, keeping 

325 only the specified dimension. 

326 

327 Parameters 

328 ---------- 

329 dim_to_keep : int 

330 Index of the distribution to be kept. Any other dimensions will be 

331 "collapsed" into the univariate atoms, combining probabilities. 

332 seed : int, optional 

333 Seed for random number generator of univariate distribution 

334 

335 Returns 

336 ------- 

337 univariate_dstn : DiscreteDistribution 

338 Univariate distribution with only the specified index. 

339 """ 

340 # Do basic validity and triviality checks 

341 if (self.atoms.shape[0] == 1) and (dim_to_keep == 0): 

342 return deepcopy(self) # Return copy of self if only one dimension 

343 if dim_to_keep >= self.atoms.shape[0]: 

344 raise ValueError("dim_to_keep exceeds dimensionality of distribution.") 

345 

346 # Construct values and probabilities for univariate distribution 

347 atoms_temp = self.atoms[dim_to_keep] 

348 vals_to_keep = np.unique(atoms_temp) 

349 probs_to_keep = np.zeros_like(vals_to_keep) 

350 for i in range(vals_to_keep.size): 

351 val = vals_to_keep[i] 

352 these = atoms_temp == val 

353 probs_to_keep[i] = np.sum(self.pmv[these]) 

354 

355 # Make and return the univariate distribution 

356 univariate_dstn = DiscreteDistribution( 

357 pmv=probs_to_keep, atoms=vals_to_keep, seed=seed 

358 ) 

359 return univariate_dstn 

360 

361 

362class DiscreteDistributionLabeled(DiscreteDistribution): 

363 """ 

364 A representation of a discrete probability distribution 

365 stored in an underlying `xarray.Dataset`. 

366 

367 Parameters 

368 ---------- 

369 pmv : np.array 

370 An array of values representing a probability mass function. 

371 data : np.array 

372 Discrete point values for each probability mass. 

373 For multivariate distributions, the last dimension of atoms must index 

374 "atom" or the random realization. For instance, if atoms.shape == (2,6,4), 

375 the random variable has 4 possible realizations and each of them has shape (2,6). 

376 seed : int 

377 Seed for random number generator. 

378 name : str 

379 Name of the distribution. 

380 attrs : dict 

381 Attributes for the distribution. 

382 var_names : list of str 

383 Names of the variables in the distribution. 

384 var_attrs : list of dict 

385 Attributes of the variables in the distribution. 

386 

387 """ 

388 

389 def __init__( 

390 self, 

391 pmv: np.ndarray, 

392 atoms: np.ndarray, 

393 seed: int = None, 

394 limit: Optional[Dict[str, Any]] = None, 

395 name: str = "DiscreteDistributionLabeled", 

396 attrs: Optional[Dict[str, Any]] = None, 

397 var_names: Optional[List[str]] = None, 

398 var_attrs: Optional[List[Optional[Dict[str, Any]]]] = None, 

399 ): 

400 super().__init__(pmv, atoms, seed=seed, limit=limit) 

401 

402 # vector-value distributions 

403 

404 if self.atoms.ndim > 2: 

405 raise NotImplementedError( 

406 "Only vector-valued distributions are supported for now." 

407 ) 

408 

409 attrs = {} if attrs is None else attrs 

410 if limit is None: 

411 limit = { 

412 "infimum": np.min(self.atoms, axis=-1), 

413 "supremum": np.max(self.atoms, axis=-1), 

414 } 

415 self.limit = limit 

416 attrs.update(limit) 

417 attrs["name"] = name 

418 

419 n_var = self.atoms.shape[0] 

420 

421 # give dummy names to variables if none are provided 

422 if var_names is None: 

423 var_names = ["var_" + str(i) for i in range(n_var)] 

424 

425 assert len(var_names) == n_var, ( 

426 "Number of variable names does not match number of variables." 

427 ) 

428 

429 # give dummy attributes to variables if none are provided 

430 if var_attrs is None: 

431 var_attrs = [None] * n_var 

432 

433 # a DiscreteDistributionLabeled is an xr.Dataset where the only 

434 # dimension is "atom", which indexes the random realizations. 

435 self.dataset = xr.Dataset( 

436 { 

437 var_names[i]: xr.DataArray( 

438 self.atoms[i], 

439 dims=("atom"), 

440 attrs=var_attrs[i], 

441 ) 

442 for i in range(n_var) 

443 }, 

444 attrs=attrs, 

445 ) 

446 

447 # the probability mass values are stored in 

448 # a DataArray with dimension "atom" 

449 self.probability = xr.DataArray(self.pmv, dims=("atom")) 

450 

451 @classmethod 

452 def from_unlabeled( 

453 cls, 

454 dist, 

455 name="DiscreteDistributionLabeled", 

456 attrs=None, 

457 var_names=None, 

458 var_attrs=None, 

459 ): 

460 ldd = cls( 

461 dist.pmv, 

462 dist.atoms, 

463 seed=dist.seed, 

464 limit=dist.limit, 

465 name=name, 

466 attrs=attrs, 

467 var_names=var_names, 

468 var_attrs=var_attrs, 

469 ) 

470 

471 return ldd 

472 

473 @classmethod 

474 def from_dataset(cls, x_obj, pmf): 

475 ldd = cls.__new__(cls) 

476 

477 if isinstance(x_obj, xr.Dataset): 

478 ldd.dataset = x_obj 

479 elif isinstance(x_obj, xr.DataArray): 

480 ldd.dataset = xr.Dataset({x_obj.name: x_obj}) 

481 elif isinstance(x_obj, dict): 

482 ldd.dataset = xr.Dataset(x_obj) 

483 

484 ldd.probability = pmf 

485 

486 return ldd 

487 

488 @property 

489 def _weighted(self): 

490 """ 

491 Returns a DatasetWeighted object for the distribution. 

492 """ 

493 return self.dataset.weighted(self.probability) 

494 

495 @property 

496 def variables(self): 

497 """ 

498 A dict-like container of DataArrays corresponding to 

499 the variables of the distribution. 

500 """ 

501 return self.dataset.data_vars 

502 

503 @property 

504 def name(self): 

505 """ 

506 The distribution's name. 

507 """ 

508 return self.dataset.name 

509 

510 @property 

511 def attrs(self): 

512 """ 

513 The distribution's attributes. 

514 """ 

515 

516 return self.dataset.attrs 

517 

518 def dist_of_func( 

519 self, func: Callable = lambda x: x, *args, **kwargs 

520 ) -> DiscreteDistribution: 

521 """ 

522 Finds the distribution of a random variable Y that is a function 

523 of discrete random variable atoms, Y=f(atoms). 

524 

525 Parameters 

526 ---------- 

527 func : function 

528 The function to be evaluated. 

529 This function should take the full array of distribution values. 

530 It may also take other arguments \\*args. 

531 \\*args : 

532 Additional non-stochastic arguments for func, 

533 The function is computed as ``f(dstn, *args)``. 

534 \\*\\*kwargs : 

535 Additional keyword arguments for func. Must be xarray compatible 

536 in order to work with xarray broadcasting. 

537 

538 Returns 

539 ------- 

540 f_dstn : DiscreteDistribution or DiscreteDistributionLabeled 

541 The distribution of func(dstn). 

542 """ 

543 

544 def func_wrapper(x: np.ndarray, *args: Any) -> np.ndarray: 

545 """ 

546 Wrapper function for `func` that handles labeled indexing. 

547 """ 

548 

549 idx = self.variables.keys() 

550 wrapped = dict(zip(idx, x)) 

551 

552 return func(wrapped, *args) 

553 

554 if len(kwargs): 

555 f_query = func(self.dataset, **kwargs) 

556 ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability) 

557 

558 return ldd 

559 

560 return super().dist_of_func(func_wrapper, *args) 

561 

562 def expected( 

563 self, func: Optional[Callable] = None, *args: Any, **kwargs: Any 

564 ) -> Union[float, np.ndarray]: 

565 """ 

566 Expectation of a function, given an array of configurations of its inputs 

567 along with a DiscreteDistributionLabeled object that specifies the probability 

568 of each configuration. 

569 

570 Parameters 

571 ---------- 

572 func : function 

573 The function to be evaluated. 

574 This function should take the full array of distribution values 

575 and return either arrays of arbitrary shape or scalars. 

576 It may also take other arguments \\*args. 

577 This function differs from the standalone `calc_expectation` 

578 method in that it uses numpy's vectorization and broadcasting 

579 rules to avoid costly iteration. 

580 Note: If you need to use a function that acts on single outcomes 

581 of the distribution, consider `distribution.calc_expectation`. 

582 \\*args : 

583 Other inputs for func, representing the non-stochastic arguments. 

584 The the expectation is computed at ``f(dstn, *args)``. 

585 labels : bool 

586 If True, the function should use labeled indexing instead of integer 

587 indexing using the distribution's underlying rv coordinates. For example, 

588 if `dims = ('rv', 'x')` and `coords = {'rv': ['a', 'b'], }`, then 

589 the function can be `lambda x: x["a"] + x["b"]`. 

590 

591 Returns 

592 ------- 

593 f_exp : np.array or scalar 

594 The expectation of the function at the queried values. 

595 Scalar if only one value. 

596 """ 

597 

598 def func_wrapper(x, *args): 

599 """ 

600 Wrapper function for `func` that handles labeled indexing. 

601 """ 

602 

603 idx = self.variables.keys() 

604 wrapped = dict(zip(idx, x)) 

605 

606 return func(wrapped, *args) 

607 

608 if len(kwargs): 

609 f_query = func(self.dataset, *args, **kwargs) 

610 ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability) 

611 

612 return ldd._weighted.mean("atom") 

613 else: 

614 if func is None: 

615 return super().expected() 

616 else: 

617 return super().expected(func_wrapper, *args)