# Coverage for HARK/simulator.py: 93% (1550 statements)
# coverage.py v7.13.5, created at 2026-04-26 06:00 +0000
1"""
2A module with classes and functions for automated simulation of HARK.AgentType
3models from a human- and machine-readable model specification.
4"""
6from dataclasses import dataclass, field
7from copy import copy, deepcopy
8import numpy as np
9from numba import njit
10from sympy.utilities.lambdify import lambdify
11from sympy import symbols, IndexedBase
12from typing import Callable
13from HARK.utilities import NullFunc, make_exponential_grid
14from HARK.distributions import Distribution
15from scipy.sparse import csr_matrix, csc_matrix
16from scipy.sparse.linalg import eigs
17from scipy.optimize import brentq
18from itertools import product
19import importlib.resources
20import yaml
22# Prevent pre-commit from removing sympy
23x = symbols("x")
24del x
25y = IndexedBase("y")
26del y
@dataclass(kw_only=True)
class ModelEvent:
    """
    Class for representing "events" that happen to agents in the course of their
    model. These might be statements of dynamics, realization of a random shock,
    or the evaluation of a function (potentially a control or other solution-
    based object). This is a superclass for types of events defined below.

    Parameters
    ----------
    description : str
        Text description of this model event.
    statement : str
        The line of the model statement that this event corresponds to.
    parameters : dict
        Dictionary of objects that are static / universal within this event.
    assigns : list[str]
        List of names of variables that this event assigns values for.
    needs : list[str]
        List of names of variables that this event requires to be run.
    data : dict
        Dictionary of current variable values within this event.
    common : bool
        Indicator for whether the variables assigned in this event are commonly
        held across all agents, rather than idiosyncratic.
    N : int
        Number of agents currently in this event.
    """

    statement: str = field(default="")
    parameters: dict = field(default_factory=dict)
    description: str = field(default="")
    assigns: list[str] = field(default_factory=list, repr=False)
    needs: list = field(default_factory=list, repr=False)
    data: dict = field(default_factory=dict, repr=False)
    common: bool = field(default=False, repr=False)
    N: int = field(default=1, repr=False)

    def run(self):
        """
        Execute this event. This method should be filled in by each subclass.
        """
        pass  # pragma: nocover

    def reset(self):
        """Clear all current variable values for this event."""
        self.data = {}

    def assign(self, output):
        """
        Store `output` in self.data under the name(s) in self.assigns. Scalar
        outputs are wrapped in a one-element np.ndarray so downstream code can
        always assume array-valued data.
        """
        if len(self.assigns) > 1:
            # Multi-assignment: output must provide one entry per assigned name
            assert len(self.assigns) == len(output)
            for j in range(len(self.assigns)):
                var = self.assigns[j]
                if type(output[j]) is not np.ndarray:
                    output[j] = np.array([output[j]])
                self.data[var] = output[j]
        else:
            var = self.assigns[0]
            if type(output) is not np.ndarray:
                output = np.array([output])
            self.data[var] = output

    def expand_information(self, origins, probs, atoms, which=None):
        """
        This method is only called internally when a RandomEvent or MarkovEvent
        runs its quasi_run() method. It expands the set of of "probability blobs"
        by applying a random realization event. All extant blobs for which the
        shock applies are replicated for each atom in the random event, with the
        probability mass divided among the replicates.

        Parameters
        ----------
        origins : np.array
            Array that tracks which arrival state space node each blob originated
            from. This is expanded into origins_new, which is returned.
        probs : np.array
            Vector of probabilities of each of the random possibilities.
        atoms : [np.array]
            List of arrays with realization values for the distribution. Each
            array corresponds to one variable that is assigned by this event.
        which : np.array or None
            If given, a Boolean array indicating which of the pre-existing blobs
            is affected by the given probabilities and atoms. By default, all
            blobs are assumed to be affected.

        Returns
        -------
        origins_new : np.array
            Expanded boolean array of indicating the arrival state space node that
            each blob originated from.
        """
        K = probs.size
        N = self.N
        if which is None:
            which = np.ones(N, dtype=bool)
        other = np.logical_not(which)
        M = np.sum(which)  # how many blobs are we affecting?
        MX = N - M  # how many blobs are we not affecting?

        # Update probabilities of outcomes: each affected blob's mass is split
        # across the K atoms; unaffected blobs keep their mass and are moved
        # to the front of the (reordered) arrays.
        pmv_old = np.reshape(self.data["pmv_"][which], (M, 1))
        pmv_new = (pmv_old * np.reshape(probs, (1, K))).flatten()
        self.data["pmv_"] = np.concatenate((self.data["pmv_"][other], pmv_new))

        # Replicate the pre-existing data for each atom
        for var in self.data.keys():
            if (var == "pmv_") or (var in self.assigns):
                continue  # don't double expand pmv, and don't touch assigned variables
            data_old = np.reshape(self.data[var][which], (M, 1))
            data_new = np.tile(data_old, (1, K)).flatten()
            self.data[var] = np.concatenate((self.data[var][other], data_new))

        # If any of the assigned variables don't exist yet, add dummy versions
        # of them. This section exists so that the code works with "partial events"
        # on both the first pass and subsequent passes.
        for j in range(len(self.assigns)):
            var = self.assigns[j]
            if var in self.data.keys():
                continue
            self.data[var] = np.zeros(N, dtype=atoms[j].dtype)
            # Zeros are just dummy values

        # Add the new random variables to the simulation data. This generates
        # replicates for the affected blobs and leaves the others untouched,
        # still with their dummy values. They will be altered on later passes.
        for j in range(len(self.assigns)):
            var = self.assigns[j]
            data_new = np.tile(np.reshape(atoms[j], (1, K)), (M, 1)).flatten()
            self.data[var] = np.concatenate((self.data[var][other], data_new))

        # Expand the origins array to account for the new replicates; ordering
        # matches the data arrays above (unaffected blobs first).
        origins_new = np.tile(np.reshape(origins[which], (M, 1)), (1, K)).flatten()
        origins_new = np.concatenate((origins[other], origins_new))
        self.N = MX + M * K

        # Send the new origins array back to the calling process
        return origins_new

    def add_idiosyncratic_bernoulli_info(self, origins, probs):
        """
        Special method for adding Bernoulli outcomes to the information set when
        probabilities are idiosyncratic to each agent. All extant blobs are duplicated
        with the appropriate probability

        Parameters
        ----------
        origins : np.array
            Array that tracks which arrival state space node each blob originated
            from. This is expanded into origins_new, which is returned.
        probs : np.array
            Vector of probabilities of drawing True for each blob.

        Returns
        -------
        origins_new : np.array
            Expanded boolean array of indicating the arrival state space node that
            each blob originated from.
        """
        N = self.N

        # Update probabilities of outcomes, replicating each one: column 0 is
        # the False outcome (mass 1-p), column 1 is True (mass p)
        pmv_old = np.reshape(self.data["pmv_"], (N, 1))
        P = np.reshape(probs, (N, 1))
        PX = np.concatenate([1.0 - P, P], axis=1)
        pmv_new = (pmv_old * PX).flatten()
        self.data["pmv_"] = pmv_new

        # Replicate the pre-existing data for each atom
        for var in self.data.keys():
            if (var == "pmv_") or (var in self.assigns):
                continue  # don't double expand pmv, and don't touch assigned variables
            data_old = np.reshape(self.data[var], (N, 1))
            data_new = np.tile(data_old, (1, 2)).flatten()
            self.data[var] = data_new

        # Add the (one and only) new random variable to the simulation data
        var = self.assigns[0]
        data_new = np.tile(np.array([[0, 1]]), (N, 1)).flatten()
        self.data[var] = data_new

        # Expand the origins array to account for the new replicates
        origins_new = np.tile(np.reshape(origins, (N, 1)), (1, 2)).flatten()
        self.N = N * 2

        # Send the new origins array back to the calling process
        return origins_new
@dataclass(kw_only=True)
class DynamicEvent(ModelEvent):
    """
    Class for representing model dynamics for an agent, consisting of an expression
    to be evaluated and variables to which the results are assigned.

    Parameters
    ----------
    expr : Callable
        Function or expression to be evaluated for the assigned variables.
    args : list[str]
        Ordered list of argument names for the expression.
    """

    expr: Callable = field(default_factory=NullFunc, repr=False)
    args: list[str] = field(default_factory=list, repr=False)

    def evaluate(self):
        """Evaluate the expression on current data, returning the result."""
        # Parameters take precedence over simulation data on a name clash
        namespace = {**self.data, **self.parameters}
        return self.expr(*(namespace[name] for name in self.args))

    def run(self):
        """Evaluate the expression and assign the result."""
        self.assign(self.evaluate())

    def quasi_run(self, origins, norm=None):
        """Deterministic event: just run; the set of blobs is unchanged."""
        self.run()
        return origins
@dataclass(kw_only=True)
class RandomEvent(ModelEvent):
    """
    Class for representing the realization of random variables for an agent,
    consisting of a shock distribution and variables to which the results are assigned.

    Parameters
    ----------
    dstn : Distribution
        Distribution of one or more random variables that are drawn from during
        this event and assigned to the corresponding variables.
    """

    dstn: Distribution = field(default_factory=Distribution, repr=False)

    def reset(self):
        """Reset the distribution's RNG and clear current data."""
        self.dstn.reset()
        ModelEvent.reset(self)

    def draw(self):
        """
        Draw realizations from the distribution: one draw shared by all agents
        if the event is common, otherwise one draw per agent.
        """
        out = np.empty((len(self.assigns), self.N))
        if not self.common:
            out[:, :] = self.dstn.draw(self.N)
        else:
            out[:, :] = self.dstn.draw(1)
        if len(self.assigns) == 1:
            out = out.flatten()
        return out

    def run(self):
        """Draw from the distribution and assign the results."""
        self.assign(self.draw())

    def quasi_run(self, origins, norm=None):
        """
        Expand the set of probability blobs by this event's discretized
        distribution, optionally applying Harmenberg normalization.
        """
        # Get distribution
        atoms = self.dstn.atoms
        probs = self.dstn.pmv.copy()

        # Apply Harmenberg normalization if applicable: weight each atom's
        # probability by the realization of the normalizing shock. Catch only
        # ValueError (norm is None or not assigned here); a bare except would
        # silently hide genuine errors in the normalization itself.
        try:
            harm_idx = self.assigns.index(norm)
        except ValueError:
            pass
        else:
            probs *= atoms[harm_idx]

        # Expand the set of simulated blobs
        origins_new = self.expand_information(origins, probs, atoms)
        return origins_new
@dataclass(kw_only=True)
class RandomIndexedEvent(RandomEvent):
    """
    Class for representing the realization of random variables for an agent,
    consisting of a list of shock distributions, an index for the list, and the
    variables to which the results are assigned.

    Parameters
    ----------
    dstn : [Distribution]
        List of distributions of one or more random variables that are drawn
        from during this event and assigned to the corresponding variables.
    index : str
        Name of the index that is used to choose a distribution for each agent.
    """

    index: str = field(default="", repr=False)
    dstn: list[Distribution] = field(default_factory=list, repr=False)

    def draw(self):
        """
        Draw realizations for each agent from the distribution selected by
        that agent's index value.
        """
        idx = self.data[self.index]
        K = len(self.assigns)
        out = np.empty((K, self.N))
        out.fill(np.nan)

        if self.common:
            k = idx[0]  # this will behave badly if index is not itself common
            out[:, :] = self.dstn[k].draw(1)
            return out

        for k in range(len(self.dstn)):
            these = idx == k
            if not np.any(these):
                continue
            out[:, these] = self.dstn[k].draw(np.sum(these))
        if K == 1:
            out = out.flatten()
        return out

    def reset(self):
        """Reset every distribution's RNG and clear current data."""
        for k in range(len(self.dstn)):
            self.dstn[k].reset()
        ModelEvent.reset(self)

    def quasi_run(self, origins, norm=None):
        """
        Expand the probability blobs distribution-by-distribution, applying
        each distribution only to the blobs whose index selects it.
        """
        origins_new = origins.copy()
        J = len(self.dstn)

        # Determine once (it's loop invariant) whether Harmenberg normalization
        # applies; catch only ValueError rather than a bare except so that real
        # errors are not silently swallowed.
        try:
            harm_idx = self.assigns.index(norm)
        except ValueError:
            harm_idx = None

        for j in range(J):
            # Re-read the index each pass: expand_information grows all data arrays
            idx = self.data[self.index]
            these = idx == j

            # Get distribution
            atoms = self.dstn[j].atoms
            probs = self.dstn[j].pmv.copy()

            # Apply Harmenberg normalization if applicable
            if harm_idx is not None:
                probs *= atoms[harm_idx]

            # Expand the set of simulated blobs
            origins_new = self.expand_information(
                origins_new, probs, atoms, which=these
            )

        # Return the altered origins array
        return origins_new
@dataclass(kw_only=True)
class MarkovEvent(ModelEvent):
    """
    Class for representing the realization of a Markov draw for an agent, in which
    a Markov probabilities (array, vector, or a single float) is used to determine
    the realization of some discrete outcome. If the probabilities are a 2D array,
    it represents a Markov matrix (rows sum to 1), and there must be an index; if
    the probabilities are a vector, it should be a stochastic vector; if it's a
    single float, it represents a Bernoulli probability.
    """

    probs: str = field(default="", repr=False)
    index: str = field(default="", repr=False)
    N: int = field(default=1, repr=False)
    seed: int = field(default=0, repr=False)
    # seed is overwritten when each period is created

    def __post_init__(self):
        self.reset_rng()

    def reset(self):
        """Reset the RNG to its seed and clear current data."""
        self.reset_rng()
        ModelEvent.reset(self)

    def reset_rng(self):
        """Re-create the RNG from the stored seed for reproducibility."""
        self.RNG = np.random.RandomState(self.seed)

    def draw(self):
        """
        Realize the discrete outcome for each agent. The named `probs` object
        is looked up first among parameters, then in current data; its form
        (2D matrix with index / stochastic vector / scalar-or-per-agent
        probability) determines the branch taken below.
        """
        # Initialize the output
        out = -np.ones(self.N, dtype=int)
        if self.probs in self.parameters:
            probs = self.parameters[self.probs]
            probs_are_param = True
        else:
            probs = self.data[self.probs]
            probs_are_param = False

        # Make the base draw(s): one uniform if common, else one per agent
        if self.common:
            X = self.RNG.rand(1)
        else:
            X = self.RNG.rand(self.N)

        if self.index:  # it's a Markov matrix
            idx = self.data[self.index]
            J = probs.shape[0]
            for j in range(J):
                these = idx == j
                if not np.any(these):
                    continue
                # Inverse-CDF draw from row j of the Markov matrix
                P = np.cumsum(probs[j, :])
                if self.common:
                    out[:] = np.searchsorted(P, X[0])  # only one value of X!
                else:
                    out[these] = np.searchsorted(P, X[these])
            return out

        if (isinstance(probs, np.ndarray)) and (
            probs_are_param
        ):  # it's a stochastic vector
            P = np.cumsum(probs)
            if self.common:
                out[:] = np.searchsorted(P, X[0])
                return out
            else:
                return np.searchsorted(P, X)

        # Otherwise, this is just a Bernoulli RV (probability may be a scalar
        # parameter or an idiosyncratic per-agent array from data)
        P = probs
        if self.common:
            out[:] = X < P
            return out
        else:
            return X < P  # basic Bernoulli

    def run(self):
        """Draw the Markov outcome and assign it."""
        self.assign(self.draw())

    def quasi_run(self, origins, norm=None):
        """
        Expand the probability blobs by this event's discrete outcome
        distribution, branching on the same cases as draw().
        """
        if self.probs in self.parameters:
            probs = self.parameters[self.probs]
            probs_are_param = True
        else:
            probs = self.data[self.probs]
            probs_are_param = False

        # If it's a Markov matrix:
        if self.index:
            K = probs.shape[0]
            atoms = np.array([np.arange(probs.shape[1], dtype=int)])
            origins_new = origins.copy()
            for k in range(K):
                # Re-read the index each pass: expand_information grows the data
                idx = self.data[self.index]
                these = idx == k
                probs_temp = probs[k, :]
                origins_new = self.expand_information(
                    origins_new, probs_temp, atoms, which=these
                )
            return origins_new

        # If it's a stochastic vector:
        if (isinstance(probs, np.ndarray)) and (probs_are_param):
            atoms = np.array([np.arange(probs.shape[0], dtype=int)])
            origins_new = self.expand_information(origins, probs, atoms)
            return origins_new

        # Otherwise, this is just a Bernoulli RV, but it might have idiosyncratic probability
        if probs_are_param:
            P = probs
            atoms = np.array([[False, True]])
            origins_new = self.expand_information(origins, np.array([1 - P, P]), atoms)
            return origins_new

        # Final case: probability is idiosyncratic Bernoulli
        origins_new = self.add_idiosyncratic_bernoulli_info(origins, probs)
        return origins_new
