Coverage for HARK / simulator.py: 93%
1437 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 05:31 +0000
1"""
2A module with classes and functions for automated simulation of HARK.AgentType
3models from a human- and machine-readable model specification.
4"""
6from dataclasses import dataclass, field
7from copy import copy, deepcopy
8import numpy as np
9from numba import njit
10from sympy.utilities.lambdify import lambdify
11from sympy import symbols, IndexedBase
12from typing import Callable
13from HARK.utilities import NullFunc, make_exponential_grid
14from HARK.distributions import Distribution
15from scipy.sparse import csr_matrix
16from scipy.sparse.linalg import eigs
17from scipy.optimize import brentq
18from itertools import product
19import importlib.resources
20import yaml
# Prevent pre-commit from removing sympy: each imported sympy name is used once
# here (so import-pruning hooks treat it as used), and the throwaway module-level
# names are deleted immediately afterward.
x = symbols("x")
y = IndexedBase("y")
del x, y
29@dataclass(kw_only=True)
30class ModelEvent:
31 """
32 Class for representing "events" that happen to agents in the course of their
33 model. These might be statements of dynamics, realization of a random shock,
34 or the evaluation of a function (potentially a control or other solution-
35 based object). This is a superclass for types of events defined below.
37 Parameters
38 ----------
39 description : str
40 Text description of this model event.
41 statement : str
42 The line of the model statement that this event corresponds to.
43 parameters : dict
44 Dictionary of objects that are static / universal within this event.
45 assigns : list[str]
46 List of names of variables that this event assigns values for.
47 needs : list[str]
48 List of names of variables that this event requires to be run.
49 data : dict
50 Dictionary of current variable values within this event.
51 common : bool
52 Indicator for whether the variables assigned in this event are commonly
53 held across all agents, rather than idiosyncratic.
54 N : int
55 Number of agents currently in this event.
56 """
58 statement: str = field(default="")
59 parameters: dict = field(default_factory=dict)
60 description: str = field(default="")
61 assigns: list[str] = field(default_factory=list, repr=False)
62 needs: list = field(default_factory=list, repr=False)
63 data: dict = field(default_factory=dict, repr=False)
64 common: bool = field(default=False, repr=False)
65 N: int = field(default=1, repr=False)
67 def run(self):
68 """
69 This method should be filled in by each subclass.
70 """
71 pass # pragma: nocover
73 def reset(self):
74 self.data = {}
76 def assign(self, output):
77 if len(self.assigns) > 1:
78 assert len(self.assigns) == len(output)
79 for j in range(len(self.assigns)):
80 var = self.assigns[j]
81 if type(output[j]) is not np.ndarray:
82 output[j] = np.array([output[j]])
83 self.data[var] = output[j]
84 else:
85 var = self.assigns[0]
86 if type(output) is not np.ndarray:
87 output = np.array([output])
88 self.data[var] = output
90 def expand_information(self, origins, probs, atoms, which=None):
91 """
92 This method is only called internally when a RandomEvent or MarkovEvent
93 runs its quasi_run() method. It expands the set of of "probability blobs"
94 by applying a random realization event. All extant blobs for which the
95 shock applies are replicated for each atom in the random event, with the
96 probability mass divided among the replicates.
98 Parameters
99 ----------
100 origins : np.array
101 Array that tracks which arrival state space node each blob originated
102 from. This is expanded into origins_new, which is returned.
103 probs : np.array
104 Vector of probabilities of each of the random possibilities.
105 atoms : [np.array]
106 List of arrays with realization values for the distribution. Each
107 array corresponds to one variable that is assigned by this event.
108 which : np.array or None
109 If given, a Boolean array indicating which of the pre-existing blobs
110 is affected by the given probabilities and atoms. By default, all
111 blobs are assumed to be affected.
113 Returns
114 -------
115 origins_new : np.array
116 Expanded boolean array of indicating the arrival state space node that
117 each blob originated from.
118 """
119 K = probs.size
120 N = self.N
121 if which is None:
122 which = np.ones(N, dtype=bool)
123 other = np.logical_not(which)
124 M = np.sum(which) # how many blobs are we affecting?
125 MX = N - M # how many blobs are we not affecting?
127 # Update probabilities of outcomes
128 pmv_old = np.reshape(self.data["pmv_"][which], (M, 1))
129 pmv_new = (pmv_old * np.reshape(probs, (1, K))).flatten()
130 self.data["pmv_"] = np.concatenate((self.data["pmv_"][other], pmv_new))
132 # Replicate the pre-existing data for each atom
133 for var in self.data.keys():
134 if (var == "pmv_") or (var in self.assigns):
135 continue # don't double expand pmv, and don't touch assigned variables
136 data_old = np.reshape(self.data[var][which], (M, 1))
137 data_new = np.tile(data_old, (1, K)).flatten()
138 self.data[var] = np.concatenate((self.data[var][other], data_new))
140 # If any of the assigned variables don't exist yet, add dummy versions
141 # of them. This section exists so that the code works with "partial events"
142 # on both the first pass and subsequent passes.
143 for j in range(len(self.assigns)):
144 var = self.assigns[j]
145 if var in self.data.keys():
146 continue
147 self.data[var] = np.zeros(N, dtype=atoms[j].dtype)
148 # Zeros are just dummy values
150 # Add the new random variables to the simulation data. This generates
151 # replicates for the affected blobs and leaves the others untouched,
152 # still with their dummy values. They will be altered on later passes.
153 for j in range(len(self.assigns)):
154 var = self.assigns[j]
155 data_new = np.tile(np.reshape(atoms[j], (1, K)), (M, 1)).flatten()
156 self.data[var] = np.concatenate((self.data[var][other], data_new))
158 # Expand the origins array to account for the new replicates
159 origins_new = np.tile(np.reshape(origins[which], (M, 1)), (1, K)).flatten()
160 origins_new = np.concatenate((origins[other], origins_new))
161 self.N = MX + M * K
163 # Send the new origins array back to the calling process
164 return origins_new
166 def add_idiosyncratic_bernoulli_info(self, origins, probs):
167 """
168 Special method for adding Bernoulli outcomes to the information set when
169 probabilities are idiosyncratic to each agent. All extant blobs are duplicated
170 with the appropriate probability
172 Parameters
173 ----------
174 origins : np.array
175 Array that tracks which arrival state space node each blob originated
176 from. This is expanded into origins_new, which is returned.
177 probs : np.array
178 Vector of probabilities of drawing True for each blob.
180 Returns
181 -------
182 origins_new : np.array
183 Expanded boolean array of indicating the arrival state space node that
184 each blob originated from.
185 """
186 N = self.N
188 # # Update probabilities of outcomes, replicating each one
189 pmv_old = np.reshape(self.data["pmv_"], (N, 1))
190 P = np.reshape(probs, (N, 1))
191 PX = np.concatenate([1.0 - P, P], axis=1)
192 pmv_new = (pmv_old * PX).flatten()
193 self.data["pmv_"] = pmv_new
195 # Replicate the pre-existing data for each atom
196 for var in self.data.keys():
197 if (var == "pmv_") or (var in self.assigns):
198 continue # don't double expand pmv, and don't touch assigned variables
199 data_old = np.reshape(self.data[var], (N, 1))
200 data_new = np.tile(data_old, (1, 2)).flatten()
201 self.data[var] = data_new
203 # Add the (one and only) new random variable to the simulation data
204 var = self.assigns[0]
205 data_new = np.tile(np.array([[0, 1]]), (N, 1)).flatten()
206 self.data[var] = data_new
208 # Expand the origins array to account for the new replicates
209 origins_new = np.tile(np.reshape(origins, (N, 1)), (1, 2)).flatten()
210 self.N = N * 2
212 # Send the new origins array back to the calling process
213 return origins_new
@dataclass(kw_only=True)
class DynamicEvent(ModelEvent):
    """
    Class for representing model dynamics for an agent, consisting of an expression
    to be evaluated and variables to which the results are assigned.

    Parameters
    ----------
    expr : Callable
        Function or expression to be evaluated for the assigned variables.
    args : list[str]
        Ordered list of argument names for the expression.
    """

    expr: Callable = field(default_factory=NullFunc, repr=False)
    args: list[str] = field(default_factory=list, repr=False)

    def evaluate(self):
        """
        Call expr with the current values of the names in args, in order, and
        return its result. A name present in both parameters and data resolves
        to the value in parameters.
        """
        lookup = {**self.data, **self.parameters}
        arg_values = [lookup[name] for name in self.args]
        return self.expr(*arg_values)

    def run(self):
        """Evaluate the expression and store its output under assigns."""
        result = self.evaluate()
        self.assign(result)

    def quasi_run(self, origins, norm=None):
        """
        Quasi-simulation step: run the event once and hand origins back
        unchanged (no blobs are created). norm is accepted but not used.
        """
        self.run()
        return origins
@dataclass(kw_only=True)
class RandomEvent(ModelEvent):
    """
    Class for representing the realization of random variables for an agent,
    consisting of a shock distribution and variables to which the results are assigned.

    Parameters
    ----------
    dstn : Distribution
        Distribution of one or more random variables that are drawn from during
        this event and assigned to the corresponding variables.
    """

    dstn: Distribution = field(default_factory=Distribution, repr=False)

    def reset(self):
        """Reset the distribution's random state and clear this event's data."""
        self.dstn.reset()
        ModelEvent.reset(self)

    def draw(self):
        """
        Draw from dstn: one draw per agent, or a single draw broadcast to all
        agents if the event is common. Returns a (len(assigns), N) array, or a
        flat length-N array when only one variable is assigned.
        """
        out = np.empty((len(self.assigns), self.N))
        if not self.common:
            out[:, :] = self.dstn.draw(self.N)
        else:
            out[:, :] = self.dstn.draw(1)
        if len(self.assigns) == 1:
            out = out.flatten()
        return out

    def run(self):
        """Draw the random variables and store them under assigns."""
        self.assign(self.draw())

    def quasi_run(self, origins, norm=None):
        """
        Quasi-simulate this event by replicating every blob once per atom of dstn.

        Parameters
        ----------
        origins : np.array
            Arrival-node origin of each current blob.
        norm : str or None
            Name of a shock variable for Harmenberg normalization. If it is one
            of this event's assigned variables, the atom probabilities are
            weighted by that variable's atom values; otherwise no normalization.

        Returns
        -------
        origins_new : np.array
            Expanded origins array.
        """
        # Get distribution
        atoms = self.dstn.atoms
        probs = self.dstn.pmv.copy()

        # Apply Harmenberg normalization if applicable. Errors raised while
        # normalizing are real problems and are no longer silently ignored.
        if norm in self.assigns:
            probs *= atoms[self.assigns.index(norm)]

        # Expand the set of simulated blobs
        return self.expand_information(origins, probs, atoms)
@dataclass(kw_only=True)
class RandomIndexedEvent(RandomEvent):
    """
    Class for representing the realization of random variables for an agent,
    consisting of a list of shock distributions, an index for the list, and the
    variables to which the results are assigned.

    Parameters
    ----------
    dstn : [Distribution]
        List of distributions of one or more random variables that are drawn
        from during this event and assigned to the corresponding variables.
    index : str
        Name of the index that is used to choose a distribution for each agent.
    """

    index: str = field(default="", repr=False)
    dstn: list[Distribution] = field(default_factory=list, repr=False)

    def draw(self):
        """
        Draw each agent's variables from the distribution selected by its index.

        Agents whose index matches no distribution are left as NaN. Returns a
        (len(assigns), N) array, or a flat length-N array when only one variable
        is assigned (in both the common and idiosyncratic cases).
        """
        idx = self.data[self.index]
        K = len(self.assigns)
        out = np.full((K, self.N), np.nan)

        if self.common:
            k = idx[0]  # this will behave badly if index is not itself common
            out[:, :] = self.dstn[k].draw(1)
        else:
            for k, dstn_k in enumerate(self.dstn):
                these = idx == k
                if not np.any(these):
                    continue
                out[:, these] = dstn_k.draw(np.sum(these))

        # Flatten in both branches so the output shape matches RandomEvent.draw
        if K == 1:
            out = out.flatten()
        return out

    def reset(self):
        """Reset every distribution's random state and clear this event's data."""
        for dstn_k in self.dstn:
            dstn_k.reset()
        ModelEvent.reset(self)

    def quasi_run(self, origins, norm=None):
        """
        Quasi-simulate this event: for each distribution j, replicate the blobs
        whose index equals j once per atom of that distribution.

        Parameters
        ----------
        origins : np.array
            Arrival-node origin of each current blob.
        norm : str or None
            Name of a shock variable for Harmenberg normalization; applied only
            if it is one of this event's assigned variables.

        Returns
        -------
        origins_new : np.array
            Expanded origins array.
        """
        origins_new = origins.copy()
        harm_idx = self.assigns.index(norm) if norm in self.assigns else None

        for j, dstn_j in enumerate(self.dstn):
            # Re-read the index every pass: expand_information replicates and
            # reorders self.data on each call.
            these = self.data[self.index] == j

            # Get distribution
            atoms = dstn_j.atoms
            probs = dstn_j.pmv.copy()

            # Apply Harmenberg normalization if applicable
            if harm_idx is not None:
                probs *= atoms[harm_idx]

            # Expand the set of simulated blobs
            origins_new = self.expand_information(
                origins_new, probs, atoms, which=these
            )

        # Return the altered origins array
        return origins_new
