Coverage for HARK / simulator.py: 94%

1576 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-10 06:19 +0000

1""" 

2A module with classes and functions for automated simulation of HARK.AgentType 

3models from a human- and machine-readable model specification. 

4""" 

5 

6from dataclasses import dataclass, field 

7from copy import copy, deepcopy 

8import numpy as np 

9from numba import njit 

10from sympy.utilities.lambdify import lambdify 

11from sympy import symbols, IndexedBase 

12from typing import Callable 

13from HARK.utilities import NullFunc, make_polynomial_grid, make_grid_exp_mult 

14from HARK.distributions import Distribution 

15from scipy.sparse import csr_matrix, csc_matrix 

16from scipy.sparse.linalg import eigs 

17from scipy.optimize import brentq 

18from itertools import product 

19import importlib.resources 

20import yaml 

21 

# Prevent pre-commit from removing sympy: reference both imported names once so
# linters see them as used, without leaving any module-level bindings behind.
symbols("x")
IndexedBase("y")

27 

28 

29@dataclass(kw_only=True) 

30class ModelEvent: 

31 """ 

32 Class for representing "events" that happen to agents in the course of their 

33 model. These might be statements of dynamics, realization of a random shock, 

34 or the evaluation of a function (potentially a control or other solution- 

35 based object). This is a superclass for types of events defined below. 

36 

37 Parameters 

38 ---------- 

39 description : str 

40 Text description of this model event. 

41 statement : str 

42 The line of the model statement that this event corresponds to. 

43 parameters : dict 

44 Dictionary of objects that are static / universal within this event. 

45 assigns : list[str] 

46 List of names of variables that this event assigns values for. 

47 needs : list[str] 

48 List of names of variables that this event requires to be run. 

49 data : dict 

50 Dictionary of current variable values within this event. 

51 common : bool 

52 Indicator for whether the variables assigned in this event are commonly 

53 held across all agents, rather than idiosyncratic. 

54 N : int 

55 Number of agents currently in this event. 

56 """ 

57 

58 statement: str = field(default="") 

59 parameters: dict = field(default_factory=dict) 

60 description: str = field(default="") 

61 assigns: list[str] = field(default_factory=list, repr=False) 

62 needs: list = field(default_factory=list, repr=False) 

63 data: dict = field(default_factory=dict, repr=False) 

64 common: bool = field(default=False, repr=False) 

65 N: int = field(default=1, repr=False) 

66 

67 def run(self): 

68 """ 

69 This method should be filled in by each subclass. 

70 """ 

71 pass # pragma: nocover 

72 

73 def reset(self): 

74 self.data = {} 

75 

76 def assign(self, output): 

77 if len(self.assigns) > 1: 

78 assert len(self.assigns) == len(output) 

79 for j in range(len(self.assigns)): 

80 var = self.assigns[j] 

81 if type(output[j]) is not np.ndarray: 

82 output[j] = np.array([output[j]]) 

83 self.data[var] = output[j] 

84 else: 

85 var = self.assigns[0] 

86 if type(output) is not np.ndarray: 

87 output = np.array([output]) 

88 self.data[var] = output 

89 

90 def expand_information(self, origins, probs, atoms, which=None): 

91 """ 

92 This method is only called internally when a RandomEvent or MarkovEvent 

93 runs its quasi_run() method. It expands the set of of "probability blobs" 

94 by applying a random realization event. All extant blobs for which the 

95 shock applies are replicated for each atom in the random event, with the 

96 probability mass divided among the replicates. 

97 

98 Parameters 

99 ---------- 

100 origins : np.array 

101 Array that tracks which arrival state space node each blob originated 

102 from. This is expanded into origins_new, which is returned. 

103 probs : np.array 

104 Vector of probabilities of each of the random possibilities. 

105 atoms : [np.array] 

106 List of arrays with realization values for the distribution. Each 

107 array corresponds to one variable that is assigned by this event. 

108 which : np.array or None 

109 If given, a Boolean array indicating which of the pre-existing blobs 

110 is affected by the given probabilities and atoms. By default, all 

111 blobs are assumed to be affected. 

112 

113 Returns 

114 ------- 

115 origins_new : np.array 

116 Expanded boolean array of indicating the arrival state space node that 

117 each blob originated from. 

118 """ 

119 K = probs.size 

120 N = self.N 

121 if which is None: 

122 which = np.ones(N, dtype=bool) 

123 other = np.logical_not(which) 

124 M = np.sum(which) # how many blobs are we affecting? 

125 MX = N - M # how many blobs are we not affecting? 

126 

127 # Update probabilities of outcomes 

128 pmv_old = np.reshape(self.data["pmv_"][which], (M, 1)) 

129 pmv_new = (pmv_old * np.reshape(probs, (1, K))).flatten() 

130 self.data["pmv_"] = np.concatenate((self.data["pmv_"][other], pmv_new)) 

131 

132 # Replicate the pre-existing data for each atom 

133 for var in self.data.keys(): 

134 if (var == "pmv_") or (var in self.assigns): 

135 continue # don't double expand pmv, and don't touch assigned variables 

136 data_old = np.reshape(self.data[var][which], (M, 1)) 

137 data_new = np.tile(data_old, (1, K)).flatten() 

138 self.data[var] = np.concatenate((self.data[var][other], data_new)) 

139 

140 # If any of the assigned variables don't exist yet, add dummy versions 

141 # of them. This section exists so that the code works with "partial events" 

142 # on both the first pass and subsequent passes. 

143 for j in range(len(self.assigns)): 

144 var = self.assigns[j] 

145 if var in self.data.keys(): 

146 continue 

147 self.data[var] = np.zeros(N, dtype=atoms[j].dtype) 

148 # Zeros are just dummy values 

149 

150 # Add the new random variables to the simulation data. This generates 

151 # replicates for the affected blobs and leaves the others untouched, 

152 # still with their dummy values. They will be altered on later passes. 

153 for j in range(len(self.assigns)): 

154 var = self.assigns[j] 

155 data_new = np.tile(np.reshape(atoms[j], (1, K)), (M, 1)).flatten() 

156 self.data[var] = np.concatenate((self.data[var][other], data_new)) 

157 

158 # Expand the origins array to account for the new replicates 

159 origins_new = np.tile(np.reshape(origins[which], (M, 1)), (1, K)).flatten() 

160 origins_new = np.concatenate((origins[other], origins_new)) 

161 self.N = MX + M * K 

162 

163 # Send the new origins array back to the calling process 

164 return origins_new 

165 

166 def add_idiosyncratic_bernoulli_info(self, origins, probs): 

167 """ 

168 Special method for adding Bernoulli outcomes to the information set when 

169 probabilities are idiosyncratic to each agent. All extant blobs are duplicated 

170 with the appropriate probability 

171 

172 Parameters 

173 ---------- 

174 origins : np.array 

175 Array that tracks which arrival state space node each blob originated 

176 from. This is expanded into origins_new, which is returned. 

177 probs : np.array 

178 Vector of probabilities of drawing True for each blob. 

179 

180 Returns 

181 ------- 

182 origins_new : np.array 

183 Expanded boolean array of indicating the arrival state space node that 

184 each blob originated from. 

185 """ 

186 N = self.N 

187 

188 # # Update probabilities of outcomes, replicating each one 

189 pmv_old = np.reshape(self.data["pmv_"], (N, 1)) 

190 P = np.reshape(probs, (N, 1)) 

191 PX = np.concatenate([1.0 - P, P], axis=1) 

192 pmv_new = (pmv_old * PX).flatten() 

193 self.data["pmv_"] = pmv_new 

194 

195 # Replicate the pre-existing data for each atom 

196 for var in self.data.keys(): 

197 if (var == "pmv_") or (var in self.assigns): 

198 continue # don't double expand pmv, and don't touch assigned variables 

199 data_old = np.reshape(self.data[var], (N, 1)) 

200 data_new = np.tile(data_old, (1, 2)).flatten() 

201 self.data[var] = data_new 

202 

203 # Add the (one and only) new random variable to the simulation data 

204 var = self.assigns[0] 

205 data_new = np.tile(np.array([[0, 1]]), (N, 1)).flatten() 

206 self.data[var] = data_new 

207 

208 # Expand the origins array to account for the new replicates 

209 origins_new = np.tile(np.reshape(origins, (N, 1)), (1, 2)).flatten() 

210 self.N = N * 2 

211 

212 # Send the new origins array back to the calling process 

213 return origins_new 

214 

215 

@dataclass(kw_only=True)
class DynamicEvent(ModelEvent):
    """
    Event that evaluates an expression of model variables and parameters and
    stores the result(s) under this event's assigned variable names.

    Parameters
    ----------
    expr : Callable
        Function or expression to be evaluated for the assigned variables.
    args : list[str]
        Ordered list of argument names for the expression.
    """

    expr: Callable = field(default_factory=NullFunc, repr=False)
    args: list[str] = field(default_factory=list, repr=False)

    def evaluate(self):
        """
        Call self.expr with the values named in self.args, looking each name up
        in self.data and self.parameters (parameters win on a name clash).
        """
        lookup = {**self.data, **self.parameters}
        values = [lookup[name] for name in self.args]
        return self.expr(*values)

    def run(self):
        """Evaluate the expression and store its output via assign()."""
        result = self.evaluate()
        self.assign(result)

    def quasi_run(self, origins, norm=None):
        """
        Run this deterministic event during quasi-simulation. No blobs are
        created, so ``origins`` is handed back unchanged; ``norm`` is unused.
        """
        self.run()
        return origins

246 

247 

@dataclass(kw_only=True)
class RandomEvent(ModelEvent):
    """
    Class for representing the realization of random variables for an agent,
    consisting of a shock distribution and variables to which the results are assigned.

    Parameters
    ----------
    dstn : Distribution
        Distribution of one or more random variables that are drawn from during
        this event and assigned to the corresponding variables.
    """

    dstn: Distribution = field(default_factory=Distribution, repr=False)

    def reset(self):
        """Call reset() on the distribution, then clear this event's data."""
        self.dstn.reset()
        ModelEvent.reset(self)

    def draw(self):
        """
        Draw values for the assigned variables: N idiosyncratic draws, or one
        draw broadcast to all N agents when self.common. Returns a 1D array when
        one variable is assigned, else an array of shape (len(assigns), N).
        """
        out = np.empty((len(self.assigns), self.N))
        if not self.common:
            out[:, :] = self.dstn.draw(self.N)
        else:
            out[:, :] = self.dstn.draw(1)
        if len(self.assigns) == 1:
            out = out.flatten()
        return out

    def run(self):
        """Draw the random variables and store them via assign()."""
        self.assign(self.draw())

    def _apply_harmenberg(self, probs, atoms, norm):
        """
        Apply Harmenberg permanent-income normalization to ``probs`` in place.

        If ``norm`` matches one of the variable names in ``self.assigns``,
        scale ``probs`` by the corresponding atoms; otherwise leave ``probs``
        unchanged. Returns ``probs`` for chained use.
        """
        # Only the name lookup may legitimately "fail". A shape mismatch in the
        # multiplication also raises ValueError and must not be swallowed.
        if norm not in self.assigns:
            return probs
        harm_idx = self.assigns.index(norm)
        probs *= atoms[harm_idx]
        return probs

    def quasi_run(self, origins, norm=None):
        """
        Quasi-simulate this event: replicate every blob once per atom of
        self.dstn, weighted by the atom probabilities (Harmenberg-scaled when
        ``norm`` names one of the assigned variables). Returns the expanded
        origins array.
        """
        # Get distribution; copy pmv so the in-place scaling can't alter dstn
        atoms = self.dstn.atoms
        probs = self.dstn.pmv.copy()

        probs = self._apply_harmenberg(probs, atoms, norm)

        # Expand the set of simulated blobs
        return self.expand_information(origins, probs, atoms)

305 

306 

@dataclass(kw_only=True)
class RandomIndexedEvent(RandomEvent):
    """
    Class for representing the realization of random variables for an agent,
    consisting of a list of shock distributions, an index for the list, and the
    variables to which the results are assigned.

    Parameters
    ----------
    dstn : [Distribution]
        List of distributions of one or more random variables that are drawn
        from during this event and assigned to the corresponding variables.
    index : str
        Name of the index that is used to choose a distribution for each agent.
    """

    index: str = field(default="", repr=False)
    dstn: list[Distribution] = field(default_factory=list, repr=False)

    def draw(self):
        """
        Draw values for the assigned variables, using for each agent the
        distribution self.dstn[k] where k is that agent's index value. Agents
        whose index matches no distribution get NaN. Returns a 1D array when
        one variable is assigned (in both the common and idiosyncratic cases),
        else an array of shape (len(assigns), N).
        """
        idx = self.data[self.index]
        K = len(self.assigns)
        out = np.full((K, self.N), np.nan)

        if self.common:
            k = idx[0]  # this will behave badly if index is not itself common
            out[:, :] = self.dstn[k].draw(1)
        else:
            for k, dstn_k in enumerate(self.dstn):
                these = idx == k
                if not np.any(these):
                    continue
                out[:, these] = dstn_k.draw(np.sum(these))

        # Flatten for a single assigned variable, matching RandomEvent.draw
        if K == 1:
            out = out.flatten()
        return out

    def reset(self):
        """Call reset() on every distribution, then clear this event's data."""
        for dstn_k in self.dstn:
            dstn_k.reset()
        ModelEvent.reset(self)

    def quasi_run(self, origins, norm=None):
        """
        Quasi-simulate this event: for each distribution j, replicate the blobs
        whose index equals j once per atom of self.dstn[j]. Returns the
        expanded origins array.
        """
        origins_new = origins.copy()

        for j, dstn_j in enumerate(self.dstn):
            # Re-read the index each pass: expand_information replaces the
            # arrays in self.data (including this one) with expanded versions.
            idx = self.data[self.index]
            these = idx == j

            # Get distribution; copy pmv so in-place scaling can't alter dstn
            atoms = dstn_j.atoms
            probs = dstn_j.pmv.copy()

            probs = self._apply_harmenberg(probs, atoms, norm)

            # Expand the set of simulated blobs
            origins_new = self.expand_information(
                origins_new, probs, atoms, which=these
            )

        # Return the altered origins array
        return origins_new

372 

373 

@dataclass(kw_only=True)
class MarkovEvent(ModelEvent):
    """
    Class for representing the realization of a Markov draw for an agent, in which
    Markov probabilities (an array, a vector, or a single float) are used to
    determine the realization of some discrete outcome. If the probabilities are a
    2D array, they represent a Markov matrix (rows sum to 1), and there must be an
    index; if the probabilities are a vector, it should be a stochastic vector; if
    it's a single float, it represents a Bernoulli probability.

    Parameters
    ----------
    probs : str
        Name of the probabilities object, looked up in self.parameters first
        and otherwise in self.data.
    index : str
        Name of the variable holding each agent's current row of the Markov
        matrix; empty when probs is not a matrix.
    seed : int
        Seed for this event's RandomState.
    """

    probs: str = field(default="", repr=False)
    index: str = field(default="", repr=False)
    N: int = field(default=1, repr=False)
    seed: int = field(default=0, repr=False)
    # seed is overwritten when each period is created

    def __post_init__(self):
        self.reset_rng()

    def reset(self):
        """Re-seed the RNG and clear this event's data."""
        self.reset_rng()
        ModelEvent.reset(self)

    def reset_rng(self):
        """Create a fresh RandomState from self.seed."""
        self.RNG = np.random.RandomState(self.seed)

    def _get_probs(self):
        """
        Return (probs, probs_are_param): the object named by self.probs, taken
        from self.parameters when present there and otherwise from self.data,
        plus a flag recording which source was used.
        """
        if self.probs in self.parameters:
            return self.parameters[self.probs], True
        return self.data[self.probs], False

    def draw(self):
        """
        Draw the discrete outcome for each agent.

        Markov-matrix and stochastic-vector cases return integer outcome
        indices (-1 for agents whose index matches no matrix row); the
        Bernoulli case returns a bool array in both common and idiosyncratic
        modes.
        """
        # Initialize the output
        out = -np.ones(self.N, dtype=int)
        probs, probs_are_param = self._get_probs()

        # Make the base draw(s)
        if self.common:
            X = self.RNG.rand(1)
        else:
            X = self.RNG.rand(self.N)

        if self.index:  # it's a Markov matrix
            idx = self.data[self.index]
            J = probs.shape[0]
            for j in range(J):
                these = idx == j
                if not np.any(these):
                    continue
                P = np.cumsum(probs[j, :])
                if self.common:
                    out[:] = np.searchsorted(P, X[0])  # only one value of X!
                else:
                    out[these] = np.searchsorted(P, X[these])
            return out

        if isinstance(probs, np.ndarray) and probs_are_param:  # stochastic vector
            P = np.cumsum(probs)
            if self.common:
                out[:] = np.searchsorted(P, X[0])
                return out
            return np.searchsorted(P, X)

        # Otherwise, this is just a Bernoulli RV. Store the shared outcome in a
        # bool array so it has the same dtype as the idiosyncratic `X < P`.
        P = probs
        if self.common:
            shared = np.empty(self.N, dtype=bool)
            shared[:] = X < P
            return shared
        return X < P  # basic Bernoulli

    def run(self):
        """Draw the outcome and store it via assign()."""
        self.assign(self.draw())

    def quasi_run(self, origins, norm=None):
        """
        Quasi-simulate this event by expanding blobs over every possible
        outcome: per matrix row (Markov matrix), over the vector's entries
        (stochastic vector), over False/True (Bernoulli with a parameter
        probability), or via add_idiosyncratic_bernoulli_info when the
        Bernoulli probability comes from self.data. ``norm`` is unused.
        Returns the expanded origins array.
        """
        probs, probs_are_param = self._get_probs()

        # If it's a Markov matrix:
        if self.index:
            K = probs.shape[0]
            atoms = np.array([np.arange(probs.shape[1], dtype=int)])
            origins_new = origins.copy()
            for k in range(K):
                # Re-read each pass: expand_information replaces self.data arrays
                idx = self.data[self.index]
                these = idx == k
                origins_new = self.expand_information(
                    origins_new, probs[k, :], atoms, which=these
                )
            return origins_new

        # If it's a stochastic vector:
        if isinstance(probs, np.ndarray) and probs_are_param:
            atoms = np.array([np.arange(probs.shape[0], dtype=int)])
            return self.expand_information(origins, probs, atoms)

        # Otherwise, this is just a Bernoulli RV, but it might have idiosyncratic probability
        if probs_are_param:
            P = probs
            atoms = np.array([[False, True]])
            return self.expand_information(origins, np.array([1 - P, P]), atoms)

        # Final case: probability is idiosyncratic Bernoulli
        return self.add_idiosyncratic_bernoulli_info(origins, probs)

490 

491 

@dataclass(kw_only=True)
class EvaluationEvent(ModelEvent):
    """
    Event that evaluates a model function, such as a policy function / decision
    rule from the model's solution or any other non-algebraic function, and
    stores its output under this event's assigned variable names.

    Parameters
    ----------
    func : Callable
        Model function that is evaluated in this event, with the output assigned
        to the appropriate variables.
    arguments : list[str]
        Ordered list of names whose current values are passed to func.
    """

    func: Callable = field(default_factory=NullFunc, repr=False)
    arguments: list[str] = field(default_factory=list, repr=False)

    def evaluate(self):
        """
        Call self.func on the values named in self.arguments; each name is
        resolved from self.parameters if present there, else from self.data.
        """
        values = []
        for name in self.arguments:
            if name in self.parameters:
                values.append(self.parameters[name])
            else:
                values.append(self.data[name])
        return self.func(*values)

    def run(self):
        """Evaluate the function and store its output via assign()."""
        result = self.evaluate()
        self.assign(result)

    def quasi_run(self, origins, norm=None):
        """
        Run this deterministic event during quasi-simulation; blobs are not
        expanded, so ``origins`` is returned as-is and ``norm`` is ignored.
        """
        self.run()
        return origins

