Coverage for HARK/distributions/discrete.py: 92%

156 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-02 05:14 +0000

1from typing import Any, Callable, Dict, List, Optional, Union 

2 

3from copy import deepcopy 

4import numpy as np 

5import xarray as xr 

6from scipy import stats 

7from scipy.stats import rv_discrete 

8from scipy.stats._distn_infrastructure import rv_discrete_frozen 

9 

10from HARK.distributions.base import Distribution 

11 

12 

class DiscreteFrozenDistribution(rv_discrete_frozen, Distribution):
    """
    A frozen (parameterized) discrete distribution from scipy.stats that
    additionally carries seed management.
    """

    def __init__(
        self, dist: rv_discrete, *args: Any, seed: int = 0, **kwds: Any
    ) -> None:
        """
        Freeze a scipy.stats discrete distribution and attach seed management.

        Parameters
        ----------
        dist : rv_discrete
            Discrete distribution from scipy.stats.
        *args, **kwds :
            Shape parameters of ``dist``; forwarded unchanged to
            ``rv_discrete_frozen.__init__``.
        seed : int, optional
            Seed for random number generator, by default 0
        """
        # Each base class is initialised explicitly (scipy side first) because
        # the two constructors take different arguments.
        rv_discrete_frozen.__init__(self, dist, *args, **kwds)
        Distribution.__init__(self, seed=seed)

34 

35 

class Bernoulli(DiscreteFrozenDistribution):
    """
    A Bernoulli distribution over the two outcomes 0 and 1.

    Parameters
    ----------
    p : float or [float]
        Probability or probabilities of the event occurring (True).

    seed : int
        Seed for random number generator.
    """

    def __init__(self, p=0.5, seed=0):
        self.p = np.asarray(p)
        # The parent constructor freezes scipy's bernoulli with this p and
        # handles the seed.
        super().__init__(stats.bernoulli, p=self.p, seed=seed)

        # Probability mass: first entry for outcome 0, second for outcome 1.
        prob_false = 1 - self.p
        self.pmv = np.array([prob_false, self.p])

        # Kept 2D (1 x 2) so atoms is shaped like other distributions.
        self.atoms = np.array([[0, 1]])

        # The limit dict and the infimum/supremum attributes hold separate
        # array objects with the same values.
        self.limit = dict(
            dist=self,
            infimum=np.array([0.0]),
            supremum=np.array([1.0]),
        )
        self.infimum = np.array([0.0])
        self.supremum = np.array([1.0])

    def dim(self):
        """
        Shape of ``self.atoms`` without its last axis; the last dimension of
        self.atoms indexes "atom."
        """
        atoms_shape = self.atoms.shape
        return atoms_shape[:-1]

71 

72 