@dataclass(kw_only=True)
class MarkovEvent(ModelEvent):
    """
    Class for representing the realization of a Markov draw for an agent, in which
    Markov probabilities (an array, a vector, or a single float) are used to
    determine the realization of some discrete outcome. If the probabilities are a
    2D array, they represent a Markov matrix (rows sum to 1), and there must be an
    index; if the probabilities are a vector, it should be a stochastic vector; if
    it's a single float, it represents a Bernoulli probability.

    Parameters
    ----------
    probs : str
        Name of the probabilities; looked up in parameters if present there,
        otherwise in data.
    index : str
        Name of the variable selecting a row of the Markov matrix. Empty when
        probs is not a Markov matrix.
    seed : int
        Seed for this event's RandomState.
    """

    probs: str = field(default="", repr=False)
    index: str = field(default="", repr=False)
    N: int = field(default=1, repr=False)
    seed: int = field(default=0, repr=False)
    # seed is overwritten when each period is created

    def __post_init__(self):
        self.reset_rng()

    def reset(self):
        """Re-seed the RNG and clear this event's data."""
        self.reset_rng()
        ModelEvent.reset(self)

    def reset_rng(self):
        """Create a fresh RandomState from self.seed."""
        self.RNG = np.random.RandomState(self.seed)

    def _get_probs(self):
        """
        Look up the probabilities named by self.probs.

        Returns
        -------
        probs : object
            The value from parameters if the name is a parameter, otherwise the
            current value from data.
        probs_are_param : bool
            True if the probabilities came from parameters.
        """
        if self.probs in self.parameters:
            return self.parameters[self.probs], True
        return self.data[self.probs], False

    def draw(self):
        """
        Draw the discrete outcome for each agent (or one shared outcome if common).

        Returns integer states for Markov-matrix and stochastic-vector draws (-1
        marks agents whose index matches no matrix row), and booleans for
        Bernoulli draws.
        """
        # Initialize the output
        out = -np.ones(self.N, dtype=int)
        probs, probs_are_param = self._get_probs()

        # Make the base draw(s)
        if self.common:
            X = self.RNG.rand(1)
        else:
            X = self.RNG.rand(self.N)

        if self.index:  # it's a Markov matrix
            idx = self.data[self.index]
            for j in range(probs.shape[0]):
                these = idx == j
                if not np.any(these):
                    continue
                P = np.cumsum(probs[j, :])
                if self.common:
                    out[:] = np.searchsorted(P, X[0])  # only one value of X!
                else:
                    out[these] = np.searchsorted(P, X[these])
            return out

        if isinstance(probs, np.ndarray) and probs_are_param:  # stochastic vector
            P = np.cumsum(probs)
            if self.common:
                out[:] = np.searchsorted(P, X[0])
                return out
            return np.searchsorted(P, X)

        # Otherwise, this is just a Bernoulli RV. Return booleans in the common
        # case too, so both cases produce the same dtype.
        outcome = X < probs
        if self.common:
            return np.broadcast_to(outcome, (self.N,)).copy()
        return outcome  # basic Bernoulli

    def run(self):
        """Draw the outcome and store it under assigns."""
        self.assign(self.draw())

    def quasi_run(self, origins, norm=None):
        """
        Quasi-simulate this event by replicating blobs for each possible outcome.

        Parameters
        ----------
        origins : np.array
            Arrival-node origin of each current blob.
        norm : str or None
            Accepted for interface compatibility; not used by this event.

        Returns
        -------
        origins_new : np.array
            Expanded origins array.
        """
        probs, probs_are_param = self._get_probs()

        # If it's a Markov matrix:
        if self.index:
            atoms = np.array([np.arange(probs.shape[1], dtype=int)])
            origins_new = origins.copy()
            for k in range(probs.shape[0]):
                # Re-read the index each pass; expand_information reorders data
                these = self.data[self.index] == k
                origins_new = self.expand_information(
                    origins_new, probs[k, :], atoms, which=these
                )
            return origins_new

        # If it's a stochastic vector:
        if isinstance(probs, np.ndarray) and probs_are_param:
            atoms = np.array([np.arange(probs.shape[0], dtype=int)])
            return self.expand_information(origins, probs, atoms)

        # Otherwise, this is just a Bernoulli RV, but it might have idiosyncratic probability
        if probs_are_param:
            atoms = np.array([[False, True]])
            return self.expand_information(origins, np.array([1 - probs, probs]), atoms)

        # Final case: probability is idiosyncratic Bernoulli
        return self.add_idiosyncratic_bernoulli_info(origins, probs)
@dataclass(kw_only=True)
class EvaluationEvent(ModelEvent):
    """
    Class for representing the evaluation of a model function. This might be from
    the solution of the model (like a policy function or decision rule) or just
    a non-algebraic function used in the model. This looks a lot like DynamicEvent.

    Parameters
    ----------
    func : Callable
        Model function that is evaluated in this event, with the output assigned
        to the appropriate variables.
    arguments : list[str]
        Ordered list of names whose current values are passed to func.
    """

    func: Callable = field(default_factory=NullFunc, repr=False)
    arguments: list[str] = field(default_factory=list, repr=False)

    def evaluate(self):
        """
        Call func with the current values of the names in arguments, in order,
        and return its result. Values in parameters take precedence over
        same-named values in data.
        """
        available = {**self.data, **self.parameters}
        call_args = [available[name] for name in self.arguments]
        return self.func(*call_args)

    def run(self):
        """Evaluate the function and store its output under assigns."""
        result = self.evaluate()
        self.assign(result)

    def quasi_run(self, origins, norm=None):
        """
        Quasi-simulation step: run the event once and return origins as-is
        (no blobs are created). norm is accepted but not used.
        """
        self.run()
        return origins
519@dataclass(kw_only=True)
520class SimBlock:
521 """
522 Class for representing a "block" of a simulated model, which might be a whole
523 period or a "stage" within a period.
525 Parameters
526 ----------
527 description : str
528 Textual description of what happens in this simulated block.
529 statement : str
530 Verbatim model statement that was used to create this block.
531 content : dict
532 Dictionary of objects that are constant / universal within the block.
533 This includes both traditional numeric parameters as well as functions.
534 arrival : list[str]
535 List of inbound states: information available at the *start* of the block.
536 events: list[ModelEvent]
537 Ordered list of events that happen during the block.
538 data: dict
539 Dictionary that stores current variable values.
540 N : int
541 Number of idiosyncratic agents in this block.
542 """
544 statement: str = field(default="", repr=False)
545 content: dict = field(default_factory=dict)
546 description: str = field(default="", repr=False)
547 arrival: list[str] = field(default_factory=list, repr=False)
548 events: list[ModelEvent] = field(default_factory=list, repr=False)
549 data: dict = field(default_factory=dict, repr=False)
550 N: int = field(default=1, repr=False)
552 def run(self):
553 """
554 Run this simulated block by running each of its events in order.
555 """
556 for j in range(len(self.events)):
557 event = self.events[j]
558 for k in range(len(event.assigns)):
559 var = event.assigns[k]
560 if var in event.data.keys():
561 del event.data[var]
562 for k in range(len(event.needs)):
563 var = event.needs[k]
564 event.data[var] = self.data[var]
565 event.N = self.N
566 event.run()
567 for k in range(len(event.assigns)):
568 var = event.assigns[k]
569 self.data[var] = event.data[var]
571 def reset(self):
572 """
573 Reset the simulated block by resetting each of its events.
574 """
575 self.data = {}
576 for j in range(len(self.events)):
577 self.events[j].reset()
579 def distribute_content(self):
580 """
581 Fill in parameters, functions, and distributions to each event.
582 """
583 for event in self.events:
584 for param in event.parameters.keys():
585 try:
586 event.parameters[param] = self.content[param]
587 except:
588 raise ValueError(
589 "Could not distribute the parameter called " + param + "!"
590 )
591 if (type(event) is RandomEvent) or (type(event) is RandomIndexedEvent):
592 try:
593 event.dstn = self.content[event._dstn_name]
594 except:
595 raise ValueError(
596 "Could not find a distribution called " + event._dstn_name + "!"
597 )
598 if type(event) is EvaluationEvent:
599 try:
600 event.func = self.content[event._func_name]
601 except:
602 raise ValueError(
603 "Could not find a function called " + event._func_name + "!"
604 )
606 def _build_input_grids(self, grid_specs, arrival_N):
607 """
608 Build input and output grid dictionaries from grid specifications.
610 Creates grids_in for arrival variables and grids_out for outcome variables,
611 tracking which arrival variables have been covered and whether each output
612 grid is continuous.
614 Parameters
615 ----------
616 grid_specs : dict
617 Dictionary of grid specifications keyed by variable name.
618 arrival_N : int
619 Number of arrival variables in this block.
621 Returns
622 -------
623 grids_in : dict
624 Grids for arrival (input) variables.
625 grids_out : dict
626 Grids for outcome (output) variables.
627 continuous_grid_out_bool : list
628 List of booleans indicating whether each output grid is continuous.
629 grid_orders : dict
630 Polynomial order for each variable grid (-1 for discrete, None if unknown).
631 dummy_grid : np.ndarray or None
632 Dummy grid used only when arrival_N == 0.
633 """
634 completed = arrival_N * [False]
635 grids_in = {}
636 grids_out = {}
637 dummy_grid = None
638 if arrival_N == 0: # should only be for initializer block
639 dummy_grid = np.array([0])
640 grids_in["_dummy"] = dummy_grid
642 continuous_grid_out_bool = []
643 grid_orders = {}
644 for var in grid_specs.keys():
645 spec = grid_specs[var]
646 try:
647 idx = self.arrival.index(var)
648 completed[idx] = True
649 is_arrival = True
650 except:
651 is_arrival = False
652 if ("min" in spec) and ("max" in spec):
653 Q = spec["order"] if "order" in spec else 1.0
654 bot = spec["min"]
655 top = spec["max"]
656 N = spec["N"]
657 new_grid = make_exponential_grid(bot, top, N, Q)
658 is_cont = True
659 grid_orders[var] = Q
660 elif "N" in spec:
661 new_grid = np.arange(spec["N"], dtype=int)
662 is_cont = False
663 grid_orders[var] = -1
664 else:
665 new_grid = None # could not make grid, construct later
666 is_cont = False
667 grid_orders[var] = None
669 if is_arrival:
670 grids_in[var] = new_grid
671 else:
672 grids_out[var] = new_grid
673 continuous_grid_out_bool.append(is_cont)
675 # Verify that specifications were passed for all arrival variables
676 for j in range(len(self.arrival)):
677 if not completed[j]:
678 raise ValueError(
679 "No grid specification was provided for " + self.arrival[j] + "!"
680 )
682 return grids_in, grids_out, continuous_grid_out_bool, grid_orders, dummy_grid
684 def _build_twist_grids(
685 self, twist, grids_in, grid_orders, grids_out, continuous_grid_out_bool
686 ):
687 """
688 Override output grids with arrival-matching grids for continuation variables.
690 When an intertemporal twist is provided, the result grids for continuation
691 variables are set to match the corresponding arrival variable grids.
693 Parameters
694 ----------
695 twist : dict
696 Mapping from continuation variable names to arrival variable names.
697 grids_in : dict
698 Grids for arrival variables.
699 grid_orders : dict
700 Polynomial orders for each variable grid (modified in place).
701 grids_out : dict
702 Grids for output variables (modified in place).
703 continuous_grid_out_bool : list
704 Boolean continuity flags for output grids (extended in place).
706 Returns
707 -------
708 grids_out : dict
709 Updated output grids.
710 grid_orders : dict
711 Updated grid orders.
712 grid_out_is_continuous : np.ndarray
713 Boolean array indicating continuity of each output grid.
714 """
715 for cont_var in twist.keys():
716 arr_var = twist[cont_var]
717 if cont_var not in list(grids_out.keys()):
718 is_cont = grids_in[arr_var].dtype is np.dtype(np.float64)
719 continuous_grid_out_bool.append(is_cont)
720 grids_out[cont_var] = copy(grids_in[arr_var])
721 grid_orders[cont_var] = grid_orders[arr_var]
722 grid_out_is_continuous = np.array(continuous_grid_out_bool)
723 return grids_out, grid_orders, grid_out_is_continuous
725 def _project_onto_output_grids(
726 self,
727 grids_out,
728 grid_out_is_continuous,
729 grid_orders,
730 cont_vars,
731 twist,
732 N_orig,
733 J,
734 N,
735 ):
736 """
737 Project quasi-simulation results onto the discretized output grids.
739 Loops over each output variable, dispatching to the appropriate
740 aggregation routine based on whether the grid is continuous or discrete.
742 Parameters
743 ----------
744 grids_out : dict
745 Output grids (may be updated for None grids).
746 grid_out_is_continuous : np.ndarray
747 Boolean continuity flags for each output variable.
748 grid_orders : dict
749 Polynomial orders for each variable grid.
750 cont_vars : list
751 Names of continuation variables.
752 twist : dict or None
753 The intertemporal twist mapping (used only to check if provided).
754 N_orig : int
755 Number of original arrival-state grid points.
756 J : int
757 Size of the arrival state mesh.
758 N : int
759 Number of agents in this block.
761 Returns
762 -------
763 matrices_out : dict
764 Transition matrices for each output variable.
765 cont_idx : dict
766 Lower-bracket indices for continuation variables.
767 cont_alpha : dict
768 Interpolation weights for continuation variables.
769 cont_M : dict
770 Grid sizes for continuation variables.
771 cont_discrete : dict
772 Whether each continuation variable uses a discrete grid.
773 grids_out : dict
774 Updated output grids (some None entries may be filled in).
775 grid_out_is_continuous : np.ndarray
776 Updated continuity flags (may be changed for size-1 float grids).
777 """
778 origin_array = self.origin_array
779 matrices_out = {}
780 cont_idx = {}
781 cont_alpha = {}
782 cont_M = {}
783 cont_discrete = {}
784 k = 0
785 for var in grids_out.keys():
786 if var not in self.data.keys():
787 raise ValueError(
788 "Variable " + var + " does not exist but a grid was specified!"
789 )
790 grid = grids_out[var]
791 vals = self.data[var]
792 pmv = self.data["pmv_"]
793 M = grid.size if grid is not None else 0
795 # Semi-hacky fix to deal with omitted arrival variables
796 if (M == 1) and (vals.dtype is np.dtype(np.float64)):
797 grid = grid.astype(float)
798 grids_out[var] = grid
799 grid_out_is_continuous[k] = True
801 if grid_out_is_continuous[k]:
802 # Split the final values among discrete gridpoints on the interior.
803 if M > 1:
804 Q = grid_orders[var]
805 if var in cont_vars:
806 trans_matrix, cont_idx[var], cont_alpha[var] = (
807 aggregate_blobs_onto_polynomial_grid_alt(
808 vals, pmv, origin_array, grid, J, Q
809 )
810 )
811 cont_M[var] = M
812 cont_discrete[var] = False
813 else:
814 trans_matrix = aggregate_blobs_onto_polynomial_grid(
815 vals, pmv, origin_array, grid, J, Q
816 )
817 else: # Skip if the grid is a dummy with only one value.
818 trans_matrix = np.ones((J, M))
819 if var in cont_vars:
820 cont_idx[var] = np.zeros(N, dtype=int)
821 cont_alpha[var] = np.zeros(N)
822 cont_M[var] = M
823 cont_discrete[var] = False
825 else: # Grid is discrete, can use simpler method
826 if grid is None:
827 M = np.max(vals.astype(int))
828 if var == "dead":
829 M = 2
830 grid = np.arange(M, dtype=int)
831 grids_out[var] = grid
832 M = grid.size
833 vals = vals.astype(int)
834 trans_matrix = aggregate_blobs_onto_discrete_grid(
835 vals, pmv, origin_array, M, J
836 )
837 if var in cont_vars:
838 cont_idx[var] = vals
839 cont_alpha[var] = np.zeros(N)
840 cont_M[var] = M
841 cont_discrete[var] = True
843 # Store the transition matrix for this variable
844 matrices_out[var] = trans_matrix
845 k += 1
847 return (
848 matrices_out,
849 cont_idx,
850 cont_alpha,
851 cont_M,
852 cont_discrete,
853 grids_out,
854 grid_out_is_continuous,
855 )
    def _build_master_transition_array(
        self, cont_vars, cont_idx, cont_alpha, cont_M, cont_discrete, N_orig, N, D
    ):
        """
        Construct the master arrival-to-continuation transition array.

        Combines per-variable index arrays and interpolation weights into a
        single tensor using multilinear interpolation. The offset index arithmetic
        and continuation-variable ordering are load-bearing.

        Each quasi-simulated blob has, for every continuation dimension d, a
        lower-bracket index (cont_idx) and an upper-bracket weight (cont_alpha).
        The continuation grid is flattened in row-major order over
        shape[1:] = [cont_M[v] for v in cont_vars]. Each blob contributes to the
        2**C corners of its bracketing hypercube, where C counts the non-trivial
        dimensions.

        Parameters
        ----------
        cont_vars : list
            Names of continuation variables, ordered to match arrival variables.
        cont_idx : dict
            Lower-bracket indices for each continuation variable.
        cont_alpha : dict
            Interpolation weights (upper bracket) for each continuation variable.
        cont_M : dict
            Grid size for each continuation variable.
        cont_discrete : dict
            Whether each continuation variable uses a discrete (not continuous) grid.
        N_orig : int
            Number of arrival-state grid points.
        N : int
            Total number of quasi-simulated agents.
        D : int
            Number of continuation dimensions (presumably len(cont_vars); the
            loops below index cont_vars[d] for d < D).

        Returns
        -------
        master_trans_array_X : np.ndarray
            Unnormalized master transition array of shape (N_orig, prod(cont_M)).
        """
        pmv = self.data["pmv_"]
        origin_array = self.origin_array

        # Count the number of non-trivial dimensions. A continuation dimension
        # is non-trivial if it is both continuous and has more than one grid node.
        C = 0
        shape = [N_orig]  # shape[0] is the arrival axis; shape[d + 1] is cont dim d
        trivial = []
        for var in cont_vars:
            shape.append(cont_M[var])
            if (not cont_discrete[var]) and (cont_M[var] > 1):
                C += 1
                trivial.append(False)
            else:
                trivial.append(True)
        trivial = np.array(trivial)

        # Make a binary array of offsets from the base index. Each row is one
        # corner of the 2**C hypercube: 0 = lower bracket, 1 = upper bracket.
        # Trivial dimensions always use offset 0; the non-trivial dimensions
        # consume the columns of bin_array_base in order (tracked by c).
        bin_array_base = np.array(list(product([0, 1], repeat=C)))
        bin_array = np.empty((2**C, D), dtype=int)
        some_zeros = np.zeros(2**C, dtype=int)
        c = 0
        for d in range(D):
            bin_array[:, d] = some_zeros if trivial[d] else bin_array_base[:, c]
            c += not trivial[d]

        # Make a vector of dimensional offsets from the base index. These are
        # row-major strides over shape[1:]: the stride of dim d is the product of
        # the sizes of all later dims (shape[d + 2:]); the last dim keeps stride 1.
        dim_offsets = np.ones(D, dtype=int)
        for d in range(D - 1):
            dim_offsets[d] = np.prod(shape[(d + 2) :])
        dim_offsets_X = np.tile(dim_offsets, (2**C, 1))
        # Flat offset of each hypercube corner relative to a blob's base index
        offsets = np.sum(bin_array * dim_offsets_X, axis=1)

        # Make combined arrays of indices and alphas. alpha_array[:, d, 0] is the
        # lower-bracket weight and alpha_array[:, d, 1] the upper-bracket weight.
        index_array = np.empty((N, D), dtype=int)
        alpha_array = np.empty((N, D, 2))
        for d in range(D):
            var = cont_vars[d]
            index_array[:, d] = cont_idx[var]
            alpha_array[:, d, 0] = 1.0 - cont_alpha[var]
            alpha_array[:, d, 1] = cont_alpha[var]
        # Flat (row-major) base index of each blob on the continuation grid
        idx_array = np.dot(index_array, dim_offsets)

        # Make the master transition array. calc_overall_trans_probs is defined
        # elsewhere; presumably it accumulates each blob's pmv times the product
        # of its corner weights into row origin_array[n], column
        # idx_array[n] + offsets[corner] -- NOTE(review): confirm.
        blank = np.zeros(np.array((N_orig, np.prod(shape[1:]))))
        master_trans_array_X = calc_overall_trans_probs(
            blank, idx_array, alpha_array, bin_array, offsets, pmv, origin_array
        )
        return master_trans_array_X