522 

523 

524@dataclass(kw_only=True) 

525class SimBlock: 

526 """ 

527 Class for representing a "block" of a simulated model, which might be a whole 

528 period or a "stage" within a period. 

529 

530 Parameters 

531 ---------- 

532 description : str 

533 Textual description of what happens in this simulated block. 

534 statement : str 

535 Verbatim model statement that was used to create this block. 

536 content : dict 

537 Dictionary of objects that are constant / universal within the block. 

538 This includes both traditional numeric parameters as well as functions. 

539 arrival : list[str] 

540 List of inbound states: information available at the *start* of the block. 

541 events: list[ModelEvent] 

542 Ordered list of events that happen during the block. 

543 data: dict 

544 Dictionary that stores current variable values. 

545 N : int 

546 Number of idiosyncratic agents in this block. 

547 """ 

548 

549 statement: str = field(default="", repr=False) 

550 content: dict = field(default_factory=dict) 

551 description: str = field(default="", repr=False) 

552 arrival: list[str] = field(default_factory=list, repr=False) 

553 events: list[ModelEvent] = field(default_factory=list, repr=False) 

554 data: dict = field(default_factory=dict, repr=False) 

555 N: int = field(default=1, repr=False) 

556 

557 def run(self): 

558 """ 

559 Run this simulated block by running each of its events in order. 

560 """ 

561 for j in range(len(self.events)): 

562 event = self.events[j] 

563 for k in range(len(event.assigns)): 

564 var = event.assigns[k] 

565 if var in event.data.keys(): 

566 del event.data[var] 

567 for k in range(len(event.needs)): 

568 var = event.needs[k] 

569 event.data[var] = self.data[var] 

570 event.N = self.N 

571 event.run() 

572 for k in range(len(event.assigns)): 

573 var = event.assigns[k] 

574 self.data[var] = event.data[var] 

575 

576 def reset(self): 

577 """ 

578 Reset the simulated block by resetting each of its events. 

579 """ 

580 self.data = {} 

581 for j in range(len(self.events)): 

582 self.events[j].reset() 

583 

584 def _lookup_content(self, name, kind): 

585 """Return ``self.content[name]`` or raise ValueError naming ``kind``.""" 

586 try: 

587 return self.content[name] 

588 except KeyError: 

589 verb = "distribute the" if kind == "parameter" else "find a" 

590 raise ValueError(f"Could not {verb} {kind} called {name}!") from None 

591 

592 def distribute_content(self): 

593 """ 

594 Fill in parameters, functions, and distributions to each event. 

595 """ 

596 for event in self.events: 

597 for param in event.parameters.keys(): 

598 event.parameters[param] = self._lookup_content(param, "parameter") 

599 if (type(event) is RandomEvent) or (type(event) is RandomIndexedEvent): 

600 event.dstn = self._lookup_content(event._dstn_name, "distribution") 

601 if type(event) is EvaluationEvent: 

602 event.func = self._lookup_content(event._func_name, "function") 

603 

604 def _build_input_grids(self, grid_specs, arrival_N): 

605 """ 

606 Build input and output grid dictionaries from grid specifications. 

607 

608 Creates grids_in for arrival variables and grids_out for outcome variables, 

609 tracking which arrival variables have been covered and whether each output 

610 grid is continuous. 

611 

612 Parameters 

613 ---------- 

614 grid_specs : dict 

615 Dictionary of grid specifications keyed by variable name. 

616 arrival_N : int 

617 Number of arrival variables in this block. 

618 

619 Returns 

620 ------- 

621 grids_in : dict 

622 Grids for arrival (input) variables. 

623 grids_out : dict 

624 Grids for outcome (output) variables. 

625 continuous_grid_out_bool : list 

626 List of booleans indicating whether each output grid is continuous. 

627 grid_orders : dict 

628 Polynomial order for each variable grid (-1 for discrete, None if not polynomial). 

629 grid_nests : dict 

630 Number of times each variable grid is exponentially nested (None if not exponential). 

631 dummy_grid : np.ndarray or None 

632 Dummy grid used only when arrival_N == 0. 

633 """ 

634 completed = arrival_N * [False] 

635 grids_in = {} 

636 grids_out = {} 

637 dummy_grid = None 

638 if arrival_N == 0: # should only be for initializer block 

639 dummy_grid = np.array([0]) 

640 grids_in["_dummy"] = dummy_grid 

641 

642 continuous_grid_out_bool = [] 

643 grid_orders = {} 

644 grid_nests = {} 

645 for var in grid_specs.keys(): 

646 spec = grid_specs[var] 

647 try: 

648 idx = self.arrival.index(var) 

649 completed[idx] = True 

650 is_arrival = True 

651 except ValueError: 

652 is_arrival = False 

653 

654 # Determine whether it's polynomial, exponential, or custom 

655 if "custom" in spec: # it's custom-specified 

656 this_grid = np.array(spec["custom"]).astype(float).flatten() 

657 if not np.all(np.diff(this_grid) > 0): 

658 raise ValueError( 

659 "Custom grid for " + var + " is not strictly increasing!" 

660 ) 

661 new_grid = this_grid 

662 is_cont = True 

663 grid_orders[var] = None 

664 grid_nests[var] = None 

665 elif ("min" in spec) and ("max" in spec): 

666 bot = spec["min"] 

667 top = spec["max"] 

668 try: 

669 N = spec["N"] 

670 except: 

671 raise KeyError( 

672 "Need to specify number of gridpoints N in " + var + " grid!" 

673 ) 

674 if "order" in spec: # it's polynomially spaced 

675 Q = spec["order"] 

676 new_grid = make_polynomial_grid(bot, top, N, Q) 

677 grid_orders[var] = Q 

678 grid_nests[var] = None 

679 elif "nest" in spec: # it's exponentially spaced 

680 K = spec["nest"] 

681 if type(K) is not int: 

682 raise TypeError( 

683 "The nest level for " + var + " must be an integer!" 

684 ) 

685 if K < 1: 

686 raise ValueError( 

687 "Nest level must be at least 1, but is " 

688 + str(K) 

689 + " for " 

690 + var 

691 + "!" 

692 ) 

693 new_grid = make_grid_exp_mult( 

694 0.0, top - bot, N, timestonest=K, offset=bot 

695 ) 

696 grid_orders[var] = None 

697 grid_nests[var] = K 

698 else: # default to linearly spaced grid 

699 new_grid = np.linspace(bot, top, N) 

700 grid_orders[var] = 1.0 

701 grid_nests[var] = None 

702 is_cont = True 

703 elif "N" in spec: 

704 new_grid = np.arange(spec["N"], dtype=int) 

705 is_cont = False 

706 grid_orders[var] = -1 

707 grid_nests[var] = None 

708 else: 

709 new_grid = None # could not make grid, construct later 

710 is_cont = False 

711 grid_orders[var] = None 

712 grid_nests[var] = None 

713 

714 if is_arrival: 

715 grids_in[var] = new_grid 

716 else: 

717 grids_out[var] = new_grid 

718 continuous_grid_out_bool.append(is_cont) 

719 

720 # Verify that specifications were passed for all arrival variables 

721 for j in range(len(self.arrival)): 

722 if not completed[j]: 

723 raise ValueError( 

724 "No grid specification was provided for " + self.arrival[j] + "!" 

725 ) 

726 

727 return ( 

728 grids_in, 

729 grids_out, 

730 continuous_grid_out_bool, 

731 grid_orders, 

732 grid_nests, 

733 dummy_grid, 

734 ) 

735 

736 def _build_twist_grids( 

737 self, 

738 twist, 

739 grids_in, 

740 grid_orders, 

741 grid_nests, 

742 grids_out, 

743 continuous_grid_out_bool, 

744 ): 

745 """ 

746 Override output grids with arrival-matching grids for continuation variables. 

747 

748 When an intertemporal twist is provided, the result grids for continuation 

749 variables are set to match the corresponding arrival variable grids. 

750 

751 Parameters 

752 ---------- 

753 twist : dict 

754 Mapping from continuation variable names to arrival variable names. 

755 grids_in : dict 

756 Grids for arrival variables. 

757 grid_orders : dict 

758 Polynomial orders for each variable grid (modified in place). 

759 grid_nests : dict 

760 Exponential nesting for each variable grid (modified in place). 

761 grids_out : dict 

762 Grids for output variables (modified in place). 

763 continuous_grid_out_bool : list 

764 Boolean continuity flags for output grids (extended in place). 

765 

766 Returns 

767 ------- 

768 grids_out : dict 

769 Updated output grids. 

770 grid_orders : dict 

771 Updated grid orders. 

772 grid_nests : dict 

773 Updated exponential nesting. 

774 grid_out_is_continuous : np.ndarray 

775 Boolean array indicating continuity of each output grid. 

776 """ 

777 for cont_var in twist.keys(): 

778 arr_var = twist[cont_var] 

779 if cont_var not in list(grids_out.keys()): 

780 is_cont = np.issubdtype(grids_in[arr_var].dtype, np.floating) 

781 continuous_grid_out_bool.append(is_cont) 

782 grids_out[cont_var] = copy(grids_in[arr_var]) 

783 grid_orders[cont_var] = grid_orders[arr_var] 

784 grid_nests[cont_var] = grid_nests[arr_var] 

785 grid_out_is_continuous = np.array(continuous_grid_out_bool) 

786 return grids_out, grid_orders, grid_nests, grid_out_is_continuous 

787 

788 def _project_onto_output_grids( 

789 self, 

790 grids_out, 

791 grid_out_is_continuous, 

792 grid_orders, 

793 grid_nests, 

794 cont_vars, 

795 twist, 

796 N_orig, 

797 J, 

798 N, 

799 ): 

800 """ 

801 Project quasi-simulation results onto the discretized output grids. 

802 

803 Loops over each output variable, dispatching to the appropriate 

804 aggregation routine based on whether the grid is continuous or discrete. 

805 

806 Parameters 

807 ---------- 

808 grids_out : dict 

809 Output grids (may be updated for None grids). 

810 grid_out_is_continuous : np.ndarray 

811 Boolean continuity flags for each output variable. 

812 grid_orders : dict 

813 Polynomial orders for each variable grid. 

814 grid_nests : dict 

815 Exponential nesting count for each variable grid. 

816 cont_vars : list 

817 Names of continuation variables. 

818 twist : dict or None 

819 The intertemporal twist mapping (used only to check if provided). 

820 N_orig : int 

821 Number of original arrival-state grid points. 

822 J : int 

823 Size of the arrival state mesh. 

824 N : int 

825 Number of agents in this block. 

826 

827 Returns 

828 ------- 

829 matrices_out : dict 

830 Transition matrices for each output variable. 

831 cont_idx : dict 

832 Lower-bracket indices for continuation variables. 

833 cont_alpha : dict 

834 Interpolation weights for continuation variables. 

835 cont_M : dict 

836 Grid sizes for continuation variables. 

837 cont_discrete : dict 

838 Whether each continuation variable uses a discrete grid. 

839 grids_out : dict 

840 Updated output grids (some None entries may be filled in). 

841 grid_out_is_continuous : np.ndarray 

842 Updated continuity flags (may be changed for size-1 float grids). 

843 """ 

844 origin_array = self.origin_array 

845 matrices_out = {} 

846 cont_idx = {} 

847 cont_alpha = {} 

848 cont_M = {} 

849 cont_discrete = {} 

850 k = 0 

851 for var in grids_out.keys(): 

852 if var not in self.data.keys(): 

853 raise ValueError( 

854 "Variable " + var + " does not exist but a grid was specified!" 

855 ) 

856 grid = grids_out[var] 

857 vals = self.data[var] 

858 pmv = self.data["pmv_"] 

859 M = grid.size if grid is not None else 0 

860 

861 # Semi-hacky fix to deal with omitted arrival variables 

862 if (M == 1) and np.issubdtype(vals.dtype, np.floating): 

863 grid = grid.astype(float) 

864 grids_out[var] = grid 

865 grid_out_is_continuous[k] = True 

866 

867 if grid_out_is_continuous[k]: 

868 # Split the final values among discrete gridpoints on the interior. 

869 if M > 1: 

870 Q = grid_orders[var] 

871 K = grid_nests[var] 

872 is_cont = var in cont_vars 

873 if Q is not None: # it's a polynomial grid 

874 temp_func = ( 

875 aggregate_blobs_onto_polynomial_grid_alt 

876 if is_cont 

877 else aggregate_blobs_onto_polynomial_grid 

878 ) 

879 temp_out = temp_func( 

880 vals, 

881 pmv, 

882 origin_array, 

883 grid, 

884 J, 

885 Q, 

886 ) 

887 elif K is not None: # it's a multi-exponential grid 

888 temp_func = ( 

889 aggregate_blobs_onto_exponential_grid_alt 

890 if is_cont 

891 else aggregate_blobs_onto_exponential_grid 

892 ) 

893 temp_out = temp_func( 

894 vals, 

895 pmv, 

896 origin_array, 

897 grid, 

898 J, 

899 K, 

900 ) 

901 else: # it's a custom grid 

902 temp_func = ( 

903 aggregate_blobs_onto_custom_grid_alt 

904 if is_cont 

905 else aggregate_blobs_onto_custom_grid 

906 ) 

907 temp_out = temp_func( 

908 vals, 

909 pmv, 

910 origin_array, 

911 grid, 

912 J, 

913 ) 

914 

915 # Unpack the output from that function call 

916 if is_cont: 

917 cont_M[var] = M 

918 cont_discrete[var] = False 

919 trans_matrix, cont_idx[var], cont_alpha[var] = temp_out 

920 else: 

921 trans_matrix = temp_out 

922 

923 else: # Skip if the grid is a dummy with only one value. 

924 trans_matrix = np.ones((J, M)) 

925 if var in cont_vars: 

926 cont_idx[var] = np.zeros(N, dtype=int) 

927 cont_alpha[var] = np.zeros(N) 

928 cont_M[var] = M 

929 cont_discrete[var] = False 

930 

931 else: # Grid is discrete, can use simpler method 

932 if grid is None: 

933 M = np.max(vals.astype(int)) 

934 if var == "dead": 

935 M = 2 

936 grid = np.arange(M, dtype=int) 

937 grids_out[var] = grid 

938 M = grid.size 

939 vals = vals.astype(int) 

940 trans_matrix = aggregate_blobs_onto_discrete_grid( 

941 vals, pmv, origin_array, M, J 

942 ) 

943 if var in cont_vars: 

944 cont_idx[var] = vals 

945 cont_alpha[var] = np.zeros(N) 

946 cont_M[var] = M 

947 cont_discrete[var] = True 

948 

949 # Store the transition matrix for this variable 

950 matrices_out[var] = trans_matrix 

951 k += 1 

952 

953 return ( 

954 matrices_out, 

955 cont_idx, 

956 cont_alpha, 

957 cont_M, 

958 cont_discrete, 

959 grids_out, 

960 grid_out_is_continuous, 

961 ) 

962 

963 def _build_master_transition_array( 

964 self, cont_vars, cont_idx, cont_alpha, cont_M, cont_discrete, N_orig, N, D 

965 ): 

966 """ 

967 Construct the master arrival-to-continuation transition array. 

968 

969 Combines per-variable index arrays and interpolation weights into a 

970 single tensor using multilinear interpolation. The offset index arithmetic 

971 and continuation-variable ordering are load-bearing. 

972 

973 Parameters 

974 ---------- 

975 cont_vars : list 

976 Names of continuation variables, ordered to match arrival variables. 

977 cont_idx : dict 

978 Lower-bracket indices for each continuation variable. 

979 cont_alpha : dict 

980 Interpolation weights (upper bracket) for each continuation variable. 

981 cont_M : dict 

982 Grid size for each continuation variable. 

983 cont_discrete : dict 

984 Whether each continuation variable uses a discrete (not continuous) grid. 

985 N_orig : int 

986 Number of arrival-state grid points. 

987 N : int 

988 Total number of quasi-simulated agents. 

989 D : int 

990 Number of continuation dimensions. 

991 

992 Returns 

993 ------- 

994 master_trans_array_X : np.ndarray 

995 Unnormalized master transition array of shape (N_orig, prod(cont_M)). 

996 """ 

997 pmv = self.data["pmv_"] 

998 origin_array = self.origin_array 

999 

1000 # Count the number of non-trivial dimensions. A continuation dimension 

1001 # is non-trivial if it is both continuous and has more than one grid node. 

1002 C = 0 

1003 shape = [N_orig] 

1004 trivial = [] 

1005 for var in cont_vars: 

1006 shape.append(cont_M[var]) 

1007 if (not cont_discrete[var]) and (cont_M[var] > 1): 

1008 C += 1 

1009 trivial.append(False) 

1010 else: 

1011 trivial.append(True) 

1012 trivial = np.array(trivial) 

1013 

1014 # Make a binary array of offsets from the base index 

1015 bin_array_base = np.array(list(product([0, 1], repeat=C))) 

1016 bin_array = np.empty((2**C, D), dtype=int) 

1017 some_zeros = np.zeros(2**C, dtype=int) 

1018 c = 0 

1019 for d in range(D): 

1020 bin_array[:, d] = some_zeros if trivial[d] else bin_array_base[:, c] 

1021 c += not trivial[d] 

1022 

1023 # Make a vector of dimensional offsets from the base index 

1024 dim_offsets = np.ones(D, dtype=int) 

1025 for d in range(D - 1): 

1026 dim_offsets[d] = np.prod(shape[(d + 2) :]) 

1027 dim_offsets_X = np.tile(dim_offsets, (2**C, 1)) 

1028 offsets = np.sum(bin_array * dim_offsets_X, axis=1) 

1029 

1030 # Make combined arrays of indices and alphas 

1031 index_array = np.empty((N, D), dtype=int) 

1032 alpha_array = np.empty((N, D, 2)) 

1033 for d in range(D): 

1034 var = cont_vars[d] 

1035 index_array[:, d] = cont_idx[var] 

1036 alpha_array[:, d, 0] = 1.0 - cont_alpha[var] 

1037 alpha_array[:, d, 1] = cont_alpha[var] 

1038 idx_array = np.dot(index_array, dim_offsets) 

1039 

1040 # Make the master transition array 

1041 blank = np.zeros(np.array((N_orig, np.prod(shape[1:])))) 

1042 master_trans_array_X = calc_overall_trans_probs( 

1043 blank, idx_array, alpha_array, bin_array, offsets, pmv, origin_array 

1044 ) 

1045 return master_trans_array_X 

1046 

1047 def _condition_on_survival(self, master_trans_array_X, matrices_out, N_orig): 

1048 """ 

1049 Condition the master transition array on agent survival. 

1050 

1051 Divides through by the survival probability so that the transition 

1052 array represents the distribution conditional on not dying this period. 

1053 

1054 Parameters 

1055 ---------- 

1056 master_trans_array_X : np.ndarray 

1057 Unconditioned master transition array of shape (N_orig, M) where 

1058 M = prod(cont_M). Reshaped internally to (N_orig, N_orig, 2) 

1059 assuming one binary continuation variable (dead/alive). 

1060 matrices_out : dict 

1061 Per-variable transition matrices; must contain 'dead'. 

1062 N_orig : int 

1063 Number of arrival-state grid points. 

1064 

1065 Returns 

1066 ------- 

1067 master_trans_array_X : np.ndarray 

1068 Survival-conditioned master transition array of shape (N_orig, N_orig). 

1069 """ 