class DiscreteDistribution(Distribution):
    """
    A representation of a discrete probability distribution.

    Parameters
    ----------
    pmv : np.array
        An array of floats representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        For multivariate distributions, the last dimension of atoms must index
        "atom" or the random realization. For instance, if atoms.shape == (2,6,4),
        the random variable has 4 possible realizations and each of them has shape (2,6).
    limit : dict
        Dictionary with information about the continuous distribution from which
        this distribution was generated. The reference distribution is in the entry
        called 'dist'. When omitted, it is filled with the per-row minimum
        ('infimum') and maximum ('supremum') of the atoms.
    seed : int
        Seed for random number generator.

    Raises
    ------
    ValueError
        If the number of elements of ``pmv`` differs from the length of the
        last dimension of ``atoms``.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: int = 0,
        limit: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(seed=seed)

        self.pmv = np.asarray(pmv)
        # A 1D vector of atoms becomes a single-row 2D array.
        self.atoms = np.atleast_2d(atoms)
        if limit is None:
            limit = {
                "infimum": np.min(self.atoms, axis=-1),
                "supremum": np.max(self.atoms, axis=-1),
            }
        self.limit = limit

        # Check that pmv and atoms have compatible dimensions.
        if self.pmv.size != self.atoms.shape[-1]:
            raise ValueError(
                "Provided pmv and atoms arrays have incompatible dimensions. "
                + "The length of the pmv must be equal to that of atoms's last dimension."
            )

    def dim(self) -> tuple:
        """
        Shape of one realization, i.e. ``self.atoms.shape`` without its last
        axis. The last dimension of self.atoms indexes "atom."
        """
        return self.atoms.shape[:-1]

    def draw_events(self, N: int) -> np.ndarray:
        """
        Draws N 'events' from the distribution PMF.
        These events are indices into atoms.

        Parameters
        ----------
        N : int
            Number of events to draw.

        Returns
        -------
        indices : np.array
            Integer indices, each in ``[0, pmv.size - 1]``.
        """
        # Uniform draws are mapped through the cumulative distribution.
        base_draws = self._rng.uniform(size=N)
        cum_dist = np.cumsum(self.pmv)

        # Convert the basic uniform draws into discrete draws
        indices = cum_dist.searchsorted(base_draws)

        # If the cumulative sum ends slightly below 1 (floating-point
        # rounding), a draw above it maps to index pmv.size, which is out of
        # range for atoms. Assign such draws to the last atom instead.
        return np.minimum(indices, self.pmv.size - 1)

    def draw(
        self,
        N: int,
        atoms: Union[None, int, np.ndarray] = None,
        exact_match: bool = False,
    ) -> np.ndarray:
        """
        Simulates N draws from a discrete distribution with probabilities P and outcomes atoms.

        Parameters
        ----------
        N : int
            Number of draws to simulate.
        atoms : None, int, or np.array
            If None, then use this distribution's atoms for point values.
            If an int (Python or NumPy integer), then the index of atoms for
            the point values.
            If an np.array, use the array for the point values.
        exact_match : boolean
            Whether the draws should "exactly" match the discrete distribution (as
            closely as possible given finite draws). When True, returned draws are
            a random permutation of the N-length list that best fits the discrete
            distribution. When False (default), each draw is independent from the
            others and the result could deviate from the input.

        Returns
        -------
        draws : np.array
            An array of draws from the discrete distribution; each element is a value in atoms.
        """
        if atoms is None:
            atoms = self.atoms
        elif isinstance(atoms, (int, np.integer)):
            # NumPy integers (e.g. from np.arange) are valid indices too.
            atoms = self.atoms[atoms]

        if exact_match:
            events = np.arange(self.pmv.size)  # just a list of integers
            # Cutoff points between discrete outcomes
            cutoffs = np.round(np.cumsum(self.pmv) * N).astype(int)

            # Number of copies of each event index, so that the list of event
            # indices closely matches the discrete distribution.
            counts = np.diff(cutoffs, prepend=0)
            event_list = np.repeat(events, counts)

            # Randomly permute the event indices. event_list stays an integer
            # array even when it is empty, so it can always index atoms.
            indices = self._rng.permutation(event_list)
        else:
            # Draw event indices randomly from the discrete distribution
            indices = self.draw_events(N)

        # Create and fill in the output array of draws based on the output of event indices
        draws = atoms[..., indices]

        # TODO: some models expect univariate draws to just be a 1d vector. Fix those models.
        if draws.ndim == 2 and draws.shape[0] == 1:
            draws = draws.flatten()

        return draws

    def expected(
        self, func: Optional[Callable] = None, *args: np.ndarray
    ) -> np.ndarray:
        """
        Expected value of a function, given an array of configurations of its
        inputs along with a DiscreteDistribution object that specifies the
        probability of each configuration.

        If no function is provided, it's much faster to go straight to dot
        product instead of calling the dummy function.

        If a function is provided, we need to add one more dimension,
        the atom dimension, to any inputs that are n-dim arrays.
        This allows numpy to easily broadcast the function's output.
        For more information on broadcasting, see:
        https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values
            and return either arrays of arbitrary shape or scalars.
            It may also take other arguments ``*args``.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        *args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        if func is None:
            f_query = self.atoms
        else:
            # Append the atom axis to array arguments so they broadcast
            # against the last (atom) axis of self.atoms.
            args = [
                np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
                for arg in args
            ]
            f_query = func(self.atoms, *args)

        # The dot product weights the last axis of f_query by the pmv.
        return np.dot(f_query, self.pmv)

    def dist_of_func(
        self, func: Callable[..., float] = lambda x: x, *args: Any
    ) -> "DiscreteDistribution":
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments ``*args``.
        *args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.

        Returns
        -------
        f_dstn : DiscreteDistribution
            The distribution of func(dstn). It has the same probabilities and
            seed as this distribution.
        """
        # we need to add one more dimension,
        # the atom dimension, to any inputs that are n-dim arrays.
        # This allows numpy to easily broadcast the function's output.
        args = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        f_query = func(self.atoms, *args)

        return DiscreteDistribution(list(self.pmv), f_query, seed=self.seed)

    def discretize(self, N: int, *args: Any, **kwargs: Any) -> "DiscreteDistribution":
        """
        `DiscreteDistribution` is already an approximation, so this method
        returns the distribution itself (the same object, not a copy). All
        arguments are ignored.

        TODO: print warning message?
        """
        return self

    def make_univariate(self, dim_to_keep, seed=0):
        """
        Make a univariate discrete distribution from this distribution, keeping
        only the specified dimension.

        Parameters
        ----------
        dim_to_keep : int
            Index of the distribution to be kept. Any other dimensions will be
            "collapsed" into the univariate atoms, combining probabilities.
        seed : int, optional
            Seed for random number generator of univariate distribution.
            Not used when this distribution already has a single dimension,
            in which case a deep copy of it is returned.

        Returns
        -------
        univariate_dstn : DiscreteDistribution
            Univariate distribution with only the specified index.

        Raises
        ------
        ValueError
            If ``dim_to_keep`` is not smaller than the number of dimensions.
        """
        # Do basic validity and triviality checks
        n_dims = self.atoms.shape[0]
        if (n_dims == 1) and (dim_to_keep == 0):
            return deepcopy(self)  # Return copy of self if only one dimension
        if dim_to_keep >= n_dims:
            raise ValueError("dim_to_keep exceeds dimensionality of distribution.")

        # Construct values and probabilities for univariate distribution
        atoms_temp = self.atoms[dim_to_keep]
        vals_to_keep = np.unique(atoms_temp)

        # Probabilities are always floats. Allocating them with the dtype of
        # the atoms would truncate every probability to 0 for integer atoms.
        probs_to_keep = np.zeros(vals_to_keep.size, dtype=float)
        for i, val in enumerate(vals_to_keep):
            probs_to_keep[i] = np.sum(self.pmv[atoms_temp == val])

        # Make and return the univariate distribution
        return DiscreteDistribution(pmv=probs_to_keep, atoms=vals_to_keep, seed=seed)

339 

340 

class DiscreteDistributionLabeled(DiscreteDistribution):
    """
    A representation of a discrete probability distribution
    stored in an underlying `xarray.Dataset`.

    Parameters
    ----------
    pmv : np.array
        An array of values representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        For multivariate distributions, the last dimension of atoms must index
        "atom" or the random realization. Only vector-valued distributions
        (atoms with at most 2 dimensions) are supported.
    seed : int
        Seed for random number generator.
    limit : dict
        Limit information, handled as in `DiscreteDistribution`; its entries
        are also copied into the dataset attributes.
    name : str
        Name of the distribution.
    attrs : dict
        Attributes for the distribution. The dict passed in is not modified.
    var_names : list of str
        Names of the variables in the distribution.
    var_attrs : list of dict
        Attributes of the variables in the distribution.

    Raises
    ------
    NotImplementedError
        If the atoms have more than 2 dimensions.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: int = 0,
        limit: Optional[Dict[str, Any]] = None,
        name: str = "DiscreteDistributionLabeled",
        attrs: Optional[Dict[str, Any]] = None,
        var_names: Optional[List[str]] = None,
        var_attrs: Optional[List[Optional[Dict[str, Any]]]] = None,
    ):
        super().__init__(pmv, atoms, seed=seed, limit=limit)

        # vector-value distributions
        if self.atoms.ndim > 2:
            raise NotImplementedError(
                "Only vector-valued distributions are supported for now."
            )

        # Work on a copy so the caller's dict is never mutated.
        attrs = {} if attrs is None else dict(attrs)
        # The parent constructor already resolved the limit (given or derived
        # from the atoms), so self.limit is reused rather than recomputed.
        attrs.update(self.limit)
        attrs["name"] = name

        n_var = self.atoms.shape[0]

        # give dummy names to variables if none are provided
        if var_names is None:
            var_names = ["var_" + str(i) for i in range(n_var)]

        # Kept as an assert so the exception type seen by callers is unchanged.
        assert len(var_names) == n_var, (
            "Number of variable names does not match number of variables."
        )

        # give dummy attributes to variables if none are provided
        if var_attrs is None:
            var_attrs = [None] * n_var

        # a DiscreteDistributionLabeled is an xr.Dataset where the only
        # dimension is "atom", which indexes the random realizations.
        self.dataset = xr.Dataset(
            {
                var_names[i]: xr.DataArray(
                    self.atoms[i],
                    dims="atom",
                    attrs=var_attrs[i],
                )
                for i in range(n_var)
            },
            attrs=attrs,
        )

        # the probability mass values are stored in
        # a DataArray with dimension "atom"
        self.probability = xr.DataArray(self.pmv, dims="atom")

    @classmethod
    def from_unlabeled(
        cls,
        dist,
        name="DiscreteDistributionLabeled",
        attrs=None,
        var_names=None,
        var_attrs=None,
    ):
        """
        Build a labeled distribution from an unlabeled one, reusing its
        ``pmv``, ``atoms``, ``seed`` and ``limit``. The remaining arguments
        are passed to the constructor unchanged.
        """
        return cls(
            dist.pmv,
            dist.atoms,
            seed=dist.seed,
            limit=dist.limit,
            name=name,
            attrs=attrs,
            var_names=var_names,
            var_attrs=var_attrs,
        )

    @classmethod
    def from_dataset(cls, x_obj, pmf):
        """
        Build a labeled distribution directly from xarray data.

        The constructor is bypassed: only ``dataset`` and ``probability`` are
        set on the result, so attributes assigned in ``__init__`` (such as
        ``pmv`` and ``atoms``) are not available on it.

        Parameters
        ----------
        x_obj : xr.Dataset, xr.DataArray or dict
            Data of the distribution. A DataArray is wrapped in a Dataset
            under its own name; a dict is passed to ``xr.Dataset``.
        pmf : xr.DataArray
            Probability mass of each atom.

        Raises
        ------
        TypeError
            If ``x_obj`` is none of the supported types.
        """
        ldd = cls.__new__(cls)

        if isinstance(x_obj, xr.Dataset):
            ldd.dataset = x_obj
        elif isinstance(x_obj, xr.DataArray):
            ldd.dataset = xr.Dataset({x_obj.name: x_obj})
        elif isinstance(x_obj, dict):
            ldd.dataset = xr.Dataset(x_obj)
        else:
            # Fail here rather than return an object without a dataset.
            raise TypeError(
                "x_obj must be an xarray.Dataset, an xarray.DataArray or a dict, "
                "not " + type(x_obj).__name__ + "."
            )

        ldd.probability = pmf

        return ldd

    @property
    def _weighted(self):
        """
        Returns a DatasetWeighted object for the distribution.
        """
        return self.dataset.weighted(self.probability)

    @property
    def variables(self):
        """
        A dict-like container of DataArrays corresponding to
        the variables of the distribution.
        """
        return self.dataset.data_vars

    @property
    def name(self):
        """
        The distribution's name.
        """
        return self.dataset.name

    @property
    def attrs(self):
        """
        The distribution's attributes.
        """
        return self.dataset.attrs

    def _with_labeled_input(self, func: Callable) -> Callable:
        """
        Wrap `func` so that it receives a dict mapping each variable name to
        its row of atoms instead of the raw atoms array.
        """

        def func_wrapper(x: np.ndarray, *args: Any) -> np.ndarray:
            # Variable names are read at call time and paired with the rows
            # of the atoms array.
            wrapped = dict(zip(self.variables.keys(), x))
            return func(wrapped, *args)

        return func_wrapper

    def dist_of_func(
        self, func: Callable = lambda x: x, *args, **kwargs
    ) -> DiscreteDistribution:
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            Without keyword arguments it receives a dict mapping variable
            names to their atoms; with keyword arguments it receives the
            underlying `xarray.Dataset`.
            It may also take other arguments ``*args``.
        *args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.
        **kwargs :
            Additional keyword arguments for func. Must be xarray compatible
            in order to work with xarray broadcasting.

        Returns
        -------
        f_dstn : DiscreteDistribution or DiscreteDistributionLabeled
            The distribution of func(dstn). Labeled when keyword arguments
            are given, unlabeled otherwise.
        """
        if kwargs:
            # Positional arguments are forwarded as well, matching the
            # keyword path of `expected`.
            f_query = func(self.dataset, *args, **kwargs)
            return DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

        return super().dist_of_func(self._with_labeled_input(func), *args)

    def expected(
        self, func: Optional[Callable] = None, *args: Any, **kwargs: Any
    ) -> Union[float, np.ndarray]:
        """
        Expectation of a function, given an array of configurations of its inputs
        along with a DiscreteDistributionLabeled object that specifies the probability
        of each configuration.

        Parameters
        ----------
        func : function
            The function to be evaluated. If None, the expectation of the
            atoms themselves is returned and all other arguments are ignored.
            Without keyword arguments the function receives a dict mapping
            variable names to their atoms, e.g. `lambda x: x["a"] + x["b"]`;
            with keyword arguments it receives the underlying `xarray.Dataset`.
            It may also take other arguments ``*args``.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        *args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.
        **kwargs :
            Additional keyword arguments for func. Must be xarray compatible
            in order to work with xarray broadcasting.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value. On the keyword-argument path the result
            is the xarray weighted mean over the "atom" dimension.
        """
        if func is None:
            # No function to call: plain mean of the atoms.
            return super().expected()

        if kwargs:
            f_query = func(self.dataset, *args, **kwargs)
            ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability)
            return ldd._weighted.mean("atom")

        return super().expected(self._with_labeled_input(func), *args)

594 return super().expected() 

595 else: 

596 return super().expected(func_wrapper, *args)