941 def _condition_on_survival(self, master_trans_array_X, matrices_out, N_orig):
942 """
943 Condition the master transition array on agent survival.
945 Divides through by the survival probability so that the transition
946 array represents the distribution conditional on not dying this period.
948 Parameters
949 ----------
950 master_trans_array_X : np.ndarray
951 Unconditioned master transition array of shape (N_orig, M) where
952 M = prod(cont_M). Reshaped internally to (N_orig, N_orig, 2)
953 assuming one binary continuation variable (dead/alive).
954 matrices_out : dict
955 Per-variable transition matrices; must contain 'dead'.
956 N_orig : int
957 Number of arrival-state grid points.
959 Returns
960 -------
961 master_trans_array_X : np.ndarray
962 Survival-conditioned master transition array of shape (N_orig, N_orig).
963 """
964 master_trans_array_X = np.reshape(master_trans_array_X, (N_orig, N_orig, 2))
965 survival_probs = np.reshape(matrices_out["dead"][:, 0], [N_orig, 1])
966 master_trans_array_X = master_trans_array_X[..., 0] / survival_probs
967 return master_trans_array_X
def make_transition_matrices(self, grid_specs, twist=None, norm=None):
    """
    Construct a transition matrix for this block, moving from a discretized
    grid of arrival variables to a discretized grid of end-of-block variables.
    User specifies how the grids of pre-states should be built. Output is
    stored in attributes of self as follows:

    - matrices : A dictionary of arrays that cast from the arrival state space
                 to the grid of outcome variables. Doing np.dot(dstn, matrices[var])
                 will yield the discretized distribution of that outcome variable.
    - grids : A dictionary of discretized grids for arrival and outcome variables.
              Doing np.dot(np.dot(dstn, matrices[var]), grids[var]) yields the
              *average* of that outcome in the population.
    - mesh : List with one entry per arrival grid node, each entry being the
             list of arrival-variable values at that node (in self.arrival order).
    - trans_array : Square arrival-to-arrival transition array. Only created when
                    a twist is given AND this block has arrival variables.
    - init_dstn : Flattened distribution of arrival states. Only created when
                  this block has no arrival variables (the initializer block).

    Parameters
    ----------
    grid_specs : dict
        Dictionary of dictionaries of grid specifications. For now, these have
        at most a minimum value, a maximum value, a number of nodes, and a poly-
        nomial order. They are equispaced if a min and max are specified, and
        polynomially spaced with the specified order > 0 if provided. Otherwise,
        they are set at 0,..,N if only N is provided.
    twist : dict or None
        Mapping from end-of-period (continuation) variables to successor's
        arrival variables. When this is specified, additional output is created
        for the "full period" arrival-to-arrival transition matrix.
    norm : str or None
        Name of the shock variable by which to normalize for Harmenberg
        aggregation. By default, no normalization happens.

    Returns
    -------
    None
    """
    arrival_N = len(self.arrival)

    # Build input and output grids from grid specifications
    grids_in, grids_out, continuous_grid_out_bool, grid_orders, dummy_grid = (
        self._build_input_grids(grid_specs, arrival_N)
    )

    # If a twist was specified, override output grids for continuation variables
    if twist is not None:
        grids_out, grid_orders, grid_out_is_continuous = self._build_twist_grids(
            twist, grids_in, grid_orders, grids_out, continuous_grid_out_bool
        )
    else:
        grid_out_is_continuous = np.array(continuous_grid_out_bool)

    # Make meshes of all the arrival grids, which will be the initial simulation data
    if arrival_N > 0:
        state_meshes = np.meshgrid(
            *[grids_in[var] for var in self.arrival], indexing="ij"
        )
    else:  # this only happens in the initializer block
        state_meshes = [dummy_grid.copy()]
    state_init = {
        self.arrival[k]: state_meshes[k].flatten() for k in range(arrival_N)
    }
    N_orig = state_meshes[0].size
    mesh_tuples = [
        [state_init[var][n] for var in self.arrival] for n in range(N_orig)
    ]

    # Quasi-simulate this block
    self.run_quasi_sim(state_init, norm=norm)
    has_mortality = "dead" in self.data

    # Add survival to output if mortality is in the model
    if has_mortality:
        grids_out["dead"] = None

    # Get continuation variable names, making sure they're in the same order
    # as named by the arrival variables.
    if twist is not None:
        arrival_to_cont = {twist[var]: var for var in twist.keys()}
        cont_vars = [arrival_to_cont[var] for var in self.arrival]
        if has_mortality:
            cont_vars.append("dead")
            grid_out_is_continuous = np.concatenate(
                (grid_out_is_continuous, [False])
            )
    else:
        cont_vars = list(grids_out.keys())  # all outcomes are arrival vars
    D = len(cont_vars)

    # Project the final results onto the output or result grids. The number
    # of initial mesh points (passed twice, as N_orig and J) is N_orig.
    N = self.N
    (
        matrices_out,
        cont_idx,
        cont_alpha,
        cont_M,
        cont_discrete,
        grids_out,
        grid_out_is_continuous,
    ) = self._project_onto_output_grids(
        grids_out,
        grid_out_is_continuous,
        grid_orders,
        cont_vars,
        twist,
        N_orig,
        N_orig,
        N,
    )

    # Construct the master arrival-to-continuation transition array
    master_trans_array_X = self._build_master_transition_array(
        cont_vars, cont_idx, cont_alpha, cont_M, cont_discrete, N_orig, N, D
    )

    # Condition on survival if relevant
    if has_mortality:
        master_trans_array_X = self._condition_on_survival(
            master_trans_array_X, matrices_out, N_orig
        )

    # Store the results as attributes of self
    self.grids = {**grids_in, **grids_out}
    self.matrices = matrices_out
    self.mesh = mesh_tuples
    if arrival_N == 0:
        # Initializer block: the "transition" matrix is really just the
        # initial distribution of states at model birth; flatten it.
        self.init_dstn = master_trans_array_X.flatten()
    elif twist is not None:
        # Only with a twist do continuation variables live on the arrival
        # grid, so only then can the array be reshaped to be square.
        self.trans_array = np.reshape(master_trans_array_X, (N_orig, N_orig))
def run_quasi_sim(self, data, j0=0, twist=None, norm=None):
    """
    "Quasi-simulate" this block: run it forward from the given starting data
    for *every* possible shock realization, tracking probability masses in the
    special data field "pmv_" rather than drawing random shocks.

    Events j0, j0+1, ..., (last) are run first. If a twist is provided, the
    data is then replaced by a fresh dictionary holding only "pmv_" and the
    twisted variables (each end-of-block variable relabeled as its arrival
    variable), and events 0, ..., j0-1 are run, so the quasi-simulation ends at
    the same event index where it began. With j0 == 0 the relabeling still
    happens, but no further events are run. Any variable not named in the
    twist (including mortality, "dead") is dropped at the relabeling step.

    Parameters
    ----------
    data : dict
        Dictionary of initial data, mapping variable names to vectors of values.
        It is deep-copied, so the caller's arrays are not modified.
    j0 : int, optional
        Event index number at which to start (and, with a twist, end) the
        quasi-simulation. By default, it is run from index 0.
    twist : dict, optional
        Optional dictionary mapping end-of-block variables back to arrival
        variables. When provided, the quasi-sim wraps around as described above.
    norm : str or None
        The name of the variable on which to perform Harmenberg normalization.

    Returns
    -------
    None
    """
    # Each starting node carries unit probability mass; an empty data dict
    # (the initializer block) means a single starting node.
    self.N = next(iter(data.values())).size if data else 1
    starting_data = deepcopy(data)
    starting_data["pmv_"] = np.ones(self.N)

    # Each node initially originates from itself
    origins = np.arange(self.N, dtype=int)

    # Reset the block's state and give it the initial state data
    self.reset()
    self.data.update(starting_data)

    def advance(event_indices, origins):
        # Hand every event the block's full data dict, run it, and track the
        # (possibly expanded) number of nodes afterwards.
        for idx in event_indices:
            ev = self.events[idx]
            ev.data = self.data
            ev.N = self.N
            origins = ev.quasi_run(origins, norm=norm)
            self.N = self.data["pmv_"].size
        return origins

    origins = advance(range(j0, len(self.events)), origins)

    if twist is not None:
        relabeled = {"pmv_": self.data["pmv_"].copy()}
        for end_name, arrival_name in twist.items():
            relabeled[arrival_name] = self.data[end_name].copy()
        self.data = relabeled
        origins = advance(range(j0), origins)

    # Record which starting node each final node descends from
    self.origin_array = origins
@dataclass(kw_only=True)
class AgentSimulator:
    """
    A class for representing an entire simulator structure for an AgentType.
    It includes a sequence of SimBlocks representing periods of the model, which
    could be built from the information on an AgentType instance.

    Parameters
    ----------
    name : str
        Short name of this model.
    description : str
        Textual description of what happens in this simulated block.
    statement : str
        Verbatim model statement that was used to create this simulator.
    comments : dict
        Dictionary of comments or descriptions for various model objects.
    parameters : list[str]
        List of parameter names used in the model.
    distributions : list[str]
        List of distribution names used in the model.
    functions : list[str]
        List of function names used in the model.
    common : list[str]
        Names of variables that are common across idiosyncratic agents.
    types : dict
        Dictionary of data types for all variables in the model.
    N_agents : int
        Number of idiosyncratic agents in this simulation.
    T_total : int
        Total number of periods in these agents' model.
    T_sim : int
        Maximum number of periods that will be simulated, determining the size
        of the history arrays.
    T_age : int
        Period after which to automatically terminate an agent if they would
        survive past this period.
    stop_dead : bool
        Whether simulated agents who draw dead=True should actually cease acting.
        Default is True. Setting to False allows "cohort-style" simulation that
        will generate many agents that survive to old ages. In most cases, T_sim
        should not exceed T_age, unless the user really does want multiple
        successive cohorts to be born and fully simulated.
    replace_dead : bool
        Whether simulated agents who are marked as dead should be replaced with
        newborns (default True) or simply cease acting without replacement (False).
        The latter option is useful for models with state-dependent mortality,
        to allow "cohort-style" simulation with the correct distribution of states
        for survivors at each age. Setting to False has no effect if stop_dead is True.
    periods : list[SimBlock]
        Ordered list of simulation blocks, each representing a period.
    twist : dict
        Dictionary that maps period t-1 variables to period t variables, as a
        relabeling "between" periods.
    initializer : SimBlock
        A special simulated block that should have *no* arrival variables, because
        it represents the initialization of "newborn" agents. If not provided, a
        default-constructed SimBlock is used.
    data : dict
        Dictionary that holds *current* values of model variables.
    track_vars : list[str]
        List of names of variables whose history should be tracked in the simulation.
    history : dict
        Dictionary that holds the histories of tracked variables.
    """

    name: str = field(default="")
    description: str = field(default="")
    statement: str = field(default="", repr=False)
    comments: dict = field(default_factory=dict, repr=False)
    parameters: list[str] = field(default_factory=list, repr=False)
    distributions: list[str] = field(default_factory=list, repr=False)
    functions: list[str] = field(default_factory=list, repr=False)
    common: list[str] = field(default_factory=list, repr=False)
    types: dict = field(default_factory=dict, repr=False)
    N_agents: int = field(default=1)
    T_total: int = field(default=1, repr=False)
    T_sim: int = field(default=1)
    T_age: int = field(default=0, repr=False)
    stop_dead: bool = field(default=True)
    replace_dead: bool = field(default=True)
    periods: list[SimBlock] = field(default_factory=list, repr=False)
    twist: dict = field(default_factory=dict, repr=False)
    data: dict = field(default_factory=dict, repr=False)
    # The field() call must be the *assigned default*, not the annotation;
    # otherwise default_factory and repr=False are silently ignored.
    initializer: SimBlock = field(default_factory=SimBlock, repr=False)
    track_vars: list[str] = field(default_factory=list, repr=False)
    history: dict = field(default_factory=dict, repr=False)