@dataclass(kw_only=True)
class EvaluationEvent(ModelEvent):
    """
    Class for representing the evaluation of a model function. This might be from
    the solution of the model (like a policy function or decision rule) or just
    a non-algebraic function used in the model. This looks a lot like DynamicEvent.

    Parameters
    ----------
    func : Callable
        Model function that is evaluated in this event, with the output assigned
        to the appropriate variables.
    """

    func: Callable = field(default_factory=NullFunc, repr=False)
    arguments: list[str] = field(default_factory=list, repr=False)

    def evaluate(self):
        """Evaluate the model function on current data, returning the result."""
        # Parameters take precedence over simulation data on a name clash
        scope = {**self.data, **self.parameters}
        return self.func(*(scope[name] for name in self.arguments))

    def run(self):
        """Evaluate the function and assign the result."""
        self.assign(self.evaluate())

    def quasi_run(self, origins, norm=None):
        """Deterministic event: just run; the set of blobs is unchanged."""
        self.run()
        return origins
@dataclass(kw_only=True)
class SimBlock:
    """
    Class for representing a "block" of a simulated model, which might be a whole
    period or a "stage" within a period.

    Parameters
    ----------
    description : str
        Textual description of what happens in this simulated block.
    statement : str
        Verbatim model statement that was used to create this block.
    content : dict
        Dictionary of objects that are constant / universal within the block.
        This includes both traditional numeric parameters as well as functions.
    arrival : list[str]
        List of inbound states: information available at the *start* of the block.
    events : list[ModelEvent]
        Ordered list of events that happen during the block.
    data : dict
        Dictionary that stores current variable values.
    N : int
        Number of idiosyncratic agents in this block.
    """

    statement: str = field(default="", repr=False)
    content: dict = field(default_factory=dict)
    description: str = field(default="", repr=False)
    arrival: list[str] = field(default_factory=list, repr=False)
    events: list[ModelEvent] = field(default_factory=list, repr=False)
    data: dict = field(default_factory=dict, repr=False)
    N: int = field(default=1, repr=False)
552 def run(self):
553 """
554 Run this simulated block by running each of its events in order.
555 """
556 for j in range(len(self.events)):
557 event = self.events[j]
558 for k in range(len(event.assigns)):
559 var = event.assigns[k]
560 if var in event.data.keys():
561 del event.data[var]
562 for k in range(len(event.needs)):
563 var = event.needs[k]
564 event.data[var] = self.data[var]
565 event.N = self.N
566 event.run()
567 for k in range(len(event.assigns)):
568 var = event.assigns[k]
569 self.data[var] = event.data[var]
571 def reset(self):
572 """
573 Reset the simulated block by resetting each of its events.
574 """
575 self.data = {}
576 for j in range(len(self.events)):
577 self.events[j].reset()
579 def distribute_content(self):
580 """
581 Fill in parameters, functions, and distributions to each event.
582 """
583 for event in self.events:
584 for param in event.parameters.keys():
585 try:
586 event.parameters[param] = self.content[param]
587 except:
588 raise ValueError(
589 "Could not distribute the parameter called " + param + "!"
590 )
591 if (type(event) is RandomEvent) or (type(event) is RandomIndexedEvent):
592 try:
593 event.dstn = self.content[event._dstn_name]
594 except:
595 raise ValueError(
596 "Could not find a distribution called " + event._dstn_name + "!"
597 )
598 if type(event) is EvaluationEvent:
599 try:
600 event.func = self.content[event._func_name]
601 except:
602 raise ValueError(
603 "Could not find a function called " + event._func_name + "!"
604 )
606 def _build_input_grids(self, grid_specs, arrival_N):
607 """
608 Build input and output grid dictionaries from grid specifications.
610 Creates grids_in for arrival variables and grids_out for outcome variables,
611 tracking which arrival variables have been covered and whether each output
612 grid is continuous.
614 Parameters
615 ----------
616 grid_specs : dict
617 Dictionary of grid specifications keyed by variable name.
618 arrival_N : int
619 Number of arrival variables in this block.
621 Returns
622 -------
623 grids_in : dict
624 Grids for arrival (input) variables.
625 grids_out : dict
626 Grids for outcome (output) variables.
627 continuous_grid_out_bool : list
628 List of booleans indicating whether each output grid is continuous.
629 grid_orders : dict
630 Polynomial order for each variable grid (-1 for discrete, None if unknown).
631 dummy_grid : np.ndarray or None
632 Dummy grid used only when arrival_N == 0.
633 """
634 completed = arrival_N * [False]
635 grids_in = {}
636 grids_out = {}
637 dummy_grid = None
638 if arrival_N == 0: # should only be for initializer block
639 dummy_grid = np.array([0])
640 grids_in["_dummy"] = dummy_grid
642 continuous_grid_out_bool = []
643 grid_orders = {}
644 for var in grid_specs.keys():
645 spec = grid_specs[var]
646 try:
647 idx = self.arrival.index(var)
648 completed[idx] = True
649 is_arrival = True
650 except:
651 is_arrival = False
652 if ("min" in spec) and ("max" in spec):
653 Q = spec["order"] if "order" in spec else 1.0
654 bot = spec["min"]
655 top = spec["max"]
656 N = spec["N"]
657 new_grid = make_exponential_grid(bot, top, N, Q)
658 is_cont = True
659 grid_orders[var] = Q
660 elif "N" in spec:
661 new_grid = np.arange(spec["N"], dtype=int)
662 is_cont = False
663 grid_orders[var] = -1
664 else:
665 new_grid = None # could not make grid, construct later
666 is_cont = False
667 grid_orders[var] = None
669 if is_arrival:
670 grids_in[var] = new_grid
671 else:
672 grids_out[var] = new_grid
673 continuous_grid_out_bool.append(is_cont)
675 # Verify that specifications were passed for all arrival variables
676 for j in range(len(self.arrival)):
677 if not completed[j]:
678 raise ValueError(
679 "No grid specification was provided for " + self.arrival[j] + "!"
680 )
682 return grids_in, grids_out, continuous_grid_out_bool, grid_orders, dummy_grid
684 def _build_twist_grids(
685 self, twist, grids_in, grid_orders, grids_out, continuous_grid_out_bool
686 ):
687 """
688 Override output grids with arrival-matching grids for continuation variables.
690 When an intertemporal twist is provided, the result grids for continuation
691 variables are set to match the corresponding arrival variable grids.
693 Parameters
694 ----------
695 twist : dict
696 Mapping from continuation variable names to arrival variable names.
697 grids_in : dict
698 Grids for arrival variables.
699 grid_orders : dict
700 Polynomial orders for each variable grid (modified in place).
701 grids_out : dict
702 Grids for output variables (modified in place).
703 continuous_grid_out_bool : list
704 Boolean continuity flags for output grids (extended in place).
706 Returns
707 -------
708 grids_out : dict
709 Updated output grids.
710 grid_orders : dict
711 Updated grid orders.
712 grid_out_is_continuous : np.ndarray
713 Boolean array indicating continuity of each output grid.
714 """
715 for cont_var in twist.keys():
716 arr_var = twist[cont_var]
717 if cont_var not in list(grids_out.keys()):
718 is_cont = grids_in[arr_var].dtype is np.dtype(np.float64)
719 continuous_grid_out_bool.append(is_cont)
720 grids_out[cont_var] = copy(grids_in[arr_var])
721 grid_orders[cont_var] = grid_orders[arr_var]
722 grid_out_is_continuous = np.array(continuous_grid_out_bool)
723 return grids_out, grid_orders, grid_out_is_continuous
    def _project_onto_output_grids(
        self,
        grids_out,
        grid_out_is_continuous,
        grid_orders,
        cont_vars,
        twist,
        N_orig,
        J,
        N,
    ):
        """
        Project quasi-simulation results onto the discretized output grids.

        Loops over each output variable, dispatching to the appropriate
        aggregation routine based on whether the grid is continuous or discrete.

        Parameters
        ----------
        grids_out : dict
            Output grids (may be updated for None grids).
        grid_out_is_continuous : np.ndarray
            Boolean continuity flags for each output variable.
        grid_orders : dict
            Polynomial orders for each variable grid.
        cont_vars : list
            Names of continuation variables.
        twist : dict or None
            The intertemporal twist mapping (used only to check if provided).
        N_orig : int
            Number of original arrival-state grid points.
        J : int
            Size of the arrival state mesh.
        N : int
            Number of agents in this block.

        Returns
        -------
        matrices_out : dict
            Transition matrices for each output variable.
        cont_idx : dict
            Lower-bracket indices for continuation variables.
        cont_alpha : dict
            Interpolation weights for continuation variables.
        cont_M : dict
            Grid sizes for continuation variables.
        cont_discrete : dict
            Whether each continuation variable uses a discrete grid.
        grids_out : dict
            Updated output grids (some None entries may be filled in).
        grid_out_is_continuous : np.ndarray
            Updated continuity flags (may be changed for size-1 float grids).
        """
        origin_array = self.origin_array
        matrices_out = {}
        cont_idx = {}
        cont_alpha = {}
        cont_M = {}
        cont_discrete = {}
        k = 0  # position of var within grid_out_is_continuous
        for var in grids_out.keys():
            if var not in self.data.keys():
                raise ValueError(
                    "Variable " + var + " does not exist but a grid was specified!"
                )
            grid = grids_out[var]
            vals = self.data[var]
            pmv = self.data["pmv_"]
            M = grid.size if grid is not None else 0

            # Semi-hacky fix to deal with omitted arrival variables
            if (M == 1) and (vals.dtype is np.dtype(np.float64)):
                grid = grid.astype(float)
                grids_out[var] = grid
                grid_out_is_continuous[k] = True

            if grid_out_is_continuous[k]:
                # Split the final values among discrete gridpoints on the interior.
                if M > 1:
                    Q = grid_orders[var]
                    if var in cont_vars:
                        # The _alt variant also returns bracketing indices and
                        # interpolation weights, needed for continuation variables
                        trans_matrix, cont_idx[var], cont_alpha[var] = (
                            aggregate_blobs_onto_polynomial_grid_alt(
                                vals, pmv, origin_array, grid, J, Q
                            )
                        )
                        cont_M[var] = M
                        cont_discrete[var] = False
                    else:
                        trans_matrix = aggregate_blobs_onto_polynomial_grid(
                            vals, pmv, origin_array, grid, J, Q
                        )
                else:  # Skip if the grid is a dummy with only one value.
                    trans_matrix = np.ones((J, M))
                    if var in cont_vars:
                        cont_idx[var] = np.zeros(N, dtype=int)
                        cont_alpha[var] = np.zeros(N)
                        cont_M[var] = M
                        cont_discrete[var] = False

            else:  # Grid is discrete, can use simpler method
                if grid is None:
                    # Size the grid from the observed values; 'dead' is always
                    # binary even if no one died in this pass
                    M = np.max(vals.astype(int))
                    if var == "dead":
                        M = 2
                    grid = np.arange(M, dtype=int)
                    grids_out[var] = grid
                M = grid.size
                vals = vals.astype(int)
                trans_matrix = aggregate_blobs_onto_discrete_grid(
                    vals, pmv, origin_array, M, J
                )
                if var in cont_vars:
                    cont_idx[var] = vals
                    cont_alpha[var] = np.zeros(N)
                    cont_M[var] = M
                    cont_discrete[var] = True

            # Store the transition matrix for this variable
            matrices_out[var] = trans_matrix
            k += 1

        return (
            matrices_out,
            cont_idx,
            cont_alpha,
            cont_M,
            cont_discrete,
            grids_out,
            grid_out_is_continuous,
        )
    def _build_master_transition_array(
        self, cont_vars, cont_idx, cont_alpha, cont_M, cont_discrete, N_orig, N, D
    ):
        """
        Construct the master arrival-to-continuation transition array.

        Combines per-variable index arrays and interpolation weights into a
        single tensor using multilinear interpolation. The offset index arithmetic
        and continuation-variable ordering are load-bearing.

        Parameters
        ----------
        cont_vars : list
            Names of continuation variables, ordered to match arrival variables.
        cont_idx : dict
            Lower-bracket indices for each continuation variable.
        cont_alpha : dict
            Interpolation weights (upper bracket) for each continuation variable.
        cont_M : dict
            Grid size for each continuation variable.
        cont_discrete : dict
            Whether each continuation variable uses a discrete (not continuous) grid.
        N_orig : int
            Number of arrival-state grid points.
        N : int
            Total number of quasi-simulated agents.
        D : int
            Number of continuation dimensions.

        Returns
        -------
        master_trans_array_X : np.ndarray
            Unnormalized master transition array of shape (N_orig, prod(cont_M)).
        """
        pmv = self.data["pmv_"]
        origin_array = self.origin_array

        # Count the number of non-trivial dimensions. A continuation dimension
        # is non-trivial if it is both continuous and has more than one grid node.
        C = 0
        shape = [N_orig]
        trivial = []
        for var in cont_vars:
            shape.append(cont_M[var])
            if (not cont_discrete[var]) and (cont_M[var] > 1):
                C += 1
                trivial.append(False)
            else:
                trivial.append(True)
        trivial = np.array(trivial)

        # Make a binary array of offsets from the base index: each row is one
        # corner of the C-dimensional interpolation hypercube, padded with
        # zeros in trivial dimensions
        bin_array_base = np.array(list(product([0, 1], repeat=C)))
        bin_array = np.empty((2**C, D), dtype=int)
        some_zeros = np.zeros(2**C, dtype=int)
        c = 0
        for d in range(D):
            bin_array[:, d] = some_zeros if trivial[d] else bin_array_base[:, c]
            c += not trivial[d]

        # Make a vector of dimensional offsets from the base index (strides of
        # the flattened row-major continuation grid)
        dim_offsets = np.ones(D, dtype=int)
        for d in range(D - 1):
            dim_offsets[d] = np.prod(shape[(d + 2) :])
        dim_offsets_X = np.tile(dim_offsets, (2**C, 1))
        offsets = np.sum(bin_array * dim_offsets_X, axis=1)

        # Make combined arrays of indices and alphas; slot 0 holds the lower
        # bracket weight (1 - alpha), slot 1 the upper bracket weight (alpha)
        index_array = np.empty((N, D), dtype=int)
        alpha_array = np.empty((N, D, 2))
        for d in range(D):
            var = cont_vars[d]
            index_array[:, d] = cont_idx[var]
            alpha_array[:, d, 0] = 1.0 - cont_alpha[var]
            alpha_array[:, d, 1] = cont_alpha[var]
        idx_array = np.dot(index_array, dim_offsets)

        # Make the master transition array
        blank = np.zeros(np.array((N_orig, np.prod(shape[1:]))))
        master_trans_array_X = calc_overall_trans_probs(
            blank, idx_array, alpha_array, bin_array, offsets, pmv, origin_array
        )
        return master_trans_array_X
