Coverage for HARK / distributions / discrete.py: 92%

207 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 06:00 +0000

1from typing import Any, Callable, Dict, List, Optional, Union 

2 

3from copy import deepcopy 

4import numpy as np 

5import xarray as xr 

6from scipy import stats 

7from scipy.stats import rv_discrete 

8from scipy.stats._distn_infrastructure import rv_discrete_frozen 

9 

10from HARK.distributions.base import Distribution 

11 

12 

def _weighted_mean_var(var, pmv):
    """Average one DataArray over its ``atom`` dimension using weights *pmv*.

    Anything that is not an :class:`xr.DataArray`, or that has no ``atom``
    dimension, is handed back untouched so that callers such as
    :func:`_weighted_mean` can pass every variable through this function
    without filtering first.

    The result is an :class:`xr.DataArray` over the non-atom dimensions when
    any remain; otherwise it is a Python float (or the raw numpy result if it
    is not zero-dimensional).
    """
    has_atom_axis = isinstance(var, xr.DataArray) and "atom" in var.dims
    if not has_atom_axis:
        return var

    other_dims = tuple(name for name in var.dims if name != "atom")
    # Contract the weights against the atom axis; the other axes keep their order.
    weighted = np.tensordot(
        pmv, var.values, axes=([0], [var.dims.index("atom")])
    )

    if not other_dims:
        return float(weighted) if np.ndim(weighted) == 0 else weighted

    # Carry over only the coordinates that belong to surviving dimensions.
    kept_coords = {
        name: var.coords[name] for name in other_dims if name in var.coords
    }
    return xr.DataArray(weighted, dims=other_dims, coords=kept_coords)

30 

31 

def _weighted_mean(data, pmv):
    """Probability-weighted mean over the ``atom`` dimension.

    Each variable is reduced with :func:`_weighted_mean_var`, which leaves
    values without an ``atom`` dimension (or that are not DataArrays)
    unchanged.

    Parameters
    ----------
    data : xr.Dataset, xr.DataArray or dict
        Values to average. A dict is treated as a mapping from variable
        name to value.
    pmv : array-like
        Probability weights, one per atom.

    Returns
    -------
    xr.Dataset
        When *data* is a Dataset or a dict.
    xr.DataArray or scalar
        When *data* is a DataArray.

    Raises
    ------
    TypeError
        If *data* is none of the supported types.
    """
    if isinstance(data, xr.DataArray):
        return _weighted_mean_var(data, pmv)

    if isinstance(data, xr.Dataset):
        named_values = data.data_vars
    elif isinstance(data, dict):
        named_values = data
    else:
        raise TypeError(
            f"_weighted_mean: unsupported data type '{type(data).__name__}'. "
            "Expected xr.Dataset, xr.DataArray, or dict. "
            "Ensure the function passed to expected() returns one of these types "
            "when keyword arguments are used."
        )

    # _weighted_mean_var already passes non-DataArray values through as-is,
    # so dict entries need no separate type check here.
    return xr.Dataset(
        {
            label: _weighted_mean_var(value, pmv)
            for label, value in named_values.items()
        }
    )

66 

67 

class DiscreteFrozenDistribution(rv_discrete_frozen, Distribution):
    """
    Parameterized discrete distribution from scipy.stats with seed management.
    """

    def __init__(
        self, dist: rv_discrete, *args: Any, seed: Optional[int] = None, **kwds: Any
    ) -> None:
        """
        Parameterized discrete distribution from scipy.stats with seed management.

        Parameters
        ----------
        dist : rv_discrete
            Discrete distribution from scipy.stats.
        *args, **kwds
            Shape/location parameters forwarded unchanged to
            ``rv_discrete_frozen.__init__``.
        seed : int, optional
            Seed for random number generator, by default None. The value is
            forwarded as-is to ``Distribution.__init__``, which decides how a
            missing seed is handled.
        """
        # Each base takes different arguments, so they are initialized
        # explicitly rather than through a cooperative super() chain.
        rv_discrete_frozen.__init__(self, dist, *args, **kwds)
        Distribution.__init__(self, seed=seed)

89 

90 