def simulate(self, T=None):
    """
    Run the agent simulation forward, replacing dead agents as warranted and
    recording tracked variables in the history. This is the primary
    user-facing simulation method.

    Parameters
    ----------
    T : int or None
        Number of periods to simulate. If None, every remaining period up to
        T_sim is simulated.

    Raises
    ------
    ValueError
        If simulating T more periods would go beyond T_sim.
    """
    remaining = self.T_sim - self.t_sim
    n_periods = remaining if T is None else T
    if n_periods > remaining:
        raise ValueError("Can't simulate more than T_sim periods!")

    for _ in range(n_periods):
        self.sim_one_period()
        mortal = "dead" in self.data

        # Agents who have hit the maximum age are forcibly marked as dead
        if mortal and self.T_age > 0:
            self.data["dead"][self.t_age == self.T_age] = True

        self.store_tracked_vars()
        self.advance_age()

        # Only retire decedents when simulated deaths actually stop agents
        if mortal and self.stop_dead:
            self.mark_dead_agents()
        self.t_sim += 1
def reset(self):
    """
    Return this simulator to its pristine state so it can be run again from
    scratch: simulation time goes back to zero, current data is cleared, empty
    history arrays are allocated for each tracked variable, every block is
    reset, and all agents are flagged as awaiting birth (t_age == -1, no
    period assignment).
    """
    n_agents = self.N_agents
    n_rows = self.T_sim
    self.t_sim = 0  # Time index for the simulation

    self.clear_data()
    self.history = {}
    for name in self.track_vars:
        self.history[name] = np.empty((n_rows, n_agents), dtype=self.types[name])

    for block in [self.initializer, *self.periods]:
        block.reset()

    # No agent is assigned to any period yet; all are treated as newborns
    self.t_seq_bool_array = np.zeros((self.T_total, n_agents), dtype=bool)
    self.t_age = np.full(n_agents, -1, dtype=int)
def clear_data(self, skip=None):
    """
    Overwrite every current data array (one per entry of self.types) with a
    blank array of length N_agents, except for the names listed in skip.

    Blank values by declared type: float -> NaN, bool -> False,
    int -> 0 (as int32), complex -> NaN.

    Parameters
    ----------
    skip : [str] or None
        Names of variables *not* to be cleared from data. Default is None.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If a variable's declared type is not one of the four above.
    """
    excluded = [] if skip is None else skip
    size = (self.N_agents,)
    # Matched by identity, in this order, against each declared type
    blank_makers = (
        (float, lambda: np.full(size, np.nan)),
        (bool, lambda: np.zeros(size, dtype=bool)),
        (int, lambda: np.zeros(size, dtype=np.int32)),
        (complex, lambda: np.full(size, np.nan, dtype=complex)),
    )
    for name, declared in self.types.items():
        if name in excluded:
            continue
        for known_type, make_blank in blank_makers:
            if declared is known_type:
                self.data[name] = make_blank()
                break
        else:
            raise ValueError(
                "Type "
                + str(declared)
                + " of variable "
                + name
                + " was not recognized!"
            )
def mark_dead_agents(self):
    """
    Look at the special data field "dead" and mark those agents for replacement:
    they are removed from every period of the age sequence (t_seq_bool_array)
    and their age (t_age) is set to -1. If no variable called "dead" exists in
    the current data, this does nothing.
    """
    # Honor the documented contract instead of raising KeyError
    if "dead" not in self.data:
        return
    who_died = self.data["dead"]
    self.t_seq_bool_array[:, who_died] = False
    self.t_age[who_died] = -1
def create_newborns(self):
    """
    Call the initializer to generate newborns wherever t_age == -1.

    Newborns receive the initializer's values for the first period's arrival
    variables. Every other variable is reset to a blank value matching
    clear_data (NaN for float/complex, False for bool, 0 for int), so no
    uninitialized memory leaks into the data. Newborns are then assigned age 0
    and the first period of the sequence.
    """
    # Skip this step if there are no newborns
    newborns = self.t_age == -1
    if not np.any(newborns):
        return

    # Generate initial arrival variables
    N = np.sum(newborns)
    self.initializer.data = {}  # by definition
    self.initializer.N = N
    self.initializer.run()

    # Set the initial arrival data for newborns and clear other variables
    init_arrival = self.periods[0].arrival
    for var, var_type in self.types.items():
        if var in init_arrival:
            values = self.initializer.data[var]
        else:
            # Deterministic blank (np.empty would expose arbitrary memory,
            # e.g. a spurious dead=True for a newborn)
            values = np.zeros(N, dtype=var_type)
            if np.issubdtype(values.dtype, np.inexact):
                values[:] = np.nan
        self.data[var][newborns] = values

    # Set newborns' period to 0
    self.t_age[newborns] = 0
    self.t_seq_bool_array[0, newborns] = True
def store_tracked_vars(self):
    """
    Copy the current value of each tracked variable into row t_sim of its
    history array.
    """
    row = self.t_sim
    for name in self.track_vars:
        self.history[name][row] = self.data[name]
def advance_age(self):
    """
    Increment t_age for every living agent (t_age >= 0) and move each living
    agent's period indicator in t_seq_bool_array forward by one row. Agents in
    the last period of the sequence wrap around to the initial period; in a
    lifecycle model, those agents should be marked as dead and replaced in
    short order. Agents with t_age == -1 are left untouched.
    """
    living = self.t_age >= 0
    self.t_age[living] += 1
    # Rolling by one along the period axis sends row t to row t+1 and the
    # final row back to row 0.
    self.t_seq_bool_array[:, living] = np.roll(
        self.t_seq_bool_array[:, living], 1, axis=0
    )
def sim_one_period(self):
    """
    Advance every agent by one period of the model. This carries end-of-period
    values into arrival variables via the twist, clears everything else, creates
    newborns when appropriate, and runs each period block on the agents
    currently in that period. It does NOT eliminate dead agents nor record
    history; users should call simulate(1) instead of this method directly.
    """
    # Relabel last period's end-of-period values as this period's arrival
    # values; every variable not carried over is wiped.
    carried = []
    for prev_name, this_name in self.twist.items():
        carried.append(this_name)
        self.data[this_name] = self.data[prev_name].copy()
    self.clear_data(skip=carried)

    # Newborns need arrival values too: always in the first simulated period,
    # and thereafter only if decedents are replaced.
    if self.replace_dead or self.t_sim == 0:
        self.create_newborns()

    for age in range(self.T_total):
        mask = self.t_seq_bool_array[age, :]
        if not mask.any():
            continue  # nobody is in this period right now
        block = self.periods[age]
        block.data = {name: self.data[name][mask] for name in block.arrival}
        block.N = np.sum(mask)
        block.run()
        # Write everything the block produced back into the population arrays
        for name, values in block.data.items():
            self.data[name][mask] = values

    # Expose timing information alongside the model variables
    self.data["t_age"] = self.t_age.copy()
    self.data["t_seq"] = np.argmax(self.t_seq_bool_array, axis=0).astype(int)
def make_transition_matrices(
    self, grid_specs, norm=None, fake_news_timing=False, for_t=None
):
    """
    Build Markov-style transition matrices for each period of the model, as
    well as the initial distribution of arrival variables for newborns.
    Stores results to the attributes of self as follows:

    - trans_arrays : List of Markov matrices for transitioning from the arrival
                     state space in period t to the arrival state space in t+1.
                     This transition includes death (and replacement). These
                     are copies: each period's own trans_array attribute is
                     left without death/replacement applied, so repeated calls
                     (e.g. with for_t) never apply replacement twice.
    - newborn_dstn : Stochastic vector as a NumPy array, representing the distribution
                     of arrival states for "newborns" who were just initialized.
    - state_grids : Nested list of tuples representing the arrival state space for
                    each period. Each element corresponds to the discretized arrival
                    state space point with the same index in trans_arrays (and
                    newborn_dstn). Arrival states are ordered within a tuple in the
                    same order as the model file. Linked from period[t].mesh.
    - outcome_arrays : List of dictionaries of arrays that cast from the arrival
                       state space to the grid of outcome variables, for each period.
                       Doing np.dot(state_dstn, outcome_arrays[t][var]) will yield
                       the discretized distribution of that outcome variable. Linked
                       from periods[t].matrices.
    - outcome_grids : List of dictionaries of discretized outcomes in each period.
                      Keys are names of outcome variables, and entries are vectors
                      of discretized values that the outcome variable can take on.
                      Doing np.dot(np.dot(state_dstn, outcome_arrays[var]), outcome_grids[var])
                      yields the *average* of that outcome in the population. Linked
                      from periods[t].grids.

    Parameters
    ----------
    grid_specs : dict
        Dictionary of dictionaries with specifications for discretized grids
        of all variables of interest. If any arrival variables are omitted,
        they will be given a default trivial grid with one node at 0. This
        should only be done if that arrival variable is closely tied to the
        Harmenberg normalizing variable; see below. A grid specification must
        include a number of gridpoints N, and should also include a min and
        max if the variable is continuous. If the variable is discrete, the
        grid values are assumed to be 0,..,N.
    norm : str or None
        Name of the variable for which Harmenberg normalization should be
        applied, if any. This should be a variable that is directly drawn
        from a distribution, not a "downstream" variable.
    fake_news_timing : bool
        Indicator for whether this call is part of the "fake news" algorithm
        for constructing sequence space Jacobians (SSJs). This should only
        ever be set to True in that situation, which affects how mortality
        is handled between periods. In short, the simulator usually assumes
        that "newborns" start with t_seq=0, but during the fake news algorithm,
        that is not the case.
    for_t : list or None
        Optional list of time indices for which the matrices should be built.
        When not specified, all periods are constructed. The most common use
        for this arg is during the "fake news" algorithm for lifecycle models.
        Periods not listed reuse the transition arrays they built previously.

    Returns
    -------
    None
    """
    arrival = self.periods[0].arrival

    # Ordinary periods get (copies of) every grid spec; the initializer gets
    # only arrival-variable specs, in arrival order. Any arrival variable
    # without a spec gets a trivial one-node grid in both.
    grid_specs_other = {name: copy(spec) for name, spec in grid_specs.items()}
    grid_specs_init = {}
    for name in arrival:
        if name in grid_specs:
            grid_specs_init[name] = copy(grid_specs[name])
        else:
            dummy_grid_spec = {"N": 1}
            grid_specs_init[name] = dummy_grid_spec
            grid_specs_other[name] = dummy_grid_spec

    # Make the initial state distribution for newborns
    self.initializer.make_transition_matrices(grid_specs_init)
    self.newborn_dstn = self.initializer.init_dstn
    K = self.newborn_dstn.size

    # Make the period-by-period transition matrices
    these_t = range(len(self.periods)) if for_t is None else for_t
    for t in these_t:
        block = self.periods[t]
        block.make_transition_matrices(
            grid_specs_other, twist=self.twist, norm=norm
        )
        block.reset()

    # Work on copies so that replacement below never mutates the blocks'
    # own arrays (which would compound on periods excluded by for_t).
    p2p_trans_arrays = [block.trans_array.copy() for block in self.periods]

    # Apply agent replacement to the last period of the model, representing
    # newborns filling in for decedents. This will usually only do anything
    # at all in "one period infinite horizon" models. If this is part of the
    # fake news algorithm for constructing SSJs, then replace decedents with
    # newborns in *all* periods, because model timing is funny in this case.
    T_set = list(range(len(self.periods))) if fake_news_timing else [-1]
    newborn_dstn = np.reshape(self.newborn_dstn, (1, K))
    for t in T_set:
        matrices = self.periods[t].matrices
        if "dead" not in matrices:
            continue
        death_col = np.reshape(matrices["dead"][:, 1], (-1, 1))
        p2p_trans_arrays[t] = (
            p2p_trans_arrays[t] * (1.0 - death_col) + death_col * newborn_dstn
        )

    # Store the transition arrays as attributes of self
    self.trans_arrays = p2p_trans_arrays

    # Build and store lists of state meshes, outcome arrays, and outcome grids
    self.state_grids = [block.mesh for block in self.periods]
    self.outcome_grids = [block.grids for block in self.periods]
    self.outcome_arrays = [block.matrices for block in self.periods]
def find_steady_state(self):
    """
    Calculate the steady state distribution of arrival states for a "one period
    infinite horizon" model, storing the result to the attribute steady_state_dstn.
    Should only be run after make_transition_matrices(), and only if T_total = 1
    and the model is infinite horizon.

    Raises
    ------
    ValueError
        If T_total != 1, or if the largest-magnitude eigenvalue of the
        transition matrix is not close to 1.
    """
    if self.T_total != 1:
        raise ValueError(
            "This method currently only works with one period infinite horizon problems."
        )

    # Find the eigenvector associated with the largest eigenvalue of the
    # infinite horizon transition matrix. The largest eigenvalue *should*
    # be 1 for any Markov matrix, but double check to be sure.
    trans = self.trans_arrays[0]
    n_states = trans.shape[0]
    if n_states < 3:
        # ARPACK (scipy eigs) requires k < N - 1 and rejects sparse input
        # otherwise, so very small state spaces use a dense solver instead.
        vals, vecs = np.linalg.eig(np.asarray(trans).transpose())
        idx = int(np.argmax(np.abs(vals)))
        v, V = vals[idx : idx + 1], vecs[:, idx : idx + 1]
    else:
        v, V = eigs(csr_matrix(trans.transpose()), k=1)
    if not np.isclose(v[0], 1.0):
        raise ValueError(
            "The largest eigenvalue of the transition matrix isn't close to 1!"
        )

    # Normalize that eigenvector and make sure it's real, then store it
    D = V[:, 0]
    self.steady_state_dstn = (D / np.sum(D)).real
def get_long_run_average(self, var):
    """
    Return the long run (steady state) population average of one named
    outcome variable. Requires find_steady_state() to have been run first.

    Parameters
    ----------
    var : str
        Name of the outcome variable to average.

    Returns
    -------
    var_mean : float
        Steady state distribution pushed through the period-0 outcome array for
        var, then dotted with that variable's outcome grid.

    Raises
    ------
    ValueError
        If steady_state_dstn has not been computed yet.
    """
    if not hasattr(self, "steady_state_dstn"):
        raise ValueError("This method can only be run after find_steady_state()!")

    outcome_dstn = np.dot(self.steady_state_dstn, self.outcome_arrays[0][var])
    return np.dot(outcome_dstn, self.outcome_grids[0][var])