941 def _condition_on_survival(self, master_trans_array_X, matrices_out, N_orig):
942 """
943 Condition the master transition array on agent survival.
945 Divides through by the survival probability so that the transition
946 array represents the distribution conditional on not dying this period.
948 Parameters
949 ----------
950 master_trans_array_X : np.ndarray
951 Unconditioned master transition array of shape (N_orig, M) where
952 M = prod(cont_M). Reshaped internally to (N_orig, N_orig, 2)
953 assuming one binary continuation variable (dead/alive).
954 matrices_out : dict
955 Per-variable transition matrices; must contain 'dead'.
956 N_orig : int
957 Number of arrival-state grid points.
959 Returns
960 -------
961 master_trans_array_X : np.ndarray
962 Survival-conditioned master transition array of shape (N_orig, N_orig).
963 """
964 master_trans_array_X = np.reshape(master_trans_array_X, (N_orig, N_orig, 2))
965 survival_probs = np.reshape(matrices_out["dead"][:, 0], [N_orig, 1])
966 master_trans_array_X = master_trans_array_X[..., 0] / survival_probs
967 return master_trans_array_X
    def make_transition_matrices(self, grid_specs, twist=None, norm=None):
        """
        Construct a transition matrix for this block, moving from a discretized
        grid of arrival variables to a discretized grid of end-of-block variables.
        User specifies how the grids of pre-states should be built. Output is
        stored in attributes of self as follows:

        - matrices : A dictionary of arrays that cast from the arrival state space
          to the grid of outcome variables. Doing np.dot(dstn, matrices[var])
          will yield the discretized distribution of that outcome variable.
        - grids : A dictionary of discretized grids for outcome variables. Doing
          np.dot(np.dot(dstn, matrices[var]), grids[var]) yields the *average*
          of that outcome in the population.
        - trans_array : The full-period Markov transition matrix that goes from
          arrival variables in t to arrival variables in t+1, including
          mortality. Only set when `twist` is provided.

        Parameters
        ----------
        grid_specs : dict
            Dictionary of dictionaries of grid specifications. For now, these have
            at most a minimum value, a maximum value, a number of nodes, and a poly-
            nomial order. They are equispaced if a min and max are specified, and
            polynomially spaced with the specified order > 0 if provided. Otherwise,
            they are set at 0,..,N if only N is provided.
        twist : dict or None
            Mapping from end-of-period (continuation) variables to successor's
            arrival variables. When this is specified, additional output is created
            for the "full period" arrival-to-arrival transition matrix.
        norm : str or None
            Name of the shock variable by which to normalize for Harmenberg
            aggregation. By default, no normalization happens.

        Returns
        -------
        None
        """
        arrival_N = len(self.arrival)  # number of arrival (pre-state) variables

        # Build input and output grids from grid specifications
        grids_in, grids_out, continuous_grid_out_bool, grid_orders, dummy_grid = (
            self._build_input_grids(grid_specs, arrival_N)
        )

        # If a twist was specified, override output grids for continuation variables
        if twist is not None:
            grids_out, grid_orders, grid_out_is_continuous = self._build_twist_grids(
                twist, grids_in, grid_orders, grids_out, continuous_grid_out_bool
            )
        else:
            grid_out_is_continuous = np.array(continuous_grid_out_bool)

        # Make meshes of all the arrival grids, which will be the initial simulation data
        if arrival_N > 0:
            state_meshes = np.meshgrid(
                *[grids_in[var] for var in self.arrival], indexing="ij"
            )
        else:  # this only happens in the initializer block
            state_meshes = [dummy_grid.copy()]
        state_init = {
            self.arrival[k]: state_meshes[k].flatten() for k in range(arrival_N)
        }
        # N_orig is the number of points in the (flattened) arrival state mesh
        N_orig = state_meshes[0].size
        # One tuple of arrival-state values per mesh point, in model-file order
        mesh_tuples = [
            [state_init[self.arrival[k]][n] for k in range(arrival_N)]
            for n in range(N_orig)
        ]

        # Quasi-simulate this block: run every shock realization, tracking masses
        self.run_quasi_sim(state_init, norm=norm)

        # Add survival to output if mortality is in the model; the grid entry is
        # None because "dead" is a binary indicator, not a gridded outcome
        if "dead" in self.data.keys():
            grids_out["dead"] = None

        # Get continuation variable names, making sure they're in the same order
        # as named by the arrival variables. This should maybe be done in the
        # simulator when it's initialized.
        if twist is not None:
            cont_vars_orig = list(twist.keys())
            # Invert the twist: arrival var -> continuation var that feeds it
            temp_dict = {twist[var]: var for var in cont_vars_orig}
            cont_vars = []
            for var in self.arrival:
                cont_vars.append(temp_dict[var])
            if "dead" in self.data.keys():
                cont_vars.append("dead")
                grid_out_is_continuous = np.concatenate(
                    (grid_out_is_continuous, [False])
                )
        else:
            cont_vars = list(grids_out.keys())  # all outcomes are arrival vars
        D = len(cont_vars)

        # Project the final results onto the output or result grids.
        # N is the number of quasi-simulated "agents" after all events ran.
        N = self.N
        J = state_meshes[0].size
        (
            matrices_out,
            cont_idx,
            cont_alpha,
            cont_M,
            cont_discrete,
            grids_out,
            grid_out_is_continuous,
        ) = self._project_onto_output_grids(
            grids_out,
            grid_out_is_continuous,
            grid_orders,
            cont_vars,
            twist,
            N_orig,
            J,
            N,
        )

        # Construct the master arrival-to-continuation transition array
        master_trans_array_X = self._build_master_transition_array(
            cont_vars, cont_idx, cont_alpha, cont_M, cont_discrete, N_orig, N, D
        )

        # Condition on survival if relevant (see _condition_on_survival)
        if "dead" in self.data.keys():
            master_trans_array_X = self._condition_on_survival(
                master_trans_array_X, matrices_out, N_orig
            )

        # Reshape the transition matrix depending on what kind of block this is
        if arrival_N == 0:
            # If this is the initializer block, the "transition" matrix is really
            # just the initial distribution of states at model birth; flatten it.
            master_init_array = master_trans_array_X.flatten()
        else:
            # In an ordinary period, reshape the transition array so it's square.
            master_trans_array = np.reshape(master_trans_array_X, (N_orig, N_orig))

        # Store the results as attributes of self
        grids = {}
        grids.update(grids_in)
        grids.update(grids_out)
        self.grids = grids
        self.matrices = matrices_out
        self.mesh = mesh_tuples
        # NOTE(review): if twist is given while arrival_N == 0, master_trans_array
        # would be unbound here; presumably the initializer is never called with a
        # twist -- confirm with callers.
        if twist is not None:
            self.trans_array = master_trans_array
        if arrival_N == 0:
            self.init_dstn = master_init_array
1116 def run_quasi_sim(self, data, j0=0, twist=None, norm=None):
1117 """
1118 "Quasi-simulate" this block from given starting data at some event index,
1119 looping back to end at the same point (only if j0 > 0 and twist is given).
1120 To quasi-simulate means to run the model forward for *every* possible shock
1121 realization, tracking probability masses.
1123 If the quasi-simulation loops through the twist, mortality is ignored.
1125 Parameters
1126 ----------
1127 data : dict
1128 Dictionary of initial data, mapping variable names to vectors of values.
1129 j0 : int, optional
1130 Event index number at which to start (and end)) the quasi-simulation.
1131 By default, it is run from index 0.
1132 twist : dict, optional
1133 Optional dictionary mapping end-of-block variables back to arrival variables.
1134 If this is provided *and* j0 > 0, then the quasi-sim is run for a complete
1135 period, starting and ending at the same index. Else it's run to end of period.
1136 norm : str or None
1137 The name of the variable on which to perform Harmenberg normalization.
1139 Returns
1140 -------
1141 None
1142 """
1143 # Make the initial vector of probability masses
1144 if not data: # data is empty because it's initializer block
1145 N_orig = 1
1146 else:
1147 key = list(data.keys())[0]
1148 N_orig = data[key].size
1149 self.N = N_orig
1150 state_init = deepcopy(data)
1151 state_init["pmv_"] = np.ones(self.N)
1153 # Initialize the array of arrival states
1154 origin_array = np.arange(self.N, dtype=int)
1156 # Reset the block's state and give it the initial state data
1157 self.reset()
1158 self.data.update(state_init)
1160 # Loop through each event in order and quasi-simulate it
1161 J = len(self.events)
1162 for j in range(j0, J):
1163 event = self.events[j]
1164 event.data = self.data # Give event *all* data directly
1165 event.N = self.N
1166 origin_array = event.quasi_run(origin_array, norm=norm)
1167 self.N = self.data["pmv_"].size
1169 # If we didn't start at the beginning and there is a twist, loop back to
1170 # the start and do the remaining events
1171 if twist is not None:
1172 new_data = {"pmv_": self.data["pmv_"].copy()}
1173 for end_var in twist.keys():
1174 arr_var = twist[end_var]
1175 new_data[arr_var] = self.data[end_var].copy()
1176 self.data = new_data
1177 for j in range(j0):
1178 event = self.events[j]
1179 event.data = self.data # Give event *all* data directly
1180 event.N = self.N
1181 origin_array = event.quasi_run(origin_array, norm=norm)
1182 self.N = self.data["pmv_"].size
1184 # Assign the origin array as an attribute of self
1185 self.origin_array = origin_array
@dataclass(kw_only=True)
class AgentSimulator:
    """
    A class for representing an entire simulator structure for an AgentType.
    It includes a sequence of SimBlocks representing periods of the model, which
    could be built from the information on an AgentType instance.

    Parameters
    ----------
    name : str
        Short name of this model.
    description : str
        Textual description of what happens in this simulated block.
    statement : str
        Verbatim model statement that was used to create this simulator.
    comments : dict
        Dictionary of comments or descriptions for various model objects.
    parameters : list[str]
        List of parameter names used in the model.
    distributions : list[str]
        List of distribution names used in the model.
    functions : list[str]
        List of function names used in the model.
    common : list[str]
        Names of variables that are common across idiosyncratic agents.
    types : dict
        Dictionary of data types for all variables in the model.
    N_agents : int
        Number of idiosyncratic agents in this simulation.
    T_total : int
        Total number of periods in these agents' model.
    T_sim : int
        Maximum number of periods that will be simulated, determining the size
        of the history arrays.
    T_age : int
        Period after which to automatically terminate an agent if they would
        survive past this period.
    stop_dead : bool
        Whether simulated agents who draw dead=True should actually cease acting.
        Default is True. Setting to False allows "cohort-style" simulation that
        will generate many agents that survive to old ages. In most cases, T_sim
        should not exceed T_age, unless the user really does want multiple succ-
        essive cohorts to be born and fully simulated.
    replace_dead : bool
        Whether simulated agents who are marked as dead should be replaced with
        newborns (default True) or simply cease acting without replacement (False).
        The latter option is useful for models with state-dependent mortality,
        to allow "cohort-style" simulation with the correct distribution of states
        for survivors at each age. Setting to False has no effect if stop_dead is True.
    periods : list[SimBlock]
        Ordered list of simulation blocks, each representing a period.
    twist : dict
        Dictionary that maps period t-1 variables to period t variables, as a
        relabeling "between" periods.
    initializer : SimBlock
        A special simulated block that should have *no* arrival variables, because
        it represents the initialization of "newborn" agents.
    data : dict
        Dictionary that holds *current* values of model variables.
    track_vars : list[str]
        List of names of variables whose history should be tracked in the simulation.
    history : dict
        Dictionary that holds the histories of tracked variables.
    """

    name: str = field(default="")
    description: str = field(default="")
    statement: str = field(default="", repr=False)
    comments: dict = field(default_factory=dict, repr=False)
    parameters: list[str] = field(default_factory=list, repr=False)
    distributions: list[str] = field(default_factory=list, repr=False)
    functions: list[str] = field(default_factory=list, repr=False)
    common: list[str] = field(default_factory=list, repr=False)
    types: dict = field(default_factory=dict, repr=False)
    N_agents: int = field(default=1)
    T_total: int = field(default=1, repr=False)
    T_sim: int = field(default=1)
    T_age: int = field(default=0, repr=False)
    stop_dead: bool = field(default=True)
    replace_dead: bool = field(default=True)
    periods: list[SimBlock] = field(default_factory=list, repr=False)
    twist: dict = field(default_factory=dict, repr=False)
    data: dict = field(default_factory=dict, repr=False)
    # BUGFIX: this line previously read `initializer: field(...)` (a colon instead
    # of `=`), which made the field() call the *annotation*, dropped the SimBlock
    # type, and left `initializer` as a required keyword argument with no default.
    initializer: SimBlock = field(default_factory=SimBlock, repr=False)
    track_vars: list[str] = field(default_factory=list, repr=False)
    history: dict = field(default_factory=dict, repr=False)