class Bernoulli(DiscreteFrozenDistribution):
    """
    Bernoulli distribution over the outcomes 0 and 1.

    Parameters
    ----------
    p : float or [float]
        Probability or probabilities of the event occurring (True).

    seed : int
        Seed for random number generator.
    """

    def __init__(self, p=0.5, seed=None):
        self.p = np.asarray(p)
        # Freeze scipy's bernoulli at this p and hand the seed to the base class.
        super().__init__(stats.bernoulli, p=self.p, seed=seed)

        success = self.p
        self.pmv = np.array([1 - success, success])
        # Kept two-dimensional so the last axis indexes atoms, as in the
        # other distributions.
        self.atoms = np.array([[0, 1]])

        # The limit dict and the attributes each get their own arrays.
        self.limit = dict(
            dist=self,
            infimum=np.array([0.0]),
            supremum=np.array([1.0]),
        )
        self.infimum = np.array([0.0])
        self.supremum = np.array([1.0])

    def dim(self):
        """
        Shape of a single realization: every axis of ``self.atoms`` except
        the last one, which indexes "atom."
        """
        n_axes = self.atoms.ndim
        return self.atoms.shape[: n_axes - 1]

126 

127 

class DiscreteDistribution(Distribution):
    """
    A representation of a discrete probability distribution.

    Parameters
    ----------
    pmv : np.array
        An array of floats representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        For multivariate distributions, the last dimension of atoms must index
        "atom" or the random realization. For instance, if atoms.shape == (2,6,4),
        the random variable has 4 possible realizations and each of them has shape (2,6).
    limit : dict
        Dictionary with information about the continuous distribution from which
        this distribution was generated. The reference distribution is in the entry
        called 'dist'. When omitted, a dict holding the per-variable minimum
        ("infimum") and maximum ("supremum") of the atoms is built instead.
    seed : int
        Seed for random number generator.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(seed=seed)

        self.pmv = np.asarray(pmv)
        # Univariate atoms given as a 1-D array become a single row.
        self.atoms = np.atleast_2d(atoms)
        if limit is None:
            limit = {
                "infimum": np.min(self.atoms, axis=-1),
                "supremum": np.max(self.atoms, axis=-1),
            }
        self.limit = limit

        # Check that pmv and atoms have compatible dimensions.
        if not self.pmv.size == self.atoms.shape[-1]:
            raise ValueError(
                "Provided pmv and atoms arrays have incompatible dimensions. "
                + "The length of the pmv must be equal to that of atoms's last dimension."
            )

    def __repr__(self):
        out = self.__class__.__name__ + " with " + str(self.pmv.size) + " atoms, "
        if self.atoms.shape[0] > 1:
            # Multivariate: show the bounds of every variable.
            out += "inf=" + str(tuple(self.limit["infimum"])) + ", "
            out += "sup=" + str(tuple(self.limit["supremum"])) + ", "
        else:
            out += "inf=" + str(self.limit["infimum"][0]) + ", "
            out += "sup=" + str(self.limit["supremum"][0]) + ", "
        out += "seed=" + str(self.seed)
        return out

    def dim(self) -> tuple:
        """
        Shape of a single realization: every axis of ``self.atoms`` except
        the last one, which indexes "atom."
        """
        return self.atoms.shape[:-1]

    def draw_events(self, N: int) -> np.ndarray:
        """
        Draws N 'events' from the distribution PMF.
        These events are indices into atoms.
        """
        # Invert the cumulative distribution at N uniform draws.
        base_draws = self._rng.uniform(size=N)
        cum_dist = np.cumsum(self.pmv)
        indices = cum_dist.searchsorted(base_draws)

        # Round-off can leave the last cumulative value slightly below a
        # draw, which would yield an index one past the final atom.
        return np.minimum(indices, cum_dist.size - 1)

    def draw(
        self,
        N: int,
        atoms: Union[None, int, np.ndarray] = None,
        shuffle: bool = False,
    ) -> np.ndarray:
        """
        Simulates N draws from a discrete distribution with probabilities P and outcomes atoms.

        Parameters
        ----------
        N : int
            Number of draws to simulate.
        atoms : None, int, or np.array
            If None, then use this distribution's atoms for point values.
            If an int (Python or numpy integer), then the index of atoms for
            the point values.
            If an np.array, use the array for the point values.
        shuffle : boolean
            Whether the draws should "shuffle" the discrete distribution, matching
            proportions of outcomes as closely as possible to the probabilities given
            finite draws. When True, returned draws are a random permutation of the
            N-length list that best fits the discrete distribution. When False
            (default), each draw is independent from the others and the result could
            deviate from the probabilities.

        Returns
        -------
        draws : np.array
            An array of draws from the discrete distribution; each element is a value in atoms.
        """
        if atoms is None:
            atoms = self.atoms
        elif isinstance(atoms, (int, np.integer)):
            atoms = self.atoms[atoms]

        # "Shuffle" an almost-exact population of draws based on the pmv
        if shuffle:
            P = self.pmv
            J = P.size
            K_exact = N * P  # slots per outcome in real numbers
            K = np.floor(K_exact).astype(int)  # whole slots allocated to each atom
            M = N - np.sum(K)  # number of unallocated slots
            # Fractional part of each atom's slot count, in [0, 1): the share
            # of a slot it is still owed. Atoms whose allocation is already
            # exact get zero weight and cannot receive an extra slot.
            Q = K_exact - K
            extra_draws = self._rng.random(M)  # uniform draws for "extra" slots

            # Fill in each unallocated slot, one by one
            for m in range(M):
                Q_total = np.sum(Q)
                # Remainders can only be exhausted early when the pmv does not
                # sum to one; fall back to the pmv itself instead of dividing
                # by zero.
                Q_adj = Q / Q_total if Q_total > 0.0 else P / np.sum(P)
                Q_sum = np.cumsum(Q_adj)
                # Clip in case round-off leaves Q_sum[-1] below the draw.
                j = min(np.searchsorted(Q_sum, extra_draws[m]), J - 1)
                K[j] += 1  # increment its allocated slots
                Q[j] = 0.0  # each atom can get at most one extra slot

            # Atom index j repeated K[j] times; stays an integer array even
            # when it is empty (N == 0).
            events = np.repeat(np.arange(J), K)

            # Draw a random permutation of the indices
            indices = self._rng.permutation(events)

        # Draw event indices randomly from the discrete distribution
        else:
            indices = self.draw_events(N)

        # Look up the realizations along the last (atom) axis
        draws = atoms[..., indices]

        # TODO: some models expect univariate draws to just be a 1d vector. Fix those models.
        if len(draws.shape) == 2 and draws.shape[0] == 1:
            draws = draws.flatten()

        return draws

    def expected(
        self, func: Optional[Callable] = None, *args: np.ndarray
    ) -> np.ndarray:
        """
        Expected value of a function, given an array of configurations of its
        inputs along with a DiscreteDistribution object that specifies the
        probability of each configuration.

        If no function is provided, it's much faster to go straight to dot
        product instead of calling the dummy function.

        If a function is provided, we need to add one more dimension,
        the atom dimension, to any inputs that are n-dim arrays.
        This allows numpy to easily broadcast the function's output.
        For more information on broadcasting, see:
        https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values
            and return either arrays of arbitrary shape or scalars.
            It may also take other arguments \\*args.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value.
        """
        if func is None:
            return np.dot(self.atoms, self.pmv)

        # Give array arguments a trailing axis to broadcast against the atoms;
        # with no extra arguments this is simply an empty list.
        expanded = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        return np.dot(func(self.atoms, *expanded), self.pmv)

    def dist_of_func(
        self, func: Callable[..., float] = lambda x: x, *args: Any
    ) -> "DiscreteDistribution":
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            This function should take the full array of distribution values.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.

        Returns
        -------
        f_dstn : DiscreteDistribution
            The distribution of func(dstn). It reuses this distribution's
            probabilities and seed.
        """
        # we need to add one more dimension,
        # the atom dimension, to any inputs that are n-dim arrays.
        # This allows numpy to easily broadcast the function's output.
        args = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        f_query = func(self.atoms, *args)

        f_dstn = DiscreteDistribution(list(self.pmv), f_query, seed=self.seed)

        return f_dstn

    def discretize(self, N: int, *args: Any, **kwargs: Any) -> "DiscreteDistribution":
        """
        `DiscreteDistribution` is already an approximation, so this method
        ignores its arguments and returns this same object (not a copy).

        TODO: print warning message?
        """
        return self

    def make_univariate(self, dim_to_keep, seed=0):
        """
        Make a univariate discrete distribution from this distribution, keeping
        only the specified dimension.

        Parameters
        ----------
        dim_to_keep : int
            Index of the distribution to be kept. Any other dimensions will be
            "collapsed" into the univariate atoms, combining probabilities.
        seed : int, optional
            Seed for random number generator of univariate distribution

        Returns
        -------
        univariate_dstn : DiscreteDistribution
            Univariate distribution with only the specified index.

        Raises
        ------
        ValueError
            If ``dim_to_keep`` is not smaller than the number of variables.
        """
        # Do basic validity and triviality checks
        if (self.atoms.shape[0] == 1) and (dim_to_keep == 0):
            return deepcopy(self)  # Return copy of self if only one dimension
        if dim_to_keep >= self.atoms.shape[0]:
            raise ValueError("dim_to_keep exceeds dimensionality of distribution.")

        # Construct values and probabilities for univariate distribution
        atoms_temp = self.atoms[dim_to_keep]
        vals_to_keep = np.unique(atoms_temp)
        # Built as a float array: inheriting the atoms' dtype would truncate
        # the probabilities to zero for integer-valued atoms.
        probs_to_keep = np.array(
            [np.sum(self.pmv[atoms_temp == val]) for val in vals_to_keep],
            dtype=float,
        )

        # Make and return the univariate distribution
        univariate_dstn = DiscreteDistribution(
            pmv=probs_to_keep, atoms=vals_to_keep, seed=seed
        )
        return univariate_dstn