1070 master_trans_array_X = np.reshape(master_trans_array_X, (N_orig, N_orig, 2)) 

1071 survival_probs = np.reshape(matrices_out["dead"][:, 0], [N_orig, 1]) 

1072 master_trans_array_X = master_trans_array_X[..., 0] / survival_probs 

1073 return master_trans_array_X 

1074 

1075 def make_transition_matrices(self, grid_specs, twist=None, norm=None): 

1076 """ 

1077 Construct a transition matrix for this block, moving from a discretized 

1078 grid of arrival variables to a discretized grid of end-of-block variables. 

1079 User specifies how the grids of pre-states should be built. Output is 

1080 stored in attributes of self as follows: 

1081 

1082 - matrices : A dictionary of arrays that cast from the arrival state space 

1083 to the grid of outcome variables. Doing np.dot(dstn, matrices[var]) 

1084 will yield the discretized distribution of that outcome variable. 

1085 - grids : A dictionary of discretized grids for outcome variables. Doing 

1086 np.dot(np.dot(dstn, matrices[var]), grids[var]) yields the *average* 

1087 of that outcome in the population. 

1088 - trans_array : The full-period Markov transition matrix that goes from 

1089 arrival variables in t to arrival variables in t+1, including 

1090 mortality. 

1091 

1092 Parameters 

1093 ---------- 

1094 grid_specs : dict 

1095 Dictionary of dictionaries of grid specifications. The specification 

1096 for a continuous state variable's grid should include an entry for the 

1097 `min` and `max`, as well as a number of gridpoints `N`. The spacing of 

1098 the points is determined by setting `order` (polynomial) or `nest` 

1099 (multi-exponential). If neither is provided, then spacing is linear. 

1100 If only `N` is given, then the gridpoints are the integers 0,...,N. 

1101 Alternatively, a grid specification can provide only the key `custom`, 

1102 with a strictly increasing 1D array as its value. 

1103 twist : dict or None 

1104 Mapping from end-of-period (continuation) variables to successor's 

1105 arrival variables. When this is specified, additional output is created 

1106 for the "full period" arrival-to-arrival transition matrix. 

1107 norm : str or None 

1108 Name of the shock variable by which to normalize for Harmenberg 

1109 aggregation. By default, no normalization happens. 

1110 

1111 Returns 

1112 ------- 

1113 None 

1114 """ 

1115 arrival_N = len(self.arrival) 

1116 

1117 # Build input and output grids from grid specifications 

1118 ( 

1119 grids_in, 

1120 grids_out, 

1121 continuous_grid_out_bool, 

1122 grid_orders, 

1123 grid_nests, 

1124 dummy_grid, 

1125 ) = self._build_input_grids(grid_specs, arrival_N) 

1126 

1127 # If a twist was specified, override output grids for continuation variables 

1128 if twist is not None: 

1129 grids_out, grid_orders, grid_nests, grid_out_is_continuous = ( 

1130 self._build_twist_grids( 

1131 twist, 

1132 grids_in, 

1133 grid_orders, 

1134 grid_nests, 

1135 grids_out, 

1136 continuous_grid_out_bool, 

1137 ) 

1138 ) 

1139 else: 

1140 grid_out_is_continuous = np.array(continuous_grid_out_bool) 

1141 

1142 # Make meshes of all the arrival grids, which will be the initial simulation data 

1143 if arrival_N > 0: 

1144 state_meshes = np.meshgrid( 

1145 *[grids_in[var] for var in self.arrival], indexing="ij" 

1146 ) 

1147 else: # this only happens in the initializer block 

1148 state_meshes = [dummy_grid.copy()] 

1149 state_init = { 

1150 self.arrival[k]: state_meshes[k].flatten() for k in range(arrival_N) 

1151 } 

1152 N_orig = state_meshes[0].size 

1153 mesh_tuples = [ 

1154 [state_init[self.arrival[k]][n] for k in range(arrival_N)] 

1155 for n in range(N_orig) 

1156 ] 

1157 

1158 # Quasi-simulate this block 

1159 self.run_quasi_sim(state_init, norm=norm) 

1160 

1161 # Add survival to output if mortality is in the model 

1162 if "dead" in self.data.keys(): 

1163 grids_out["dead"] = None 

1164 

1165 # Get continuation variable names, making sure they're in the same order 

1166 # as named by the arrival variables. This should maybe be done in the 

1167 # simulator when it's initialized. 

1168 if twist is not None: 

1169 cont_vars_orig = list(twist.keys()) 

1170 temp_dict = {twist[var]: var for var in cont_vars_orig} 

1171 cont_vars = [] 

1172 for var in self.arrival: 

1173 cont_vars.append(temp_dict[var]) 

1174 if "dead" in self.data.keys(): 

1175 cont_vars.append("dead") 

1176 grid_out_is_continuous = np.concatenate( 

1177 (grid_out_is_continuous, [False]) 

1178 ) 

1179 else: 

1180 cont_vars = list(grids_out.keys()) # all outcomes are arrival vars 

1181 D = len(cont_vars) 

1182 

1183 # Project the final results onto the output or result grids 

1184 N = self.N 

1185 J = state_meshes[0].size 

1186 ( 

1187 matrices_out, 

1188 cont_idx, 

1189 cont_alpha, 

1190 cont_M, 

1191 cont_discrete, 

1192 grids_out, 

1193 grid_out_is_continuous, 

1194 ) = self._project_onto_output_grids( 

1195 grids_out, 

1196 grid_out_is_continuous, 

1197 grid_orders, 

1198 grid_nests, 

1199 cont_vars, 

1200 twist, 

1201 N_orig, 

1202 J, 

1203 N, 

1204 ) 

1205 

1206 # Construct the master arrival-to-continuation transition array 

1207 master_trans_array_X = self._build_master_transition_array( 

1208 cont_vars, cont_idx, cont_alpha, cont_M, cont_discrete, N_orig, N, D 

1209 ) 

1210 

1211 # Condition on survival if relevant 

1212 if "dead" in self.data.keys(): 

1213 master_trans_array_X = self._condition_on_survival( 

1214 master_trans_array_X, matrices_out, N_orig 

1215 ) 

1216 

1217 # Reshape the transition matrix depending on what kind of block this is 

1218 if arrival_N == 0: 

1219 # If this is the initializer block, the "transition" matrix is really 

1220 # just the initial distribution of states at model birth; flatten it. 

1221 master_init_array = master_trans_array_X.flatten() 

1222 else: 

1223 # In an ordinary period, reshape the transition array so it's square. 

1224 master_trans_array = np.reshape(master_trans_array_X, (N_orig, N_orig)) 

1225 

1226 # Store the results as attributes of self 

1227 grids = {} 

1228 grids.update(grids_in) 

1229 grids.update(grids_out) 

1230 self.grids = grids 

1231 self.matrices = matrices_out 

1232 self.mesh = mesh_tuples 

1233 if twist is not None: 

1234 self.trans_array = master_trans_array 

1235 if arrival_N == 0: 

1236 self.init_dstn = master_init_array 

1237 

1238 def run_quasi_sim(self, data, j0=0, twist=None, norm=None): 

1239 """ 

1240 "Quasi-simulate" this block from given starting data at some event index, 

1241 looping back to end at the same point (only if j0 > 0 and twist is given). 

1242 To quasi-simulate means to run the model forward for *every* possible shock 

1243 realization, tracking probability masses. 

1244 

1245 If the quasi-simulation loops through the twist, mortality is ignored. 

1246 

1247 Parameters 

1248 ---------- 

1249 data : dict 

1250 Dictionary of initial data, mapping variable names to vectors of values. 

1251 j0 : int, optional 

1252 Event index number at which to start (and end)) the quasi-simulation. 

1253 By default, it is run from index 0. 

1254 twist : dict, optional 

1255 Optional dictionary mapping end-of-block variables back to arrival variables. 

1256 If this is provided *and* j0 > 0, then the quasi-sim is run for a complete 

1257 period, starting and ending at the same index. Else it's run to end of period. 

1258 norm : str or None 

1259 The name of the variable on which to perform Harmenberg normalization. 

1260 

1261 Returns 

1262 ------- 

1263 None 

1264 """ 

1265 # Make the initial vector of probability masses 

1266 if not data: # data is empty because it's initializer block 

1267 N_orig = 1 

1268 else: 

1269 key = list(data.keys())[0] 

1270 N_orig = data[key].size 

1271 self.N = N_orig 

1272 state_init = deepcopy(data) 

1273 state_init["pmv_"] = np.ones(self.N) 

1274 

1275 # Initialize the array of arrival states 

1276 origin_array = np.arange(self.N, dtype=int) 

1277 

1278 # Reset the block's state and give it the initial state data 

1279 self.reset() 

1280 self.data.update(state_init) 

1281 

1282 # Loop through each event in order and quasi-simulate it 

1283 J = len(self.events) 

1284 for j in range(j0, J): 

1285 event = self.events[j] 

1286 event.data = self.data # Give event *all* data directly 

1287 event.N = self.N 

1288 origin_array = event.quasi_run(origin_array, norm=norm) 

1289 self.N = self.data["pmv_"].size 

1290 

1291 # If we didn't start at the beginning and there is a twist, loop back to 

1292 # the start and do the remaining events 

1293 if twist is not None: 

1294 new_data = {"pmv_": self.data["pmv_"].copy()} 

1295 for end_var in twist.keys(): 

1296 arr_var = twist[end_var] 

1297 new_data[arr_var] = self.data[end_var].copy() 

1298 self.data = new_data 

1299 for j in range(j0): 

1300 event = self.events[j] 

1301 event.data = self.data # Give event *all* data directly 

1302 event.N = self.N 

1303 origin_array = event.quasi_run(origin_array, norm=norm) 

1304 self.N = self.data["pmv_"].size 

1305 

1306 # Assign the origin array as an attribute of self 

1307 self.origin_array = origin_array 

1308 

1309 

1310@dataclass(kw_only=True) 

1311class AgentSimulator: 

1312 """ 

1313 A class for representing an entire simulator structure for an AgentType. 

1314 It includes a sequence of SimBlocks representing periods of the model, which 

1315 could be built from the information on an AgentType instance. 

1316 

1317 Parameters 

1318 ---------- 

1319 name : str 

1320 Short name of this model.s 

1321 description : str 

1322 Textual description of what happens in this simulated block. 

1323 statement : str 

1324 Verbatim model statement that was used to create this simulator. 

1325 comments : dict 

1326 Dictionary of comments or descriptions for various model objects. 

1327 parameters : list[str] 

1328 List of parameter names used in the model. 

1329 distributions : list[str] 

1330 List of distribution names used in the model. 

1331 functions : list[str] 

1332 List of function names used in the model. 

1333 common: list[str] 

1334 Names of variables that are common across idiosyncratic agents. 

1335 types: dict 

1336 Dictionary of data types for all variables in the model. 

1337 N_agents: int 

1338 Number of idiosyncratic agents in this simulation. 

1339 T_total: int 

1340 Total number of periods in these agents' model. 

1341 T_sim: int 

1342 Maximum number of periods that will be simulated, determining the size 

1343 of the history arrays. 

1344 T_age: int 

1345 Period after which to automatically terminate an agent if they would 

1346 survive past this period. 

1347 stop_dead : bool 

1348 Whether simulated agents who draw dead=True should actually cease acting. 

1349 Default is True. Setting to False allows "cohort-style" simulation that 

1350 will generate many agents that survive to old ages. In most cases, T_sim 

1351 should not exceed T_age, unless the user really does want multiple succ- 

1352 essive cohorts to be born and fully simulated. 

1353 replace_dead : bool 

1354 Whether simulated agents who are marked as dead should be replaced with 

1355 newborns (default True) or simply cease acting without replacement (False). 

1356 The latter option is useful for models with state-dependent mortality, 

1357 to allow "cohort-style" simulation with the correct distribution of states 

1358 for survivors at each age. Setting to False has no effect if stop_dead is True. 

1359 periods: list[SimBlock] 

1360 Ordered list of simulation blocks, each representing a period. 

1361 twist : dict 

1362 Dictionary that maps period t-1 variables to period t variables, as a 

1363 relabeling "between" periods. 

1364 initializer : SimBlock 

1365 A special simulated block that should have *no* arrival variables, because 

1366 it represents the initialization of "newborn" agents. 

1367 data : dict 

1368 Dictionary that holds *current* values of model variables. 

1369 track_vars : list[str] 

1370 List of names of variables whose history should be tracked in the simulation. 

1371 history : dict 

1372 Dictionary that holds the histories of tracked variables. 

1373 """ 

1374 

1375 name: str = field(default="") 

1376 description: str = field(default="") 

1377 statement: str = field(default="", repr=False) 

1378 comments: dict = field(default_factory=dict, repr=False) 

1379 parameters: list[str] = field(default_factory=list, repr=False) 

1380 distributions: list[str] = field(default_factory=list, repr=False) 

1381 functions: list[str] = field(default_factory=list, repr=False) 

1382 common: list[str] = field(default_factory=list, repr=False) 

1383 types: dict = field(default_factory=dict, repr=False) 

1384 N_agents: int = field(default=1) 

1385 T_total: int = field(default=1, repr=False) 

1386 T_sim: int = field(default=1) 

1387 T_age: int = field(default=0, repr=False) 

1388 stop_dead: bool = field(default=True) 

1389 replace_dead: bool = field(default=True) 

1390 periods: list[SimBlock] = field(default_factory=list, repr=False) 

1391 twist: dict = field(default_factory=dict, repr=False) 

1392 data: dict = field(default_factory=dict, repr=False) 

1393 initializer: field(default_factory=SimBlock, repr=False) 

1394 track_vars: list[str] = field(default_factory=list, repr=False) 

1395 history: dict = field(default_factory=dict, repr=False) 

1396 

1397 def simulate(self, T=None): 

1398 """ 

1399 Simulates the model for T periods, including replacing dead agents as 

1400 warranted and storing tracked variables in the history. If T is not 

1401 specified, the agents are simulated for the entire T_sim periods. 

1402 This is the primary user-facing simulation method. 

1403 """ 

1404 if T is None: 

1405 T = self.T_sim - self.t_sim # All remaining simulated periods 

1406 if (T + self.t_sim) > self.T_sim: 

1407 raise ValueError("Can't simulate more than T_sim periods!") 

1408 

1409 # Execute the simulation loop for T periods 

1410 for t in range(T): 

1411 # Do the ordinary work for simulating a period 

1412 self.sim_one_period() 

1413 

1414 # Mark agents who have reached maximum allowable age 

1415 if "dead" in self.data.keys() and self.T_age > 0: 

1416 too_old = self.t_age == self.T_age 

1417 self.data["dead"][too_old] = True 

1418 

1419 # Record tracked variables and advance age 

1420 self.store_tracked_vars() 

1421 self.advance_age() 

1422 

1423 # Handle death and replacement depending on simulation style 

1424 if "dead" in self.data.keys() and self.stop_dead: 

1425 self.mark_dead_agents() 

1426 self.t_sim += 1 

1427 

1428 def reset(self): 

1429 """ 

1430 Completely reset this simulator back to its original state so that it 

1431 can be run from scratch. This should allow it to generate the same results 

1432 every single time the simulator is run (if nothing changes). 

1433 """ 

1434 N = self.N_agents 

1435 T = self.T_sim 

1436 self.t_sim = 0 # Time index for the simulation 

1437 

1438 # Reset the variable data and history arrays 

1439 self.clear_data() 

1440 self.history = {} 

1441 for var in self.track_vars: 

1442 self.history[var] = np.empty((T, N), dtype=self.types[var]) 

1443 

1444 # Reset all of the blocks / periods 

1445 self.initializer.reset() 

1446 for t in range(len(self.periods)): 

1447 self.periods[t].reset() 

1448 

1449 # Specify all agents as "newborns" assigned to the initializer block 

1450 self.t_seq_bool_array = np.zeros((self.T_total, N), dtype=bool) 

1451 self.t_age = -np.ones(N, dtype=int) 

1452 

1453 def clear_data(self, skip=None): 

1454 """ 

1455 Reset all current data arrays back to blank, other than those designated 

1456 to be skipped, if any. 

1457 

1458 Parameters 

1459 ---------- 

1460 skip : [str] or None 

1461 Names of variables *not* to be cleared from data. Default is None. 

1462 

1463 Returns 

1464 ------- 

1465 None 

1466 """ 

1467 if skip is None: 

1468 skip = [] 

1469 N = self.N_agents 

1470 for var in self.types.keys(): 

1471 if var in skip: 

1472 continue 

1473 this_type = self.types[var] 

1474 if this_type is float: 

1475 self.data[var] = np.full((N,), np.nan) 

1476 elif this_type is bool: 

1477 self.data[var] = np.zeros((N,), dtype=bool) 

1478 elif this_type is int: 

1479 self.data[var] = np.zeros((N,), dtype=np.int32) 

1480 elif this_type is complex: 

1481 self.data[var] = np.full((N,), np.nan, dtype=complex) 

1482 else: 

1483 raise ValueError( 

1484 "Type " 

1485 + str(this_type) 

1486 + " of variable " 

1487 + var 

1488 + " was not recognized!" 

1489 ) 

1490 

1491 def mark_dead_agents(self): 

1492 """ 

1493 Looks at the special data field "dead" and marks those agents for replacement. 

1494 If no variable called "dead" has been defined, this is skipped. 

1495 """ 

1496 who_died = self.data["dead"] 

1497 self.t_seq_bool_array[:, who_died] = False 

1498 self.t_age[who_died] = -1 

1499 

1500 def create_newborns(self): 

1501 """ 

1502 Calls the initializer to generate newborns where needed. 

1503 """ 

1504 # Skip this step if there are no newborns 

1505 newborns = self.t_age == -1 

1506 if not np.any(newborns): 

1507 return 

1508 

1509 # Generate initial arrival variables 

1510 N = np.sum(newborns) 

1511 self.initializer.data = {} # by definition 

1512 self.initializer.N = N 

1513 self.initializer.run() 

1514 

1515 # Set the initial arrival data for newborns and clear other variables 

1516 init_arrival = self.periods[0].arrival 

1517 for var in self.types: 

1518 self.data[var][newborns] = ( 

1519 self.initializer.data[var] 

1520 if var in init_arrival 

1521 else np.empty(N, dtype=self.types[var]) 

1522 ) 

1523 

1524 # Set newborns' period to 0 

1525 self.t_age[newborns] = 0 

1526 self.t_seq_bool_array[0, newborns] = True 

1527 

1528 def store_tracked_vars(self): 

1529 """ 

1530 Record current values of requested variables in the history dictionary. 

1531 """ 

1532 for var in self.track_vars: 

1533 self.history[var][self.t_sim, :] = self.data[var] 

1534 

1535 def advance_age(self): 

1536 """ 

1537 Increments age for all agents, altering t_age and t_age_bool. Agents in 

1538 the last period of the sequence will be assigned to the initial period. 

1539 In a lifecycle model, those agents should be marked as dead and replaced 

1540 in short order. 

1541 """ 

1542 alive = self.t_age >= 0 # Don't age the dead 