1275 def simulate(self, T=None):
1276 """
1277 Simulates the model for T periods, including replacing dead agents as
1278 warranted and storing tracked variables in the history. If T is not
1279 specified, the agents are simulated for the entire T_sim periods.
1280 This is the primary user-facing simulation method.
1281 """
1282 if T is None:
1283 T = self.T_sim - self.t_sim # All remaining simulated periods
1284 if (T + self.t_sim) > self.T_sim:
1285 raise ValueError("Can't simulate more than T_sim periods!")
1287 # Execute the simulation loop for T periods
1288 for t in range(T):
1289 # Do the ordinary work for simulating a period
1290 self.sim_one_period()
1292 # Mark agents who have reached maximum allowable age
1293 if "dead" in self.data.keys() and self.T_age > 0:
1294 too_old = self.t_age == self.T_age
1295 self.data["dead"][too_old] = True
1297 # Record tracked variables and advance age
1298 self.store_tracked_vars()
1299 self.advance_age()
1301 # Handle death and replacement depending on simulation style
1302 if "dead" in self.data.keys() and self.stop_dead:
1303 self.mark_dead_agents()
1304 self.t_sim += 1
1306 def reset(self):
1307 """
1308 Completely reset this simulator back to its original state so that it
1309 can be run from scratch. This should allow it to generate the same results
1310 every single time the simulator is run (if nothing changes).
1311 """
1312 N = self.N_agents
1313 T = self.T_sim
1314 self.t_sim = 0 # Time index for the simulation
1316 # Reset the variable data and history arrays
1317 self.clear_data()
1318 self.history = {}
1319 for var in self.track_vars:
1320 self.history[var] = np.empty((T, N), dtype=self.types[var])
1322 # Reset all of the blocks / periods
1323 self.initializer.reset()
1324 for t in range(len(self.periods)):
1325 self.periods[t].reset()
1327 # Specify all agents as "newborns" assigned to the initializer block
1328 self.t_seq_bool_array = np.zeros((self.T_total, N), dtype=bool)
1329 self.t_age = -np.ones(N, dtype=int)
1331 def clear_data(self, skip=None):
1332 """
1333 Reset all current data arrays back to blank, other than those designated
1334 to be skipped, if any.
1336 Parameters
1337 ----------
1338 skip : [str] or None
1339 Names of variables *not* to be cleared from data. Default is None.
1341 Returns
1342 -------
1343 None
1344 """
1345 if skip is None:
1346 skip = []
1347 N = self.N_agents
1348 for var in self.types.keys():
1349 if var in skip:
1350 continue
1351 this_type = self.types[var]
1352 if this_type is float:
1353 self.data[var] = np.full((N,), np.nan)
1354 elif this_type is bool:
1355 self.data[var] = np.zeros((N,), dtype=bool)
1356 elif this_type is int:
1357 self.data[var] = np.zeros((N,), dtype=np.int32)
1358 elif this_type is complex:
1359 self.data[var] = np.full((N,), np.nan, dtype=complex)
1360 else:
1361 raise ValueError(
1362 "Type "
1363 + str(this_type)
1364 + " of variable "
1365 + var
1366 + " was not recognized!"
1367 )
1369 def mark_dead_agents(self):
1370 """
1371 Looks at the special data field "dead" and marks those agents for replacement.
1372 If no variable called "dead" has been defined, this is skipped.
1373 """
1374 who_died = self.data["dead"]
1375 self.t_seq_bool_array[:, who_died] = False
1376 self.t_age[who_died] = -1
1378 def create_newborns(self):
1379 """
1380 Calls the initializer to generate newborns where needed.
1381 """
1382 # Skip this step if there are no newborns
1383 newborns = self.t_age == -1
1384 if not np.any(newborns):
1385 return
1387 # Generate initial arrival variables
1388 N = np.sum(newborns)
1389 self.initializer.data = {} # by definition
1390 self.initializer.N = N
1391 self.initializer.run()
1393 # Set the initial arrival data for newborns and clear other variables
1394 init_arrival = self.periods[0].arrival
1395 for var in self.types:
1396 self.data[var][newborns] = (
1397 self.initializer.data[var]
1398 if var in init_arrival
1399 else np.empty(N, dtype=self.types[var])
1400 )
1402 # Set newborns' period to 0
1403 self.t_age[newborns] = 0
1404 self.t_seq_bool_array[0, newborns] = True
1406 def store_tracked_vars(self):
1407 """
1408 Record current values of requested variables in the history dictionary.
1409 """
1410 for var in self.track_vars:
1411 self.history[var][self.t_sim, :] = self.data[var]
1413 def advance_age(self):
1414 """
1415 Increments age for all agents, altering t_age and t_age_bool. Agents in
1416 the last period of the sequence will be assigned to the initial period.
1417 In a lifecycle model, those agents should be marked as dead and replaced
1418 in short order.
1419 """
1420 alive = self.t_age >= 0 # Don't age the dead
1421 self.t_age[alive] += 1
1422 X = self.t_seq_bool_array # For shorter typing on next line
1423 self.t_seq_bool_array[:, alive] = np.concatenate(
1424 (X[-1:, alive], X[:-1, alive]), axis=0
1425 )
1427 def sim_one_period(self):
1428 """
1429 Simulates one period of the model by advancing all agents one period.
1430 This includes creating newborns, but it does NOT include eliminating
1431 dead agents nor storing tracked results in the history. This method
1432 should usually not be called by a user, instead using simulate(1) if
1433 you want to run the model for exactly one period.
1434 """
1435 # Use the "twist" information to advance last period's end-of-period
1436 # information/values to be the arrival variables for this period. Then, for
1437 # each variable other than those brought in with the twist, wipe it clean.
1438 keepers = []
1439 for var_tm1 in self.twist:
1440 var_t = self.twist[var_tm1]
1441 keepers.append(var_t)
1442 self.data[var_t] = self.data[var_tm1].copy()
1443 self.clear_data(skip=keepers)
1445 # Create newborns first so the arrival vars exist. This should be done in
1446 # the first simulated period (t_sim=0) or if decedents should be replaced.
1447 if self.replace_dead or self.t_sim == 0:
1448 self.create_newborns()
1450 # Loop through ages and run the model on the appropriately aged agents
1451 for t in range(self.T_total):
1452 these = self.t_seq_bool_array[t, :]
1453 if not np.any(these):
1454 continue # Skip any "empty ages"
1455 this_period = self.periods[t]
1457 data_temp = {var: self.data[var][these] for var in this_period.arrival}
1458 this_period.data = data_temp
1459 this_period.N = np.sum(these)
1460 this_period.run()
1462 # Extract all of the variables from this period and write it to data
1463 for var in this_period.data.keys():
1464 self.data[var][these] = this_period.data[var]
1466 # Put time information into the data dictionary
1467 self.data["t_age"] = self.t_age.copy()
1468 self.data["t_seq"] = np.argmax(self.t_seq_bool_array, axis=0).astype(int)
    def make_transition_matrices(
        self, grid_specs, norm=None, fake_news_timing=False, for_t=None
    ):
        """
        Build Markov-style transition matrices for each period of the model, as
        well as the initial distribution of arrival variables for newborns.
        Stores results to the attributes of self as follows:

        - trans_arrays : List of Markov matrices for transitioning from the arrival
          state space in period t to the arrival state space in t+1.
          This transition includes death (and replacement).
        - newborn_dstn : Stochastic vector as a NumPy array, representing the distribution
          of arrival states for "newborns" who were just initialized.
        - state_grids : Nested list of tuples representing the arrival state space for
          each period. Each element corresponds to the discretized arrival
          state space point with the same index in trans_arrays (and
          newborn_dstn). Arrival states are ordered within a tuple in the
          same order as the model file. Linked from period[t].mesh.
        - outcome_arrays : List of dictionaries of arrays that cast from the arrival
          state space to the grid of outcome variables, for each period.
          Doing np.dot(state_dstn, outcome_arrays[t][var]) will yield
          the discretized distribution of that outcome variable. Linked
          from periods[t].matrices.
        - outcome_grids : List of dictionaries of discretized outcomes in each period.
          Keys are names of outcome variables, and entries are vectors
          of discretized values that the outcome variable can take on.
          Doing np.dot(np.dot(state_dstn, outcome_arrays[var]), outcome_grids[var])
          yields the *average* of that outcome in the population. Linked
          from periods[t].grids.

        Parameters
        ----------
        grid_specs : dict
            Dictionary of dictionaries with specifications for discretized grids
            of all variables of interest. If any arrival variables are omitted,
            they will be given a default trivial grid with one node at 0. This
            should only be done if that arrival variable is closely tied to the
            Harmenberg normalizing variable; see below. A grid specification must
            include a number of gridpoints N, and should also include a min and
            max if the variable is continuous. If the variable is discrete, the
            grid values are assumed to be 0,..,N.
        norm : str or None
            Name of the variable for which Harmenberg normalization should be
            applied, if any. This should be a variable that is directly drawn
            from a distribution, not a "downstream" variable.
        fake_news_timing : bool
            Indicator for whether this call is part of the "fake news" algorithm
            for constructing sequence space Jacobians (SSJs). This should only
            ever be set to True in that situation, which affects how mortality
            is handled between periods. In short, the simulator usually assumes
            that "newborns" start with t_seq=0, but during the fake news algorithm,
            that is not the case.
        for_t : list or None
            Optional list of time indices for which the matrices should be built.
            When not specified, all periods are constructed. The most common use
            for this arg is during the "fake news" algorithm for lifecycle models.

        Returns
        -------
        None
        """
        # Sort grid specifications into those needed by the initializer vs those
        # used by other blocks (ordinary periods)
        arrival = self.periods[0].arrival
        arrival_N = len(arrival)
        check_bool = np.zeros(arrival_N, dtype=bool)  # which arrival vars got a spec
        grid_specs_init_orig = {}
        grid_specs_other = {}
        for name in grid_specs.keys():
            if name in arrival:
                idx = arrival.index(name)
                check_bool[idx] = True
                grid_specs_init_orig[name] = copy(grid_specs[name])
            grid_specs_other[name] = copy(grid_specs[name])

        # Build the dictionary of arrival variables, making sure it's in the
        # same order as named self.arrival. For any arrival grids that are
        # not specified, make a dummy specification (single node).
        grid_specs_init = {}
        for n in range(arrival_N):
            name = arrival[n]
            if check_bool[n]:
                grid_specs_init[name] = grid_specs_init_orig[name]
                continue
            dummy_grid_spec = {"N": 1}
            grid_specs_init[name] = dummy_grid_spec
            grid_specs_other[name] = dummy_grid_spec

        # Make the initial state distribution for newborns
        self.initializer.make_transition_matrices(grid_specs_init)
        self.newborn_dstn = self.initializer.init_dstn
        K = self.newborn_dstn.size  # size of the arrival state space

        # Make the period-by-period transition matrices
        these_t = range(len(self.periods)) if for_t is None else for_t
        for t in these_t:
            block = self.periods[t]
            block.make_transition_matrices(
                grid_specs_other, twist=self.twist, norm=norm
            )
            block.reset()
        self.grid_specs = grid_specs_other
        self.norm = norm

        # Extract the master transition matrices into a single list.
        # NOTE(review): when for_t is given, periods outside for_t must already
        # have trans_array from a prior call -- confirm with callers.
        p2p_trans_arrays = [block.trans_array for block in self.periods]

        # Apply agent replacement to the last period of the model, representing
        # newborns filling in for decedents. This will usually only do anything
        # at all in "one period infinite horizon" models. If this is part of the
        # fake news algorithm for constructing SSJs, then replace decedents with
        # newborns in *all* periods, because model timing is funny in this case.
        if fake_news_timing:
            T_set = np.arange(len(self.periods)).tolist()
        else:
            T_set = [-1]
        newborn_dstn = np.reshape(self.newborn_dstn, (1, K))
        for t in T_set:
            if "dead" not in self.periods[t].matrices.keys():
                continue
            # Column 1 of the "dead" matrix holds death probabilities; scale
            # survivors' transitions and route decedents to the newborn dstn
            death_prbs = self.periods[t].matrices["dead"][:, 1]
            p2p_trans_arrays[t] *= np.tile(np.reshape(1 - death_prbs, (K, 1)), (1, K))
            p2p_trans_arrays[t] += np.reshape(death_prbs, (K, 1)) * newborn_dstn

        # Store the transition arrays as attributes of self
        self.trans_arrays = p2p_trans_arrays

        # Build and store lists of state meshes, outcome arrays, and outcome grids
        self.state_grids = [self.periods[t].mesh for t in range(len(self.periods))]
        self.outcome_grids = [self.periods[t].grids for t in range(len(self.periods))]
        self.outcome_arrays = [
            self.periods[t].matrices for t in range(len(self.periods))
        ]
1604 def find_steady_state(self):
1605 """
1606 Calculates the steady state distribution of arrival states for a "one period
1607 infinite horizon" model, storing the result to the attribute steady_state_dstn.
1608 Should only be run after make_transition_matrices(), and only if T_total = 1
1609 and the model is infinite horizon.
1610 """
1611 if self.T_total != 1:
1612 raise ValueError(
1613 "This method currently only works with one period infinite horizon problems."
1614 )
1616 # Find the eigenvector associated with the largest eigenvalue of the
1617 # infinite horizon transition matrix. The largest eigenvalue *should*
1618 # be 1 for any Markov matrix, but double check to be sure.
1619 trans_T = csr_matrix(self.trans_arrays[0].transpose())
1620 v, V = eigs(trans_T, k=1)
1621 if not np.isclose(v[0], 1.0):
1622 raise ValueError(
1623 "The largest eigenvalue of the transition matrix isn't close to 1!"
1624 )
1626 # Normalize that eigenvector and make sure its real, then store it
1627 D = V[:, 0]
1628 SS_dstn = (D / np.sum(D)).real
1629 self.steady_state_dstn = SS_dstn
1631 def get_long_run_dstn(self, var):
1632 """
1633 Calculate and return the long run / steady state population distribution
1634 of one named variable. Should only be run after find_steady_state().
1636 Parameters
1637 ----------
1638 var : str
1639 Name of the variable for which to calculate the long run distribution.
1641 Returns
1642 -------
1643 var_dstn : np.array
1644 Long run / steady state population distribution of the variable,
1645 as a stochastic vector defined on the variable's discretized grid.
1646 """
1647 if not hasattr(self, "steady_state_dstn"):
1648 raise ValueError("This method can only be run after find_steady_state()!")
1650 dstn = self.steady_state_dstn
1651 array = self.outcome_arrays[0][var]
1652 var_dstn = np.dot(dstn, array)
1653 return var_dstn
1655 def get_long_run_average(self, var):
1656 """
1657 Calculate and return the long run / steady state population average of
1658 one named variable. Should only be run after find_steady_state().
1660 Parameters
1661 ----------
1662 var : str
1663 Name of the variable for which to calculate the long run average.
1665 Returns
1666 -------
1667 var_mean : float
1668 Long run / steady state population average of the variable.
1669 """
1670 var_dstn = self.get_long_run_dstn(var)
1671 grid = self.outcome_grids[0][var]
1672 var_mean = np.dot(var_dstn, grid)
1673 return var_mean
    def find_target_state(self, target_var, bounds=None, N=201, tol=1e-8, **kwargs):
        """
        Find the "target" level of a state variable: the value such that the expectation
        of next period's state is the same value (when following the policy function),
        *and* is locally stable (pushes up from below and down from above when nearby).
        Only works for standard infinite horizon models with a single endogenous state
        variable. Other variables whose values must be known (e.g. exogenously evolving
        states) can also be specified.

        The search procedure is to first examine a grid of candidates on the bounds,
        calculating E[Delta x] for state x, and then perform a local search for each
        interval where it flips from positive to negative.

        This procedure ignores mortality entirely. It represents a stable or target
        level conditional on the agent continuing from t to t+1.

        If additional information must be known, other model variables can be passed
        as keyword arguments, e.g. pLvl=1.0. This feature is used for exogenous state
        variables, such as persistent income pLvl in the GenIncProcess model. The user
        simply passes its mean (central) value, which is easily known in advance.

        Parameters
        ----------
        target_var : str
            Name of the state variable of interest.
        bounds : [float], optional
            Upper and lower boundaries for the target search. If not provided, defaults
            to [0.0, 100.0].
        N : int, optional
            Number of values of the variable of interest to test on the initial pass.
            If not provided, defaults to 201. This affects the "resolution" when there
            are multiple possible target levels (uncommon).
        tol : float, optional
            Maximum acceptable deviation from true target E[Delta x] = 0 to be accepted.
            If not specified, defaults to 1e-8.

        Returns
        -------
        state_targ : [float]
            List of target_var x values such that E[Delta x] = 0, which can be empty.
        """
        if self.T_total != 1:
            raise ValueError(
                "This method currently only works with one period infinite horizon problems."
            )
        bounds = bounds or [0.0, 100.0]
        state_grid = np.linspace(bounds[0], bounds[1], num=N)

        # Process keyword arguments into a dictionary of fixed values
        period = self.periods[0]
        event_count = len(period.events)
        fixed = {}
        var_names = list(self.types.keys())
        for name in kwargs:
            if name in var_names:
                fixed[name] = kwargs[name]
            else:
                raise ValueError(
                    "Could not find a model variable called " + name + " to hold fixed!"
                )

        # Find the event index at which to start and stop the quasi-simulation:
        # the first index by which the target variable and all fixed variables
        # have been assigned. If the target is an arrival variable, start at 0.
        var_names = [target_var] + list(fixed.keys())
        var_count = len(var_names)
        found_count = 0
        found = var_count * [False]
        j = 0
        if target_var not in period.arrival:
            while (found_count < var_count) and (j < event_count):
                event = period.events[j]
                assigns = event.assigns
                for i in range(var_count):
                    if (var_names[i] in assigns) and (not found[i]):
                        found[i] = True
                        found_count += 1
                j += 1
            if not np.all(found):
                raise ValueError(
                    "Could not find events that assign target variable and all fixed variables!"
                )
        idx0 = j  # Event index where the quasi-sim should start and stop

        # Construct the starting information set for the quasi-simulation
        data_init = {}
        trivial_vars = []

        # Assign dummy data for all vars assigned prior to start/stop; their
        # values don't matter, but the quasi-sim requires them to exist.
        # (Note: this loop deliberately reuses j; idx0 was already captured.)
        for var in period.arrival:
            data_init[var] = np.zeros(N, dtype=int)  # dummy data
            trivial_vars.append(var)
        for j in range(idx0):
            event = period.events[j]
            assigns = event.assigns
            for var in assigns:
                data_init[var] = np.zeros(N, dtype=int)  # dummy data
                trivial_vars.append(var)

        # Assign fixed data and the grid of candidate target values
        for key in fixed.keys():
            data_init[key] = fixed[key] * np.ones(N)
        data_init[target_var] = state_grid

        # Run the quasi-simulation on the initial grid of states
        period.run_quasi_sim(data_init, j0=idx0, twist=self.twist)
        origins = period.origin_array
        data_final = period.data[target_var]
        pmv_final = period.data["pmv_"]

        # Calculate mean value of next period's state at each point in the grid,
        # weighting terminal points by probability mass, grouped by origin node
        E_state_next = np.empty(N)
        for n in range(N):
            these = origins == n
            E_state_next[n] = np.dot(pmv_final[these], data_final[these])
        E_delta_state = E_state_next - state_grid  # expected change in state

        # Find indices in the grid where E[Delta x] flips from positive to
        # negative -- the locally stable crossings we care about
        sign = E_delta_state > 0.0
        flip = np.logical_and(sign[:-1], np.logical_not(sign[1:]))
        flip_idx = np.argwhere(flip).flatten()
        if flip_idx.size == 0:
            state_targ = []
            return state_targ

        # Reduce the fixed values in data_init to single valued vectors, since
        # the root search below evaluates one candidate x at a time
        for var in trivial_vars:
            data_init[var] = np.array([0])
        for key in fixed.keys():
            data_init[key] = np.array([fixed[key]])

        # Define a function that can be used to search for states where E[Delta x] = 0.
        # It closes over data_init and mutates the target entry on each call.
        def delta_zero_func(x):
            data_init[target_var] = np.array([x])
            period.run_quasi_sim(data_init, j0=idx0, twist=self.twist)
            data_final = period.data[target_var]
            pmv_final = period.data["pmv_"]
            E_delta = np.dot(pmv_final, data_final) - x
            return E_delta

        # For each segment index with a sign flip for E[Delta x], bracket and
        # solve for the root with Brent's method
        state_targ = []
        for i in flip_idx:
            bot = state_grid[i]
            top = state_grid[i + 1]
            x_targ = brentq(delta_zero_func, bot, top, xtol=tol, rtol=tol)
            state_targ.append(x_targ)

        # Return the output
        return state_targ