def find_target_state(self, target_var, bounds=None, N=201, tol=1e-8, **kwargs):
    r"""
    Find the "target" level of a state variable: the value such that the expectation
    of next period's state is the same value (when following the policy function),
    *and* is locally stable (pushes up from below and down from above when nearby).
    Only works for standard infinite horizon models with a single endogenous state
    variable. Other variables whose values must be known (e.g. exogenously evolving
    states) can also be specified.

    The search procedure is to first examine a grid of candidates on the bounds,
    calculating E[\Delta x] for state x, and then perform a local search for each
    interval where it flips from positive to negative.

    This procedure ignores mortality entirely. It represents a stable or target
    level conditional on the agent continuing from t to t+1.

    If additional information must be known, other model variables can be passed
    as keyword arguments, e.g. pLvl=1.0. This feature is used for exogenous state
    variables, such as persistent income pLvl in the GenIncProcess model. The user
    simply passes its mean (central) value, which is easily known in advance.

    Parameters
    ----------
    target_var : str
        Name of the state variable of interest.
    bounds : [float] or np.ndarray, optional
        Lower and upper boundaries for the target search. If not provided (or
        empty), defaults to [0.0, 100.0].
    N : int, optional
        Number of values of the variable of interest to test on the initial pass.
        If not provided, defaults to 201. This affects the "resolution" when there
        are multiple possible target levels (uncommon).
    tol : float, optional
        Maximum acceptable deviation from true target E[\Delta x] = 0 to be accepted.
        If not specified, defaults to 1e-8. The relative tolerance given to the
        root finder is never set below the minimum that brentq accepts.

    Returns
    -------
    state_targ : [float]
        List of target_var x values such that E[\Delta x] = 0, which can be empty.

    Raises
    ------
    ValueError
        If T_total != 1, if a keyword argument is not a model variable, or if no
        events assign the target variable and all fixed variables.
    """
    if self.T_total != 1:
        raise ValueError(
            "This method currently only works with one period infinite horizon problems."
        )
    # Explicit None/empty check: `bounds or default` fails for NumPy arrays
    if bounds is None or len(bounds) == 0:
        bounds = [0.0, 100.0]
    state_grid = np.linspace(bounds[0], bounds[1], num=N)

    # Process keyword arguments into a dictionary of fixed values
    period = self.periods[0]
    event_count = len(period.events)
    model_vars = list(self.types.keys())
    fixed = {}
    for name in kwargs:
        if name not in model_vars:
            raise ValueError(
                "Could not find a model variable called " + name + " to hold fixed!"
            )
        fixed[name] = kwargs[name]

    # Find the event index at which to start and stop the quasi-simulation:
    # just after the last event assigning the target or any fixed variable.
    var_names = [target_var] + list(fixed.keys())
    var_count = len(var_names)
    found_count = 0
    found = var_count * [False]
    j = 0
    if target_var not in period.arrival:
        while (found_count < var_count) and (j < event_count):
            assigns = period.events[j].assigns
            for i in range(var_count):
                if (var_names[i] in assigns) and (not found[i]):
                    found[i] = True
                    found_count += 1
            j += 1
        if not np.all(found):
            raise ValueError(
                "Could not find events that assign target variable and all fixed variables!"
            )
    idx0 = j  # Event index where the quasi-sim should start and stop

    # Construct the starting information set for the quasi-simulation, with
    # dummy data for all vars assigned prior to the start/stop index
    data_init = {}
    trivial_vars = []
    for var in period.arrival:
        data_init[var] = np.zeros(N, dtype=int)  # dummy data
        trivial_vars.append(var)
    for j in range(idx0):
        for var in period.events[j].assigns:
            data_init[var] = np.zeros(N, dtype=int)  # dummy data
            trivial_vars.append(var)

    # Assign fixed data and the grid of candidate target values
    for key in fixed.keys():
        data_init[key] = fixed[key] * np.ones(N)
    data_init[target_var] = state_grid

    # Run the quasi-simulation on the initial grid of states
    period.run_quasi_sim(data_init, j0=idx0, twist=self.twist)
    origins = period.origin_array
    data_final = period.data[target_var]
    pmv_final = period.data["pmv_"]

    # Calculate mean value of next period's state at each point in the grid
    E_state_next = np.empty(N)
    for n in range(N):
        these = origins == n
        E_state_next[n] = np.dot(pmv_final[these], data_final[these])
    E_delta_state = E_state_next - state_grid  # expected change in state

    # Find indices in the grid where E[\Delta x] flips from positive to negative
    sign = E_delta_state > 0.0
    flip = np.logical_and(sign[:-1], np.logical_not(sign[1:]))
    flip_idx = np.argwhere(flip).flatten()
    if flip_idx.size == 0:
        return []

    # Reduce the fixed values in data_init to single valued vectors
    for var in trivial_vars:
        data_init[var] = np.array([0])
    for key in fixed.keys():
        data_init[key] = np.array([fixed[key]])

    # Function whose roots are states where E[\Delta x] = 0
    def delta_zero_func(x):
        data_init[target_var] = np.array([x])
        period.run_quasi_sim(data_init, j0=idx0, twist=self.twist)
        return np.dot(period.data["pmv_"], period.data[target_var]) - x

    # brentq raises if rtol is below 4 * machine epsilon
    rtol = max(tol, 4.0 * np.finfo(float).eps)

    # For each segment index with a sign flip for E[\Delta x], find x_targ
    state_targ = []
    for i in flip_idx:
        x_targ = brentq(
            delta_zero_func, state_grid[i], state_grid[i + 1], xtol=tol, rtol=rtol
        )
        state_targ.append(x_targ)
    return state_targ
def simulate_cohort_by_grids(
    self,
    outcomes,
    T_max=None,
    calc_dstn=False,
    calc_avg=True,
    from_dstn=None,
):
    """
    Generate a simulated "cohort style" history for this type of agents using
    discretized grid methods. Can only be run after running make_transition_matrices().
    Starting from the distribution of states at birth, the population is moved
    forward in time via the transition matrices, and the distribution and/or
    average of specified outcomes are stored in the dictionary attributes
    history_dstn and history_avg respectively.

    Parameters
    ----------
    outcomes : str or [str]
        Names of one or more outcome variables to be tracked during the grid
        simulation. Each named variable should have an outcome grid specified
        when make_transition_matrices() was called, whether explicitly or
        implicitly. The existence of these grids is checked as a first step.
    T_max : int or None
        If specified, the number of periods of the model to actually generate
        output for. If not specified, all periods are run.
    calc_dstn : bool
        Whether outcome distributions should be stored in the dictionary
        attribute history_dstn. The default is False.
    calc_avg : bool
        Whether outcome averages should be stored in the dictionary attribute
        history_avg. The default is True.
    from_dstn : np.array or None
        Optional initial distribution of arrival states. If not specified, the
        newborn distribution in the initializer is assumed to be used.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If newborn_dstn or trans_arrays is missing, if there are fewer transition
        arrays than periods to simulate, or if a requested outcome lacks an
        outcome matrix (or, when calc_avg is True, an outcome grid) in any period
        to be simulated.
    """
    # First, verify that newborn and transition matrices exist for all periods
    if not hasattr(self, "newborn_dstn"):
        raise ValueError(
            "The newborn state distribution does not exist; make_transition_matrices() must be run before grid simulations!"
        )
    if T_max is None:
        T_max = self.T_total
    T_max = np.minimum(T_max, self.T_total)
    if not hasattr(self, "trans_arrays"):
        raise ValueError(
            "The transition arrays do not exist; make_transition_matrices() must be run before grid simulations!"
        )
    if len(self.trans_arrays) < T_max:
        raise ValueError(
            "There are somehow fewer elements of trans_array than there should be!"
        )
    if not (calc_dstn or calc_avg):
        return  # No work actually requested, we're done here

    if isinstance(outcomes, str):
        outcomes = [outcomes]

    # Check up front that every requested outcome has what the loop below needs
    # in every simulated period, so a missing grid fails fast with a clear
    # message instead of a KeyError partway through the simulation.
    for t in range(T_max):
        period = self.periods[t]
        matrices = getattr(period, "matrices", {})
        grids = getattr(period, "grids", {})
        for name in outcomes:
            if name not in matrices:
                raise ValueError(
                    "No outcome matrix for "
                    + name
                    + " in period "
                    + str(t)
                    + "; was it given a grid when make_transition_matrices() was run?"
                )
            if calc_avg and name not in grids:
                raise ValueError(
                    "No outcome grid for " + name + " in period " + str(t) + "!"
                )

    # Initialize generated output as requested; distribution lists are stacked
    # into an array at the end if all distributions turn out to be the same size
    if calc_dstn:
        history_dstn = {name: [] for name in outcomes}
    if calc_avg:
        history_avg = {name: np.empty(T_max) for name in outcomes}

    # Initialize the state distribution
    current_dstn = (
        self.newborn_dstn.copy() if from_dstn is None else from_dstn.copy()
    )
    state_dstn_by_age = []

    # Loop over requested periods of this agent type's model
    for t in range(T_max):
        period = self.periods[t]
        state_dstn_by_age.append(current_dstn)

        # Calculate outcome distributions and averages as requested
        for name in outcomes:
            this_outcome = period.matrices[name].transpose()
            this_dstn = np.dot(this_outcome, current_dstn)
            if calc_dstn:
                history_dstn[name].append(this_dstn)
            if calc_avg:
                history_avg[name][t] = np.dot(this_dstn, period.grids[name])

        # Advance the distribution to the next period
        current_dstn = np.dot(self.trans_arrays[t].transpose(), current_dstn)

    # Reshape the distribution histories if possible. An empty history (T_max
    # of zero) is left as an empty list, since there is nothing to stack.
    if calc_dstn:
        for name in outcomes:
            dstn_list = history_dstn[name]
            if dstn_list and all(d.size == dstn_list[0].size for d in dstn_list):
                history_dstn[name] = np.stack(dstn_list, axis=1)

    # Store results as attributes of self
    self.state_dstn_by_age = state_dstn_by_age
    if calc_dstn:
        self.history_dstn = history_dstn
    if calc_avg:
        self.history_avg = history_avg
def describe_model(self, display=True):
    """
    Build a human-readable summary of the model statement.

    The summary has three sections: the initialization block run at birth, the
    dynamics within a period, and the relabeling ("twist") that maps each
    end-of-period variable onto its next-period name.

    Parameters
    ----------
    display : bool
        If True (default), print the summary and return None. Otherwise return
        the summary as a string without printing it.

    Returns
    -------
    str or None
    """
    rule = "----------------------------------\n"

    # One line per twist entry: previous-period name <---> current-period name
    twist_lines = [
        prev_name + "[t-1] <---> " + next_name + "[t]\n"
        for prev_name, next_name in self.twist.items()
    ]

    sections = [
        rule,
        "%%%%% INITIALIZATION AT BIRTH %%%%\n",
        rule,
        self.initializer.statement,
        rule,
        "%%%% DYNAMICS WITHIN PERIOD t %%%%\n",
        rule,
        self.statement,
        rule,
        "%%%%%%% RELABELING / TWIST %%%%%%%\n",
        rule,
        "".join(twist_lines),
        "-----------------------------------",
    ]
    text = "".join(sections)

    if not display:
        return text
    print(text)
def describe_symbols(self, display=True):
    """
    Build an aligned table of the model's named symbols with their kinds.

    Each line has the form ``name (kind[, common])``, padded to a common width
    and followed by ``: comment``. The kind is the declared type's name for
    typed variables, ``dstn`` / ``param`` / ``func`` for distributions,
    parameters, and functions, and ``unknown`` for a commented symbol that
    matches none of these.

    Parameters
    ----------
    display : bool
        If True (default), print the table and return None. Otherwise return
        it as a string.

    Returns
    -------
    str or None
    """
    labels = []
    descriptions = []
    for key, comment in self.comments.items():
        descriptions.append(comment)

        # Get type of object
        if key in self.types:
            kind = str(self.types[key].__name__)
        elif key in self.distributions:
            kind = "dstn"
        elif key in self.parameters:
            kind = "param"
        elif key in self.functions:
            kind = "func"
        else:
            # Without this fallback the label would be unbound for the first
            # key, or silently carried over from the previous key.
            kind = "unknown"

        # Add tags
        if key in self.common:
            kind += ", common"
        labels.append(key + " (" + kind + ")")

    # Pad each label so that all comments line up; empty table -> empty string
    width = max((len(label) for label in labels), default=0) + 1
    output = "".join(
        label + (width - len(label)) * " " + ": " + desc + "\n"
        for label, desc in zip(labels, descriptions)
    )

    # Return or print the output
    if display:
        print(output)
        return
    return output
def describe(self, symbols=True, model=True, display=True):
    """
    Assemble a complete description of the model.

    The text starts with ``name: description``. It is optionally followed by
    the symbol table (from describe_symbols) and/or the model statement (from
    describe_model).

    Parameters
    ----------
    symbols : bool
        Whether to include the symbol table.
    model : bool
        Whether to include the model statement.
    display : bool
        If True (default), print the text and return None. Otherwise return it.

    Returns
    -------
    str or None
    """
    # Assemble the requested output
    pieces = [self.name + ": " + self.description + "\n"]
    if symbols or model:
        pieces.append("\n")
    if symbols:
        pieces.append("----------------------------------\n")
        pieces.append("%%%%%%%%%%%%% SYMBOLS %%%%%%%%%%%%\n")
        pieces.append("----------------------------------\n")
        pieces.append(self.describe_symbols(display=False))
    if model:
        pieces.append(self.describe_model(display=False))
    elif symbols:
        # Close off the symbol table when no model statement follows it
        pieces.append("----------------------------------")
    text = "".join(pieces)

    if not display:
        return text
    print(text)