1543 self.t_age[alive] += 1 

1544 X = self.t_seq_bool_array # For shorter typing on next line 

1545 self.t_seq_bool_array[:, alive] = np.concatenate( 

1546 (X[-1:, alive], X[:-1, alive]), axis=0 

1547 ) 

1548 

1549 def sim_one_period(self): 

1550 """ 

1551 Simulates one period of the model by advancing all agents one period. 

1552 This includes creating newborns, but it does NOT include eliminating 

1553 dead agents nor storing tracked results in the history. This method 

1554 should usually not be called by a user, instead using simulate(1) if 

1555 you want to run the model for exactly one period. 

1556 """ 

1557 # Use the "twist" information to advance last period's end-of-period 

1558 # information/values to be the arrival variables for this period. Then, for 

1559 # each variable other than those brought in with the twist, wipe it clean. 

1560 keepers = [] 

1561 for var_tm1 in self.twist: 

1562 var_t = self.twist[var_tm1] 

1563 keepers.append(var_t) 

1564 self.data[var_t] = self.data[var_tm1].copy() 

1565 self.clear_data(skip=keepers) 

1566 

1567 # Create newborns first so the arrival vars exist. This should be done in 

1568 # the first simulated period (t_sim=0) or if decedents should be replaced. 

1569 if self.replace_dead or self.t_sim == 0: 

1570 self.create_newborns() 

1571 

1572 # Loop through ages and run the model on the appropriately aged agents 

1573 for t in range(self.T_total): 

1574 these = self.t_seq_bool_array[t, :] 

1575 if not np.any(these): 

1576 continue # Skip any "empty ages" 

1577 this_period = self.periods[t] 

1578 

1579 data_temp = {var: self.data[var][these] for var in this_period.arrival} 

1580 this_period.data = data_temp 

1581 this_period.N = np.sum(these) 

1582 this_period.run() 

1583 

1584 # Extract all of the variables from this period and write it to data 

1585 for var in this_period.data.keys(): 

1586 self.data[var][these] = this_period.data[var] 

1587 

1588 # Put time information into the data dictionary 

1589 self.data["t_age"] = self.t_age.copy() 

1590 self.data["t_seq"] = np.argmax(self.t_seq_bool_array, axis=0).astype(int) 

1591 

1592 def make_transition_matrices( 

1593 self, grid_specs, norm=None, fake_news_timing=False, for_t=None 

1594 ): 

1595 """ 

1596 Build Markov-style transition matrices for each period of the model, as 

1597 well as the initial distribution of arrival variables for newborns. 

1598 Stores results to the attributes of self as follows: 

1599 

1600 - trans_arrays : List of Markov matrices for transitioning from the arrival 

1601 state space in period t to the arrival state space in t+1. 

1602 This transition includes death (and replacement). 

1603 - newborn_dstn : Stochastic vector as a NumPy array, representing the distribution 

1604 of arrival states for "newborns" who were just initialized. 

1605 - state_grids : Nested list of tuples representing the arrival state space for 

1606 each period. Each element corresponds to the discretized arrival 

1607 state space point with the same index in trans_arrays (and 

1608 newborn_dstn). Arrival states are ordered within a tuple in the 

1609 same order as the model file. Linked from period[t].mesh. 

1610 - outcome_arrays : List of dictionaries of arrays that cast from the arrival 

1611 state space to the grid of outcome variables, for each period. 

1612 Doing np.dot(state_dstn, outcome_arrays[t][var]) will yield 

1613 the discretized distribution of that outcome variable. Linked 

1614 from periods[t].matrices. 

1615 - outcome_grids : List of dictionaries of discretized outcomes in each period. 

1616 Keys are names of outcome variables, and entries are vectors 

1617 of discretized values that the outcome variable can take on. 

1618 Doing np.dot(np.dot(state_dstn, outcome_arrays[var]), outcome_grids[var]) 

1619 yields the *average* of that outcome in the population. Linked 

1620 from periods[t].grids. 

1621 

1622 Parameters 

1623 ---------- 

1624 grid_specs : dict 

1625 Dictionary of dictionaries with specifications for discretized grids 

1626 of all variables of interest. If any arrival variables are omitted, 

1627 they will be given a default trivial grid with one node at 0. This 

1628 should only be done if that arrival variable is closely tied to the 

1629 Harmenberg normalizing variable; see below. A grid specification must 

1630 include a number of gridpoints N, and should also include a min and 

1631 max if the variable is continuous. If the variable is discrete, the 

1632 grid values are assumed to be 0,..,N. 

1633 norm : str or None 

1634 Name of the variable for which Harmenberg normalization should be 

1635 applied, if any. This should be a variable that is directly drawn 

1636 from a distribution, not a "downstream" variable. 

1637 fake_news_timing : bool 

1638 Indicator for whether this call is part of the "fake news" algorithm 

1639 for constructing sequence space Jacobians (SSJs). This should only 

1640 ever be set to True in that situation, which affects how mortality 

1641 is handled between periods. In short, the simulator usually assumes 

1642 that "newborns" start with t_seq=0, but during the fake news algorithm, 

1643 that is not the case. 

1644 for_t : list or None 

1645 Optional list of time indices for which the matrices should be built. 

1646 When not specified, all periods are constructed. The most common use 

1647 for this arg is during the "fake news" algorithm for lifecycle models. 

1648 

1649 Returns 

1650 ------- 

1651 None 

1652 """ 

1653 # Sort grid specifications into those needed by the initializer vs those 

1654 # used by other blocks (ordinary periods) 

1655 arrival = self.periods[0].arrival 

1656 arrival_N = len(arrival) 

1657 check_bool = np.zeros(arrival_N, dtype=bool) 

1658 grid_specs_init_orig = {} 

1659 grid_specs_other = {} 

1660 for name in grid_specs.keys(): 

1661 if name in arrival: 

1662 idx = arrival.index(name) 

1663 check_bool[idx] = True 

1664 grid_specs_init_orig[name] = copy(grid_specs[name]) 

1665 grid_specs_other[name] = copy(grid_specs[name]) 

1666 

1667 # Build the dictionary of arrival variables, making sure it's in the 

1668 # same order as named self.arrival. For any arrival grids that are 

1669 # not specified, make a dummy specification. 

1670 grid_specs_init = {} 

1671 for n in range(arrival_N): 

1672 name = arrival[n] 

1673 if check_bool[n]: 

1674 grid_specs_init[name] = grid_specs_init_orig[name] 

1675 continue 

1676 dummy_grid_spec = {"N": 1} 

1677 grid_specs_init[name] = dummy_grid_spec 

1678 grid_specs_other[name] = dummy_grid_spec 

1679 

1680 # Make the initial state distribution for newborns 

1681 self.initializer.make_transition_matrices(grid_specs_init) 

1682 self.newborn_dstn = self.initializer.init_dstn 

1683 K = self.newborn_dstn.size 

1684 

1685 # Make the period-by-period transition matrices 

1686 these_t = list(range(len(self.periods))) if for_t is None else for_t 

1687 for t in these_t: 

1688 block = self.periods[t] 

1689 block.make_transition_matrices( 

1690 grid_specs_other, twist=self.twist, norm=norm 

1691 ) 

1692 block.reset() 

1693 self.grid_specs = grid_specs_other 

1694 self.norm = norm 

1695 

1696 # Extract the master transition matrices into a single list 

1697 p2p_trans_arrays = [self.periods[t].trans_array for t in these_t] 

1698 

1699 # Apply agent replacement to the last period of the model, representing 

1700 # newborns filling in for decedents. This will usually only do anything 

1701 # at all in "one period infinite horizon" models. If this is part of the 

1702 # fake news algorithm for constructing SSJs, then replace decedents with 

1703 # newborns in *all* periods, because model timing is funny in this case. 

1704 if fake_news_timing: 

1705 T_set = np.arange(len(self.periods)).tolist() 

1706 elif for_t is None: 

1707 T_set = [len(self.periods) - 1] 

1708 else: 

1709 T_set = [] 

1710 newborn_dstn = np.reshape(self.newborn_dstn, (1, K)) 

1711 for t in T_set: 

1712 if t not in these_t: 

1713 continue 

1714 if "dead" not in self.periods[t].matrices.keys(): 

1715 continue 

1716 death_prbs = self.periods[t].matrices["dead"][:, 1] 

1717 p2p_trans_arrays[t] *= np.tile(np.reshape(1 - death_prbs, (K, 1)), (1, K)) 

1718 p2p_trans_arrays[t] += np.reshape(death_prbs, (K, 1)) * newborn_dstn 

1719 

1720 # Store the transition arrays as attributes of self 

1721 self.trans_arrays = p2p_trans_arrays 

1722 

1723 # Build and store lists of state meshes, outcome arrays, and outcome grids 

1724 self.state_grids = [self.periods[t].mesh for t in these_t] 

1725 self.outcome_grids = [self.periods[t].grids for t in these_t] 

1726 self.outcome_arrays = [self.periods[t].matrices for t in these_t] 

1727 

def find_steady_state(self):
    """
    Calculates the steady state distribution of arrival states for a "one period
    infinite horizon" model, storing the result to the attribute steady_state_dstn.
    Should only be run after make_transition_matrices(), and only if T_total = 1
    and the model is infinite horizon.

    The steady state is the left eigenvector of the single period's transition
    matrix (i.e. a right eigenvector of its transpose) associated with its
    largest-magnitude eigenvalue, normalized to sum to one.

    Raises
    ------
    ValueError
        If T_total != 1, or if the largest eigenvalue found is not close to 1.
    """
    if self.T_total != 1:
        raise ValueError(
            "This method currently only works with one period infinite horizon problems."
        )

    # Find the eigenvector associated with the largest eigenvalue of the
    # infinite horizon transition matrix. The largest eigenvalue *should*
    # be 1 for any Markov matrix, but double check to be sure.
    trans_T = self.trans_arrays[0].transpose()
    K = trans_T.shape[0]
    if K < 3:
        # ARPACK (scipy.sparse.linalg.eigs) requires k < K - 1 and refuses
        # sparse input otherwise, so very small state spaces use a dense solver.
        vals, vecs = np.linalg.eig(np.asarray(trans_T))
        i = int(np.argmax(np.abs(vals)))
        top_val = vals[i]
        D = vecs[:, i]
    else:
        v, V = eigs(csr_matrix(trans_T), k=1)
        top_val = v[0]
        D = V[:, 0]

    if not np.isclose(top_val, 1.0):
        raise ValueError(
            "The largest eigenvalue of the transition matrix isn't close to 1!"
        )

    # Normalize that eigenvector and make sure it's real, then store it
    SS_dstn = (D / np.sum(D)).real
    self.steady_state_dstn = SS_dstn

1754 

def get_long_run_dstn(self, var):
    """
    Return the long run (steady state) population distribution of one variable.

    The steady state distribution of arrival states is pushed through the
    variable's outcome array for the first (only) period. find_steady_state()
    must have been run beforehand.

    Parameters
    ----------
    var : str
        Name of the variable whose long run distribution is wanted.

    Returns
    -------
    var_dstn : np.array
        Stochastic vector over the variable's discretized outcome grid.

    Raises
    ------
    ValueError
        If steady_state_dstn has not been computed yet.
    """
    try:
        arrival_dstn = self.steady_state_dstn
    except AttributeError:
        raise ValueError(
            "This method can only be run after find_steady_state()!"
        ) from None
    outcome_matrix = self.outcome_arrays[0][var]
    return np.dot(arrival_dstn, outcome_matrix)

1778 

def get_long_run_average(self, var):
    """
    Return the long run (steady state) population mean of one variable.

    This is the expectation of the variable's outcome grid under the
    distribution returned by get_long_run_dstn(var), so find_steady_state()
    must have been run first.

    Parameters
    ----------
    var : str
        Name of the variable whose long run average is wanted.

    Returns
    -------
    var_mean : float
        Long run / steady state population average of the variable.
    """
    long_run_dstn = self.get_long_run_dstn(var)
    outcome_grid = self.outcome_grids[0][var]
    return np.dot(long_run_dstn, outcome_grid)

1798 

def find_target_state(self, target_var, bounds=None, N=201, tol=1e-8, **kwargs):
    """
    Find the "target" level of a state variable: the value such that the expectation
    of next period's state is the same value (when following the policy function),
    *and* is locally stable (pushes up from below and down from above when nearby).
    Only works for standard infinite horizon models with a single endogenous state
    variable. Other variables whose values must be known (e.g. exogenously evolving
    states) can also be specified.

    The search procedure is to first examine a grid of candidates on the bounds,
    calculating E[Delta x] for state x, and then perform a local search for each
    interval where it flips from positive to negative.

    This procedure ignores mortality entirely. It represents a stable or target
    level conditional on the agent continuing from t to t+1.

    If additional information must be known, other model variables can be passed
    as keyword arguments, e.g. pLvl=1.0. This feature is used for exogenous state
    variables, such as persistent income pLvl in the GenIncProcess model. The user
    simply passes its mean (central) value, which is easily known in advance.

    Parameters
    ----------
    target_var : str
        Name of the state variable of interest.
    bounds : [float] or np.array, optional
        Lower and upper boundaries (in that order) for the target search. If not
        provided (or empty), defaults to [0.0, 100.0].
    N : int, optional
        Number of values of the variable of interest to test on the initial pass.
        If not provided, defaults to 201. This affects the "resolution" when there
        are multiple possible target levels (uncommon).
    tol : float, optional
        Convergence tolerance on the state value, passed to brentq as both xtol
        and rtol. If not specified, defaults to 1e-8.

    Returns
    -------
    state_targ : [float]
        List of target_var x values such that E[Delta x] = 0, which can be empty.

    Raises
    ------
    ValueError
        If T_total != 1, or (from the helpers) if a keyword argument does not
        name a model variable or the needed events cannot be found.
    """
    if self.T_total != 1:
        raise ValueError(
            "This method currently only works with one period infinite horizon problems."
        )
    # Test for None/emptiness explicitly: `bounds or default` raises a
    # "truth value is ambiguous" ValueError when bounds is a numpy array.
    if bounds is None or len(bounds) == 0:
        bounds = [0.0, 100.0]
    state_grid = np.linspace(bounds[0], bounds[1], num=N)

    period = self.periods[0]
    fixed = self._validate_fixed_kwargs(kwargs)
    idx0 = self._find_search_start_event(target_var, fixed, period)
    data_init, trivial_vars = self._make_target_data_init(
        target_var, fixed, period, state_grid, idx0, N
    )

    # Run the quasi-simulation on the initial grid of states
    period.run_quasi_sim(data_init, j0=idx0, twist=self.twist)
    E_delta_state = self._compute_expected_delta(period, target_var, state_grid, N)

    # Find indices in the grid where E[Delta x] flips from positive to negative;
    # only those crossings are locally stable targets.
    sign = E_delta_state > 0.0
    flip_idx = np.argwhere(
        np.logical_and(sign[:-1], np.logical_not(sign[1:]))
    ).flatten()
    if flip_idx.size == 0:
        return []

    # Reduce the fixed values in data_init to single valued vectors so that the
    # root finder can quasi-simulate one candidate state at a time.
    for var in trivial_vars:
        data_init[var] = np.array([0])
    for key in fixed.keys():
        data_init[key] = np.array([fixed[key]])

    def delta_zero_func(x):
        # E[x_{t+1}] - x for a single starting value x of the target variable
        data_init[target_var] = np.array([x])
        period.run_quasi_sim(data_init, j0=idx0, twist=self.twist)
        E_delta = np.dot(period.data["pmv_"], period.data[target_var]) - x
        return E_delta

    state_targ = []
    for i in flip_idx:
        x_targ = brentq(
            delta_zero_func, state_grid[i], state_grid[i + 1], xtol=tol, rtol=tol
        )
        state_targ.append(x_targ)
    return state_targ

1885 

1886 def _validate_fixed_kwargs(self, kwargs): 

1887 var_names = list(self.types.keys()) 

1888 fixed = {} 

1889 for name in kwargs: 

1890 if name not in var_names: 

1891 raise ValueError( 

1892 "Could not find a model variable called " + name + " to hold fixed!" 

1893 ) 

1894 fixed[name] = kwargs[name] 

1895 return fixed 

1896 

1897 @staticmethod 

1898 def _find_search_start_event(target_var, fixed, period): 

1899 if target_var in period.arrival: 

1900 return 0 

1901 var_names = [target_var] + list(fixed.keys()) 

1902 var_count = len(var_names) 

1903 found = [False] * var_count 

1904 found_count = 0 

1905 j = 0 

1906 event_count = len(period.events) 

1907 while (found_count < var_count) and (j < event_count): 

1908 assigns = period.events[j].assigns 

1909 for i in range(var_count): 

1910 if (var_names[i] in assigns) and (not found[i]): 

1911 found[i] = True 

1912 found_count += 1 

1913 j += 1 

1914 if not np.all(found): 

1915 raise ValueError( 

1916 "Could not find events that assign target variable and all fixed variables!" 

1917 ) 

1918 return j 

1919 

1920 @staticmethod 

1921 def _make_target_data_init(target_var, fixed, period, state_grid, idx0, N): 

1922 data_init = {} 

1923 trivial_vars = [] 

1924 for var in period.arrival: 

1925 data_init[var] = np.zeros(N, dtype=int) 

1926 trivial_vars.append(var) 

1927 for j in range(idx0): 

1928 for var in period.events[j].assigns: 

1929 data_init[var] = np.zeros(N, dtype=int) 

1930 trivial_vars.append(var) 

1931 for key, val in fixed.items(): 

1932 data_init[key] = val * np.ones(N) 

1933 data_init[target_var] = state_grid 

1934 return data_init, trivial_vars 

1935 

1936 @staticmethod 

1937 def _compute_expected_delta(period, target_var, state_grid, N): 

1938 origins = period.origin_array 

1939 data_final = period.data[target_var] 

1940 pmv_final = period.data["pmv_"] 

1941 E_state_next = np.empty(N) 

1942 for n in range(N): 

1943 these = origins == n 

1944 E_state_next[n] = np.dot(pmv_final[these], data_final[these]) 

1945 return E_state_next - state_grid 

1946 

1947 def _validate_shock_call(self, calc_dstn, calc_avg, shock, from_dstn): 

1948 if not (calc_dstn or calc_avg): 

1949 raise ValueError( 

1950 "At least one of calc_dstn or calc_avg must be true, or there's no work!" 

1951 ) 

1952 if (shock is None) and (from_dstn is None): 

1953 raise ValueError( 

1954 "The shock or from_dstn must be specified, or there's nothing to simulate!" 

1955 ) 

1956 if self.T_total != 1: 

1957 raise ValueError( 

1958 "simulate_shock_by_grids is only implemented for infinite-horizon models with T_total == 1." 

1959 ) 

1960 if not hasattr(self, "trans_arrays"): 

1961 raise KeyError( 

1962 "This method can't be run before running make_transition_matrices!" 

1963 ) 

1964 

1965 @staticmethod 

1966 def _normalize_shock_args(shock, outcomes): 

1967 if shock is None: 

1968 shock = [] 

1969 if type(shock) is str: 

1970 shock = [shock] 

1971 if isinstance(outcomes, str): 

1972 outcomes = [outcomes] 

1973 return shock, outcomes 

1974 

1975 def _resolve_initial_dstn(self, from_dstn): 

1976 if from_dstn is None: 

1977 if not hasattr(self, "steady_state_dstn"): 

1978 self.find_steady_state() 

1979 return self.steady_state_dstn 