1824 def simulate_shock_by_grids(
1825 self,
1826 outcomes,
1827 T,
1828 shock=None,
1829 from_dstn=None,
1830 calc_dstn=False,
1831 calc_avg=True,
1832 ):
1833 """
1834 Generate the time series of population outcomes in response to an unexpected
1835 shock. The shock can be specified as additive or multiplicative events that
1836 are applied to the steady state distribution of arrival states, or as a user-
1837 specified distribution of arrival states. This method is intended only for
1838 infinite horizon, single period models. Stores results in the dictionary
1839 attributes history_avg and history_dstn respectively.
1841 This method can only be run after running make_transition_matrices.
1843 Parameters
1844 ----------
1845 outcomes : str or [str]
1846 Names of one or more outcome variables
1847 T : int
1848 Number of periods to simulate after the shock.
1849 shock : str or [str], optional
1850 One or more of "shock operations" to be applied to the steady state
1851 (or custom distribution, if specified). Each shock operation should
1852 name a continuation variable (something named on the left side of
1853 the twist), and be followed by an operator and a value. At this time,
1854 the only valid operators are "+", "*", and "=". For example, the shock
1855 "aNrm + 0.1" means that 0.1 should be added to (the distribution of)
1856 end-of-period assets, while "pLvl * 0.8" means that permanent income
1857 should be reduced by 20% for the entire population. The "=" operator
1858 shifts the entire population to the specified value. Not all arrival
1859 variables must be named in this argument. Indeed, none need be named.
1860 The numeric value should *not* use scientific notation nor other math
1861 operations; e.g. use "0.0001" and not "1e-4".
1862 from_dstn : np.array, optional
1863 If provided, a user-specified distribution of arrival states. If none is
1864 given (typical), then the steady state distribution is used. Any shocks
1865 described in shock are applied to this initial distribution.
1866 calc_dstn : bool, optional
1867 Whether to store the distribution of outcomes over time in history_dstn.
1868 The default is False.
1869 calc_avg : bool, optional
1870 Whether to store the population average of the outcomes over time in
1871 history_avg. The default is True.
1873 Returns
1874 -------
1875 None.
1876 """
1877 if not (calc_dstn or calc_avg):
1878 raise ValueError(
1879 "At least one of calc_dstn or calc_avg must be true, or there's no work!"
1880 )
1881 if (shock is None) and (from_dstn is None):
1882 raise ValueError(
1883 "The shock or from_dstn must be specified, or there's nothing to simulate!"
1884 )
1885 if self.T_total != 1:
1886 raise ValueError(
1887 "simulate_shock_by_grids is only implemented for infinite-horizon models with T_total == 1."
1888 )
1889 if not hasattr(self, "trans_arrays"):
1890 raise KeyError(
1891 "This method can't be run before running make_transition_matrices!"
1892 )
1893 if shock is None:
1894 shock = []
1895 if type(shock) is str:
1896 shock = [shock]
1897 if isinstance(outcomes, str):
1898 outcomes = [outcomes]
1900 # Get the starting unperturbed distribution
1901 if from_dstn is None:
1902 if not hasattr(self, "steady_state_dstn"):
1903 self.find_steady_state()
1904 init_dstn = self.steady_state_dstn
1905 else:
1906 dstn_sum = np.sum(from_dstn)
1907 dstn_N = from_dstn.size
1908 if not np.isclose(dstn_sum, 1.0):
1909 raise ValueError(
1910 "Specified from_dstn should be a stochastic vector, but its values sum to "
1911 + str(dstn_sum)
1912 )
1913 arrival_N = len(self.state_grids[0])
1914 if not arrival_N == dstn_N:
1915 raise ValueError(
1916 "Specified from_dstn should be a vector of size "
1917 + str(arrival_N)
1918 + ", but has size "
1919 + str(dstn_N)
1920 + "!"
1921 )
1922 init_dstn = from_dstn
1924 # Make dynamic event strings for each shock statement
1925 event_strings = []
1926 shock_vars = []
1927 op_list = ["+", "*", "="]
1928 for k in range(len(shock)):
1929 # Parse the shock statement for its parts
1930 S = shock[k]
1931 op = None
1932 for j in range(len(op_list)):
1933 if op_list[j] in S:
1934 op = op_list[j]
1935 break
1936 if op is None:
1937 raise ValueError(
1938 "The shock statement (" + S + ") did not contain a valid operator!"
1939 )
1940 loc = S.index(op)
1941 var = S[:loc].strip()
1942 val = S[(loc + 1) :].strip()
1943 if var not in self.twist:
1944 raise KeyError(
1945 "All shocked variables must be continuation states, but "
1946 + var
1947 + " is not!"
1948 )
1949 try:
1950 float(val)
1951 except (ValueError, TypeError):
1952 raise ValueError("Couldn't interpret " + val + " as a number!")
1954 # Make a string for this shock event
1955 var_alt = self.twist[var]
1956 if op == "+":
1957 this_event = var + " = " + var_alt + " + " + val
1958 elif op == "*":
1959 this_event = var + " = " + var_alt + " * " + val
1960 elif op == "=":
1961 this_event = var + " = " + var_alt + " * 0.0 + " + val
1962 event_strings.append(this_event)
1963 shock_vars.append(var)
1965 # For any continuation states that weren't shocked, make a trivial event
1966 cont_vars = list(self.twist.keys())
1967 for k in range(len(cont_vars)):
1968 var = cont_vars[k]
1969 if var in shock_vars:
1970 continue
1971 var_alt = self.twist[var]
1972 this_event = var + " = " + var_alt
1973 event_strings.append(this_event)
1975 # Extract grid specifications only for arrival and continuation variables
1976 grid_specs_temp = {}
1977 for var in self.grid_specs:
1978 if (var in self.twist) or (var in self.periods[0].arrival):
1979 grid_specs_temp[var] = self.grid_specs[var]
1981 # Make a fake model block that applies the shock
1982 shock_model = {"name": "exogenous shock", "dynamics": event_strings}
1983 shock_block, info, offset, solution, comments = make_template_block(
1984 shock_model, arrival=self.periods[0].arrival
1985 )
1986 shock_block.make_transition_matrices(grid_specs_temp, twist=self.twist)
1987 shock_block.reset()
1989 # Apply the shock transition to the starting distribution
1990 init_dstn = np.dot(init_dstn, shock_block.trans_array)
1992 # Initialize generated output
1993 history_dstn = {}
1994 history_avg = {}
1996 # Initialize the state distribution
1997 current_dstn = init_dstn.copy()
1999 # If we need the full distribution history, allocate state_dstn_by_t;
2000 # otherwise, avoid this potentially large O(num_states * T) array.
2001 if calc_dstn:
2002 state_dstn_by_t = np.empty((current_dstn.size, T))
2003 else:
2004 state_dstn_by_t = None
2006 trans_array = csc_matrix(self.trans_arrays[0])
2008 # If we only need averages (no full distributions), we can stream
2009 # the averages over time without storing the full state history.
2010 if calc_avg and not calc_dstn:
2011 outcome_arrays_0 = self.outcome_arrays[0]
2012 outcome_grids_0 = self.outcome_grids[0]
2013 for name in outcomes:
2014 history_avg[name] = np.empty(T)
2016 # Loop over requested periods of this agent type's model
2017 for t in range(T):
2018 # Store full state distribution history only if requested
2019 if calc_dstn:
2020 state_dstn_by_t[:, t] = current_dstn
2022 # Stream averages when only averages are needed
2023 if calc_avg and not calc_dstn:
2024 for name in outcomes:
2025 this_outcome = outcome_arrays_0[name]
2026 this_grid = outcome_grids_0[name]
2027 this_dstn_t = np.dot(current_dstn, this_outcome)
2028 history_avg[name][t] = np.dot(this_grid, this_dstn_t)
2030 current_dstn = current_dstn @ trans_array
2032 # Calculate history of outcomes as requested
2033 for name in outcomes:
2034 this_outcome = self.outcome_arrays[0][name]
2035 this_grid = self.outcome_grids[0][name]
2037 if calc_dstn:
2038 this_dstn = np.dot(this_outcome.T, state_dstn_by_t)
2039 history_dstn[name] = this_dstn
2041 if calc_avg:
2042 history_avg[name] = np.dot(this_grid, this_dstn)
2043 elif calc_avg:
2044 # Averages have already been filled in the time loop
2045 continue
2046 # Store results as attributes of self
2047 self.history_dstn = history_dstn
2048 self.history_avg = history_avg
2050 def simulate_cohort_by_grids(
2051 self,
2052 outcomes,
2053 T_max=None,
2054 calc_dstn=False,
2055 calc_avg=True,
2056 from_dstn=None,
2057 ):
2058 """
2059 Generate a simulated "cohort style" history for this type of agents using
2060 discretized grid methods. Can only be run after running make_transition_matrices().
2061 Starting from the distribution of states at birth, the population is moved
2062 forward in time via the transition matrices, and the distribution and/or
2063 average of specified outcomes are stored in the dictionary attributes
2064 history_dstn and history_avg respectively.
2066 Parameters
2067 ----------
2068 outcomes : str or [str]
2069 Names of one or more outcome variables to be tracked during the grid
2070 simulation. Each named variable should have an outcome grid specified
2071 when make_transition_matrices() was called, whether explicitly or
2072 implicitly. The existence of these grids is checked as a first step.
2073 T_max : int or None
2074 If specified, the number of periods of the model to actually generate
2075 output for. If not specified, all periods are run.
2076 calc_dstn : bool
2077 Whether outcome distributions should be stored in the dictionary
2078 attribute history_dstn. The default is False.
2079 calc_avg : bool
2080 Whether outcome averages should be stored in the dictionary attribute
2081 history_avg. The default is True.
2082 from_dstn : np.array or None
2083 Optional initial distribution of arrival states. If not specified, the
2084 newborn distribution in the initializer is assumed to be used.
2086 Returns
2087 -------
2088 None
2089 """
2090 # First, verify that newborn and transition matrices exist for all periods
2091 if not hasattr(self, "newborn_dstn"):
2092 raise ValueError(
2093 "The newborn state distribution does not exist; make_transition_matrices() must be run before grid simulations!"
2094 )
2095 if T_max is None:
2096 T_max = self.T_total
2097 T_max = np.minimum(T_max, self.T_total)
2098 if not hasattr(self, "trans_arrays"):
2099 raise ValueError(
2100 "The transition arrays do not exist; make_transition_matrices() must be run before grid simulations!"
2101 )
2102 if len(self.trans_arrays) < T_max:
2103 raise ValueError(
2104 "There are somehow fewer elements of trans_array than there should be!"
2105 )
2106 if not (calc_dstn or calc_avg):
2107 return # No work actually requested, we're done here
2109 # Initialize generated output as requested
2110 if isinstance(outcomes, str):
2111 outcomes = [outcomes]
2112 if calc_dstn:
2113 history_dstn = {}
2114 for name in outcomes: # List will be concatenated to array at end
2115 history_dstn[name] = [] # if all distributions are same size
2116 if calc_avg:
2117 history_avg = {}
2118 for name in outcomes:
2119 history_avg[name] = np.empty(T_max)
2121 # Initialize the state distribution
2122 current_dstn = (
2123 self.newborn_dstn.copy() if from_dstn is None else from_dstn.copy()
2124 )
2125 state_dstn_by_age = []
2127 # Loop over requested periods of this agent type's model
2128 for t in range(T_max):
2129 state_dstn_by_age.append(current_dstn)
2131 # Calculate outcome distributions and averages as requested
2132 for name in outcomes:
2133 this_outcome = self.periods[t].matrices[name].transpose()
2134 this_dstn = np.dot(this_outcome, current_dstn)
2135 if calc_dstn:
2136 history_dstn[name].append(this_dstn)
2137 if calc_avg:
2138 this_grid = self.periods[t].grids[name]
2139 history_avg[name][t] = np.dot(this_dstn, this_grid)
2141 # Advance the distribution to the next period
2142 current_dstn = np.dot(self.trans_arrays[t].transpose(), current_dstn)
2144 # Reshape the distribution histories if possible
2145 if calc_dstn:
2146 for name in outcomes:
2147 dstn_sizes = np.array([dstn.size for dstn in history_dstn[name]])
2148 if np.all(dstn_sizes == dstn_sizes[0]):
2149 history_dstn[name] = np.stack(history_dstn[name], axis=1)
2151 # Store results as attributes of self
2152 self.state_dstn_by_age = state_dstn_by_age
2153 if calc_dstn:
2154 self.history_dstn = history_dstn
2155 if calc_avg:
2156 self.history_avg = history_avg
2158 def describe_model(self, display=True):
2159 """
2160 Convenience method that prints model information to screen.
2161 """
2162 # Make a twist statement
2163 twist_statement = ""
2164 for var_tm1 in self.twist.keys():
2165 var_t = self.twist[var_tm1]
2166 new_line = var_tm1 + "[t-1] <---> " + var_t + "[t]\n"
2167 twist_statement += new_line
2169 # Assemble the overall model statement
2170 output = ""
2171 output += "----------------------------------\n"
2172 output += "%%%%% INITIALIZATION AT BIRTH %%%%\n"
2173 output += "----------------------------------\n"
2174 output += self.initializer.statement
2175 output += "----------------------------------\n"
2176 output += "%%%% DYNAMICS WITHIN PERIOD t %%%%\n"
2177 output += "----------------------------------\n"
2178 output += self.statement
2179 output += "----------------------------------\n"
2180 output += "%%%%%%% RELABELING / TWIST %%%%%%%\n"
2181 output += "----------------------------------\n"
2182 output += twist_statement
2183 output += "-----------------------------------"
2185 # Return or print the output
2186 if display:
2187 print(output)
2188 return
2189 else:
2190 return output
2192 def describe_symbols(self, display=True):
2193 """
2194 Convenience method that prints symbol information to screen.
2195 """
2196 # Get names and types
2197 symbols_lines = []
2198 comments = []
2199 for key in self.comments.keys():
2200 comments.append(self.comments[key])
2202 # Get type of object
2203 if key in self.types.keys():
2204 this_type = str(self.types[key].__name__)
2205 elif key in self.distributions:
2206 this_type = "dstn"
2207 elif key in self.parameters:
2208 this_type = "param"
2209 elif key in self.functions:
2210 this_type = "func"
2212 # Add tags
2213 if key in self.common:
2214 this_type += ", common"
2215 # if key in self.solution:
2216 # this_type += ', solution'
2217 this_line = key + " (" + this_type + ")"
2218 symbols_lines.append(this_line)
2220 # Add comments, aligned
2221 symbols_text = ""
2222 longest = np.max([len(this) for this in symbols_lines])
2223 for j in range(len(symbols_lines)):
2224 line = symbols_lines[j]
2225 comment = comments[j]
2226 L = len(line)
2227 pad = (longest + 1) - L
2228 symbols_text += line + pad * " " + ": " + comment + "\n"
2230 # Return or print the output
2231 output = symbols_text
2232 if display:
2233 print(output)
2234 return
2235 else:
2236 return output
2238 def describe(self, symbols=True, model=True, display=True):
2239 """
2240 Convenience method for showing all information about the model.
2241 """
2242 # Asssemble the requested output
2243 output = self.name + ": " + self.description + "\n"
2244 if symbols or model:
2245 output += "\n"
2246 if symbols:
2247 output += "----------------------------------\n"
2248 output += "%%%%%%%%%%%%% SYMBOLS %%%%%%%%%%%%\n"
2249 output += "----------------------------------\n"
2250 output += self.describe_symbols(display=False)
2251 if model:
2252 output += self.describe_model(display=False)
2253 if symbols and not model:
2254 output += "----------------------------------"
2256 # Return or print the output
2257 if display:
2258 print(output)
2259 return
2260 else:
2261 return output
2264def _parse_model_fields(model, common_override=None):
2265 """
2266 Extract the top-level fields from a parsed model dictionary.
2268 Uses dict.get() with safe defaults rather than try/except for each field,
2269 so that missing keys silently receive their default values.
2271 Parameters
2272 ----------
2273 model : dict
2274 Parsed YAML model dictionary.
2275 common_override : list or None
2276 If provided, overrides the model's 'common' field entirely.
2278 Returns
2279 -------
2280 model_name : str
2281 Name of the model, or 'DEFAULT_NAME' if absent.
2282 description : str
2283 Human-readable description, or a placeholder if absent.
2284 variables : list
2285 Declared variable lines from model['symbols']['variables'].
2286 twist : dict
2287 Intertemporal twist mapping, or empty dict if absent.
2288 common : list
2289 Variables shared across all agents.
2290 arrival : list
2291 Explicitly listed arrival variable names.
2292 """
2293 symbols = model.get("symbols", {})
2294 model_name = model.get("name", "DEFAULT_NAME")
2295 description = model.get("description", "(no description provided)")
2296 variables = symbols.get("variables", [])
2297 twist = model.get("twist", {})
2298 arrival = symbols.get("arrival", [])
2299 if common_override is not None:
2300 common = common_override
2301 else:
2302 common = symbols.get("common", [])
2303 return model_name, description, variables, twist, common, arrival
2306def _build_periods(
2307 template, agent, content, solution, offset, time_vary, time_inv, RNG, T_seq, T_cycle
2308):
2309 """
2310 Construct the list of per-period SimBlock copies for an AgentSimulator.
2312 For each period in the solution sequence, a deep copy of the template block
2313 is made and populated with the appropriate parameter data drawn from the agent.
2315 Parameters
2316 ----------
2317 template : SimBlock
2318 Template block with structure but no parameter values.
2319 agent : AgentType
2320 The agent whose solution and time-varying attributes supply parameter values.
2321 content : dict
2322 Keys are the names of objects needed by the template block.
2323 solution : list
2324 Names of objects that come from the agent's solution attribute.
2325 offset : list
2326 Names of time-varying objects whose index is shifted back by one period.
2327 time_vary : list
2328 Names of objects that vary across periods (drawn from agent attributes).
2329 time_inv : list
2330 Names of objects that are time-invariant (same across all periods).
2331 RNG : np.random.Generator
2332 Random number generator used to assign unique seeds to MarkovEvents.
2333 T_seq : int
2334 Number of periods in the solution sequence.
2335 T_cycle : int
2336 Number of periods per cycle (used to wrap the time index).
2338 Returns
2339 -------
2340 periods : list[SimBlock]
2341 Fully populated list of period blocks, one per entry in the solution.
2342 """
2343 # Build the time-invariant parameter dictionary once
2344 time_inv_dict = {}
2345 for name in content:
2346 if name in time_inv:
2347 if not hasattr(agent, name):
2348 raise ValueError(
2349 "Couldn't get a value for time-invariant object "
2350 + name
2351 + ": attribute does not exist on the agent."
2352 )
2353 time_inv_dict[name] = getattr(agent, name)
2355 periods = []
2356 t_cycle = 0
2357 for t in range(T_seq):
2358 # Make a fresh copy of the template period
2359 new_period = deepcopy(template)
2361 # Make sure each period's events have unique seeds; this is only for MarkovEvents
2362 for event in new_period.events:
2363 if hasattr(event, "seed"):
2364 event.seed = RNG.integers(0, 2**31 - 1)
2366 # Make the parameter dictionary for this period
2367 new_param_dict = deepcopy(time_inv_dict)
2368 for name in content:
2369 if name in solution:
2370 if type(agent.solution[t]) is dict:
2371 new_param_dict[name] = agent.solution[t][name]
2372 else:
2373 new_param_dict[name] = getattr(agent.solution[t], name)
2374 elif name in time_vary:
2375 s = (t_cycle - 1) if name in offset else t_cycle
2376 attr = getattr(agent, name, None)
2377 if attr is None:
2378 raise ValueError(
2379 "Couldn't get a value for time-varying object "
2380 + name
2381 + ": attribute does not exist on the agent."
2382 )
2383 try:
2384 new_param_dict[name] = attr[s]
2385 except (IndexError, TypeError):
2386 raise ValueError(
2387 "Couldn't get a value for time-varying object "
2388 + name
2389 + " at time index "
2390 + str(s)
2391 + "!"
2392 )
2393 elif name in time_inv:
2394 continue
2395 else:
2396 raise ValueError(
2397 "The object called "
2398 + name
2399 + " is not named in time_inv nor time_vary!"
2400 )
2402 # Fill in content for this period, then add it to the list
2403 new_period.content = new_param_dict
2404 new_period.distribute_content()
2405 periods.append(new_period)
2407 # Advance time according to the cycle
2408 t_cycle += 1
2409 if t_cycle == T_cycle:
2410 t_cycle = 0
2412 return periods
def make_simulator_from_agent(agent, stop_dead=True, replace_dead=True, common=None):
    """
    Build an AgentSimulator instance based on an AgentType instance. The AgentType
    should have its model attribute defined so that it can be parsed and translated
    into the simulator structure. The names of objects in the model statement
    should correspond to attributes of the AgentType.

    Parameters
    ----------
    agent : AgentType
        Agents for whom a new simulator is to be constructed.
    stop_dead : bool
        Whether simulated agents who draw dead=True should actually cease acting.
        Default is True. Setting to False allows "cohort-style" simulation that
        will generate many agents that survive to old ages. In most cases, T_sim
        should not exceed T_age, unless the user really does want multiple succ-
        essive cohorts to be born and fully simulated.
    replace_dead : bool
        Whether simulated agents who are marked as dead should be replaced with
        newborns (default True) or simply cease acting without replacement (False).
        The latter option is useful for models with state-dependent mortality,
        to allow "cohort-style" simulation with the correct distribution of states
        for survivors at each age. Setting False has no effect if stop_dead is True.
    common : [str] or None
        List of random variables that should be treated as commonly shared across
        all agents, rather than idiosyncratically drawn. If this is provided, it
        will override the model defaults.

    Returns
    -------
    new_simulator : AgentSimulator
        A simulator structure based on the agents.
    """
    # Read the model statement into a dictionary, and get names of attributes
    if hasattr(agent, "model_statement"):  # look for a custom model statement
        model_statement = copy(agent.model_statement)
    else:  # otherwise use the default model file
        # The with block closes the file on exit; no explicit close needed
        with importlib.resources.open_text("HARK.models", agent.model_file) as f:
            model_statement = f.read()
    model = yaml.safe_load(model_statement)
    time_vary = agent.time_vary
    time_inv = agent.time_inv
    cycles = agent.cycles
    T_age = agent.T_age
    comments = {}
    RNG = agent.RNG  # this is only for generating seeds for MarkovEvents

    # Extract basic fields from the model using helper
    model_name, description, variables, twist, common, arrival = _parse_model_fields(
        model, common_override=common
    )

    # Make a dictionary of declared data types and add comments
    types = {}
    for var_line in variables:  # Loop through declared variables
        var_name, var_type, flags, desc = parse_declaration_for_parts(var_line)
        if var_type is not None:
            try:
                # NOTE: eval() on text from the model statement; model files are
                # assumed to be trusted input, so only use vetted model sources
                var_type = eval(var_type)
            except Exception as e:  # narrow from bare except; chain the cause
                raise ValueError(
                    "Couldn't understand type "
                    + var_type
                    + " for declared variable "
                    + var_name
                    + "!"
                ) from e
        else:
            var_type = float  # undeclared types default to float
        types[var_name] = var_type
        comments[var_name] = desc
        if ("arrival" in flags) and (var_name not in arrival):
            arrival.append(var_name)
        if ("common" in flags) and (var_name not in common):
            common.append(var_name)

    # Make a blank "template" period with structure but no data
    template_period, information, offset, solution, block_comments = (
        make_template_block(model, arrival, common)
    )
    comments.update(block_comments)

    # Make the agent initializer, without parameter values (etc)
    initializer, init_info = make_initializer(model, arrival, common)

    # Extract basic fields from the template period and model
    statement = template_period.statement
    content = template_period.content

    # Get the names of parameters, functions, and distributions
    parameters = []
    functions = []
    distributions = []
    for key in information.keys():
        val = information[key]
        if val is None:
            parameters.append(key)
        elif type(val) is NullFunc:
            functions.append(key)
        elif type(val) is Distribution:
            distributions.append(key)

    # Loop through variables that appear in the model block but were undeclared
    for var in information.keys():
        if var in types.keys():
            continue
        this = information[var]
        if (this is None) or (type(this) is Distribution) or (type(this) is NullFunc):
            continue
        types[var] = float
        comments[var] = ""
    if "dead" in types.keys():
        types["dead"] = bool
        comments["dead"] = "whether agent died this period"
    types["t_seq"] = int
    types["t_age"] = int
    comments["t_seq"] = "which period of the sequence the agent is on"
    comments["t_age"] = "how many periods the agent has already lived for"

    # Make a dictionary for the initializer and distribute information
    init_dict = {}
    for name in init_info.keys():
        # hasattr check replaces a bare except: that swallowed all exceptions
        if not hasattr(agent, name):
            raise ValueError(
                "Couldn't get a value for initializer object " + name + "!"
            )
        init_dict[name] = getattr(agent, name)
    initializer.content = init_dict
    initializer.distribute_content()

    # Create a list of periods, pulling appropriate data from the agent for each one
    T_seq = len(agent.solution)  # Number of periods in the solution sequence
    T_cycle = agent.T_cycle
    periods = _build_periods(
        template_period,
        agent,
        content,
        solution,
        offset,
        time_vary,
        time_inv,
        RNG,
        T_seq,
        T_cycle,
    )

    # Calculate maximum age
    if T_age is None:
        T_age = 0
    if cycles > 0:
        T_age_max = T_seq - 1
        T_age = np.minimum(T_age_max, T_age)
    T_sim = getattr(agent, "T_sim", 0)  # very boring default!

    # Make and return the new simulator
    new_simulator = AgentSimulator(
        name=model_name,
        description=description,
        statement=statement,
        comments=comments,
        parameters=parameters,
        functions=functions,
        distributions=distributions,
        common=common,
        types=types,
        N_agents=agent.AgentCount,
        T_total=T_seq,
        T_sim=T_sim,
        T_age=T_age,
        stop_dead=stop_dead,
        replace_dead=replace_dead,
        periods=periods,
        twist=twist,
        initializer=initializer,
        track_vars=agent.track_vars,
    )
    new_simulator.solution = solution  # this is for use by SSJ constructor
    return new_simulator