412 

413 

class DiscreteDistributionLabeled(DiscreteDistribution):
    """
    A representation of a discrete probability distribution stored in an underlying `xarray.Dataset`.

    Parameters
    ----------
    pmv : np.array
        An array of values representing a probability mass function.
    atoms : np.array
        Discrete point values for each probability mass.
        The last dimension of atoms must index "atom" or the random
        realization. Only vector-valued distributions (atoms with at most
        two dimensions) are supported.
    seed : int
        Seed for random number generator.
    limit : dict
        Information about the distribution this one was generated from.
        When omitted, the per-variable minimum and maximum of the atoms are
        used. Its entries are also copied into the dataset attributes.
    name : str
        Name of the distribution.
    attrs : dict
        Attributes for the distribution.
    var_names : list of str
        Names of the variables in the distribution.
    var_attrs : list of dict
        Attributes of the variables in the distribution.
    """

    def __init__(
        self,
        pmv: np.ndarray,
        atoms: np.ndarray,
        seed: Optional[int] = None,
        limit: Optional[Dict[str, Any]] = None,
        name: str = "DiscreteDistributionLabeled",
        attrs: Optional[Dict[str, Any]] = None,
        var_names: Optional[List[str]] = None,
        var_attrs: Optional[List[Optional[Dict[str, Any]]]] = None,
    ):
        super().__init__(pmv, atoms, seed=seed, limit=limit)

        # vector-value distributions

        if self.atoms.ndim > 2:
            raise NotImplementedError(
                "Only vector-valued distributions are supported for now."
            )

        # Work on a copy so the caller's dict is not modified by the
        # updates below.
        attrs = {} if attrs is None else dict(attrs)
        # The base initializer has already stored the supplied limit, or
        # derived one from the atoms when none was given.
        attrs.update(self.limit)
        attrs["name"] = name

        n_var = self.atoms.shape[0]

        # give dummy names to variables if none are provided
        if var_names is None:
            var_names = ["var_" + str(i) for i in range(n_var)]

        assert len(var_names) == n_var, (
            "Number of variable names does not match number of variables."
        )

        # give dummy attributes to variables if none are provided
        if var_attrs is None:
            var_attrs = [None] * n_var

        # a DiscreteDistributionLabeled is an xr.Dataset where the only
        # dimension is "atom", which indexes the random realizations.
        self.dataset = xr.Dataset(
            {
                var_names[i]: xr.DataArray(
                    self.atoms[i],
                    dims=("atom",),
                    attrs=var_attrs[i],
                )
                for i in range(n_var)
            },
            attrs=attrs,
        )

        # the probability mass values are stored in
        # a DataArray with dimension "atom"
        self.probability = xr.DataArray(self.pmv, dims=("atom",))

        # cache for fast labeled access in expected()
        self._var_names = var_names
        self._wrapped_atoms = dict(zip(var_names, self.atoms))

    @classmethod
    def from_unlabeled(
        cls,
        dist,
        name="DiscreteDistributionLabeled",
        attrs=None,
        var_names=None,
        var_attrs=None,
    ):
        """
        Build a labeled distribution from an existing distribution, reusing
        its ``pmv``, ``atoms``, ``seed`` and ``limit``.
        """
        ldd = cls(
            dist.pmv,
            dist.atoms,
            seed=dist.seed,
            limit=dist.limit,
            name=name,
            attrs=attrs,
            var_names=var_names,
            var_attrs=var_attrs,
        )

        return ldd

    @classmethod
    def from_dataset(cls, x_obj, pmf):
        """
        Build a labeled distribution directly around xarray data, bypassing
        ``__init__``.

        Parameters
        ----------
        x_obj : xr.Dataset, xr.DataArray or dict
            Data to store as the distribution's dataset. A DataArray is
            wrapped under its own name; a dict is passed to ``xr.Dataset``.
        pmf : array-like
            Probability of each atom. Stored unchanged as ``probability``
            and as a numpy array in ``pmv``.

        Raises
        ------
        TypeError
            If ``x_obj`` is none of the supported types.
        """
        ldd = cls.__new__(cls)

        if isinstance(x_obj, xr.Dataset):
            ldd.dataset = x_obj
        elif isinstance(x_obj, xr.DataArray):
            ldd.dataset = xr.Dataset({x_obj.name: x_obj})
        elif isinstance(x_obj, dict):
            ldd.dataset = xr.Dataset(x_obj)
        else:
            raise TypeError(
                f"from_dataset: 'x_obj' must be an xr.Dataset, xr.DataArray, "
                f"or dict, got {type(x_obj).__name__}."
            )

        ldd.probability = pmf
        ldd.pmv = np.asarray(pmf)
        ldd._var_names = list(ldd.dataset.data_vars)

        # Derive atoms from dataset variables that have the "atom" dimension.
        atom_vars = [
            v
            for v in ldd._var_names
            if "atom" in ldd.dataset[v].dims and ldd.dataset[v].ndim == 1
        ]
        if atom_vars:
            ldd.atoms = np.stack([ldd.dataset[v].values for v in atom_vars])
        else:
            # No 1-D atom variable is available: use a single row of zeros
            # as a placeholder so array-based code still has atoms.
            ldd.atoms = np.atleast_2d(np.zeros(len(ldd.pmv)))

        ldd.limit = {
            "infimum": np.min(ldd.atoms, axis=-1),
            "supremum": np.max(ldd.atoms, axis=-1),
        }
        # No seed argument available; default to 0 for consistency.
        ldd.seed = 0
        ldd._rng = np.random.default_rng(0)
        # cache for fast labeled access in expected()
        ldd._wrapped_atoms = dict(zip(atom_vars, ldd.atoms))

        return ldd

    @property
    def variables(self):
        """
        A dict-like container of DataArrays corresponding to
        the variables of the distribution.
        """
        return self.dataset.data_vars

    @property
    def name(self):
        """
        The distribution's name.
        """
        # NOTE(review): presumably resolved by xarray's attribute-style
        # lookup to the "name" entry that __init__ stores in the dataset
        # attrs - confirm against the xarray version in use.
        return self.dataset.name

    @property
    def attrs(self):
        """
        The distribution's attributes.
        """

        return self.dataset.attrs

    def dist_of_func(
        self, func: Callable = lambda x: x, *args, **kwargs
    ) -> DiscreteDistribution:
        """
        Finds the distribution of a random variable Y that is a function
        of discrete random variable atoms, Y=f(atoms).

        Parameters
        ----------
        func : function
            The function to be evaluated.
            Without keyword arguments it receives a dict mapping variable
            names to atom arrays; with keyword arguments it receives the
            underlying ``xr.Dataset``.
            It may also take other arguments \\*args.
        \\*args :
            Additional non-stochastic arguments for func,
            The function is computed as ``f(dstn, *args)``.
            They are only forwarded when no keyword arguments are given.
        \\*\\*kwargs :
            Additional keyword arguments for func. Must be xarray compatible
            in order to work with xarray broadcasting.

        Returns
        -------
        f_dstn : DiscreteDistribution or DiscreteDistributionLabeled
            The distribution of func(dstn).
        """

        def func_wrapper(x: np.ndarray, *args: Any) -> np.ndarray:
            """
            Wrapper function for `func` that handles labeled indexing.
            """
            # Label rows with the cached atom-variable names. These line up
            # with the rows of the atoms array even when the dataset also
            # holds variables without an "atom" dimension.
            wrapped = dict(zip(self._wrapped_atoms.keys(), x))

            return func(wrapped, *args)

        if len(kwargs):
            f_query = func(self.dataset, **kwargs)
            ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

            return ldd

        return super().dist_of_func(func_wrapper, *args)

    def expected(
        self, func: Optional[Callable] = None, *args: Any, **kwargs: Any
    ) -> Union[float, np.ndarray]:
        """
        Expectation of a function, given an array of configurations of its inputs
        along with a DiscreteDistributionLabeled object that specifies the probability
        of each configuration.

        Parameters
        ----------
        func : function
            The function to be evaluated.
            By default, func receives a dict mapping variable names to numpy
            arrays, e.g. ``{"perm_shk": array(...), "tran_shk": array(...)}``.
            When ``labels=False``, func receives the raw numpy atoms array
            instead (integer indexing, like ``DiscreteDistribution``).
            When extra keyword arguments are passed, func receives the full
            ``xr.Dataset``.
            It may also take other arguments \\*args.
            This function differs from the standalone `calc_expectation`
            method in that it uses numpy's vectorization and broadcasting
            rules to avoid costly iteration.
            Note: If you need to use a function that acts on single outcomes
            of the distribution, consider `distribution.calc_expectation`.
        \\*args :
            Other inputs for func, representing the non-stochastic arguments.
            The expectation is computed at ``f(dstn, *args)``.
        labels : bool, optional
            If True (default), func receives a dict of labeled arrays.
            If False, func receives the raw numpy atoms array, making DDL
            compatible with functions written for ``DiscreteDistribution``.
            Has no effect when other keyword arguments are passed.
        \\*\\*kwargs :
            Any other keyword arguments are forwarded to func and switch to
            the xarray path, whose result is averaged over the "atom"
            dimension.

        Returns
        -------
        f_exp : np.array or scalar
            The expectation of the function at the queried values.
            Scalar if only one value. On the keyword-argument path the
            result is an xarray object or scalar instead.
        """
        # "labels" is consumed here and never forwarded to func.
        labels = kwargs.pop("labels", True)

        # xarray path: any remaining keyword argument means func works on
        # the Dataset.
        if kwargs:
            if func is None:
                return _weighted_mean(self.dataset, self.pmv)
            f_query = func(self.dataset, *args, **kwargs)
            return _weighted_mean(f_query, self.pmv)

        if func is None:
            return np.dot(self.atoms, self.pmv)

        # Give array arguments a trailing axis to broadcast against the atoms;
        # with no extra arguments this is simply an empty list.
        expanded = [
            np.expand_dims(arg, -1) if isinstance(arg, np.ndarray) else arg
            for arg in args
        ]
        # labels=True: cached dict {var_name: np.ndarray}, avoiding xarray
        # overhead. labels=False: raw atoms, like DiscreteDistribution.
        realizations = self._wrapped_atoms if labels else self.atoms
        return np.dot(func(realizations, *expanded), self.pmv)