1980 dstn_sum = np.sum(from_dstn) 

1981 dstn_N = from_dstn.size 

1982 if not np.isclose(dstn_sum, 1.0): 

1983 raise ValueError( 

1984 "Specified from_dstn should be a stochastic vector, but its values sum to " 

1985 + str(dstn_sum) 

1986 ) 

1987 arrival_N = len(self.state_grids[0]) 

1988 if arrival_N != dstn_N: 

1989 raise ValueError( 

1990 "Specified from_dstn should be a vector of size " 

1991 + str(arrival_N) 

1992 + ", but has size " 

1993 + str(dstn_N) 

1994 + "!" 

1995 ) 

1996 return from_dstn 

1997 

1998 def _parse_shock_statement(self, S): 

1999 op = next((c for c in ("+", "*", "=") if c in S), None) 

2000 if op is None: 

2001 raise ValueError( 

2002 "The shock statement (" + S + ") did not contain a valid operator!" 

2003 ) 

2004 loc = S.index(op) 

2005 var = S[:loc].strip() 

2006 val = S[(loc + 1) :].strip() 

2007 if var not in self.twist: 

2008 raise KeyError( 

2009 "All shocked variables must be continuation states, but " 

2010 + var 

2011 + " is not!" 

2012 ) 

2013 try: 

2014 float(val) 

2015 except (ValueError, TypeError): 

2016 raise ValueError("Couldn't interpret " + val + " as a number!") 

2017 return var, op, val 

2018 

2019 def _build_shock_event_strings(self, shock): 

2020 event_strings = [] 

2021 shock_vars = [] 

2022 for S in shock: 

2023 var, op, val = self._parse_shock_statement(S) 

2024 var_alt = self.twist[var] 

2025 if op == "+": 

2026 this_event = var + " = " + var_alt + " + " + val 

2027 elif op == "*": 

2028 this_event = var + " = " + var_alt + " * " + val 

2029 else: 

2030 this_event = var + " = " + var_alt + " * 0.0 + " + val 

2031 event_strings.append(this_event) 

2032 shock_vars.append(var) 

2033 for var in self.twist.keys(): 

2034 if var in shock_vars: 

2035 continue 

2036 event_strings.append(var + " = " + self.twist[var]) 

2037 return event_strings 

2038 

def _apply_shock_block(self, init_dstn, event_strings):
    """
    Push a distribution of arrival states through a one-off "exogenous shock" block.

    A temporary block is built from event_strings with the same arrival
    variables as the first period. Its transition matrix is made using only
    the grid specs of twist or arrival variables, and init_dstn is multiplied
    through that matrix.

    Parameters
    ----------
    init_dstn : np.array
        Distribution over arrival states before the shock.
    event_strings : list[str]
        Dynamics lines describing the shock.

    Returns
    -------
    np.array
        Distribution over arrival states after the shock.
    """
    relevant_specs = {}
    for name in self.grid_specs:
        keep = (name in self.twist) or (name in self.periods[0].arrival)
        if keep:
            relevant_specs[name] = self.grid_specs[name]

    block_spec = {"name": "exogenous shock", "dynamics": event_strings}
    shock_block, _, _, _, _ = make_template_block(
        block_spec, arrival=self.periods[0].arrival
    )
    shock_block.make_transition_matrices(relevant_specs, twist=self.twist)
    shock_block.reset()
    return np.dot(init_dstn, shock_block.trans_array)

2052 

2053 def _stream_avg_history(self, init_dstn, outcomes, T, trans_array): 

2054 outcome_arrays_0 = self.outcome_arrays[0] 

2055 outcome_grids_0 = self.outcome_grids[0] 

2056 history_avg = {name: np.empty(T) for name in outcomes} 

2057 current_dstn = init_dstn.copy() 

2058 for t in range(T): 

2059 for name in outcomes: 

2060 this_dstn_t = np.dot(current_dstn, outcome_arrays_0[name]) 

2061 history_avg[name][t] = np.dot(outcome_grids_0[name], this_dstn_t) 

2062 current_dstn = current_dstn @ trans_array 

2063 return history_avg 

2064 

2065 def _full_dstn_history(self, init_dstn, outcomes, T, trans_array, calc_avg): 

2066 current_dstn = init_dstn.copy() 

2067 state_dstn_by_t = np.empty((current_dstn.size, T)) 

2068 for t in range(T): 

2069 state_dstn_by_t[:, t] = current_dstn 

2070 current_dstn = current_dstn @ trans_array 

2071 history_dstn = {} 

2072 history_avg = {} 

2073 for name in outcomes: 

2074 this_outcome = self.outcome_arrays[0][name] 

2075 this_dstn = np.dot(this_outcome.T, state_dstn_by_t) 

2076 history_dstn[name] = this_dstn 

2077 if calc_avg: 

2078 history_avg[name] = np.dot(self.outcome_grids[0][name], this_dstn) 

2079 return history_dstn, history_avg 

2080 

def simulate_shock_by_grids(
    self,
    outcomes,
    T,
    shock=None,
    from_dstn=None,
    calc_dstn=False,
    calc_avg=True,
):
    """
    Trace population outcomes over time after an unexpected shock, using grids.

    The starting point is the steady state distribution of arrival states, or
    a user-supplied distribution. Optional additive, multiplicative, or
    replacement shocks are applied to it, and the result is moved forward T
    periods with the single period's transition matrix. Only for infinite
    horizon models with one period (T_total == 1), and only after
    make_transition_matrices() has been run. Results are stored in the
    dictionary attributes history_avg and history_dstn.

    Parameters
    ----------
    outcomes : str or [str]
        Name(s) of the outcome variables to track.
    T : int
        Number of periods to simulate after the shock.
    shock : str or [str], optional
        Shock operation(s) applied to the starting distribution. Each names a
        continuation variable (a key of the twist), then an operator, then a
        number. The operator is "+" (add), "*" (multiply), or "=" (set everyone
        to that value). For example, "aNrm + 0.1" adds 0.1 to end-of-period
        assets and "pLvl * 0.8" cuts permanent income by 20%. Any subset of
        arrival variables (including none) may be shocked. Write plain decimal
        numbers, e.g. "0.0001" rather than "1e-4".
    from_dstn : np.array, optional
        Custom starting distribution of arrival states. Defaults to the steady
        state distribution. Shocks are applied on top of it.
    calc_dstn : bool, optional
        Store full outcome distributions over time in history_dstn (default False).
    calc_avg : bool, optional
        Store outcome averages over time in history_avg (default True).

    Returns
    -------
    None.
    """
    self._validate_shock_call(calc_dstn, calc_avg, shock, from_dstn)
    shock_list, outcome_list = self._normalize_shock_args(shock, outcomes)
    base_dstn = self._resolve_initial_dstn(from_dstn)
    shocked_dstn = self._apply_shock_block(
        base_dstn, self._build_shock_event_strings(shock_list)
    )

    # Iterate with a CSC sparse copy of the single period's transition matrix.
    sparse_trans = csc_matrix(self.trans_arrays[0])
    if not calc_dstn:
        avg_hist = self._stream_avg_history(shocked_dstn, outcome_list, T, sparse_trans)
        dstn_hist = {}
    else:
        dstn_hist, avg_hist = self._full_dstn_history(
            shocked_dstn, outcome_list, T, sparse_trans, calc_avg
        )

    self.history_dstn = dstn_hist
    self.history_avg = avg_hist

2151 

def simulate_cohort_by_grids(
    self,
    outcomes,
    T_max=None,
    calc_dstn=False,
    calc_avg=True,
    from_dstn=None,
):
    """
    Simulate a single cohort forward through the model using discretized grids.

    Starting from the newborn distribution of arrival states (or from_dstn),
    the cohort's state distribution is recorded each period and advanced with
    that period's transition matrix. Requested outcome distributions and/or
    averages are stored in the dictionary attributes history_dstn and
    history_avg. The per-age state distributions are stored in
    state_dstn_by_age. make_transition_matrices() must have been run first.

    Parameters
    ----------
    outcomes : str or [str]
        Name(s) of outcome variables to track. Each must have an outcome grid
        from make_transition_matrices(), whether given explicitly or implicitly.
    T_max : int or None
        Number of model periods to generate output for, capped at T_total.
        All periods are run if None.
    calc_dstn : bool
        Whether to store outcome distributions in history_dstn (default False).
    calc_avg : bool
        Whether to store outcome averages in history_avg (default True).
    from_dstn : np.array or None
        Optional starting distribution of arrival states. If None, the newborn
        distribution is used.

    Returns
    -------
    None
    """
    T_max = self._validate_cohort_call(T_max)
    if not calc_dstn and not calc_avg:
        return  # nothing was requested

    outcome_list = [outcomes] if isinstance(outcomes, str) else outcomes

    history_dstn = None
    if calc_dstn:
        history_dstn = {name: [] for name in outcome_list}
    history_avg = None
    if calc_avg:
        history_avg = {name: np.empty(T_max) for name in outcome_list}

    start = self.newborn_dstn if from_dstn is None else from_dstn
    dstn = start.copy()
    dstn_by_age = []
    for t in range(T_max):
        dstn_by_age.append(dstn)
        self._record_cohort_outcomes(t, dstn, outcome_list, history_dstn, history_avg)
        dstn = np.dot(self.trans_arrays[t].T, dstn)

    if calc_dstn:
        self._stack_uniform_dstns(outcome_list, history_dstn)

    self.state_dstn_by_age = dstn_by_age
    if calc_dstn:
        self.history_dstn = history_dstn
    if calc_avg:
        self.history_avg = history_avg

2221 

2222 def _validate_cohort_call(self, T_max): 

2223 if not hasattr(self, "newborn_dstn"): 

2224 raise ValueError( 

2225 "The newborn state distribution does not exist; make_transition_matrices() must be run before grid simulations!" 

2226 ) 

2227 if not hasattr(self, "trans_arrays"): 

2228 raise ValueError( 

2229 "The transition arrays do not exist; make_transition_matrices() must be run before grid simulations!" 

2230 ) 

2231 if T_max is None: 

2232 T_max = self.T_total 

2233 T_max = np.minimum(T_max, self.T_total) 

2234 if len(self.trans_arrays) < T_max: 

2235 raise ValueError( 

2236 "There are somehow fewer elements of trans_array than there should be!" 

2237 ) 

2238 return T_max 

2239 

2240 def _record_cohort_outcomes( 

2241 self, t, current_dstn, outcomes, history_dstn, history_avg 

2242 ): 

2243 for name in outcomes: 

2244 this_outcome = self.periods[t].matrices[name].transpose() 

2245 this_dstn = np.dot(this_outcome, current_dstn) 

2246 if history_dstn is not None: 

2247 history_dstn[name].append(this_dstn) 

2248 if history_avg is not None: 

2249 history_avg[name][t] = np.dot(this_dstn, self.periods[t].grids[name]) 

2250 

2251 @staticmethod 

2252 def _stack_uniform_dstns(outcomes, history_dstn): 

2253 for name in outcomes: 

2254 dstn_sizes = np.array([dstn.size for dstn in history_dstn[name]]) 

2255 if np.all(dstn_sizes == dstn_sizes[0]): 

2256 history_dstn[name] = np.stack(history_dstn[name], axis=1) 

2257 

def describe_model(self, display=True):
    """
    Build a text summary of the model: the initializer statement, the dynamics
    statement, and the twist (relabeling) between periods.

    Parameters
    ----------
    display : bool
        If True (default), print the summary and return None. Otherwise return it.

    Returns
    -------
    str or None
    """
    # One "var_tm1[t-1] <---> var_t[t]" line per twist entry.
    twist_statement = "".join(
        var_tm1 + "[t-1] <---> " + var_t + "[t]\n"
        for var_tm1, var_t in self.twist.items()
    )

    rule = "-" * 34 + "\n"
    sections = [
        rule,
        "%%%%% INITIALIZATION AT BIRTH %%%%\n",
        rule,
        self.initializer.statement,
        rule,
        "%%%% DYNAMICS WITHIN PERIOD t %%%%\n",
        rule,
        self.statement,
        rule,
        "%%%%%%% RELABELING / TWIST %%%%%%%\n",
        rule,
        twist_statement,
        "-" * 35,
    ]
    output = "".join(sections)

    if not display:
        return output
    print(output)

2291 

def describe_symbols(self, display=True):
    """
    Build (and by default print) an aligned listing of the model's symbols.

    Each line reads "name (kind[, common]) : comment", one per key of
    self.comments. The kind is the declared type's __name__ for entries of
    self.types, or "dstn", "param", or "func" for distributions, parameters,
    and functions. Any other key is labeled "unknown". The colons line up one
    column past the longest "name (kind)" label.

    Parameters
    ----------
    display : bool
        If True (default), print the listing and return None. Otherwise return it.

    Returns
    -------
    str or None
        The listing (an empty string if there are no symbols) when display is False.
    """
    labels = []
    comments = []
    for key, comment in self.comments.items():
        comments.append(comment)

        # Get type of object
        if key in self.types:
            this_type = str(self.types[key].__name__)
        elif key in self.distributions:
            this_type = "dstn"
        elif key in self.parameters:
            this_type = "param"
        elif key in self.functions:
            this_type = "func"
        else:
            # Explicit fallback so an unclassified key can neither hit an unbound
            # name nor silently reuse the previous key's type.
            this_type = "unknown"

        # Add tags
        if key in self.common:
            this_type += ", common"
        labels.append(key + " (" + this_type + ")")

    # Pad every label to one column past the longest so the colons align;
    # default=0 keeps an empty symbol table from raising.
    width = max((len(label) for label in labels), default=0) + 1
    output = "".join(
        label.ljust(width) + ": " + comment + "\n"
        for label, comment in zip(labels, comments)
    )

    # Return or print the output
    if display:
        print(output)
        return
    return output

2337 

def describe(self, symbols=True, model=True, display=True):
    """
    Build a full text description of the model.

    The output starts with "name: description". If requested, it then has a
    SYMBOLS section (from describe_symbols) and the model statement (from
    describe_model). When only symbols are shown, a closing rule is appended.

    Parameters
    ----------
    symbols : bool
        Include the symbol listing (default True).
    model : bool
        Include the model statement (default True).
    display : bool
        If True (default), print and return None. Otherwise return the text.

    Returns
    -------
    str or None
    """
    rule = "-" * 34
    pieces = [self.name + ": " + self.description + "\n"]
    if symbols or model:
        pieces.append("\n")
    if symbols:
        pieces.extend(
            [
                rule + "\n",
                "%%%%%%%%%%%%% SYMBOLS %%%%%%%%%%%%\n",
                rule + "\n",
                self.describe_symbols(display=False),
            ]
        )
    if model:
        pieces.append(self.describe_model(display=False))
    if symbols and not model:
        pieces.append(rule)
    output = "".join(pieces)

    if not display:
        return output
    print(output)

2362 

2363 

2364def _parse_model_fields(model, common_override=None): 

2365 """ 

2366 Extract the top-level fields from a parsed model dictionary. 

2367 

2368 Uses dict.get() with safe defaults rather than try/except for each field, 

2369 so that missing keys silently receive their default values. 

2370 

2371 Parameters 

2372 ---------- 

2373 model : dict 

2374 Parsed YAML model dictionary. 

2375 common_override : list or None 

2376 If provided, overrides the model's 'common' field entirely. 

2377 

2378 Returns 

2379 ------- 

2380 model_name : str 

2381 Name of the model, or 'DEFAULT_NAME' if absent. 

2382 description : str 

2383 Human-readable description, or a placeholder if absent. 

2384 variables : list 

2385 Declared variable lines from model['symbols']['variables']. 

2386 twist : dict 

2387 Intertemporal twist mapping, or empty dict if absent. 

2388 common : list 

2389 Variables shared across all agents. 

2390 arrival : list 

2391 Explicitly listed arrival variable names. 

2392 """ 

2393 symbols = model.get("symbols", {}) 

2394 model_name = model.get("name", "DEFAULT_NAME") 

2395 description = model.get("description", "(no description provided)") 

2396 variables = symbols.get("variables", []) 

2397 twist = model.get("twist", {}) 

2398 arrival = symbols.get("arrival", []) 

2399 if common_override is not None: 

2400 common = common_override 

2401 else: 

2402 common = symbols.get("common", []) 

2403 return model_name, description, variables, twist, common, arrival 

2404 

2405 

2406def _build_periods( 

2407 template, agent, content, solution, offset, time_vary, time_inv, RNG, T_seq, T_cycle 

2408): 

2409 """ 

2410 Construct the list of per-period SimBlock copies for an AgentSimulator. 

2411 

2412 For each period in the solution sequence, a deep copy of the template block 

2413 is made and populated with the appropriate parameter data drawn from the agent. 

2414 

2415 Parameters 

2416 ---------- 

2417 template : SimBlock 

2418 Template block with structure but no parameter values. 

2419 agent : AgentType 

2420 The agent whose solution and time-varying attributes supply parameter values. 

2421 content : dict 

2422 Keys are the names of objects needed by the template block. 

2423 solution : list 

2424 Names of objects that come from the agent's solution attribute. 

2425 offset : list 

2426 Names of time-varying objects whose index is shifted back by one period. 

2427 time_vary : list 

2428 Names of objects that vary across periods (drawn from agent attributes). 

2429 time_inv : list 

2430 Names of objects that are time-invariant (same across all periods). 

2431 RNG : np.random.Generator 

2432 Random number generator used to assign unique seeds to MarkovEvents. 

2433 T_seq : int 

2434 Number of periods in the solution sequence. 

2435 T_cycle : int 

2436 Number of periods per cycle (used to wrap the time index). 

2437 

2438 Returns 

2439 ------- 

2440 periods : list[SimBlock] 

2441 Fully populated list of period blocks, one per entry in the solution. 

2442 """ 

2443 # Build the time-invariant parameter dictionary once 

2444 time_inv_dict = {} 

2445 for name in content: 

2446 if name in time_inv: 

2447 if not hasattr(agent, name): 

2448 raise ValueError( 

2449 "Couldn't get a value for time-invariant object " 

2450 + name 

2451 + ": attribute does not exist on the agent." 

2452 ) 

2453 time_inv_dict[name] = getattr(agent, name) 

2454 

2455 periods = [] 

2456 t_cycle = 0 

2457 for t in range(T_seq): 

2458 # Make a fresh copy of the template period 

2459 new_period = deepcopy(template) 

2460 

2461 # Make sure each period's events have unique seeds; this is only for MarkovEvents 

2462 for event in new_period.events: 

2463 if hasattr(event, "seed"): 

2464 event.seed = RNG.integers(0, 2**31 - 1) 

2465 

2466 # Make the parameter dictionary for this period 

2467 new_param_dict = deepcopy(time_inv_dict) 

2468 for name in content: 

2469 if name in solution: 

2470 if type(agent.solution[t]) is dict: 

2471 new_param_dict[name] = agent.solution[t][name] 

2472 else: 

2473 new_param_dict[name] = getattr(agent.solution[t], name) 

2474 elif name in time_vary: 

2475 s = (t_cycle - 1) if name in offset else t_cycle 

2476 attr = getattr(agent, name, None) 

2477 if attr is None: 

2478 raise ValueError( 

2479 "Couldn't get a value for time-varying object " 

2480 + name 

2481 + ": attribute does not exist on the agent." 

2482 ) 

2483 try: 

2484 new_param_dict[name] = attr[s] 

2485 except (IndexError, TypeError): 