2600def _extract_symbol_class(
2601 model, class_name, constructor, validator_msg, offset, solution, comments
2602):
2603 """
2604 Parse and collect one class of symbols (parameters, functions, or distributions).
2606 Handles the near-identical pattern repeated for each symbol class:
2607 iterate over declaration lines, build the result dict, record comments,
2608 and append names to the offset and solution lists as flagged.
2610 Parameters
2611 ----------
2612 model : dict
2613 Parsed model dictionary containing a 'symbols' sub-dict.
2614 class_name : str
2615 Key within model['symbols'] to look up ('parameters', 'functions', or
2616 'distributions').
2617 constructor : callable or None
2618 Called with no arguments to create each entry's value. Pass None for
2619 parameters (which use None as their placeholder value).
2620 validator_msg : str or None
2621 If provided, the expected datatype string (e.g. 'func' or 'dstn'). When a
2622 declaration carries a different datatype, a ValueError is raised. Pass None
2623 to skip validation (used for parameters).
2624 offset : list
2625 Accumulated list of offset-flagged names; extended in place.
2626 solution : list
2627 Accumulated list of solution-flagged names; extended in place.
2628 comments : dict
2629 Accumulated comment strings keyed by name; updated in place.
2631 Returns
2632 -------
2633 result : dict
2634 Mapping from symbol name to its constructed value (or None for parameters).
2635 """
2636 result = {}
2637 symbols = model.get("symbols", {})
2638 if class_name not in symbols:
2639 return result
2640 lines = symbols[class_name]
2641 for line in lines:
2642 name, datatype, flags, desc = parse_declaration_for_parts(line)
2643 if (
2644 (validator_msg is not None)
2645 and (datatype is not None)
2646 and (datatype != validator_msg)
2647 ):
2648 raise ValueError(
2649 name
2650 + " was declared as a "
2651 + class_name[:-1]
2652 + ", but given a different datatype!"
2653 )
2654 result[name] = constructor() if constructor is not None else None
2655 comments[name] = desc
2656 if ("offset" in flags) and (name not in offset):
2657 offset.append(name)
2658 if ("solution" in flags) and (name not in solution):
2659 solution.append(name)
2660 return result
def make_template_block(model, arrival=None, common=None):
    """
    Construct a new SimBlock object as a "template" of the model block. It has
    events and reference information, but no values filled in.

    Parameters
    ----------
    model : dict
        Dictionary with model block information, probably read in as a yaml.
    arrival : [str] or None
        List of arrival variables that were flagged or explicitly listed.
    common : [str] or None
        List of variables that are common or shared across all agents, rather
        than idiosyncratically drawn.

    Returns
    -------
    template_block : SimBlock
        A "template" of this model block, with no parameters (etc) on it.
    info : dict
        Dictionary of model objects that were referenced within the block. Keys
        are object names and entries reveal what kind of object they are:
        - None --> parameter
        - 0 --> outcome/data variable (including arrival variables)
        - NullFunc --> function
        - Distribution --> distribution
    offset : [str]
        List of object names that are offset in time by one period.
    solution : [str]
        List of object names that are part of the model solution.
    comments : dict
        Dictionary of comments included with declared functions, distributions,
        and parameters.
    """
    arrival = [] if arrival is None else arrival
    common = [] if common is None else common

    # Explicitly listed metadata, with safe defaults for missing keys
    symbols = model.get("symbols", {})
    block_name = model.get("name", None)
    offset = symbols.get("offset", [])
    solution = symbols.get("solution", [])

    # Gather declared parameters, functions, and distributions via shared helper
    comments = {}
    parameters = _extract_symbol_class(
        model, "parameters", None, None, offset, solution, comments
    )
    functions = _extract_symbol_class(
        model, "functions", NullFunc, "func", offset, solution, comments
    )
    distributions = _extract_symbol_class(
        model, "distributions", Distribution, "dstn", offset, solution, comments
    )

    # The "information" dictionary tracks every object available *at that point*
    # in the dynamic block; arrival variables are marked as data with a 0
    content = {**parameters, **functions, **distributions}
    info = deepcopy(content)
    for var in arrival:
        info[var] = 0  # Mark as a state variable

    # Parse the model dynamics and translate each line into an ordered event
    dynamics = format_block_statement(model["dynamics"])
    events = []
    referenced = []
    for line in dynamics:
        event, used = make_new_event(line, info)
        events.append(event)
        referenced += used

        # Newly assigned variables become available to later events
        for var in event.assigns:
            if var in info.keys():
                raise ValueError(var + " is assigned, but already exists!")
            info[var] = 0

        # An event is common if any of its assigned variables are common
        if any(var in common for var in event.assigns):
            event.common = True

    # Drop content that the dynamics never actually reference
    unused = [key for key in content.keys() if key not in referenced]
    for name in unused:
        del content[name]

    # Build a single aligned statement string for the whole block
    statement = ""
    longest = np.max([len(event.statement) for event in events])
    for event in events:
        pad = (longest + 1) - len(event.statement)
        statement += event.statement + pad * " " + ": " + event.description + "\n"

    # Make a description for the template block
    if block_name is None:
        description = "template block for unnamed block"
    else:
        description = "template block for " + block_name

    # Make and return the new SimBlock
    template_block = SimBlock(
        description=description,
        arrival=arrival,
        content=content,
        statement=statement,
        events=events,
    )
    return template_block, info, offset, solution, comments