2015def _parse_model_fields(model, common_override=None):
2016 """
2017 Extract the top-level fields from a parsed model dictionary.
2019 Uses dict.get() with safe defaults rather than try/except for each field,
2020 so that missing keys silently receive their default values.
2022 Parameters
2023 ----------
2024 model : dict
2025 Parsed YAML model dictionary.
2026 common_override : list or None
2027 If provided, overrides the model's 'common' field entirely.
2029 Returns
2030 -------
2031 model_name : str
2032 Name of the model, or 'DEFAULT_NAME' if absent.
2033 description : str
2034 Human-readable description, or a placeholder if absent.
2035 variables : list
2036 Declared variable lines from model['symbols']['variables'].
2037 twist : dict
2038 Intertemporal twist mapping, or empty dict if absent.
2039 common : list
2040 Variables shared across all agents.
2041 arrival : list
2042 Explicitly listed arrival variable names.
2043 """
2044 symbols = model.get("symbols", {})
2045 model_name = model.get("name", "DEFAULT_NAME")
2046 description = model.get("description", "(no description provided)")
2047 variables = symbols.get("variables", [])
2048 twist = model.get("twist", {})
2049 arrival = symbols.get("arrival", [])
2050 if common_override is not None:
2051 common = common_override
2052 else:
2053 common = symbols.get("common", [])
2054 return model_name, description, variables, twist, common, arrival
2057def _build_periods(
2058 template, agent, content, solution, offset, time_vary, time_inv, RNG, T_seq, T_cycle
2059):
2060 """
2061 Construct the list of per-period SimBlock copies for an AgentSimulator.
2063 For each period in the solution sequence, a deep copy of the template block
2064 is made and populated with the appropriate parameter data drawn from the agent.
2066 Parameters
2067 ----------
2068 template : SimBlock
2069 Template block with structure but no parameter values.
2070 agent : AgentType
2071 The agent whose solution and time-varying attributes supply parameter values.
2072 content : dict
2073 Keys are the names of objects needed by the template block.
2074 solution : list
2075 Names of objects that come from the agent's solution attribute.
2076 offset : list
2077 Names of time-varying objects whose index is shifted back by one period.
2078 time_vary : list
2079 Names of objects that vary across periods (drawn from agent attributes).
2080 time_inv : list
2081 Names of objects that are time-invariant (same across all periods).
2082 RNG : np.random.Generator
2083 Random number generator used to assign unique seeds to MarkovEvents.
2084 T_seq : int
2085 Number of periods in the solution sequence.
2086 T_cycle : int
2087 Number of periods per cycle (used to wrap the time index).
2089 Returns
2090 -------
2091 periods : list[SimBlock]
2092 Fully populated list of period blocks, one per entry in the solution.
2093 """
2094 # Build the time-invariant parameter dictionary once
2095 time_inv_dict = {}
2096 for name in content:
2097 if name in time_inv:
2098 if not hasattr(agent, name):
2099 raise ValueError(
2100 "Couldn't get a value for time-invariant object "
2101 + name
2102 + ": attribute does not exist on the agent."
2103 )
2104 time_inv_dict[name] = getattr(agent, name)
2106 periods = []
2107 t_cycle = 0
2108 for t in range(T_seq):
2109 # Make a fresh copy of the template period
2110 new_period = deepcopy(template)
2112 # Make sure each period's events have unique seeds; this is only for MarkovEvents
2113 for event in new_period.events:
2114 if hasattr(event, "seed"):
2115 event.seed = RNG.integers(0, 2**31 - 1)
2117 # Make the parameter dictionary for this period
2118 new_param_dict = deepcopy(time_inv_dict)
2119 for name in content:
2120 if name in solution:
2121 if type(agent.solution[t]) is dict:
2122 new_param_dict[name] = agent.solution[t][name]
2123 else:
2124 new_param_dict[name] = getattr(agent.solution[t], name)
2125 elif name in time_vary:
2126 s = (t_cycle - 1) if name in offset else t_cycle
2127 attr = getattr(agent, name, None)
2128 if attr is None:
2129 raise ValueError(
2130 "Couldn't get a value for time-varying object "
2131 + name
2132 + ": attribute does not exist on the agent."
2133 )
2134 try:
2135 new_param_dict[name] = attr[s]
2136 except (IndexError, TypeError):
2137 raise ValueError(
2138 "Couldn't get a value for time-varying object "
2139 + name
2140 + " at time index "
2141 + str(s)
2142 + "!"
2143 )
2144 elif name in time_inv:
2145 continue
2146 else:
2147 raise ValueError(
2148 "The object called "
2149 + name
2150 + " is not named in time_inv nor time_vary!"
2151 )
2153 # Fill in content for this period, then add it to the list
2154 new_period.content = new_param_dict
2155 new_period.distribute_content()
2156 periods.append(new_period)
2158 # Advance time according to the cycle
2159 t_cycle += 1
2160 if t_cycle == T_cycle:
2161 t_cycle = 0
2163 return periods
def make_simulator_from_agent(agent, stop_dead=True, replace_dead=True, common=None):
    """
    Build an AgentSimulator instance based on an AgentType instance. The AgentType
    should have its model attribute defined so that it can be parsed and translated
    into the simulator structure. The names of objects in the model statement
    should correspond to attributes of the AgentType.

    Parameters
    ----------
    agent : AgentType
        Agents for whom a new simulator is to be constructed. If the agent has a
        model_statement attribute, that YAML text is used as the model; otherwise
        the file named by agent.model_file is read from the HARK.models package.
    stop_dead : bool
        Whether simulated agents who draw dead=True should actually cease acting.
        Default is True. Setting to False allows "cohort-style" simulation that
        will generate many agents that survive to old ages. In most cases, T_sim
        should not exceed T_age, unless the user really does want multiple succ-
        essive cohorts to be born and fully simulated.
    replace_dead : bool
        Whether simulated agents who are marked as dead should be replaced with
        newborns (default True) or simply cease acting without replacement (False).
        The latter option is useful for models with state-dependent mortality,
        to allow "cohort-style" simulation with the correct distribution of states
        for survivors at each age. Setting False has no effect if stop_dead is True.
    common : [str] or None
        List of random variables that should be treated as commonly shared across
        all agents, rather than idiosyncratically drawn. If this is provided, it
        will override the model defaults.

    Returns
    -------
    new_simulator : AgentSimulator
        A simulator structure based on the agents.

    Raises
    ------
    ValueError
        If a declared variable type cannot be evaluated, or if the agent lacks an
        attribute that the initializer (or a period, via _build_periods) needs.
    """
    # Read the model statement into a dictionary, and get names of attributes
    if hasattr(agent, "model_statement"):  # look for a custom model statement
        model_statement = copy(agent.model_statement)
    else:  # otherwise use the default model file bundled in HARK.models
        model_statement = (
            importlib.resources.files("HARK.models")
            .joinpath(agent.model_file)
            .read_text(encoding="utf-8")
        )
    model = yaml.safe_load(model_statement)
    time_vary = agent.time_vary
    time_inv = agent.time_inv
    cycles = agent.cycles
    T_age = agent.T_age
    comments = {}
    RNG = agent.RNG  # this is only for generating seeds for MarkovEvents

    # Extract basic fields from the model using helper
    model_name, description, variables, twist, common, arrival = _parse_model_fields(
        model, common_override=common
    )

    # Make a dictionary of declared data types and add comments
    types = {}
    for var_line in variables:  # Loop through declared variables
        var_name, var_type, flags, desc = parse_declaration_for_parts(var_line)
        if var_type is not None:
            # NOTE(security): eval() executes the declared type string as Python
            # code, so model statements must come from trusted sources.
            try:
                var_type = eval(var_type)
            except Exception as err:
                raise ValueError(
                    "Couldn't understand type "
                    + var_type
                    + " for declared variable "
                    + var_name
                    + "!"
                ) from err
        else:
            var_type = float
        types[var_name] = var_type
        comments[var_name] = desc
        if ("arrival" in flags) and (var_name not in arrival):
            arrival.append(var_name)
        if ("common" in flags) and (var_name not in common):
            common.append(var_name)

    # Make a blank "template" period with structure but no data
    template_period, information, offset, solution, block_comments = (
        make_template_block(model, arrival, common)
    )
    comments.update(block_comments)

    # Make the agent initializer, without parameter values (etc)
    initializer, init_info = make_initializer(model, arrival, common)

    # Extract basic fields from the template period and model
    statement = template_period.statement
    content = template_period.content

    # Get the names of parameters, functions, and distributions; the template
    # block marks each kind with a placeholder value (None, NullFunc, Distribution)
    parameters = []
    functions = []
    distributions = []
    for key, val in information.items():
        if val is None:
            parameters.append(key)
        elif type(val) is NullFunc:
            functions.append(key)
        elif type(val) is Distribution:
            distributions.append(key)

    # Variables that appear in the model block but were undeclared default to float
    for var, this in information.items():
        if var in types:
            continue
        if (this is None) or (type(this) is Distribution) or (type(this) is NullFunc):
            continue
        types[var] = float
        comments[var] = ""
    if "dead" in types:
        types["dead"] = bool
        comments["dead"] = "whether agent died this period"
    types["t_seq"] = int
    types["t_age"] = int
    comments["t_seq"] = "which period of the sequence the agent is on"
    comments["t_age"] = "how many periods the agent has already lived for"

    # Make a dictionary for the initializer and distribute information
    init_dict = {}
    for name in init_info:
        try:
            init_dict[name] = getattr(agent, name)
        except AttributeError as err:
            raise ValueError(
                "Couldn't get a value for initializer object " + name + "!"
            ) from err
    initializer.content = init_dict
    initializer.distribute_content()

    # Create a list of periods, pulling appropriate data from the agent for each one
    T_seq = len(agent.solution)  # Number of periods in the solution sequence
    T_cycle = agent.T_cycle
    periods = _build_periods(
        template_period,
        agent,
        content,
        solution,
        offset,
        time_vary,
        time_inv,
        RNG,
        T_seq,
        T_cycle,
    )

    # Calculate maximum age; when cycles > 0 it is capped at T_seq - 1
    if T_age is None:
        T_age = 0
    if cycles > 0:
        T_age_max = T_seq - 1
        T_age = np.minimum(T_age_max, T_age)
    T_sim = getattr(agent, "T_sim", 0)  # very boring default!

    # Make and return the new simulator
    new_simulator = AgentSimulator(
        name=model_name,
        description=description,
        statement=statement,
        comments=comments,
        parameters=parameters,
        functions=functions,
        distributions=distributions,
        common=common,
        types=types,
        N_agents=agent.AgentCount,
        T_total=T_seq,
        T_sim=T_sim,
        T_age=T_age,
        stop_dead=stop_dead,
        replace_dead=replace_dead,
        periods=periods,
        twist=twist,
        initializer=initializer,
        track_vars=agent.track_vars,
    )
    new_simulator.solution = solution  # this is for use by SSJ constructor
    return new_simulator
2351def _extract_symbol_class(
2352 model, class_name, constructor, validator_msg, offset, solution, comments
2353):
2354 """
2355 Parse and collect one class of symbols (parameters, functions, or distributions).
2357 Handles the near-identical pattern repeated for each symbol class:
2358 iterate over declaration lines, build the result dict, record comments,
2359 and append names to the offset and solution lists as flagged.
2361 Parameters
2362 ----------
2363 model : dict
2364 Parsed model dictionary containing a 'symbols' sub-dict.
2365 class_name : str
2366 Key within model['symbols'] to look up ('parameters', 'functions', or
2367 'distributions').
2368 constructor : callable or None
2369 Called with no arguments to create each entry's value. Pass None for
2370 parameters (which use None as their placeholder value).
2371 validator_msg : str or None
2372 If provided, the expected datatype string (e.g. 'func' or 'dstn'). When a
2373 declaration carries a different datatype, a ValueError is raised. Pass None
2374 to skip validation (used for parameters).
2375 offset : list
2376 Accumulated list of offset-flagged names; extended in place.
2377 solution : list
2378 Accumulated list of solution-flagged names; extended in place.
2379 comments : dict
2380 Accumulated comment strings keyed by name; updated in place.
2382 Returns
2383 -------
2384 result : dict
2385 Mapping from symbol name to its constructed value (or None for parameters).
2386 """
2387 result = {}
2388 symbols = model.get("symbols", {})
2389 if class_name not in symbols:
2390 return result
2391 lines = symbols[class_name]
2392 for line in lines:
2393 name, datatype, flags, desc = parse_declaration_for_parts(line)
2394 if (
2395 (validator_msg is not None)
2396 and (datatype is not None)
2397 and (datatype != validator_msg)
2398 ):
2399 raise ValueError(
2400 name
2401 + " was declared as a "
2402 + class_name[:-1]
2403 + ", but given a different datatype!"
2404 )
2405 result[name] = constructor() if constructor is not None else None
2406 comments[name] = desc
2407 if ("offset" in flags) and (name not in offset):
2408 offset.append(name)
2409 if ("solution" in flags) and (name not in solution):
2410 solution.append(name)
2411 return result
def make_template_block(model, arrival=None, common=None):
    """
    Construct a new SimBlock object as a "template" of the model block. It has
    events and reference information, but no values filled in.

    Parameters
    ----------
    model : dict
        Dictionary with model block information, probably read in as a yaml.
    arrival : [str] or None
        List of arrival variables that were flagged or explicitly listed.
    common : [str] or None
        List of variables that are common or shared across all agents, rather
        than idiosyncratically drawn.

    Returns
    -------
    template_block : SimBlock
        A "template" of this model block, with no parameters (etc) on it. Its
        description is built from the model's 'name' field.
    info : dict
        Dictionary of model objects that were referenced within the block. Keys
        are object names and entries reveal what kind of object they are:
        - None --> parameter
        - 0 --> outcome/data variable (including arrival variables)
        - NullFunc --> function
        - Distribution --> distribution
    offset : [str]
        List of object names that are offset in time by one period.
    solution : [str]
        List of object names that are part of the model solution.
    comments : dict
        Dictionary of comments included with declared functions, distributions,
        and parameters.

    Raises
    ------
    ValueError
        If an event assigns a variable that already exists in the information set.
    """
    if arrival is None:
        arrival = []
    if common is None:
        common = []

    # Extract explicitly listed metadata with safe defaults. offset and solution
    # are copied because the symbol parsing below appends flagged names to them,
    # which would otherwise mutate the caller's model dictionary.
    symbols = model.get("symbols") or {}
    block_name = model.get("name", None)
    offset = list(symbols.get("offset") or [])
    solution = list(symbols.get("solution") or [])

    # Extract parameters, functions, and distributions using the shared helper
    comments = {}
    parameters = _extract_symbol_class(
        model, "parameters", None, None, offset, solution, comments
    )
    functions = _extract_symbol_class(
        model, "functions", NullFunc, "func", offset, solution, comments
    )
    distributions = _extract_symbol_class(
        model, "distributions", Distribution, "dstn", offset, solution, comments
    )

    # Combine those dictionaries into a single "information" dictionary, which
    # represents objects available *at that point* in the dynamic block
    content = parameters.copy()
    content.update(functions)
    content.update(distributions)
    info = deepcopy(content)
    for var in arrival:
        info[var] = 0  # Mark as a state variable

    # Parse the model dynamics
    dynamics = format_block_statement(model["dynamics"])

    # Make the list of ordered events, growing info as each event assigns variables
    events = []
    names_used_in_dynamics = set()
    for line in dynamics:
        new_event, names_used = make_new_event(line, info)
        events.append(new_event)
        names_used_in_dynamics.update(names_used)

        # Add newly assigned variables to the information set
        for var in new_event.assigns:
            if var in info:
                raise ValueError(var + " is assigned, but already exists!")
            info[var] = 0

        # If any assigned variables are common, mark the event as common
        if any(var in common for var in new_event.assigns):
            new_event.common = True

    # Keep only content that is actually referenced within the dynamics. A
    # distinct loop name is used so the model name above is not overwritten.
    content = {
        key: val for key, val in content.items() if key in names_used_in_dynamics
    }

    # Make a single string model statement with descriptions aligned
    longest = max((len(event.statement) for event in events), default=0)
    statement = "".join(
        event.statement
        + (longest + 1 - len(event.statement)) * " "
        + ": "
        + event.description
        + "\n"
        for event in events
    )

    # Make a description for the template block
    if block_name is None:
        description = "template block for unnamed block"
    else:
        description = "template block for " + block_name

    # Make and return the new SimBlock
    template_block = SimBlock(
        description=description,
        arrival=arrival,
        content=content,
        statement=statement,
        events=events,
    )
    return template_block, info, offset, solution, comments