2486 raise ValueError( 

2487 "Couldn't get a value for time-varying object " 

2488 + name 

2489 + " at time index " 

2490 + str(s) 

2491 + "!" 

2492 ) 

2493 elif name in time_inv: 

2494 continue 

2495 else: 

2496 raise ValueError( 

2497 "The object called " 

2498 + name 

2499 + " is not named in time_inv nor time_vary!" 

2500 ) 

2501 

2502 # Fill in content for this period, then add it to the list 

2503 new_period.content = new_param_dict 

2504 new_period.distribute_content() 

2505 periods.append(new_period) 

2506 

2507 # Advance time according to the cycle 

2508 t_cycle += 1 

2509 if t_cycle == T_cycle: 

2510 t_cycle = 0 

2511 

2512 return periods 

2513 

2514 

def _load_agent_model(agent):
    """
    Load and parse the YAML model statement for an agent.

    If the agent has a model_statement attribute, that text is parsed.
    Otherwise the file named by agent.model_file is read (as UTF-8 text) from
    the HARK.models package. Parsing uses yaml.safe_load.

    Parameters
    ----------
    agent : AgentType
        Agent with either a model_statement or a model_file attribute.

    Returns
    -------
    dict
        The parsed model specification.
    """
    if hasattr(agent, "model_statement"):
        model_statement = copy(agent.model_statement)
    else:
        # files()/read_text() replaces the legacy open_text() helper, which emits
        # DeprecationWarning on Python 3.11 and 3.12.
        resource = importlib.resources.files("HARK.models").joinpath(agent.model_file)
        model_statement = resource.read_text(encoding="utf-8")
    return yaml.safe_load(model_statement)

2522 

2523 

def _build_declared_types(variables, comments, arrival, common):
    """
    Map each declared model variable to its Python type, updating side tables in place.

    Each declaration line is split by parse_declaration_for_parts() into a
    name, type text (or None), flags, and description. The type text is
    evaluated to a type object; undeclared types default to float.

    Parameters
    ----------
    variables : list[str]
        Declaration lines from the model's symbols section.
    comments : dict
        Updated in place: comments[name] = description.
    arrival : list
        Updated in place: names flagged "arrival" are appended if absent.
    common : list
        Updated in place: names flagged "common" are appended if absent.

    Returns
    -------
    types : dict
        Variable name -> type object.

    Raises
    ------
    ValueError
        If a declared type cannot be evaluated (NameError or SyntaxError).
    """
    types = {}
    for declaration in variables:
        var_name, type_text, flags, desc = parse_declaration_for_parts(declaration)
        if type_text is None:
            var_type = float
        else:
            # NOTE(review): eval() executes text from the model statement, which can
            # come from a user-supplied agent.model_statement. Only load trusted models.
            try:
                var_type = eval(type_text)
            except (NameError, SyntaxError) as exc:
                raise ValueError(
                    f"Couldn't understand type {type_text} for declared variable {var_name}!"
                ) from exc
        types[var_name] = var_type
        comments[var_name] = desc
        for flag, registry in (("arrival", arrival), ("common", common)):
            if flag in flags and var_name not in registry:
                registry.append(var_name)
    return types

2544 

2545 

2546def _classify_symbols(information): 

2547 parameters = [] 

2548 functions = [] 

2549 distributions = [] 

2550 for key, val in information.items(): 

2551 if val is None: 

2552 parameters.append(key) 

2553 elif type(val) is NullFunc: 

2554 functions.append(key) 

2555 elif type(val) is Distribution: 

2556 distributions.append(key) 

2557 return parameters, functions, distributions 

2558 

2559 

2560def _augment_types_with_undeclared(information, types, comments): 

2561 for var, this in information.items(): 

2562 if var in types: 

2563 continue 

2564 if (this is None) or (type(this) is Distribution) or (type(this) is NullFunc): 

2565 continue 

2566 types[var] = float 

2567 comments[var] = "" 

2568 if "dead" in types: 

2569 types["dead"] = bool 

2570 comments["dead"] = "whether agent died this period" 

2571 types["t_seq"] = int 

2572 types["t_age"] = int 

2573 comments["t_seq"] = "which period of the sequence the agent is on" 

2574 comments["t_age"] = "how many periods the agent has already lived for" 

2575 

2576 

2577def _resolve_initializer_values(agent, init_info): 

2578 init_dict = {} 

2579 for name in init_info.keys(): 

2580 try: 

2581 init_dict[name] = getattr(agent, name) 

2582 except AttributeError as exc: 

2583 raise ValueError( 

2584 f"Couldn't get a value for initializer object {name}!" 

2585 ) from exc 

2586 return init_dict 

2587 

2588 

2589def _resolve_T_age(T_age, cycles, T_seq): 

2590 if T_age is None: 

2591 T_age = 0 

2592 if cycles > 0: 

2593 T_age = np.minimum(T_seq - 1, T_age) 

2594 return T_age 

2595 

2596 

def make_simulator_from_agent(agent, stop_dead=True, replace_dead=True, common=None):
    """
    Build an AgentSimulator instance based on an AgentType instance. The AgentType
    should have its model attribute defined so that it can be parsed and translated
    into the simulator structure. The names of objects in the model statement
    should correspond to attributes of the AgentType.

    Parameters
    ----------
    agent : AgentType
        Agents for whom a new simulator is to be constructed.
    stop_dead : bool
        Whether simulated agents who draw dead=True should actually cease acting.
        Default is True. Setting to False allows "cohort-style" simulation that
        will generate many agents that survive to old ages. In most cases, T_sim
        should not exceed T_age, unless the user really does want multiple
        successive cohorts to be born and fully simulated.
    replace_dead : bool
        Whether simulated agents who are marked as dead should be replaced with
        newborns (default True) or simply cease acting without replacement (False).
        The latter option is useful for models with state-dependent mortality,
        to allow "cohort-style" simulation with the correct distribution of states
        for survivors at each age. Setting False has no effect if stop_dead is True.
    common : [str] or None
        List of random variables that should be treated as commonly shared across
        all agents, rather than idiosyncratically drawn. If this is provided, it
        will override the model defaults.

    Returns
    -------
    new_simulator : AgentSimulator
        A simulator structure based on the agents.

    Raises
    ------
    ValueError
        If the agent lacks an attribute for an object the initializer requires,
        or if the model statement itself is malformed (raised by the parsing
        helpers called below).
    """
    # Read the agent attributes used below up front.
    model = _load_agent_model(agent)
    time_vary = agent.time_vary
    time_inv = agent.time_inv
    cycles = agent.cycles
    T_age = agent.T_age
    comments = {}
    RNG = agent.RNG  # this is only for generating seeds for MarkovEvents

    # Top-level model fields; the ``common`` argument is passed through as an
    # override and the (possibly overridden) list is rebound here.
    model_name, description, variables, twist, common, arrival = _parse_model_fields(
        model, common_override=common
    )

    # Start the type table from the model's declared variables.
    types = _build_declared_types(variables, comments, arrival, common)

    # Build the uninstantiated "template" period: its events, the information
    # dictionary of referenced objects, and the offset / solution name lists.
    template_period, information, offset, solution, block_comments = (
        make_template_block(model, arrival, common)
    )
    comments.update(block_comments)

    # Build the agent initializer and the dict of objects it needs from the agent.
    initializer, init_info = make_initializer(model, arrival, common)
    statement = template_period.statement
    content = template_period.content

    # Split referenced objects by kind and give undeclared data variables types.
    parameters, functions, distributions = _classify_symbols(information)
    _augment_types_with_undeclared(information, types, comments)

    # Fill the initializer with the agent's attribute values (ValueError if any
    # are missing), then hand those values to its events.
    initializer.content = _resolve_initializer_values(agent, init_info)
    initializer.distribute_content()

    # One simulated period per entry of the agent's solution sequence.
    T_seq = len(agent.solution)
    T_cycle = agent.T_cycle
    periods = _build_periods(
        template_period,
        agent,
        content,
        solution,
        offset,
        time_vary,
        time_inv,
        RNG,
        T_seq,
        T_cycle,
    )

    T_age = _resolve_T_age(T_age, cycles, T_seq)
    T_sim = getattr(agent, "T_sim", 0)

    # Make and return the new simulator
    new_simulator = AgentSimulator(
        name=model_name,
        description=description,
        statement=statement,
        comments=comments,
        parameters=parameters,
        functions=functions,
        distributions=distributions,
        common=common,
        types=types,
        N_agents=agent.AgentCount,
        T_total=T_seq,
        T_sim=T_sim,
        T_age=T_age,
        stop_dead=stop_dead,
        replace_dead=replace_dead,
        periods=periods,
        twist=twist,
        initializer=initializer,
        track_vars=agent.track_vars,
    )
    new_simulator.solution = solution  # this is for use by SSJ constructor
    return new_simulator

2701 

2702 

2703def _extract_symbol_class( 

2704 model, class_name, constructor, validator_msg, offset, solution, comments 

2705): 

2706 """ 

2707 Parse and collect one class of symbols (parameters, functions, or distributions). 

2708 

2709 Handles the near-identical pattern repeated for each symbol class: 

2710 iterate over declaration lines, build the result dict, record comments, 

2711 and append names to the offset and solution lists as flagged. 

2712 

2713 Parameters 

2714 ---------- 

2715 model : dict 

2716 Parsed model dictionary containing a 'symbols' sub-dict. 

2717 class_name : str 

2718 Key within model['symbols'] to look up ('parameters', 'functions', or 

2719 'distributions'). 

2720 constructor : callable or None 

2721 Called with no arguments to create each entry's value. Pass None for 

2722 parameters (which use None as their placeholder value). 

2723 validator_msg : str or None 

2724 If provided, the expected datatype string (e.g. 'func' or 'dstn'). When a 

2725 declaration carries a different datatype, a ValueError is raised. Pass None 

2726 to skip validation (used for parameters). 

2727 offset : list 

2728 Accumulated list of offset-flagged names; extended in place. 

2729 solution : list 

2730 Accumulated list of solution-flagged names; extended in place. 

2731 comments : dict 

2732 Accumulated comment strings keyed by name; updated in place. 

2733 

2734 Returns 

2735 ------- 

2736 result : dict 

2737 Mapping from symbol name to its constructed value (or None for parameters). 

2738 """ 

2739 result = {} 

2740 symbols = model.get("symbols", {}) 

2741 if class_name not in symbols: 

2742 return result 

2743 lines = symbols[class_name] 

2744 for line in lines: 

2745 name, datatype, flags, desc = parse_declaration_for_parts(line) 

2746 if ( 

2747 (validator_msg is not None) 

2748 and (datatype is not None) 

2749 and (datatype != validator_msg) 

2750 ): 

2751 raise ValueError( 

2752 name 

2753 + " was declared as a " 

2754 + class_name[:-1] 

2755 + ", but given a different datatype!" 

2756 ) 

2757 result[name] = constructor() if constructor is not None else None 

2758 comments[name] = desc 

2759 if ("offset" in flags) and (name not in offset): 

2760 offset.append(name) 

2761 if ("solution" in flags) and (name not in solution): 

2762 solution.append(name) 

2763 return result 

2764 

2765 

2766def _build_event_list(lines, info, common): 

2767 """ 

2768 Walk a sequence of statement lines and build the corresponding event list. 

2769 

2770 Each line is parsed into a new event via :func:`make_new_event`, the event's 

2771 assigned variables are folded into ``info`` (raising on duplicates), and 

2772 events whose assigned variables overlap ``common`` are flagged as common. 

2773 

2774 Returns ``(events, names_used)``. 

2775 """ 

2776 events = [] 

2777 names_used = [] 

2778 for line in lines: 

2779 new_event, used = make_new_event(line, info) 

2780 events.append(new_event) 

2781 names_used += used 

2782 

2783 for var in new_event.assigns: 

2784 if var in info.keys(): 

2785 raise ValueError(var + " is assigned, but already exists!") 

2786 info[var] = 0 

2787 

2788 for var in new_event.assigns: 

2789 if var in common: 

2790 new_event.common = True 

2791 break 

2792 return events, names_used 

2793 

2794 

2795def _format_block_statement(events): 

2796 """ 

2797 Format an aligned text representation of an event list for SimBlock display. 

2798 

2799 Each event's ``statement`` is right-padded to the longest statement, then 

2800 appended with ``": " + event.description``. 

2801 """ 

2802 if not events: 

2803 return "" 

2804 longest = np.max([len(event.statement) for event in events]) 

2805 parts = [] 

2806 for event in events: 

2807 pad = (longest + 1) - len(event.statement) 

2808 parts.append(event.statement + pad * " " + ": " + event.description + "\n") 

2809 return "".join(parts) 

2810 

2811 

2812def _collect_initializer_symbols(model, class_name, constructor, expected_datatype): 

2813 """ 

2814 Collect one class of symbols (parameters, functions, or distributions) for the 

2815 initializer SimBlock. 

2816 

2817 Parameters 

2818 ---------- 

2819 model : dict 

2820 Parsed model dictionary containing a 'symbols' sub-dict. 

2821 class_name : str 

2822 Key within ``model['symbols']`` to look up. 

2823 constructor : callable or None 

2824 Called with no arguments to build each entry's placeholder value. Pass 

2825 ``None`` for parameters (which use ``None`` as their placeholder). 

2826 expected_datatype : str or None 

2827 If non-None, declarations must either omit the datatype or use this 

2828 exact string; otherwise a ValueError is raised. 

2829 """ 

2830 result = {} 

2831 symbols = model.get("symbols", {}) 

2832 if class_name not in symbols: 

2833 return result 

2834 for line in symbols[class_name]: 

2835 name, datatype, flags, desc = parse_declaration_for_parts(line) 

2836 if ( 

2837 expected_datatype is not None 

2838 and datatype is not None 

2839 and datatype != expected_datatype 

2840 ): 

2841 raise ValueError( 

2842 f"{name} was declared as a {class_name[:-1]}, " 

2843 "but given a different datatype!" 

2844 ) 

2845 result[name] = constructor() if constructor is not None else None 

2846 return result 

2847 

2848 

def make_template_block(model, arrival=None, common=None):
    """
    Construct a new SimBlock object as a "template" of the model block. It has
    events and reference information, but no values filled in.

    Parameters
    ----------
    model : dict
        Dictionary with model block information, probably read in as a yaml.
        It is not modified.
    arrival : [str] or None
        List of arrival variables that were flagged or explicitly listed.
    common : [str] or None
        List of variables that are common or shared across all agents, rather
        than idiosyncratically drawn.

    Returns
    -------
    template_block : SimBlock
        A "template" of this model block, with no parameters (etc) on it. Its
        content holds only the declared objects actually referenced in the
        dynamics.
    info : dict
        Dictionary of model objects that were referenced within the block. Keys
        are object names and entries reveal what kind of object they are:
        - None --> parameter
        - 0 --> outcome/data variable (including arrival variables)
        - NullFunc --> function
        - Distribution --> distribution
    offset : [str]
        List of object names that are offset in time by one period.
    solution : [str]
        List of object names that are part of the model solution.
    comments : dict
        Dictionary of comments included with declared functions, distributions,
        and parameters.
    """
    if arrival is None:
        arrival = []
    if common is None:
        common = []

    # Extract explicitly listed metadata; ``or`` covers present-but-empty keys.
    symbols = model.get("symbols") or {}
    model_name = model.get("name", None)
    # Copy the explicit lists: flagged names are appended to them below, and
    # that must not mutate the caller's model dictionary.
    offset = list(symbols.get("offset") or [])
    solution = list(symbols.get("solution") or [])

    # Extract parameters, functions, and distributions using the shared helper
    comments = {}
    parameters = _extract_symbol_class(
        model, "parameters", None, None, offset, solution, comments
    )
    functions = _extract_symbol_class(
        model, "functions", NullFunc, "func", offset, solution, comments
    )
    distributions = _extract_symbol_class(
        model, "distributions", Distribution, "dstn", offset, solution, comments
    )

    # Combine those dictionaries into a single "information" dictionary, which
    # represents objects available *at that point* in the dynamic block
    content = {**parameters, **functions, **distributions}
    info = deepcopy(content)
    for var in arrival:
        info[var] = 0  # Mark as a state variable

    # Parse the model dynamics and build the event list
    dynamics = format_block_statement(model["dynamics"])
    events, names_used_in_dynamics = _build_event_list(dynamics, info, common)

    # Keep only the declared content that the dynamics actually reference.
    # (Uses its own loop variable so the model name above is not clobbered.)
    used = set(names_used_in_dynamics)
    content = {key: val for key, val in content.items() if key in used}

    # Make a single string model statement
    statement = _format_block_statement(events)

    # Make a description for the template block
    if model_name is None:
        description = "template block for unnamed block"
    else:
        description = "template block for " + model_name

    # Make and return the new SimBlock
    template_block = SimBlock(
        description=description,
        arrival=arrival,
        content=content,
        statement=statement,
        events=events,
    )
    return template_block, info, offset, solution, comments

2941 

2942 

def make_initializer(model, arrival=None, common=None):
    """
    Build the SimBlock that initializes newborn agents, from the model's
    ``initialize`` section. The block has its events and structure, but its
    content holds only placeholders (no parameter values yet).

    Parameters
    ----------
    model : dict
        Dictionary with model initializer information, probably read in as a yaml.
        Must contain an 'initialize' entry.
    arrival : [str] or None
        List of arrival variables that were flagged or explicitly listed. Each
        must be assigned somewhere in the initialize block.
    common : [str] or None
        List of variables that are common or shared across all agents; events
        assigning any of them are marked common.

    Returns
    -------
    initializer : SimBlock
        A "template" of the initializer block, with no parameters (etc) on it.
    init_requires : dict
        Dictionary of model objects that are needed by the initializer to run.
        Keys are object names and entries reveal what kind of object they are:
        - None --> parameter
        - NullFunc --> function
        - Distribution --> distribution

    Raises
    ------
    ValueError
        If an arrival variable is not assigned in the initialize block, or if an
        event references a parameter, distribution, or function that was not
        declared as such.
    """
    arrival = [] if arrival is None else arrival
    common = [] if common is None else common
    block_name = model.get("name", "DEFAULT_NAME")

    parameters = _collect_initializer_symbols(model, "parameters", None, None)
    functions = _collect_initializer_symbols(model, "functions", NullFunc, "func")
    distributions = _collect_initializer_symbols(
        model, "distributions", Distribution, "dstn"
    )

    # Every declared object is available before the first initializer statement.
    info = deepcopy({**parameters, **functions, **distributions})

    initialize_lines = format_block_statement(model["initialize"])
    events, _ = _build_event_list(initialize_lines, info, common)

    for var in arrival:
        if var not in info:
            raise ValueError(
                "The arrival variable " + var + " was not set in the initialize block!"
            )

    def _declared(table, key, kind):
        # Fetch a declared placeholder, converting a miss into a user-facing error.
        try:
            return table[key]
        except KeyError as exc:
            raise ValueError(
                f"{key} was referenced in initialize, but not declared as a {kind}!"
            ) from exc

    init_requires = {}
    for event in events:
        for var in event.parameters:
            if var not in init_requires:
                init_requires[var] = _declared(parameters, var, "parameter")
        event_kind = type(event)
        # NOTE(review): RandomIndexedEvent also records a _dstn_name but is not
        # matched here (exact type check) -- confirm that is intended.
        if event_kind is RandomEvent:
            dstn_name = getattr(event, "_dstn_name", None)
            init_requires[dstn_name] = _declared(
                distributions, dstn_name, "distribution"
            )
        elif event_kind is EvaluationEvent:
            func_name = getattr(event, "_func_name", None)
            init_requires[func_name] = _declared(functions, func_name, "function")

    return (
        SimBlock(
            description="agent initializer for " + block_name,
            content=init_requires,
            statement=_format_block_statement(events),
            events=events,
        ),
        init_requires,
    )