def make_initializer(model, arrival=None, common=None):
    """
    Construct a new SimBlock object to be the agent initializer, based on the
    model dictionary. It has structure and events, but no parameters (etc).

    Parameters
    ----------
    model : dict
        Dictionary with model initializer information, probably read in as a yaml.
        Must have "symbols" and "initialize" entries.
    arrival : [str]
        List of arrival variables that were flagged or explicitly listed. Each
        one must be assigned somewhere in the initialize block.
    common : [str]
        List of variable names that are common across agents; any event that
        assigns one of them is marked as common.

    Returns
    -------
    initializer : SimBlock
        A "template" of this model block, with no parameters (etc) on it.
    init_requires : dict
        Dictionary of model objects that are needed by the initializer to run.
        Keys are object names and entries reveal what kind of object they are:
        - None --> parameter
        - 0 --> outcome variable (these should include all arrival variables)
        - NullFunc --> function
        - Distribution --> distribution

    Raises
    ------
    ValueError
        If a declared function or distribution is given a conflicting datatype,
        if a variable is assigned twice, if an arrival variable is never set,
        or if the initialize block references an undeclared object.
    """
    if arrival is None:
        arrival = []
    if common is None:
        common = []
    name = model.get("name", "DEFAULT_NAME")
    declared = model["symbols"]

    # Extract declared parameters; each gets a None placeholder value
    parameters = {}
    for line in declared.get("parameters", []):
        param_name, datatype, flags, desc = parse_declaration_for_parts(line)
        parameters[param_name] = None

    # Extract declared functions, verifying their datatype (if given)
    functions = {}
    for line in declared.get("functions", []):
        func_name, datatype, flags, desc = parse_declaration_for_parts(line)
        if (datatype is not None) and (datatype != "func"):
            raise ValueError(
                func_name
                + " was declared as a function, but given a different datatype!"
            )
        functions[func_name] = NullFunc()

    # Extract declared distributions, verifying their datatype (if given)
    distributions = {}
    for line in declared.get("distributions", []):
        dstn_name, datatype, flags, desc = parse_declaration_for_parts(line)
        if (datatype is not None) and (datatype != "dstn"):
            raise ValueError(
                dstn_name
                + " was declared as a distribution, but given a different datatype!"
            )
        distributions[dstn_name] = Distribution()

    # Combine those dictionaries into a single "information" dictionary
    content = parameters.copy()
    content.update(functions)
    content.update(distributions)
    info = deepcopy(content)

    # Parse the initialization routine into one statement per line
    initialize = format_block_statement(model["initialize"])

    # Make the list of ordered events
    events = []
    for line in initialize:
        # Make the new event and add it to the list
        new_event, names_used = make_new_event(line, info)
        events.append(new_event)

        # Add newly assigned variables to the information set
        for var in new_event.assigns:
            if var in info.keys():
                raise ValueError(var + " is assigned, but already exists!")
            info[var] = 0

        # If any assigned variables are common, mark the event as common
        if any(var in common for var in new_event.assigns):
            new_event.common = True

    # Verify that all arrival variables were created in the initializer
    for var in arrival:
        if var not in info.keys():
            raise ValueError(
                "The arrival variable " + var + " was not set in the initialize block!"
            )

    # Make a dictionary with the information the initializer needs to run
    init_requires = {}
    for event in events:
        for var in event.parameters.keys():
            if var in init_requires:
                continue
            try:
                init_requires[var] = parameters[var]
            except KeyError:
                raise ValueError(
                    var
                    + " was referenced in initialize, but not declared as a parameter!"
                )
        if type(event) is RandomEvent:
            dstn_name = event._dstn_name
            try:
                init_requires[dstn_name] = distributions[dstn_name]
            except KeyError:
                raise ValueError(
                    dstn_name
                    + " was referenced in initialize, but not declared as a distribution!"
                )
        if type(event) is EvaluationEvent:
            func_name = event._func_name
            try:
                # BUG FIX: this was formerly stored under dstn_name, which
                # belonged to an unrelated (or undefined) distribution
                init_requires[func_name] = functions[func_name]
            except KeyError:
                raise ValueError(
                    func_name
                    + " was referenced in initialize, but not declared as a function!"
                )

    # Make a single string initializer statement, aligning the descriptions
    statement = ""
    longest = max(len(event.statement) for event in events)
    for event in events:
        this_statement = event.statement
        pad = (longest + 1) - len(this_statement)
        statement += this_statement + pad * " " + ": " + event.description + "\n"

    # Make and return the new SimBlock
    initializer = SimBlock(
        description="agent initializer for " + name,
        content=init_requires,
        statement=statement,
        events=events,
    )
    return initializer, init_requires
def make_new_event(statement, info):
    """
    Makes a "blank" version of a model event based on a statement line. Determines
    which objects are needed vs assigned vs parameters / information from context.

    Parameters
    ----------
    statement : str
        One line of a model statement, which will be turned into an event.
    info : dict
        Empty dictionary of model information that already exists. Consists of
        arrival variables, already assigned variables, parameters, and functions.
        Typing of each is based on the kind of "empty" object.

    Returns
    -------
    new_event : ModelEvent
        A new model event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.
    """
    # Classify the statement by its special characters, then dispatch to the
    # appropriate constructor: = means dynamics (@ marks a function call),
    # while ~ means a random realization ({} Markov, [] indexed, else plain).
    has_eq = "=" in statement
    has_tld = "~" in statement
    if has_eq and has_tld:
        raise ValueError("A statement line can't have both an = and a ~!")

    event_maker = None
    if has_eq:
        event_maker = make_new_evaluation if "@" in statement else make_new_dynamic
    elif has_tld:
        if ("{" in statement) and ("}" in statement):
            event_maker = make_new_markov
        elif ("[" in statement) and ("]" in statement):
            event_maker = make_new_random_indexed
        else:
            event_maker = make_new_random
    if event_maker is None:
        raise ValueError("Statement line was not any valid type!")

    return event_maker(statement, info)
def make_new_dynamic(statement, info):
    """
    Construct a new instance of DynamicEvent based on the given model statement
    line and a blank dictionary of parameters. The statement should already be
    verified to be a valid dynamic statement: it has an = but no ~ or @.

    Parameters
    ----------
    statement : str
        One line dynamics statement, which will be turned into a DynamicEvent.
    info : dict
        Empty dictionary of available information. Keys are object names; the
        "empty" value encodes what each is (None for a parameter, NullFunc for
        a function, Distribution for a distribution, anything else a variable).

    Returns
    -------
    new_dynamic : DynamicEvent
        A new dynamic event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.

    Raises
    ------
    ValueError
        If the RHS references a name missing from info, or references a function
        or distribution (neither is permitted in a dynamic expression).
    """
    # Cut the statement up into its LHS, RHS, and description
    lhs, rhs, description = parse_line_for_parts(statement, "=")

    # Parse the LHS (assignment) to get assigned variables
    assigns = parse_assignment(lhs)

    # Parse the RHS (dynamic statement) to extract object names used
    obj_names, is_indexed = extract_var_names_from_expr(rhs)

    # Allocate each variable to needed dynamic variables or parameters:
    # a None marker in info means "parameter"; anything else must be present
    # as idiosyncratic data when the event is run
    needs = []
    parameters = {}
    for j in range(len(obj_names)):
        var = obj_names[j]
        if var not in info.keys():
            raise ValueError(
                var + " is used in a dynamic expression, but does not (yet) exist!"
            )
        val = info[var]
        if type(val) is NullFunc:
            raise ValueError(
                var + " is used in a dynamic expression, but it's a function!"
            )
        if type(val) is Distribution:
            raise ValueError(
                var + " is used in a dynamic expression, but it's a distribution!"
            )
        if val is None:
            parameters[var] = None
        else:
            needs.append(var)

    # Declare a SymPy symbol for each variable used; these are temporary.
    # exec is used so that indexed names can be bound as IndexedBase rather
    # than plain Symbol objects.
    _args = []
    for j in range(len(obj_names)):
        _var = obj_names[j]
        if is_indexed[j]:
            exec(_var + " = IndexedBase('" + _var + "')")
        else:
            exec(_var + " = symbols('" + _var + "')")
        _args.append(symbols(_var))

    # Make a SymPy expression, then lambdify it.
    # NOTE(review): symbols(rhs) creates a single Symbol whose *name* is the
    # entire RHS text; lambdify then prints that name verbatim into the body
    # of the generated lambda, which effectively evaluates the expression
    # against _args. Confirm this is intentional (vs. sympify/parse_expr),
    # as it relies on lambdify's printing behavior.
    sympy_expr = symbols(rhs)
    expr = lambdify(_args, sympy_expr)

    # Make an overall list of object names referenced in this event
    names_used = assigns + obj_names

    # Make and return the new dynamic event
    new_dynamic = DynamicEvent(
        description=description,
        statement=lhs + " = " + rhs,
        assigns=assigns,
        needs=needs,
        parameters=parameters,
        expr=expr,
        args=obj_names,
    )
    return new_dynamic, names_used
def make_new_random(statement, info):
    """
    Make a new random variable realization event based on the given model statement
    line and a blank dictionary of parameters. The statement should already be
    verified to be a valid random statement: it has a ~ but no = or [].

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Empty dictionary of available information.

    Returns
    -------
    new_random : RandomEvent
        A new random event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.
    """
    # Split the statement into its assignment, distribution name, and description
    left, right, desc = parse_line_for_parts(statement, "~")
    targets = parse_assignment(left)

    # The RHS must name an object that was declared as a distribution
    if type(info[right]) is not Distribution:
        raise ValueError(
            right + " was treated as a distribution, but not declared as one!"
        )

    # Build the event; it needs no variables and has no parameters
    new_random = RandomEvent(
        description=desc,
        statement=left + " ~ " + right,
        assigns=targets,
        needs=[],
        parameters={},
        dstn=info[right],
    )
    new_random._dstn_name = right  # record the distribution's name on the event
    return new_random, targets + [right]
def make_new_random_indexed(statement, info):
    """
    Make a new indexed random variable realization event based on the given model
    statement line and a blank dictionary of parameters. The statement should
    already be verified to be a valid random statement: it has a ~ and [].

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Empty dictionary of available information.

    Returns
    -------
    new_random_indexed : RandomEvent
        A new random indexed event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.
    """
    # Split the statement into its assignment, RHS, and description
    left, right, desc = parse_line_for_parts(statement, "~")
    targets = parse_assignment(left)

    # The RHS is a distribution name and an index variable, like dstn[idx]
    dstn_name, idx_name = parse_random_indexed(right)
    if type(info[dstn_name]) is not Distribution:
        raise ValueError(
            dstn_name + " was treated as a distribution, but not declared as one!"
        )

    # Build the event; the index variable is its only data requirement
    new_random_indexed = RandomIndexedEvent(
        description=desc,
        statement=left + " ~ " + right,
        assigns=targets,
        needs=[idx_name],
        parameters={},
        index=idx_name,
    )
    new_random_indexed._dstn_name = dstn_name  # record the distribution's name
    return new_random_indexed, targets + [dstn_name, idx_name]
def make_new_markov(statement, info):
    """
    Make a new Markov-type event based on the given model statement line and a
    blank dictionary of parameters. The statement should already be verified to
    be a valid Markov statement: it has a ~ and {} and maybe (). This can represent
    a Markov matrix transition event, a draw from a discrete index, or just a
    Bernoulli random variable. If a Bernoulli event, the "probabilities" can be
    idiosyncratic data.

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Empty dictionary of available information.

    Returns
    -------
    new_markov : MarkovEvent
        A new Markov draw event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.
    """
    # Split the statement into its assignment, RHS, and description
    left, right, desc = parse_line_for_parts(statement, "~")
    targets = parse_assignment(left)

    # The RHS looks like {probs} with an optional (index)
    probs, index = parse_markov(right)
    needs = [] if index is None else [index]

    # Determine whether probs is an idiosyncratic variable or a parameter, and
    # set up the event to grab it appropriately
    if info[probs] is None:
        parameters = {probs: None}  # declared as a parameter
    else:
        needs = needs + [probs]  # idiosyncratic data, required at run time
        parameters = {}

    # Build and return the new Markov event
    new_markov = MarkovEvent(
        description=desc,
        statement=left + " ~ " + right,
        assigns=targets,
        needs=needs,
        parameters=parameters,
        probs=probs,
        index=index,
    )
    return new_markov, targets + needs + [probs]