def make_initializer(model, arrival=None, common=None):
    """
    Construct a new SimBlock object to be the agent initializer, based on the
    model dictionary. It has structure and events, but no parameters (etc).

    Parameters
    ----------
    model : dict
        Dictionary with model initializer information, probably read in as a yaml.
    arrival : [str] or None
        List of arrival variables that were flagged or explicitly listed. Each
        must be assigned somewhere in the initialize block.
    common : [str] or None
        List of variables that are common or shared across all agents; any
        initializer event that assigns one of them is marked as common.

    Returns
    -------
    initializer : SimBlock
        A "template" of this model block, with no parameters (etc) on it.
    init_requires : dict
        Dictionary of model objects that are needed by the initializer to run.
        Keys are object names and entries reveal what kind of object they are:
        - None --> parameter
        - 0 --> outcome variable (these should include all arrival variables)
        - NullFunc --> function
        - Distribution --> distribution

    Raises
    ------
    ValueError
        If a function or distribution is declared with a mismatched datatype, if
        a variable is assigned twice, if an arrival variable is never set, or if
        an event references a parameter, distribution, or function that was not
        declared.
    """
    if arrival is None:
        arrival = []
    if common is None:
        common = []
    name = model.get("name", "DEFAULT_NAME")

    # Extract parameters, functions, and distributions with the same helper that
    # make_template_block uses. The offset/solution flags and comments it records
    # are not needed here, so scratch containers are passed and then discarded.
    scratch_offset, scratch_solution, scratch_comments = [], [], {}
    parameters = _extract_symbol_class(
        model, "parameters", None, None, scratch_offset, scratch_solution, scratch_comments
    )
    functions = _extract_symbol_class(
        model, "functions", NullFunc, "func", scratch_offset, scratch_solution, scratch_comments
    )
    distributions = _extract_symbol_class(
        model, "distributions", Distribution, "dstn", scratch_offset, scratch_solution, scratch_comments
    )

    # Combine those dictionaries into a single "information" dictionary
    content = parameters.copy()
    content.update(functions)
    content.update(distributions)
    info = deepcopy(content)

    # Parse the initialization routine
    initialize = format_block_statement(model["initialize"])

    # Make the list of ordered events, growing info as each event assigns variables
    events = []
    for line in initialize:
        new_event, _names_used = make_new_event(line, info)
        events.append(new_event)

        # Add newly assigned variables to the information set
        for var in new_event.assigns:
            if var in info:
                raise ValueError(var + " is assigned, but already exists!")
            info[var] = 0

        # If any assigned variables are common, mark the event as common
        if any(var in common for var in new_event.assigns):
            new_event.common = True

    # Verify that all arrival variables were created in the initializer
    for var in arrival:
        if var not in info:
            raise ValueError(
                "The arrival variable " + var + " was not set in the initialize block!"
            )

    # Collect the declared objects that the initializer's events actually need
    init_requires = {}
    for event in events:
        for var in event.parameters.keys():
            if var in init_requires:
                continue
            if var not in parameters:
                raise ValueError(
                    var
                    + " was referenced in initialize, but not declared as a parameter!"
                )
            init_requires[var] = parameters[var]
        if type(event) is RandomEvent:
            dstn_name = event._dstn_name
            if dstn_name not in distributions:
                raise ValueError(
                    dstn_name
                    + " was referenced in initialize, but not declared as a distribution!"
                )
            init_requires[dstn_name] = distributions[dstn_name]
        if type(event) is EvaluationEvent:
            func_name = event._func_name
            if func_name not in functions:
                raise ValueError(
                    func_name
                    + " was referenced in initialize, but not declared as a function!"
                )
            # Register the function under its own name
            init_requires[func_name] = functions[func_name]

    # Make a single string initializer statement with descriptions aligned
    longest = max((len(event.statement) for event in events), default=0)
    statement = "".join(
        event.statement
        + (longest + 1 - len(event.statement)) * " "
        + ": "
        + event.description
        + "\n"
        for event in events
    )

    # Make and return the new SimBlock
    initializer = SimBlock(
        description="agent initializer for " + name,
        content=init_requires,
        statement=statement,
        events=events,
    )
    return initializer, init_requires
def make_new_event(statement, info):
    """
    Build a "blank" model event from one statement line.

    The kind of event is chosen from the operator characters in the line:

    - ``=`` with ``@``           -> EvaluationEvent (via make_new_evaluation)
    - ``=`` without ``@``        -> DynamicEvent (via make_new_dynamic)
    - ``~`` with ``{`` and ``}`` -> MarkovEvent (via make_new_markov)
    - ``~`` with ``[`` and ``]`` -> RandomIndexedEvent (via make_new_random_indexed)
    - ``~`` otherwise            -> RandomEvent (via make_new_random)

    Parameters
    ----------
    statement : str
        One line of a model statement, which will be turned into an event.
    info : dict
        Empty dictionary of model information that already exists. Consists of
        arrival variables, already assigned variables, parameters, and functions.
        Typing of each is based on the kind of "empty" object.

    Returns
    -------
    new_event : ModelEvent
        A new model event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.

    Raises
    ------
    ValueError
        If the line contains both ``=`` and ``~``, or neither of them.
    """
    uses_equals = "=" in statement
    uses_tilde = "~" in statement
    if uses_equals and uses_tilde:
        raise ValueError("A statement line can't have both an = and a ~!")

    if uses_equals:
        builder = make_new_evaluation if "@" in statement else make_new_dynamic
    elif uses_tilde:
        if "{" in statement and "}" in statement:
            builder = make_new_markov
        elif "[" in statement and "]" in statement:
            builder = make_new_random_indexed
        else:
            builder = make_new_random
    else:
        raise ValueError("Statement line was not any valid type!")

    event, names_used = builder(statement, info)
    return event, names_used
def make_new_dynamic(statement, info):
    """
    Construct a new instance of DynamicEvent based on the given model statement
    line and a blank dictionary of parameters. The statement should already be
    verified to be a valid dynamic statement: it has an = but no ~ or @.

    Names on the right-hand side are sorted into parameters (names whose value
    in info is None) and needed variables (all other names). The right-hand
    side is compiled into a callable with SymPy's lambdify. That callable takes
    the right-hand-side names as positional arguments, in the order they were
    extracted.

    Parameters
    ----------
    statement : str
        One line dynamics statement, which will be turned into a DynamicEvent.
    info : dict
        Empty dictionary of available information, keyed by object name. A value
        of None marks a parameter. NullFunc and Distribution values mark
        functions and distributions, which may not appear in a dynamic expression.

    Returns
    -------
    new_dynamic : DynamicEvent
        A new dynamic event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression: the assigned names
        first, then the names found on the right-hand side.

    Raises
    ------
    ValueError
        If a right-hand-side name is not in info, or refers to a function or a
        distribution.
    """
    # Cut the statement up into its LHS, RHS, and description
    lhs, rhs, description = parse_line_for_parts(statement, "=")

    # Parse the LHS (assignment) to get assigned variables
    assigns = parse_assignment(lhs)

    # Parse the RHS (dynamic statement) to extract object names used; is_indexed
    # runs parallel to obj_names and flags names that are used with an index
    obj_names, is_indexed = extract_var_names_from_expr(rhs)

    # Allocate each variable to needed dynamic variables or parameters
    needs = []
    parameters = {}
    for j in range(len(obj_names)):
        var = obj_names[j]
        if var not in info.keys():
            raise ValueError(
                var + " is used in a dynamic expression, but does not (yet) exist!"
            )
        val = info[var]
        if type(val) is NullFunc:
            raise ValueError(
                var + " is used in a dynamic expression, but it's a function!"
            )
        if type(val) is Distribution:
            raise ValueError(
                var + " is used in a dynamic expression, but it's a distribution!"
            )
        if val is None:
            parameters[var] = None
        else:
            needs.append(var)

    # Declare a SymPy symbol for each variable used; these are temporary.
    # NOTE(review): the names bound by exec() below are not referenced later in
    # this function (the lambdify arguments come from symbols(_var)), and exec()
    # inside a function writes to a temporary locals mapping. These lines
    # therefore appear to have no lasting effect beyond failing on names that
    # are not valid identifiers -- confirm before removing. They also execute
    # text taken from the model statement, so model files must be trusted.
    _args = []
    for j in range(len(obj_names)):
        _var = obj_names[j]
        if is_indexed[j]:
            exec(_var + " = IndexedBase('" + _var + "')")
        else:
            exec(_var + " = symbols('" + _var + "')")
        _args.append(symbols(_var))

    # Make a SymPy expression, then lambdify it.
    # NOTE(review): symbols(rhs) builds a symbol from the raw RHS text rather
    # than parsing it with sympify. This presumably relies on lambdify printing
    # that symbol's name verbatim as the generated function body -- confirm.
    sympy_expr = symbols(rhs)
    expr = lambdify(_args, sympy_expr)

    # Make an overall list of object names referenced in this event
    names_used = assigns + obj_names

    # Make and return the new dynamic event
    new_dynamic = DynamicEvent(
        description=description,
        statement=lhs + " = " + rhs,
        assigns=assigns,
        needs=needs,
        parameters=parameters,
        expr=expr,
        args=obj_names,
    )
    return new_dynamic, names_used
def make_new_random(statement, info):
    """
    Build a RandomEvent (realization of a random variable) from one line of a
    model statement. The line is expected to already be classified as a plain
    random statement: it contains a ~ but no =, [] or {}.

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Dictionary keyed by declared object names; the RHS name must be a key
        whose value has exact type Distribution.

    Returns
    -------
    new_random : RandomEvent
        A new random event with values and information missing, but structure set.
    names_used : [str]
        The assigned names followed by the distribution's name.

    Raises
    ------
    ValueError
        If the RHS object is not declared as a Distribution.
    KeyError
        If the RHS name is not a key of info.
    """
    lhs, rhs, description = parse_line_for_parts(statement, "~")
    assigns = parse_assignment(lhs)

    # The object on the right of ~ must be a declared distribution
    dstn = info[rhs]
    if type(dstn) is not Distribution:
        raise ValueError(
            f"{rhs} was treated as a distribution, but not declared as one!"
        )

    names_used = assigns + [rhs]

    event = RandomEvent(
        description=description,
        statement=f"{lhs} ~ {rhs}",
        assigns=assigns,
        needs=[],
        parameters={},
        dstn=dstn,
    )
    # Remember the distribution's name so it can be looked up again later
    event._dstn_name = rhs
    return event, names_used
def make_new_random_indexed(statement, info):
    """
    Build a RandomIndexedEvent from one line of a model statement. The line is
    expected to already be classified as an indexed random statement: it
    contains a ~ and a dstn[index] expression on the RHS.

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Dictionary keyed by declared object names; the distribution name must
        be a key whose value has exact type Distribution.

    Returns
    -------
    new_random_indexed : RandomIndexedEvent
        A new random indexed event with values and information missing, but
        structure set. Its only need is the index variable.
    names_used : [str]
        The assigned names, then the distribution name, then the index name.

    Raises
    ------
    ValueError
        If the RHS lacks a valid [index], or the named object is not declared
        as a Distribution.
    KeyError
        If the distribution name is not a key of info.
    """
    lhs, rhs, description = parse_line_for_parts(statement, "~")
    assigns = parse_assignment(lhs)
    dstn_name, index = parse_random_indexed(rhs)

    # The object being indexed must be a declared distribution
    if type(info[dstn_name]) is not Distribution:
        raise ValueError(
            f"{dstn_name} was treated as a distribution, but not declared as one!"
        )

    names_used = assigns + [dstn_name, index]

    event = RandomIndexedEvent(
        description=description,
        statement=f"{lhs} ~ {rhs}",
        assigns=assigns,
        needs=[index],
        parameters={},
        index=index,
    )
    # Remember the distribution's name so it can be looked up again later
    event._dstn_name = dstn_name
    return event, names_used
def make_new_markov(statement, info):
    """
    Build a MarkovEvent from one line of a model statement. The line is expected
    to already be classified as a Markov statement: it contains a ~ and a
    {probs} expression, optionally followed by (index). This can represent a
    Markov matrix transition, a draw from a discrete index, or a Bernoulli
    random variable; in the Bernoulli case the "probabilities" may be
    idiosyncratic data.

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Dictionary keyed by declared object names; the probabilities name must
        be a key. A value of None makes it a parameter of the event; any other
        value makes it a variable the event needs.

    Returns
    -------
    new_markov : MarkovEvent
        A new Markov draw event with values and information missing, but structure set.
    names_used : [str]
        The assigned names, then the event's needs, then the probabilities name.

    Raises
    ------
    KeyError
        If the probabilities name is not a key of info.
    """
    lhs, rhs, description = parse_line_for_parts(statement, "~")
    assigns = parse_assignment(lhs)
    probs, index = parse_markov(rhs)

    # The index variable, when present, is always needed at run time
    needs = [] if index is None else [index]

    # None in info marks probs as a parameter; otherwise it is per-agent data
    if info[probs] is not None:
        needs.append(probs)
        parameters = {}
    else:
        parameters = {probs: None}

    names_used = assigns + needs + [probs]

    event = MarkovEvent(
        description=description,
        statement=f"{lhs} ~ {rhs}",
        assigns=assigns,
        needs=needs,
        parameters=parameters,
        probs=probs,
        index=index,
    )
    return event, names_used
def make_new_evaluation(statement, info):
    """
    Make a new function evaluation event based on the given model statement line
    and the dictionary of declared objects. The statement should already be
    verified to be a valid evaluation statement: it has an @ and an = but no ~.

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into an eval event.
    info : dict
        Dictionary keyed by declared object names. Every argument name must be a
        key: a value of None marks a parameter, a value whose exact type is
        NullFunc or Distribution is rejected, and any other value marks a
        variable this event needs. The function name must also be a key; its
        value is stored on the event as func.

    Returns
    -------
    new_evaluation : EvaluationEvent
        A new evaluation event with values and information missing, but structure set.
    names_used : [str]
        The assigned names, then the argument names, then the function name.

    Raises
    ------
    ValueError
        If an argument is not a key of info, or is declared as a function or
        as a distribution.
    KeyError
        If the function name is not a key of info.
    """
    # Cut the statement up into its LHS, RHS, and description
    lhs, rhs, description = parse_line_for_parts(statement, "=")

    # Parse the LHS (assignment) to get assigned variables
    assigns = parse_assignment(lhs)

    # Parse the RHS (evaluation) for the function and its arguments
    func, arguments = parse_evaluation(rhs)

    # Allocate each argument to needed dynamic variables or parameters
    needs = []
    parameters = {}
    for var in arguments:
        if var not in info:
            raise ValueError(
                var + " is used in an evaluation statement, but does not (yet) exist!"
            )
        val = info[var]
        if type(val) is NullFunc:
            raise ValueError(
                var
                + " is used as an argument in an evaluation statement, but it's a function!"
            )
        if type(val) is Distribution:
            raise ValueError(
                var + " is used in an evaluation statement, but it's a distribution!"
            )
        # None marks a parameter; anything else is a variable this event needs
        if val is None:
            parameters[var] = None
        else:
            needs.append(var)

    # Make an overall list of object names referenced in this event
    names_used = assigns + arguments + [func]

    # Make and return the new evaluation event
    new_evaluation = EvaluationEvent(
        description=description,
        statement=lhs + " = " + rhs,
        assigns=assigns,
        needs=needs,
        parameters=parameters,
        arguments=arguments,
        func=info[func],
    )
    # Remember the function's name so it can be looked up again later
    new_evaluation._func_name = func
    return new_evaluation, names_used
def look_for_char_and_remove(phrase, symb):
    """
    Report whether a symbol occurs in a string, and return the string with every
    occurrence of that symbol deleted.

    Parameters
    ----------
    phrase : str
        String to be searched for a symbol.
    symb : char
        Single character to be searched for.

    Returns
    -------
    out : str
        The phrase with all copies of symb removed (unchanged if none were present).
    found : bool
        True if symb appeared in phrase at least once.
    """
    found = phrase.find(symb) != -1
    return phrase.replace(symb, ""), found