3038 

3039 

def make_new_event(statement, info):
    r"""
    Makes a "blank" version of a model event based on a statement line. Determines
    which objects are needed vs assigned vs parameters / information from context.

    The event kind is decided from the code portion of the line only, i.e. the
    text before any ``\\`` description delimiter:

    - ``=`` with ``@``                  --> EvaluationEvent
    - ``=`` without ``@``               --> DynamicEvent
    - ``~`` with ``{...}``              --> MarkovEvent
    - ``~`` with ``[...]`` (no braces)  --> RandomIndexedEvent
    - ``~`` otherwise                   --> RandomEvent

    Parameters
    ----------
    statement : str
        One line of a model statement, which will be turned into an event.
    info : dict
        Empty dictionary of model information that already exists. Consists of
        arrival variables, already assigned variables, parameters, and functions.
        Typing of each is based on the kind of "empty" object.

    Returns
    -------
    new_event : ModelEvent
        A new model event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.

    Raises
    ------
    ValueError
        If the code portion contains both ``=`` and ``~``, or neither.
    """
    # Ignore the free-text description when classifying: it may legitimately
    # contain characters like =, ~, @, {} or [] that would otherwise be
    # mistaken for statement syntax.
    cut = statement.find("\\")
    code = statement if cut == -1 else statement[:cut]

    has_eq = "=" in code
    has_tld = "~" in code
    has_amp = "@" in code
    has_brc = ("{" in code) and ("}" in code)
    has_brk = ("[" in code) and ("]" in code)

    if has_eq and has_tld:
        raise ValueError("A statement line can't have both an = and a ~!")
    if has_eq:
        event_maker = make_new_evaluation if has_amp else make_new_dynamic
    elif has_tld:
        if has_brc:
            event_maker = make_new_markov
        elif has_brk:
            event_maker = make_new_random_indexed
        else:
            event_maker = make_new_random
    else:
        raise ValueError("Statement line was not any valid type!")

    # The maker receives the full line so it can extract the description itself.
    new_event, names_used = event_maker(statement, info)
    return new_event, names_used

3099 

3100 

def make_new_dynamic(statement, info):
    """
    Construct a new instance of DynamicEvent based on the given model statement
    line and a blank dictionary of parameters. The statement should already be
    verified to be a valid dynamic statement: it has an = but no ~ or @.

    Parameters
    ----------
    statement : str
        One line dynamics statement, which will be turned into a DynamicEvent.
    info : dict
        Empty dictionary of available information: None marks a parameter, a
        NullFunc a function, a Distribution a distribution, and any other value
        a data variable.

    Returns
    -------
    new_dynamic : DynamicEvent
        A new dynamic event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.

    Raises
    ------
    ValueError
        If the RHS references a name that does not (yet) exist in ``info``, or
        one that is a function or a distribution.
    """
    # Cut the statement up into its LHS, RHS, and description
    lhs, rhs, description = parse_line_for_parts(statement, "=")

    # Parse the LHS (assignment) to get assigned variables
    assigns = parse_assignment(lhs)

    # Parse the RHS to extract object names used. Whether a name is indexed is
    # not needed here: every name becomes a plain positional argument below.
    obj_names, _ = extract_var_names_from_expr(rhs)

    # Allocate each variable to needed dynamic variables or parameters
    needs = []
    parameters = {}
    for var in obj_names:
        if var not in info:
            raise ValueError(
                var + " is used in a dynamic expression, but does not (yet) exist!"
            )
        val = info[var]
        if type(val) is NullFunc:
            raise ValueError(
                var + " is used in a dynamic expression, but it's a function!"
            )
        if type(val) is Distribution:
            raise ValueError(
                var + " is used in a dynamic expression, but it's a distribution!"
            )
        if val is None:
            parameters[var] = None
        else:
            needs.append(var)

    # One SymPy symbol per referenced name, in lambdify argument order.
    _args = [symbols(var) for var in obj_names]

    # NOTE: symbols(rhs) wraps the whole space-free RHS text as a single Symbol
    # rather than parsing it; this relies on lambdify printing that name verbatim
    # as the body of the generated function.
    sympy_expr = symbols(rhs)
    expr = lambdify(_args, sympy_expr)

    # Make an overall list of object names referenced in this event
    names_used = assigns + obj_names

    # Make and return the new dynamic event
    new_dynamic = DynamicEvent(
        description=description,
        statement=lhs + " = " + rhs,
        assigns=assigns,
        needs=needs,
        parameters=parameters,
        expr=expr,
        args=obj_names,
    )
    return new_dynamic, names_used

3181 

3182 

def make_new_random(statement, info):
    """
    Make a new random variable realization event based on the given model statement
    line and a blank dictionary of parameters. The statement should already be
    verified to be a valid random statement: it has a ~ but no = or [].

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Empty dictionary of available information.

    Returns
    -------
    new_random : RandomEvent
        A new random event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.

    Raises
    ------
    ValueError
        If the RHS names something that is not a declared distribution,
        including a name absent from ``info`` entirely.
    """
    # Cut the statement up into its LHS, RHS, and description
    lhs, rhs, description = parse_line_for_parts(statement, "~")

    # Parse the LHS (assignment) to get assigned variables
    assigns = parse_assignment(lhs)

    # Verify that the RHS is actually a distribution; .get() makes an undeclared
    # name produce this ValueError instead of a bare KeyError.
    dstn = info.get(rhs)
    if type(dstn) is not Distribution:
        raise ValueError(
            rhs + " was treated as a distribution, but not declared as one!"
        )

    # Make an overall list of object names referenced in this event
    names_used = assigns + [rhs]

    # Make and return the new random event
    new_random = RandomEvent(
        description=description,
        statement=lhs + " ~ " + rhs,
        assigns=assigns,
        needs=[],
        parameters={},
        dstn=dstn,
    )
    new_random._dstn_name = rhs
    return new_random, names_used

3229 

3230 

def make_new_random_indexed(statement, info):
    """
    Make a new indexed random variable realization event based on the given model
    statement line and a blank dictionary of parameters. The statement should
    already be verified to be a valid random statement: it has a ~ and [].

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Empty dictionary of available information.

    Returns
    -------
    new_random_indexed : RandomIndexedEvent
        A new random indexed event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.

    Raises
    ------
    ValueError
        If the indexed object is not a declared distribution, including a name
        absent from ``info`` entirely.
    """
    # Cut the statement up into its LHS, RHS, and description
    lhs, rhs, description = parse_line_for_parts(statement, "~")

    # Parse the LHS (assignment) to get assigned variables
    assigns = parse_assignment(lhs)

    # Split the RHS into the distribution and the index
    dstn, index = parse_random_indexed(rhs)

    # Verify that the RHS is actually a distribution; .get() makes an undeclared
    # name produce this ValueError instead of a bare KeyError.
    if type(info.get(dstn)) is not Distribution:
        raise ValueError(
            dstn + " was treated as a distribution, but not declared as one!"
        )

    # Make an overall list of object names referenced in this event
    names_used = assigns + [dstn, index]

    # Make and return the new random indexed event
    new_random_indexed = RandomIndexedEvent(
        description=description,
        statement=lhs + " ~ " + rhs,
        assigns=assigns,
        needs=[index],
        parameters={},
        index=index,
    )
    new_random_indexed._dstn_name = dstn
    return new_random_indexed, names_used

3280 

3281 

def make_new_markov(statement, info):
    """
    Make a new Markov-type event based on the given model statement line and a
    blank dictionary of parameters. The statement should already be verified to
    be a valid Markov statement: it has a ~ and {} and maybe (). This can represent
    a Markov matrix transition event, a draw from a discrete index, or just a
    Bernoulli random variable. If a Bernoulli event, the "probabilities" can be
    idiosyncratic data.

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Empty dictionary of available information.

    Returns
    -------
    new_markov : MarkovEvent
        A new Markov draw event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.

    Raises
    ------
    ValueError
        If the probabilities object is not present in ``info``.
    """
    # Cut the statement up into its LHS, RHS, and description
    lhs, rhs, description = parse_line_for_parts(statement, "~")

    # Parse the LHS (assignment) to get assigned variables
    assigns = parse_assignment(lhs)

    # Parse the RHS (Markov statement) for the array and index
    probs, index = parse_markov(rhs)
    needs = [] if index is None else [index]

    # Report an undeclared probabilities object clearly, rather than via KeyError
    if probs not in info:
        raise ValueError(
            f"{probs} is used in a Markov statement, but does not (yet) exist!"
        )

    # Determine whether probs is an idiosyncratic variable or a parameter, and
    # set up the event to grab it appropriately
    if info[probs] is None:
        parameters = {probs: None}
    else:
        needs += [probs]
        parameters = {}

    # Make an overall list of object names referenced in this event
    names_used = assigns + needs + [probs]

    # Make and return the new Markov event
    new_markov = MarkovEvent(
        description=description,
        statement=lhs + " ~ " + rhs,
        assigns=assigns,
        needs=needs,
        parameters=parameters,
        probs=probs,
        index=index,
    )
    return new_markov, names_used

3340 

3341 

def make_new_evaluation(statement, info):
    """
    Make a new function evaluation event based on the given model statement line
    and a blank dictionary of parameters. The statement should already be verified
    to be a valid evaluation statement: it has an @ and an = but no ~.

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into an eval event.
    info : dict
        Empty dictionary of available information.

    Returns
    -------
    new_evaluation : EvaluationEvent
        A new evaluation event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.

    Raises
    ------
    ValueError
        If an argument does not (yet) exist or is a function or distribution, or
        if the evaluated function is not present in ``info``.
    """
    # Cut the statement up into its LHS, RHS, and description
    lhs, rhs, description = parse_line_for_parts(statement, "=")

    # Parse the LHS (assignment) to get assigned variables
    assigns = parse_assignment(lhs)

    # Parse the RHS (evaluation) for the function and its arguments
    func, arguments = parse_evaluation(rhs)

    # Allocate each variable to needed dynamic variables or parameters
    needs = []
    parameters = {}
    for var in arguments:
        if var not in info:
            raise ValueError(
                var + " is used in an evaluation statement, but does not (yet) exist!"
            )
        val = info[var]
        if type(val) is NullFunc:
            raise ValueError(
                var
                + " is used as an argument in an evaluation statement, but it's a function!"
            )
        if type(val) is Distribution:
            raise ValueError(
                var + " is used in an evaluation statement, but it's a distribution!"
            )
        if val is None:
            parameters[var] = None
        else:
            needs.append(var)

    # Report an undeclared function clearly, rather than via KeyError below
    if func not in info:
        raise ValueError(
            f"{func} is used as a function in an evaluation statement, "
            "but does not (yet) exist!"
        )

    # Make an overall list of object names referenced in this event
    names_used = assigns + arguments + [func]

    # Make and return the new evaluation event
    new_evaluation = EvaluationEvent(
        description=description,
        statement=lhs + " = " + rhs,
        assigns=assigns,
        needs=needs,
        parameters=parameters,
        arguments=arguments,
        func=info[func],
    )
    new_evaluation._func_name = func
    return new_evaluation, names_used

3410 

3411 

def look_for_char_and_remove(phrase, symb):
    """
    Report whether ``symb`` occurs in ``phrase`` and strip every occurrence.

    Parameters
    ----------
    phrase : str
        Text to search.
    symb : str
        Character (or substring) to look for.

    Returns
    -------
    out : str
        ``phrase`` with all occurrences of ``symb`` removed.
    found : bool
        True if ``symb`` occurred at least once.
    """
    if symb not in phrase:
        return phrase, False
    return phrase.replace(symb, ""), True

3433 

3434 

def parse_declaration_for_parts(line):
    r"""
    Split a declaration line from a model file into the object's name, its datatype,
    any metadata flags, and any provided comment or description.

    Parameters
    ----------
    line : str
        Line to be parsed into the object name, object type, and a comment or
        description.

    Returns
    -------
    name : str
        Name of the object, with all whitespace and flag characters removed.
    datatype : str or None
        Provided datatype string, in parentheses, if any.
    flags : [str]
        List of metadata flags that were detected. These include ! for a variable
        that is in arrival, * for any non-variable that's part of the solution,
        + for any object that is offset in time, and & for a common random variable.
        Flags are reported in the order offset, arrival, solution, common.
    desc : str
        Comment or description following the ``\\`` delimiter, if any.

    Raises
    ------
    ValueError
        If an opening parenthesis has no matching closing parenthesis.
    """
    flags = []
    check_for_flags = {"offset": "+", "arrival": "!", "solution": "*", "common": "&"}

    # First, separate off the comment or description, if any. The delimiter is
    # two backslashes: locate the first and skip both.
    slashes = line.find("\\")
    desc = "" if slashes == -1 else line[(slashes + 2) :].strip()
    rem = line if slashes == -1 else line[:slashes].strip()

    # Now look for bracketing parentheses declaring a datatype
    lp = rem.find("(")
    if lp > -1:
        rp = rem.find(")")
        if rp == -1:
            raise ValueError("Unclosed parentheses on object declaration line!")
        datatype = rem[(lp + 1) : rp].strip()
        leftover = rem[:lp].strip()
    else:
        datatype = None
        leftover = rem

    # What's left over should be the object name plus any flags
    for key, symb in check_for_flags.items():
        if symb in leftover:
            leftover = leftover.replace(symb, "")
            flags.append(key)

    # Remove all remaining whitespace (tabs and newlines too), and that *should*
    # be the name
    name = "".join(leftover.split())
    # TODO: Check for valid name formatting based on characters.

    return name, datatype, flags, desc

3491 

3492 

def parse_line_for_parts(statement, symb):
    r"""
    Split one line of a model statement into its LHS, RHS, and description. The
    description is everything following the ``\\`` delimiter, while the LHS and
    RHS are determined by the first occurrence of a special symbol.

    Parameters
    ----------
    statement : str
        One line of a model statement, which will be parsed for its parts.
    symb : char
        The character that represents the divide between LHS and RHS

    Returns
    -------
    lhs : str
        The left-hand (assignment) side of the expression, whitespace removed.
    rhs : str
        The right-hand (evaluation) side of the expression, whitespace removed.
    desc : str
        The provided description of the expression, stripped.

    Raises
    ------
    ValueError
        If ``symb`` does not occur in ``statement``.
    """
    eq = statement.find(symb)
    if eq == -1:
        # Without this, find() returning -1 would silently mangle the line.
        raise ValueError(f"Expected {symb!r} in model statement line: {statement!r}")
    lhs = "".join(statement[:eq].split())
    not_lhs = statement[(eq + 1) :]

    # The description delimiter is two backslashes: locate the first, skip both.
    comment = not_lhs.find("\\")
    if comment == -1:
        rhs, desc = not_lhs, ""
    else:
        rhs, desc = not_lhs[:comment], not_lhs[(comment + 2) :].strip()
    rhs = "".join(rhs.split())
    return lhs, rhs, desc

3523 

3524 

def parse_assignment(lhs):
    """
    Return the variable names assigned by the left-hand side of a model line.

    A bare name is returned as a one-element list. A parenthesized,
    comma-separated tuple such as ``(a,b,c)`` is unpacked into its names, in
    order, with empty entries (e.g. from a trailing comma) dropped.

    Parameters
    ----------
    lhs : str
        Left-hand side of a model expression, already stripped of spaces.

    Returns
    -------
    assigns : List[str]
        Variable names assigned in this model line.

    Raises
    ------
    ValueError
        If the LHS opens a parenthesis but does not end with ``)``.
    """
    if lhs[0] != "(":
        return [lhs]
    if lhs[-1] != ")":
        raise ValueError("Parentheses on assignment was not closed!")
    inner = lhs[1:-1]
    return [name for name in inner.split(",") if name != ""]

3554 

3555 

def extract_var_names_from_expr(expression):
    """
    Parse the RHS of a dynamic model statement to get variable names used in it.

    The expression is scanned character by character. Any character in the
    set of math symbols ends the current name. A token that begins with a
    digit is a numeric literal and is skipped entirely, including any decimal
    point and exponent part (so ``1e-5`` or ``2.5E+3`` do not produce a
    spurious variable named ``e``/``E``). Each name is reported once, in order
    of first appearance.

    Parameters
    ----------
    expression : str
        RHS of a model statement to be parsed for variable names.

    Returns
    -------
    var_names : List[str]
        List of variable names used in the expression. These *should* be dynamic
        variables and parameters, but not functions.
    indexed : List[bool]
        Indicators for whether each variable seems to be used with indexing,
        i.e. whether its *first* occurrence is immediately followed by ``[``.
    """
    var_names = []
    indexed = []
    math_symbols = "+-/*^%.(),[]{}<>"
    digits = "0123456789"
    n_chars = len(expression)
    cur = ""
    j = 0
    while j < n_chars:
        c = expression[j]

        if cur == "" and c in digits:
            # A token starting with a digit is a numeric literal, not a name.
            # Consume all of it -- mantissa and any exponent such as the "e-5"
            # in "1e-5" -- so the exponent marker is not taken as a variable.
            j += 1
            while j < n_chars and (expression[j] in digits or expression[j] == "."):
                j += 1
            if j < n_chars and expression[j] in "eE":
                k = j + 1
                if k < n_chars and expression[k] in "+-":
                    k += 1
                # Only treat it as an exponent if digits actually follow
                if k < n_chars and expression[k] in digits:
                    j = k + 1
                    while j < n_chars and expression[j] in digits:
                        j += 1
            continue

        if c in math_symbols:
            # A math symbol terminates the current name; record it only once
            if cur != "" and cur not in var_names:
                var_names.append(cur)
                indexed.append(c == "[")
            cur = ""
        else:
            cur += c
        j += 1

    if cur != "" and cur not in var_names:
        var_names.append(cur)
        indexed.append(False)  # final symbol couldn't possibly be indexed
    return var_names, indexed

3598 

3599 

def parse_evaluation(expression):
    """
    Split a function evaluation expression of the form ``func@(arg1,arg2,...)``
    into the name of the function being called and the names of its arguments.

    Parameters
    ----------
    expression : str
        RHS of a function evaluation model statement (spaces already removed).

    Returns
    -------
    func_name : str
        Name of the function that will be called in this event: the text to
        the left of the ``@``.
    arg_names : List[str]
        Comma-separated argument names found inside the parentheses; empty
        entries are dropped.

    Raises
    ------
    ValueError
        If the ``@`` is not immediately followed by ``(`` or the expression
        does not end with ``)``.
    """
    at_pos = expression.find("@")
    func_name = expression[:at_pos]

    call_part = expression[(at_pos + 1) :]
    if call_part[0] != "(":
        raise ValueError(
            "The @ in a function evaluation statement must be followed by (!"
        )
    if call_part[-1] != ")":
        raise ValueError("A function evaluation statement must end in )!")

    inside = call_part[1:-1]
    arg_names = [arg for arg in inside.split(",") if arg != ""]
    return func_name, arg_names

3648 

3649 