def make_new_evaluation(statement, info):
    """
    Make a new function evaluation event based on the given model statement line
    and a blank dictionary of parameters. The statement should already be verified
    to be a valid evaluation statement: it has an @ and an = but no ~.

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into an eval event.
    info : dict
        Empty dictionary of available information.

    Returns
    -------
    new_evaluation : EvaluationEvent
        A new evaluation event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.

    Raises
    ------
    ValueError
        If an argument does not yet exist in info, or is a function or a
        distribution (neither may be passed as data to an evaluation).
    """
    # Cut the statement up into its LHS, RHS, and description
    lhs, rhs, description = parse_line_for_parts(statement, "=")

    # Parse the LHS (assignment) to get assigned variables
    assigns = parse_assignment(lhs)

    # Parse the RHS (evaluation) for the function and its arguments
    func, arguments = parse_evaluation(rhs)

    # Allocate each argument to needed dynamic variables or parameters
    needs = []
    parameters = {}
    for var in arguments:
        if var not in info.keys():
            raise ValueError(
                var + " is used in an evaluation statement, but does not (yet) exist!"
            )
        val = info[var]
        if type(val) is NullFunc:
            # BUG FIX: error message previously read "as an argument an
            # evaluation statement" (missing "in")
            raise ValueError(
                var
                + " is used as an argument in an evaluation statement, but it's a function!"
            )
        if type(val) is Distribution:
            raise ValueError(
                var + " is used in an evaluation statement, but it's a distribution!"
            )
        if val is None:
            parameters[var] = None  # parameter, filled in when the block runs
        else:
            needs.append(var)

    # Make an overall list of object names referenced in this event
    names_used = assigns + arguments + [func]

    # Make and return the new evaluation event
    new_evaluation = EvaluationEvent(
        description=description,
        statement=lhs + " = " + rhs,
        assigns=assigns,
        needs=needs,
        parameters=parameters,
        arguments=arguments,
        func=info[func],
    )
    new_evaluation._func_name = func
    return new_evaluation, names_used
def look_for_char_and_remove(phrase, symb):
    """
    Check whether a symbol appears in a string, and remove it if it does.

    Parameters
    ----------
    phrase : str
        String to be searched for a symbol.
    symb : char
        Single character to be searched for.

    Returns
    -------
    out : str
        Possibly shortened input phrase.
    found : bool
        Whether the symbol was found and removed.
    """
    if symb in phrase:
        return phrase.replace(symb, ""), True
    return phrase, False
def parse_declaration_for_parts(line):
    """
    Split a declaration line from a model file into the object's name, its datatype,
    any metadata flags, and any provided comment or description.

    Parameters
    ----------
    line : str
        Line to be parsed into the object name, datatype, flags, and comment.

    Returns
    -------
    name : str
        Name of the object.
    datatype : str or None
        Provided datatype string, in parentheses, if any.
    flags : [str]
        List of metadata flags that were detected. These include ! for a variable
        that is in arrival, * for any non-variable that's part of the solution,
        + for any object that is offset in time, and & for a common random variable.
    desc : str
        Comment or description following the backslash comment marker, if any.
    """
    # Peel off the trailing comment, if one is present
    mark = line.find("\\")
    if mark == -1:
        desc = ""
        rem = line
    else:
        desc = line[(mark + 2) :].strip()
        rem = line[:mark].strip()

    # Look for a parenthesized datatype declaration
    lp = rem.find("(")
    if lp == -1:
        datatype = None
        leftover = rem
    else:
        rp = rem.find(")")
        if rp == -1:
            raise ValueError("Unclosed parentheses on object declaration line!")
        datatype = rem[(lp + 1) : rp].strip()
        leftover = rem[:lp].strip()

    # Strip out each flag character, recording the ones that appear
    flags = []
    for flag_name, symb in [
        ("offset", "+"),
        ("arrival", "!"),
        ("solution", "*"),
        ("common", "&"),
    ]:
        if symb in leftover:
            flags.append(flag_name)
            leftover = leftover.replace(symb, "")

    # Remove any remaining spaces, and that *should* be the name
    name = leftover.replace(" ", "")
    # TODO: Check for valid name formatting based on characters.

    return name, datatype, flags, desc
def parse_line_for_parts(statement, symb):
    """
    Split one line of a model statement into its LHS, RHS, and description. The
    description is everything following \\, while the LHS and RHS are determined
    by a special symbol.

    Parameters
    ----------
    statement : str
        One line of a model statement, which will be parsed for its parts.
    symb : char
        The character that represents the divide between LHS and RHS

    Returns
    -------
    lhs : str
        The left-hand (assignment) side of the expression.
    rhs : str
        The right-hand (evaluation) side of the expression.
    desc : str
        The provided description of the expression.
    """
    split_at = statement.find(symb)
    lhs = statement[:split_at].replace(" ", "")
    rest = statement[(split_at + 1) :]

    # Everything after the backslash marker is the description
    mark = rest.find("\\")
    if mark == -1:
        desc = ""
        rhs = rest
    else:
        desc = rest[(mark + 2) :].strip()
        rhs = rest[:mark]
    return lhs, rhs.replace(" ", ""), desc
def parse_assignment(lhs):
    """
    Get ordered list of assigned variables from the LHS of a model line.

    Parameters
    ----------
    lhs : str
        Left-hand side of a model expression

    Returns
    -------
    assigns : List[str]
        List of variable names that are assigned in this model line.
    """
    # A bare name assigns a single variable
    if lhs[0] != "(":
        return [lhs]

    # Otherwise this is a tuple assignment like (a,b,c)
    if lhs[-1] != ")":
        raise ValueError("Parentheses on assignment was not closed!")
    return [var for var in lhs[1:-1].split(",") if var != ""]
def extract_var_names_from_expr(expression):
    """
    Parse the RHS of a dynamic model statement to get variable names used in it.

    Parameters
    ----------
    expression : str
        RHS of a model statement to be parsed for variable names.

    Returns
    -------
    var_names : List[str]
        List of variable names used in the expression. These *should* be dynamic
        variables and parameters, but not functions.
    indexed : List[bool]
        Indicators for whether each variable seems to be used with indexing.
    """
    separators = "+-/*^%.(),[]{}<>"
    numerals = "01234567890"
    var_names = []
    indexed = []
    token = ""
    for ch in expression:
        # A math symbol always ends the current token; a digit only does so
        # when no token has started (digits may appear inside names like x2)
        if (ch in separators) or (token == "" and ch in numerals):
            if token == "":
                continue
            if token in var_names:
                token = ""  # duplicate name: skip without re-recording
                continue
            var_names.append(token)
            indexed.append(ch == "[")  # a [ right after a name means indexing
            token = ""
        else:
            token += ch
    if token != "" and token not in var_names:
        var_names.append(token)
        indexed.append(False)  # final symbol couldn't possibly be indexed
    return var_names, indexed
def parse_evaluation(expression):
    """
    Separate a function evaluation expression into the function that is called
    and the variable inputs that are passed to it.

    Parameters
    ----------
    expression : str
        RHS of a function evaluation model statement, which will be parsed for
        the function and its inputs.

    Returns
    -------
    func_name : str
        Name of the function that will be called in this event.
    arg_names : List[str]
        List of arguments of the function.
    """
    # The function name is everything to the left of the @
    amp = expression.find("@")
    func_name = expression[:amp]

    # The remainder must be a parenthesized argument list
    inner = expression[(amp + 1) :]
    if inner[0] != "(":
        raise ValueError(
            "The @ in a function evaluation statement must be followed by (!"
        )
    if inner[-1] != ")":
        raise ValueError("A function evaluation statement must end in )!")

    # Split the argument list on commas, discarding empty entries
    arg_names = [arg for arg in inner[1:-1].split(",") if arg != ""]
    return func_name, arg_names
def parse_markov(expression):
    """
    Separate a Markov draw declaration into the array of probabilities and the
    index for idiosyncratic values.

    Parameters
    ----------
    expression : str
        RHS of a function evaluation model statement, which will be parsed for
        the probabilities name and index name.

    Returns
    -------
    probs : str
        Name of the probabilities object in this statement.
    index : str or None
        Name of the indexing variable in this statement, if any.
    """
    # Extract the name of the probabilities, between braces
    lb = expression.find("{")  # this *should* be 0
    rb = expression.find("}")
    if lb == -1 or rb == -1 or rb < lb + 2:
        raise ValueError("A Markov assignment must have an {array}!")
    probs = expression[(lb + 1) : rb]

    # Extract the name of the index, between parentheses, if present
    start = rb + 1
    lp = expression.find("(", start)
    rp = expression.find(")", start)
    if lp == -1 and rp == -1:  # no index present at all
        return probs, None
    if lp == -1 or rp == -1 or rp < lp + 2:
        raise ValueError("Improper Markov formatting: should be {probs}(index)!")
    return probs, expression[(lp + 1) : rp]
def parse_random_indexed(expression):
    """
    Separate an indexed random variable assignment into the distribution and
    the index for it.

    Parameters
    ----------
    expression : str
        RHS of a function evaluation model statement, which will be parsed for
        the distribution name and index name.

    Returns
    -------
    dstn : str
        Name of the distribution in this statement.
    index : str
        Name of the indexing variable in this statement.
    """
    # The index lives between square brackets; the distribution name precedes them
    lb = expression.find("[")
    rb = expression.find("]")
    if lb == -1 or rb == -1 or rb < lb + 2:
        raise ValueError("An indexed random variable assignment must have an [index]!")
    return expression[:lb], expression[(lb + 1) : rb]
def format_block_statement(statement):
    """
    Ensure that a string statement of a model block (maybe a period, maybe an
    initializer) is formatted as a list of strings, one statement per entry.

    Parameters
    ----------
    statement : str or [str]
        A model statement, which might be for a block or an initializer. The
        statement might be formatted as a list or as a single string.

    Returns
    -------
    block_statements : [str]
        A list of model statements, one per entry.

    Raises
    ------
    ValueError
        If statement is neither a string nor a list of strings.
    """
    if type(statement) is str:
        # Split a multi-line string on newlines; a single line becomes a
        # one-element list. BUG FIX: the old code called .copy() on a str
        # (AttributeError) for single-line input, and its manual newline loop
        # silently dropped the final line when there was no trailing newline.
        block_statements = statement.split("\n")
        if block_statements and block_statements[-1] == "":
            block_statements.pop()  # discard empty tail from a trailing newline
        return block_statements
    if type(statement) is list:
        for line in statement:
            if type(line) is not str:
                raise ValueError("The model statement somehow includes a non-string!")
        return statement.copy()
    # BUG FIX: previously fell through to an UnboundLocalError
    raise ValueError("The model statement must be a string or a list of strings!")
@njit
def aggregate_blobs_onto_polynomial_grid(
    vals, pmv, origins, grid, J, Q
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto a discretized
    grid of outcome values, based on their origin in the arrival state space. This
    version is for non-continuation variables, returning only the probability array
    mapping from arrival states to the outcome variable.

    Parameters
    ----------
    vals : np.array
        Outcome value of each blob.
    pmv : np.array
        Probability mass of each blob.
    origins : np.array
        Arrival state index (output row) of each blob.
    grid : np.array
        Increasing grid of outcome values with M nodes; assumed to be spaced
        polynomially with degree Q -- TODO confirm (the index lookup below
        relies on that spacing).
    J : int
        Number of arrival states (rows of the output).
    Q : float
        Polynomial degree of the grid spacing.

    Returns
    -------
    probs : np.array
        Array of shape (J, M) with the probability mass moved from each arrival
        state to each gridpoint of the outcome variable.
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    Mm1 = M - 1
    N = pmv.size
    scale = 1.0 / (top - bot)
    order = 1.0 / Q  # inverse polynomial degree, used to invert the spacing
    diffs = grid[1:] - grid[:-1]

    probs = np.zeros((J, M))

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            # Invert the polynomial spacing to find the lower bracketing
            # gridpoint, then split this blob's mass linearly between neighbors
            ii = int(np.floor(((x - bot) * scale) ** order * Mm1))
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
        elif x <= bot:
            # Out-of-bounds mass piles up at the nearest edge of the grid
            probs[jj, 0] += p
        else:
            probs[jj, -1] += p
    return probs
@njit
def aggregate_blobs_onto_polynomial_grid_alt(
    vals, pmv, origins, grid, J, Q
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto a discretized
    grid of outcome values, based on their origin in the arrival state space. This
    version is for continuation variables, returning the probability array mapping
    from arrival states to the outcome variable, the index in the outcome variable grid
    for each blob, and the alpha weighting between gridpoints.

    Parameters
    ----------
    vals : np.array
        Outcome value of each blob.
    pmv : np.array
        Probability mass of each blob.
    origins : np.array
        Arrival state index (row of probs) of each blob.
    grid : np.array
        Increasing grid of outcome values with M nodes; assumed to be spaced
        polynomially with degree Q -- TODO confirm (the index lookup below
        relies on that spacing).
    J : int
        Number of arrival states (rows of probs).
    Q : float
        Polynomial degree of the grid spacing.

    Returns
    -------
    probs : np.array
        Array of shape (J, M) with the probability mass moved from each arrival
        state to each gridpoint.
    idx : np.array
        Index of the lower bracketing gridpoint for each blob (int32).
    alpha : np.array
        Weight on the upper gridpoint for each blob: 0.0 when clipped at the
        bottom of the grid, 1.0 when clipped at the top.
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    Mm1 = M - 1
    N = pmv.size
    scale = 1.0 / (top - bot)
    order = 1.0 / Q  # inverse polynomial degree, used to invert the spacing
    diffs = grid[1:] - grid[:-1]

    probs = np.zeros((J, M))
    idx = np.empty(N, dtype=np.dtype(np.int32))
    alpha = np.empty(N)

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            # Invert the polynomial spacing to find the lower bracketing
            # gridpoint, then split this blob's mass linearly between neighbors
            ii = int(np.floor(((x - bot) * scale) ** order * Mm1))
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
            alpha[n] = temp
            idx[n] = ii
        elif x <= bot:
            # Clipped at the bottom edge: all mass on gridpoint 0
            probs[jj, 0] += p
            alpha[n] = 0.0
            idx[n] = 0
        else:
            # Clipped at the top edge: all mass on the last gridpoint
            probs[jj, -1] += p
            alpha[n] = 1.0
            idx[n] = M - 2
    return probs, idx, alpha
@njit
def aggregate_blobs_onto_discrete_grid(vals, pmv, origins, M, J):  # pragma: no cover
    """
    Numba-compatible helper function for allocating "probability blobs" to a grid
    over a discrete state-- the state itself is truly discrete. Each blob's mass
    is added at the row of its arrival state and the column of its discrete value.
    """
    out = np.zeros((J, M))
    for n in range(pmv.size):
        out[origins[n], vals[n]] += pmv[n]
    return out
@njit
def calc_overall_trans_probs(
    out, idx, alpha, binary, offset, pmv, origins
):  # pragma: no cover
    """
    Numba-compatible helper function for combining transition probabilities from
    the arrival state space to *multiple* continuation variables into a single
    unified transition matrix.

    Parameters
    ----------
    out : np.array
        Pre-allocated transition matrix; accumulated into in place and returned.
    idx : np.array
        Base column index for each blob -- presumably the flattened index of its
        lower corner across continuation dimensions; confirm against the caller.
    alpha : np.array
        Interpolation weights, read as alpha[n, d, k] for blob n, dimension d,
        and side k -- presumably k=0 weights the lower gridpoint and k=1 the
        upper one; TODO confirm.
    binary : np.array
        (B, D) array enumerating the corners of the D-dimensional bracketing
        cell; each row picks a side (0 or 1) in every dimension.
    offset : np.array
        Column offset to add to a blob's base index for each corner b --
        presumably precomputed from the grid sizes; confirm against the caller.
    pmv : np.array
        Probability mass of each blob.
    origins : np.array
        Arrival state row index of each blob.

    Returns
    -------
    out : np.array
        The same array that was passed in, with all blob mass accumulated.
    """
    N = alpha.shape[0]  # number of blobs
    B = binary.shape[0]  # number of corners of the bracketing cell
    D = binary.shape[1]  # number of continuation dimensions
    for n in range(N):
        ii = origins[n]
        jj_base = idx[n]
        p = pmv[n]
        for b in range(B):
            adj = offset[b]
            P = p
            for d in range(D):
                k = binary[b, d]
                P *= alpha[n, d, k]  # scale mass by this corner's weight in dim d
            jj = jj_base + adj
            out[ii, jj] += P
    return out