def parse_declaration_for_parts(line):
    """
    Split a declaration line from a model file into the object's name, its datatype,
    any metadata flags, and any provided comment or description.

    Parameters
    ----------
    line : str
        Line to be parsed into the object name, object type, flags, and a
        comment or description.

    Returns
    -------
    name : str
        Name of the object, with flag characters and spaces removed.
    datatype : str or None
        Provided datatype string, in parentheses, if any.
    flags : [str]
        List of metadata flags that were detected, in the order offset, arrival,
        solution, common. These are + for any object that is offset in time,
        ! for a variable that is in arrival, * for any non-variable that's part
        of the solution, and & for a common random variable.
    desc : str
        Comment or description following the double-backslash delimiter (\\\\),
        if any.

    Raises
    ------
    ValueError
        If an opening parenthesis is not followed by a closing one.
    """
    flags = []
    check_for_flags = {"offset": "+", "arrival": "!", "solution": "*", "common": "&"}

    # First, separate off the comment or description, if any. The first
    # backslash starts the delimiter, and two characters are skipped, so the
    # delimiter is written as \\ in the model file.
    slashes = line.find("\\")
    desc = "" if slashes == -1 else line[(slashes + 2) :].strip()
    rem = line if slashes == -1 else line[:slashes].strip()

    # Now look for bracketing parentheses declaring a datatype; the closing
    # parenthesis must come after the opening one
    lp = rem.find("(")
    if lp > -1:
        rp = rem.find(")", lp)
        if rp == -1:
            raise ValueError("Unclosed parentheses on object declaration line!")
        datatype = rem[(lp + 1) : rp].strip()
        leftover = rem[:lp].strip()
    else:
        datatype = None
        leftover = rem

    # What's left over should be the object name plus any flags
    for key, symb in check_for_flags.items():
        leftover, found = look_for_char_and_remove(leftover, symb)
        if found:
            flags.append(key)

    # Remove any remaining spaces, and that *should* be the name
    name = leftover.replace(" ", "")
    # TODO: Check for valid name formatting based on characters.

    return name, datatype, flags, desc
def parse_line_for_parts(statement, symb):
    """
    Break one model statement line into its left-hand side, right-hand side and
    description. The first occurrence of symb separates LHS from RHS. Within the
    RHS, the first backslash starts the description delimiter; the backslash and
    the character after it are skipped.

    Parameters
    ----------
    statement : str
        One line of a model statement, which will be parsed for its parts.
    symb : char
        The character that represents the divide between LHS and RHS.

    Returns
    -------
    lhs : str
        The left-hand (assignment) side of the expression, with spaces removed.
    rhs : str
        The right-hand (evaluation) side of the expression, with spaces removed.
    desc : str
        The provided description of the expression, stripped ("" if none).
    """
    split_at = statement.find(symb)
    lhs = statement[:split_at].replace(" ", "")
    remainder = statement[split_at + 1 :]

    slash_at = remainder.find("\\")
    if slash_at == -1:
        rhs, desc = remainder, ""
    else:
        rhs, desc = remainder[:slash_at], remainder[slash_at + 2 :].strip()

    return lhs, rhs.replace(" ", ""), desc
def parse_assignment(lhs):
    """
    Return the ordered list of variable names assigned on the LHS of a model line.

    A bare name yields a one-element list. A parenthesized, comma-separated group
    such as (a,b) yields each non-empty name found between the parentheses.

    Parameters
    ----------
    lhs : str
        Left-hand side of a model expression.

    Returns
    -------
    assigns : List[str]
        List of variable names that are assigned in this model line.

    Raises
    ------
    ValueError
        If the LHS starts with ( but does not end with ).
    IndexError
        If lhs is the empty string.
    """
    if lhs[0] != "(":
        return [lhs]
    if lhs[-1] != ")":
        raise ValueError("Parentheses on assignment was not closed!")
    return [name for name in lhs[1:-1].split(",") if name != ""]
def extract_var_names_from_expr(expression):
    """
    Parse the RHS of a dynamic model statement to get variable names used in it.

    Names are maximal runs of characters between operator/punctuation symbols.
    Numeric literals (a token that begins with a digit, including forms such as
    1.5, 1e-5, 1_000 or 0x1F) are skipped entirely rather than being mistaken
    for variable names. Comparison and logical operators (<, >, =, !, &, |, ~)
    and spaces also act as separators.

    Parameters
    ----------
    expression : str
        RHS of a model statement to be parsed for variable names.

    Returns
    -------
    var_names : List[str]
        List of variable names used in the expression, in order of first
        appearance and without duplicates. These *should* be dynamic
        variables and parameters, but not functions.
    indexed : List[bool]
        Indicators for whether each variable seems to be used with indexing,
        i.e. whether its first appearance is immediately followed by [.
    """
    var_names = []
    indexed = []
    separators = "+-/*^%.(),[]{}<>=!&|~ "
    digits = "0123456789"
    cur = ""
    in_number = False
    for c in expression:
        # Inside a numeric literal, swallow everything up to the next operator;
        # a "." stays part of the number (e.g. 1.5e3)
        if in_number and (c == "." or c not in separators):
            continue
        in_number = False
        if c in separators:
            # A separator ends the current name, if there is a new one
            if cur != "" and cur not in var_names:
                var_names.append(cur)
                indexed.append(c == "[")
            cur = ""
        elif (c in digits) and cur == "":
            # A digit at the start of a token begins a numeric literal
            in_number = True
        else:
            cur += c
    if cur != "" and cur not in var_names:
        var_names.append(cur)
        indexed.append(False)  # final symbol couldn't possibly be indexed
    return var_names, indexed
def parse_evaluation(expression):
    """
    Separate a function evaluation expression into the function that is called
    and the variable inputs that are passed to it. The expected form is
    func@(arg1,arg2,...).

    Parameters
    ----------
    expression : str
        RHS of a function evaluation model statement, which will be parsed for
        the function and its inputs.

    Returns
    -------
    func_name : str
        Name of the function that will be called in this event.
    arg_names : List[str]
        List of arguments of the function; empty entries are dropped.

    Raises
    ------
    ValueError
        If the @ is not immediately followed by (, or the expression does not
        end in ). This includes an empty argument section such as "f@".
    """
    # Get the name of the function: what's to the left of the @
    amp = expression.find("@")
    func_name = expression[:amp]

    # Check for parentheses formatting; startswith/endswith also handle an
    # empty remainder, which previously raised IndexError
    rem = expression[(amp + 1) :]
    if not rem.startswith("("):
        raise ValueError(
            "The @ in a function evaluation statement must be followed by (!"
        )
    if not rem.endswith(")"):
        raise ValueError("A function evaluation statement must end in )!")

    # Parse what's inside the parentheses for comma-separated argument names
    arg_names = [arg for arg in rem[1:-1].split(",") if arg != ""]
    return func_name, arg_names
def parse_markov(expression):
    """
    Separate the RHS of a Markov draw statement, of the form {probs} or
    {probs}(index), into the name of the probabilities object and the name of
    the indexing variable.

    Parameters
    ----------
    expression : str
        RHS of a Markov draw model statement, which will be parsed for the
        probabilities name and the index name.

    Returns
    -------
    probs : str
        Name of the probabilities object in this statement.
    index : str or None
        Name of the indexing variable in this statement, or None if there is
        no (index) part.

    Raises
    ------
    ValueError
        If the {array} is missing or empty, or the (index) part is malformed.
    """
    # Locate the {array} naming the probabilities; the { *should* be at index 0
    open_brace = expression.find("{")
    close_brace = expression.find("}")
    if -1 in (open_brace, close_brace) or close_brace - open_brace < 2:
        raise ValueError("A Markov assignment must have an {array}!")
    probs = expression[open_brace + 1 : close_brace]

    # An (index) part is optional, and is only searched for after the }
    after = close_brace + 1
    open_paren = expression.find("(", after)
    close_paren = expression.find(")", after)
    if open_paren == -1 and close_paren == -1:
        return probs, None
    if -1 in (open_paren, close_paren) or close_paren - open_paren < 2:
        raise ValueError("Improper Markov formatting: should be {probs}(index)!")
    return probs, expression[open_paren + 1 : close_paren]
def parse_random_indexed(expression):
    """
    Separate the RHS of an indexed random variable statement, of the form
    dstn[index], into the distribution name and the index name.

    Parameters
    ----------
    expression : str
        RHS of an indexed random model statement, which will be parsed for
        the distribution name and index name.

    Returns
    -------
    dstn : str
        Name of the distribution in this statement (everything before the [).
    index : str
        Name of the indexing variable in this statement.

    Raises
    ------
    ValueError
        If the [index] part is missing or empty.
    """
    open_at = expression.find("[")
    close_at = expression.find("]")
    if open_at < 0 or close_at < 0 or close_at - open_at < 2:
        raise ValueError("An indexed random variable assignment must have an [index]!")
    return expression[:open_at], expression[open_at + 1 : close_at]
def format_block_statement(statement):
    """
    Ensure that a model statement of a block (maybe a period, maybe an
    initializer) is formatted as a list of strings, one statement per entry.

    Parameters
    ----------
    statement : str or [str]
        A model statement, which might be for a block or an initializer. It may
        be a single (possibly multi-line) string, or a list or tuple of strings.

    Returns
    -------
    block_statements : [str]
        A new list of model statements, one per entry. A multi-line string is
        split on newlines. Interior empty lines are kept, but a single trailing
        newline does not produce an empty final entry.

    Raises
    ------
    ValueError
        If statement is a list containing a non-string, or is neither a string
        nor a list/tuple.
    """
    if isinstance(statement, str):
        if "\n" not in statement:
            # str is immutable (and has no .copy()), so just wrap it
            return [statement]
        block_statements = statement.split("\n")
        # A trailing newline leaves an empty final piece, which is not a line;
        # text after the last newline, however, is kept
        if block_statements[-1] == "":
            block_statements.pop()
        return block_statements
    if isinstance(statement, (list, tuple)):
        for line in statement:
            if not isinstance(line, str):
                raise ValueError("The model statement somehow includes a non-string!")
        # Return a fresh list so callers can't mutate the input through it
        return list(statement)
    raise ValueError("The model statement must be a string or a list of strings!")
@njit
def aggregate_blobs_onto_polynomial_grid(
    vals, pmv, origins, grid, J, Q
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto a discretized
    grid of outcome values, based on their origin in the arrival state space. This
    version is for non-continuation variables, returning only the probability array
    mapping from arrival states to the outcome variable.

    The bracketing gridpoint is first guessed in closed form, assuming
    grid[i] = bot + (top - bot) * (i / (M - 1)) ** Q. The guess is then clamped
    and corrected so that grid[ii] <= x <= grid[ii + 1] holds even under
    floating-point roundoff.

    Parameters
    ----------
    vals : np.array
        Outcome value of each blob.
    pmv : np.array
        Probability mass (weight) of each blob.
    origins : np.array
        Row (arrival state) index of each blob in the output array.
    grid : np.array
        Increasing grid of outcome values with at least two points.
    J : int
        Number of rows (arrival states) in the output.
    Q : float
        Exponent of the polynomial grid spacing.

    Returns
    -------
    probs : np.array
        Array of shape (J, M). Each blob's mass is split linearly between its
        two bracketing gridpoints; values at or beyond the grid ends put all
        their mass on the nearest endpoint.
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    Mm1 = M - 1
    N = pmv.size
    scale = 1.0 / (top - bot)
    order = 1.0 / Q
    diffs = grid[1:] - grid[:-1]

    probs = np.zeros((J, M))

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            ii = int(np.floor(((x - bot) * scale) ** order * Mm1))
            # Roundoff can push the guess to M-1 or onto a neighboring interval;
            # Numba does not bounds-check, so clamp and then correct the guess
            if ii > M - 2:
                ii = M - 2
            while ii > 0 and x < grid[ii]:
                ii -= 1
            while ii < M - 2 and x > grid[ii + 1]:
                ii += 1
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
        elif x <= bot:
            probs[jj, 0] += p
        else:
            probs[jj, -1] += p
    return probs
@njit
def aggregate_blobs_onto_polynomial_grid_alt(
    vals, pmv, origins, grid, J, Q
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto a discretized
    grid of outcome values, based on their origin in the arrival state space. This
    version is for continuation variables. It returns the probability array mapping
    from arrival states to the outcome variable, the index in the outcome variable
    grid for each blob, and the alpha weighting between gridpoints.

    The bracketing gridpoint is first guessed in closed form, assuming
    grid[i] = bot + (top - bot) * (i / (M - 1)) ** Q. The guess is then clamped
    and corrected so that grid[ii] <= x <= grid[ii + 1] holds even under
    floating-point roundoff.

    Parameters
    ----------
    vals : np.array
        Outcome value of each blob.
    pmv : np.array
        Probability mass (weight) of each blob.
    origins : np.array
        Row (arrival state) index of each blob in the output array.
    grid : np.array
        Increasing grid of outcome values with at least two points.
    J : int
        Number of rows (arrival states) in the output.
    Q : float
        Exponent of the polynomial grid spacing.

    Returns
    -------
    probs : np.array
        Array of shape (J, M) with each blob's mass split between gridpoints.
    idx : np.array
        int32 index of the lower bracketing gridpoint for each blob, in [0, M-2].
    alpha : np.array
        Weight on the upper gridpoint idx + 1 for each blob: 0.0 at or below the
        grid bottom, 1.0 at or above the grid top.
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    Mm1 = M - 1
    N = pmv.size
    scale = 1.0 / (top - bot)
    order = 1.0 / Q
    diffs = grid[1:] - grid[:-1]

    probs = np.zeros((J, M))
    idx = np.empty(N, dtype=np.dtype(np.int32))
    alpha = np.empty(N)

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            ii = int(np.floor(((x - bot) * scale) ** order * Mm1))
            # Roundoff can push the guess to M-1 or onto a neighboring interval;
            # Numba does not bounds-check, so clamp and then correct the guess
            if ii > M - 2:
                ii = M - 2
            while ii > 0 and x < grid[ii]:
                ii -= 1
            while ii < M - 2 and x > grid[ii + 1]:
                ii += 1
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
            alpha[n] = temp
            idx[n] = ii
        elif x <= bot:
            probs[jj, 0] += p
            alpha[n] = 0.0
            idx[n] = 0
        else:
            probs[jj, -1] += p
            alpha[n] = 1.0
            idx[n] = M - 2
    return probs, idx, alpha
@njit
def aggregate_blobs_onto_discrete_grid(vals, pmv, origins, M, J):  # pragma: no cover
    """
    Numba-compatible helper that sums "probability blobs" onto the points of a
    truly discrete state.

    Each blob k adds its mass pmv[k] to entry [origins[k], vals[k]] of a
    (J, M) array of zeros, so vals must already hold integer gridpoint indices.
    """
    out = np.zeros((J, M))
    for k in range(pmv.size):
        out[origins[k], vals[k]] += pmv[k]
    return out
@njit
def calc_overall_trans_probs(
    out, idx, alpha, binary, offset, pmv, origins
):  # pragma: no cover
    """
    Numba-compatible helper that combines transition probabilities from the
    arrival state space to *multiple* continuation variables into a single
    unified transition matrix.

    For every blob n and every row b of binary, the blob's mass pmv[n] is
    multiplied by alpha[n, d, binary[b, d]] over each column d. The product is
    added into out[origins[n], idx[n] + offset[b]]. out is updated in place
    and also returned. Presumably the rows of binary enumerate the corners of
    the interpolation hypercube, with offset giving each corner's flat-index
    shift; confirm against the caller.
    """
    num_blobs = alpha.shape[0]
    num_corners = binary.shape[0]
    num_dims = binary.shape[1]
    for n in range(num_blobs):
        row = origins[n]
        base = idx[n]
        for b in range(num_corners):
            weight = pmv[n]
            for d in range(num_dims):
                weight *= alpha[n, d, binary[b, d]]
            out[row, base + offset[b]] += weight
    return out