def parse_markov(expression):
    """
    Split a Markov draw declaration of the form ``{probs}(index)`` into the
    name of the probability array and the name of the indexing variable.

    The ``(index)`` part is optional; when no parentheses follow the closing
    brace, the index is reported as None.

    Parameters
    ----------
    expression : str
        RHS of a Markov draw model statement, which will be parsed for the
        probabilities name and index name.

    Returns
    -------
    probs : str
        Name of the probabilities object in this statement.
    index : str or None
        Name of the indexing variable in this statement, or None if absent.

    Raises
    ------
    ValueError
        If there is no non-empty ``{...}``, or if the parentheses after it are
        unbalanced or empty.
    """
    # Name of the probabilities: text between the braces (the "{" *should*
    # be at position 0)
    open_brace = expression.find("{")
    close_brace = expression.find("}")
    if -1 in (open_brace, close_brace) or close_brace < open_brace + 2:
        raise ValueError("A Markov assignment must have an {array}!")
    probs = expression[(open_brace + 1) : close_brace]

    # Optional index: text between parentheses after the closing brace
    search_from = close_brace + 1
    open_paren = expression.find("(", search_from)
    close_paren = expression.find(")", search_from)
    if open_paren == -1 and close_paren == -1:
        return probs, None
    if -1 in (open_paren, close_paren) or close_paren < open_paren + 2:
        raise ValueError("Improper Markov formatting: should be {probs}(index)!")

    return probs, expression[(open_paren + 1) : close_paren]

3686 

3687 

def parse_random_indexed(expression):
    """
    Split an indexed random variable assignment of the form ``dstn[index]``
    into the distribution name and the index name.

    Parameters
    ----------
    expression : str
        RHS of an indexed random draw model statement.

    Returns
    -------
    dstn : str
        Name of the distribution: everything before the first ``[``.
    index : str
        Name of the indexing variable: the text inside the first ``[...]``.

    Raises
    ------
    ValueError
        If there is no non-empty ``[...]`` in the expression.
    """
    open_sq = expression.find("[")
    close_sq = expression.find("]")
    missing = open_sq == -1 or close_sq == -1
    if missing or close_sq < open_sq + 2:
        raise ValueError("An indexed random variable assignment must have an [index]!")

    dstn = expression[:open_sq]
    index = expression[(open_sq + 1) : close_sq]
    return dstn, index

3717 

3718 

def format_block_statement(statement):
    """
    Ensure that a string statement of a model block (maybe a period, maybe an
    initializer) is formatted as a list of strings, one statement per entry.

    A single string is split on newlines. If the string ends with a newline,
    the empty piece after it is discarded, but every real line is kept,
    including a final line that has no trailing newline. A string without any
    newline becomes a one-element list. A list is validated and shallow-copied.

    Parameters
    ----------
    statement : str or list[str]
        A model statement, which might be for a block or an initializer. The
        statement might be formatted as a list or as a single string.

    Returns
    -------
    block_statements: [str]
        A new list of model statements, one per entry.

    Raises
    ------
    ValueError
        If ``statement`` is a list containing a non-string entry.
    TypeError
        If ``statement`` is neither a string nor a list.
    """
    if isinstance(statement, str):
        if "\n" not in statement:
            # Strings are immutable, so no copy is needed (str has no .copy())
            return [statement]
        block_statements = statement.split("\n")
        # A trailing newline leaves one empty piece that is not a statement
        if block_statements[-1] == "":
            block_statements.pop()
        return block_statements

    if isinstance(statement, list):
        for line in statement:
            if not isinstance(line, str):
                raise ValueError("The model statement somehow includes a non-string!")
        return statement.copy()

    raise TypeError(
        "A model statement must be a string or a list of strings, not "
        + type(statement).__name__
        + "!"
    )

3753 

3754 

@njit
def aggregate_blobs_onto_polynomial_grid(
    vals, pmv, origins, grid, J, Q
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto a polynomial
    grid of outcome values, based on their origin in the arrival state space. This
    version is for non-continuation variables, returning only the probability array
    mapping from arrival states to the outcome variable.

    The bracketing gridpoint for each value is found by inverting the polynomial
    spacing directly (no search). The result is then corrected for floating-point
    error so that the index always stays inside the grid.

    Parameters
    ----------
    vals : np.array
        Outcome value of each blob.
    pmv : np.array
        Probability mass of each blob; its size sets the number of blobs.
    origins : np.array
        Row index (arrival state) that each blob's mass is added to.
    grid : np.array
        Ascending grid of outcome values whose normalized points are assumed to
        follow (i / (M-1)) ** Q, matching the inversion used below.
    J : int
        Number of arrival states, i.e. rows of the output.
    Q : float
        Polynomial order of the grid spacing.

    Returns
    -------
    probs : np.array
        Array of shape (J, grid.size); probs[j, i] is the mass from arrival state j
        allocated to gridpoint i by linear interpolation. Mass at or below the grid
        goes to the first point, and mass at or above it goes to the last point.
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    Mm1 = M - 1
    N = pmv.size
    scale = 1.0 / (top - bot)
    order = 1.0 / Q
    diffs = grid[1:] - grid[:-1]

    probs = np.zeros((J, M))

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            ii = int(np.floor(((x - bot) * scale) ** order * Mm1))
            # Rounding in the inverted formula can land one bracket off. For x just
            # below top it can even give ii == M - 1, which would read past diffs
            # and write past probs[jj] (numba does not bounds-check). Clamp, then
            # nudge so that grid[ii] <= x <= grid[ii + 1].
            if ii > M - 2:
                ii = M - 2
            if x < grid[ii]:
                ii -= 1
            elif x > grid[ii + 1]:
                ii += 1
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
        elif x <= bot:
            probs[jj, 0] += p
        else:
            probs[jj, -1] += p
    return probs

3790 

3791 

@njit
def aggregate_blobs_onto_exponential_grid(
    vals, pmv, origins, grid, J, K
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto an exponential
    grid of outcome values, based on their origin in the arrival state space. This
    version is for non-continuation variables, returning only the probability array
    mapping from arrival states to the outcome variable.

    The bracketing gridpoint for each value is found by applying a K-times nested
    log(1 + .) transform to (x - bot) and assuming the grid is evenly spaced in that
    transformed space. The resulting index is corrected for floating-point error so
    that it always stays inside the grid.

    Parameters
    ----------
    vals : np.array
        Outcome value of each blob.
    pmv : np.array
        Probability mass of each blob; its size sets the number of blobs.
    origins : np.array
        Row index (arrival state) that each blob's mass is added to.
    grid : np.array
        Ascending grid of outcome values, assumed evenly spaced after the nested log
        transform described above.
    J : int
        Number of arrival states, i.e. rows of the output.
    K : int
        Number of times the log(1 + .) transform is nested.

    Returns
    -------
    probs : np.array
        Array of shape (J, grid.size); probs[j, i] is the mass from arrival state j
        allocated to gridpoint i by linear interpolation. Mass at or below the grid
        goes to the first point, and mass at or above it goes to the last point.
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    Mm1 = M - 1
    N = pmv.size
    diffs = grid[1:] - grid[:-1]
    # Transformed span of the grid, used to normalize transformed values to [0, 1]
    ltop = top - bot
    for k in range(K):
        ltop = np.log(ltop + 1)
    scale = 1.0 / ltop

    probs = np.zeros((J, M))

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            y = x - bot
            for k in range(K):
                y = np.log(y + 1.0)
            ii = int(np.floor(y * scale * Mm1))
            # Rounding can land one bracket off. For x just below top it can even
            # give ii == M - 1, which would read past diffs and write past
            # probs[jj] (numba does not bounds-check). Clamp, then nudge so that
            # grid[ii] <= x <= grid[ii + 1].
            if ii > M - 2:
                ii = M - 2
            if x < grid[ii]:
                ii -= 1
            elif x > grid[ii + 1]:
                ii += 1
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
        elif x <= bot:
            probs[jj, 0] += p
        else:
            probs[jj, -1] += p
    return probs

3832 

3833 

@njit
def aggregate_blobs_onto_custom_grid(
    vals,
    pmv,
    origins,
    grid,
    J,
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto a custom
    grid of outcome values, based on their origin in the arrival state space. This
    version is for non-continuation variables, returning only the probability array
    mapping from arrival states to the outcome variable.

    Unlike the polynomial and exponential variants, the bracketing gridpoint is
    found by binary search (np.searchsorted). This works for any ascending grid.

    Parameters
    ----------
    vals : np.array
        Outcome value of each blob.
    pmv : np.array
        Probability mass of each blob; its size sets the number of blobs.
    origins : np.array
        Row index (arrival state) that each blob's mass is added to.
    grid : np.array
        Grid of outcome values, assumed sorted ascending (np.searchsorted
        requires a sorted array).
    J : int
        Number of arrival states, i.e. rows of the output.

    Returns
    -------
    probs : np.array
        Array of shape (J, grid.size); probs[j, i] is the mass from arrival state j
        allocated to gridpoint i by linear interpolation.
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    N = pmv.size
    diffs = grid[1:] - grid[:-1]

    probs = np.zeros((J, M))
    # For interior x, searchsorted (default side="left") returns i with
    # grid[i-1] < x <= grid[i], so idx holds the lower bracketing gridpoint.
    idx = np.searchsorted(grid, vals) - 1

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            # Split the mass linearly between the two bracketing gridpoints
            ii = idx[n]
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
        elif x <= bot:
            # At or below the grid: all mass to the first gridpoint
            probs[jj, 0] += p
        else:
            # At or above the top (and, as written, NaN values, which fail both
            # comparisons above): all mass to the last gridpoint
            probs[jj, -1] += p
    return probs

3871 

3872 

@njit
def aggregate_blobs_onto_polynomial_grid_alt(
    vals, pmv, origins, grid, J, Q
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto a polynomial
    grid of outcome values, based on their origin in the arrival state space. This
    version is for continuation variables, returning the probability array mapping
    from arrival states to the outcome variable, the index in the outcome variable grid
    for each blob, and the alpha weighting between gridpoints.

    The bracketing gridpoint is found by inverting the polynomial spacing directly.
    It is then corrected for floating-point error so that it always stays inside
    the grid.

    Parameters
    ----------
    vals : np.array
        Outcome value of each blob.
    pmv : np.array
        Probability mass of each blob; its size sets the number of blobs.
    origins : np.array
        Row index (arrival state) that each blob's mass is added to.
    grid : np.array
        Ascending grid of outcome values whose normalized points are assumed to
        follow (i / (M-1)) ** Q, matching the inversion used below.
    J : int
        Number of arrival states, i.e. rows of the output.
    Q : float
        Polynomial order of the grid spacing.

    Returns
    -------
    probs : np.array
        Array of shape (J, grid.size) of mass allocated from each arrival state to
        each gridpoint by linear interpolation.
    idx : np.array
        int32 index of the lower bracketing gridpoint for each blob (0 for blobs at
        or below the grid, M-2 for blobs at or above it).
    alpha : np.array
        Weight placed on gridpoint idx + 1 for each blob (0.0 below the grid, 1.0
        above it).
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    Mm1 = M - 1
    N = pmv.size
    scale = 1.0 / (top - bot)
    order = 1.0 / Q
    diffs = grid[1:] - grid[:-1]

    probs = np.zeros((J, M))
    idx = np.empty(N, dtype=np.dtype(np.int32))
    alpha = np.empty(N)

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            ii = int(np.floor(((x - bot) * scale) ** order * Mm1))
            # Rounding in the inverted formula can land one bracket off. For x just
            # below top it can even give ii == M - 1, which would read past diffs
            # and write past probs[jj] (numba does not bounds-check). Clamp, then
            # nudge so that grid[ii] <= x <= grid[ii + 1].
            if ii > M - 2:
                ii = M - 2
            if x < grid[ii]:
                ii -= 1
            elif x > grid[ii + 1]:
                ii += 1
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
            alpha[n] = temp
            idx[n] = ii
        elif x <= bot:
            probs[jj, 0] += p
            alpha[n] = 0.0
            idx[n] = 0
        else:
            probs[jj, -1] += p
            alpha[n] = 1.0
            idx[n] = M - 2
    return probs, idx, alpha

3917 

3918 

@njit
def aggregate_blobs_onto_exponential_grid_alt(
    vals, pmv, origins, grid, J, K
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto an exponential
    grid of outcome values, based on their origin in the arrival state space. This
    version is for continuation variables, returning the probability array mapping
    from arrival states to the outcome variable, the index in the outcome variable grid
    for each blob, and the alpha weighting between gridpoints.

    The bracketing gridpoint is found by applying a K-times nested log(1 + .)
    transform to (x - bot), assuming the grid is evenly spaced in that transformed
    space. The index is then corrected for floating-point error so that it always
    stays inside the grid.

    Parameters
    ----------
    vals : np.array
        Outcome value of each blob.
    pmv : np.array
        Probability mass of each blob; its size sets the number of blobs.
    origins : np.array
        Row index (arrival state) that each blob's mass is added to.
    grid : np.array
        Ascending grid of outcome values, assumed evenly spaced after the nested log
        transform described above.
    J : int
        Number of arrival states, i.e. rows of the output.
    K : int
        Number of times the log(1 + .) transform is nested.

    Returns
    -------
    probs : np.array
        Array of shape (J, grid.size) of mass allocated from each arrival state to
        each gridpoint by linear interpolation.
    idx : np.array
        int32 index of the lower bracketing gridpoint for each blob (0 for blobs at
        or below the grid, M-2 for blobs at or above it).
    alpha : np.array
        Weight placed on gridpoint idx + 1 for each blob (0.0 below the grid, 1.0
        above it).
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    Mm1 = M - 1
    N = pmv.size
    diffs = grid[1:] - grid[:-1]
    # Transformed span of the grid, used to normalize transformed values to [0, 1]
    ltop = top - bot
    for k in range(K):
        ltop = np.log(ltop + 1)
    scale = 1.0 / ltop

    probs = np.zeros((J, M))
    idx = np.empty(N, dtype=np.dtype(np.int32))
    alpha = np.empty(N)

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            y = x - bot
            for k in range(K):
                y = np.log(y + 1)
            ii = int(np.floor(y * scale * Mm1))
            # Rounding can land one bracket off. For x just below top it can even
            # give ii == M - 1, which would read past diffs and write past
            # probs[jj] (numba does not bounds-check). Clamp, then nudge so that
            # grid[ii] <= x <= grid[ii + 1].
            if ii > M - 2:
                ii = M - 2
            if x < grid[ii]:
                ii -= 1
            elif x > grid[ii + 1]:
                ii += 1
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
            alpha[n] = temp
            idx[n] = ii
        elif x <= bot:
            probs[jj, 0] += p
            alpha[n] = 0.0
            idx[n] = 0
        else:
            probs[jj, -1] += p
            alpha[n] = 1.0
            idx[n] = M - 2
    return probs, idx, alpha

3968 

3969 

@njit
def aggregate_blobs_onto_custom_grid_alt(
    vals,
    pmv,
    origins,
    grid,
    J,
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto a custom
    grid of outcome values, based on their origin in the arrival state space. This
    version is for continuation variables, returning the probability array mapping
    from arrival states to the outcome variable, the index in the outcome variable grid
    for each blob, and the alpha weighting between gridpoints.

    The bracketing gridpoint is found by binary search (np.searchsorted). This works
    for any ascending grid.

    Parameters
    ----------
    vals : np.array
        Outcome value of each blob.
    pmv : np.array
        Probability mass of each blob; its size sets the number of blobs.
    origins : np.array
        Row index (arrival state) that each blob's mass is added to.
    grid : np.array
        Grid of outcome values, assumed sorted ascending (np.searchsorted
        requires a sorted array).
    J : int
        Number of arrival states, i.e. rows of the output.

    Returns
    -------
    probs : np.array
        Array of shape (J, grid.size) of mass allocated from each arrival state to
        each gridpoint by linear interpolation.
    idx : np.array
        Index of the lower bracketing gridpoint for each blob (0 for blobs at or
        below the grid, M-2 for blobs at or above it). NOTE(review): this comes from
        np.searchsorted, so its dtype is the platform integer rather than the int32
        used by the polynomial/exponential variants.
    alpha : np.array
        Weight placed on gridpoint idx + 1 for each blob (0.0 below the grid, 1.0
        above it).
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    N = pmv.size
    diffs = grid[1:] - grid[:-1]

    probs = np.zeros((J, M))
    # For interior x, searchsorted (default side="left") returns i with
    # grid[i-1] < x <= grid[i], so idx holds the lower bracketing gridpoint.
    # Out-of-range entries are overwritten in the loop below.
    idx = np.searchsorted(grid, vals) - 1
    alpha = np.empty(N)

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            ii = idx[n]
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
            alpha[n] = temp
            idx[n] = ii  # no-op: ii was read from idx[n] above
        elif x <= bot:
            # At or below the grid: all mass to the first gridpoint
            probs[jj, 0] += p
            alpha[n] = 0.0
            idx[n] = 0
        else:
            # At or above the top (and, as written, NaN values): all mass to the
            # last gridpoint, expressed as full weight on the upper point of the
            # last bracket
            probs[jj, -1] += p
            alpha[n] = 1.0
            idx[n] = M - 2
    return probs, idx, alpha

4015 

4016 

@njit
def aggregate_blobs_onto_discrete_grid(vals, pmv, origins, M, J):  # pragma: no cover
    """
    Numba-compatible helper that tallies "probability blobs" over a truly discrete
    state: each blob's value is used directly as the column it lands on, so no
    interpolation is needed.

    Parameters
    ----------
    vals : np.array
        Integer outcome index of each blob (column of the output).
    pmv : np.array
        Probability mass of each blob; its size sets the number of blobs.
    origins : np.array
        Integer arrival-state index of each blob (row of the output).
    M : int
        Number of discrete outcome values (columns of the output).
    J : int
        Number of arrival states (rows of the output).

    Returns
    -------
    probs : np.array
        Array of shape (J, M) with the summed mass for each (origin, outcome) pair.
    """
    tally = np.zeros((J, M))
    for n in range(pmv.size):
        tally[origins[n], vals[n]] += pmv[n]
    return tally

4031 

4032 

@njit
def calc_overall_trans_probs(
    out, idx, alpha, binary, offset, pmv, origins
):  # pragma: no cover
    """
    Numba-compatible helper function for combining transition probabilities from
    the arrival state space to *multiple* continuation variables into a single
    unified transition matrix.

    For each blob and each row b of ``binary``, the blob's mass is multiplied by
    alpha[n, d, binary[b, d]] across all dimensions d. The product is added to
    out[origins[n], idx[n] + offset[b]]. ``out`` is updated in place and also
    returned.

    Parameters
    ----------
    out : np.array
        2D array that transition mass is accumulated into (modified in place).
    idx : np.array
        Base column index into ``out`` for each blob.
    alpha : np.array
        3D array indexed as alpha[n, d, k]. Presumably these are the lower/upper
        (k = 0/1) interpolation weights of blob n for continuation variable d;
        confirm against the caller.
    binary : np.array
        2D integer array; row b selects which k to use for each dimension d.
    offset : np.array
        Column offset added to idx[n] for each row b of ``binary``.
    pmv : np.array
        Probability mass of each blob.
    origins : np.array
        Row index into ``out`` for each blob.

    Returns
    -------
    out : np.array
        The same array that was passed in, with all blob contributions added.
    """
    n_blobs = alpha.shape[0]
    n_combos = binary.shape[0]
    n_dims = binary.shape[1]
    for n in range(n_blobs):
        row = origins[n]
        base_col = idx[n]
        mass = pmv[n]
        for b in range(n_combos):
            weight = mass
            for d in range(n_dims):
                weight *= alpha[n, d, binary[b, d]]
            out[row, base_col + offset[b]] += weight
    return out