Coverage for HARK / simulator.py: 90%
1351 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
1"""
2A module with classes and functions for automated simulation of HARK.AgentType
3models from a human- and machine-readable model specification.
4"""
6from dataclasses import dataclass, field
7from copy import copy, deepcopy
8import numpy as np
9from numba import njit
10from sympy.utilities.lambdify import lambdify
11from sympy import symbols, IndexedBase
12from typing import Callable
13from HARK.utilities import NullFunc, make_exponential_grid
14from HARK.distributions import Distribution
15from scipy.sparse import csr_matrix
16from scipy.sparse.linalg import eigs
17from itertools import product
18import importlib.resources
19import yaml
# Prevent pre-commit from removing sympy: the two throwaway objects below exist
# only so that the imported names `symbols` and `IndexedBase` are referenced at
# least once in this module. Each object is deleted immediately after creation,
# so neither `x` nor `y` remains in the module namespace.
# NOTE(review): presumably the sympy names are needed when model expressions
# are lambdified elsewhere in this file -- confirm before removing.
x = symbols("x")
del x
y = IndexedBase("y")
del y
28@dataclass(kw_only=True)
29class ModelEvent:
30 """
31 Class for representing "events" that happen to agents in the course of their
32 model. These might be statements of dynamics, realization of a random shock,
33 or the evaluation of a function (potentially a control or other solution-
34 based object). This is a superclass for types of events defined below.
36 Parameters
37 ----------
38 description : str
39 Text description of this model event.
40 statement : str
41 The line of the model statement that this event corresponds to.
42 parameters : dict
43 Dictionary of objects that are static / universal within this event.
44 assigns : list[str]
45 List of names of variables that this event assigns values for.
46 needs : list[str]
47 List of names of variables that this event requires to be run.
48 data : dict
49 Dictionary of current variable values within this event.
50 common : bool
51 Indicator for whether the variables assigned in this event are commonly
52 held across all agents, rather than idiosyncratic.
53 N : int
54 Number of agents currently in this event.
55 """
57 statement: str = field(default="")
58 parameters: dict = field(default_factory=dict)
59 description: str = field(default="")
60 assigns: list[str] = field(default_factory=list, repr=False)
61 needs: list = field(default_factory=list, repr=False)
62 data: dict = field(default_factory=dict, repr=False)
63 common: bool = field(default=False, repr=False)
64 N: int = field(default=1, repr=False)
66 def run(self):
67 """
68 This method should be filled in by each subclass.
69 """
70 pass # pragma: nocover
72 def reset(self):
73 self.data = {}
75 def assign(self, output):
76 if len(self.assigns) > 1:
77 assert len(self.assigns) == len(output)
78 for j in range(len(self.assigns)):
79 var = self.assigns[j]
80 if type(output[j]) is not np.ndarray:
81 output[j] = np.array([output[j]])
82 self.data[var] = output[j]
83 else:
84 var = self.assigns[0]
85 if type(output) is not np.ndarray:
86 output = np.array([output])
87 self.data[var] = output
89 def expand_information(self, origins, probs, atoms, which=None):
90 """
91 This method is only called internally when a RandomEvent or MarkovEvent
92 runs its quasi_run() method. It expands the set of of "probability blobs"
93 by applying a random realization event. All extant blobs for which the
94 shock applies are replicated for each atom in the random event, with the
95 probability mass divided among the replicates.
97 Parameters
98 ----------
99 origins : np.array
100 Array that tracks which arrival state space node each blob originated
101 from. This is expanded into origins_new, which is returned.
102 probs : np.array
103 Vector of probabilities of each of the random possibilities.
104 atoms : [np.array]
105 List of arrays with realization values for the distribution. Each
106 array corresponds to one variable that is assigned by this event.
107 which : np.array or None
108 If given, a Boolean array indicating which of the pre-existing blobs
109 is affected by the given probabilities and atoms. By default, all
110 blobs are assumed to be affected.
112 Returns
113 -------
114 origins_new : np.array
115 Expanded boolean array of indicating the arrival state space node that
116 each blob originated from.
117 """
118 K = probs.size
119 N = self.N
120 if which is None:
121 which = np.ones(N, dtype=bool)
122 other = np.logical_not(which)
123 M = np.sum(which) # how many blobs are we affecting?
124 MX = N - M # how many blobs are we not affecting?
126 # Update probabilities of outcomes
127 pmv_old = np.reshape(self.data["pmv_"][which], (M, 1))
128 pmv_new = (pmv_old * np.reshape(probs, (1, K))).flatten()
129 self.data["pmv_"] = np.concatenate((self.data["pmv_"][other], pmv_new))
131 # Replicate the pre-existing data for each atom
132 for var in self.data.keys():
133 if (var == "pmv_") or (var in self.assigns):
134 continue # don't double expand pmv, and don't touch assigned variables
135 data_old = np.reshape(self.data[var][which], (M, 1))
136 data_new = np.tile(data_old, (1, K)).flatten()
137 self.data[var] = np.concatenate((self.data[var][other], data_new))
139 # If any of the assigned variables don't exist yet, add dummy versions
140 # of them. This section exists so that the code works with "partial events"
141 # on both the first pass and subsequent passes.
142 for j in range(len(self.assigns)):
143 var = self.assigns[j]
144 if var in self.data.keys():
145 continue
146 self.data[var] = np.zeros(N, dtype=atoms[j].dtype)
147 # Zeros are just dummy values
149 # Add the new random variables to the simulation data. This generates
150 # replicates for the affected blobs and leaves the others untouched,
151 # still with their dummy values. They will be altered on later passes.
152 for j in range(len(self.assigns)):
153 var = self.assigns[j]
154 data_new = np.tile(np.reshape(atoms[j], (1, K)), (M, 1)).flatten()
155 self.data[var] = np.concatenate((self.data[var][other], data_new))
157 # Expand the origins array to account for the new replicates
158 origins_new = np.tile(np.reshape(origins[which], (M, 1)), (1, K)).flatten()
159 origins_new = np.concatenate((origins[other], origins_new))
160 self.N = MX + M * K
162 # Send the new origins array back to the calling process
163 return origins_new
165 def add_idiosyncratic_bernoulli_info(self, origins, probs):
166 """
167 Special method for adding Bernoulli outcomes to the information set when
168 probabilities are idiosyncratic to each agent. All extant blobs are duplicated
169 with the appropriate probability
171 Parameters
172 ----------
173 origins : np.array
174 Array that tracks which arrival state space node each blob originated
175 from. This is expanded into origins_new, which is returned.
176 probs : np.array
177 Vector of probabilities of drawing True for each blob.
179 Returns
180 -------
181 origins_new : np.array
182 Expanded boolean array of indicating the arrival state space node that
183 each blob originated from.
184 """
185 N = self.N
187 # # Update probabilities of outcomes, replicating each one
188 pmv_old = np.reshape(self.data["pmv_"], (N, 1))
189 P = np.reshape(probs, (N, 1))
190 PX = np.concatenate([1.0 - P, P], axis=1)
191 pmv_new = (pmv_old * PX).flatten()
192 self.data["pmv_"] = pmv_new
194 # Replicate the pre-existing data for each atom
195 for var in self.data.keys():
196 if (var == "pmv_") or (var in self.assigns):
197 continue # don't double expand pmv, and don't touch assigned variables
198 data_old = np.reshape(self.data[var], (N, 1))
199 data_new = np.tile(data_old, (1, 2)).flatten()
200 self.data[var] = data_new
202 # Add the (one and only) new random variable to the simulation data
203 var = self.assigns[0]
204 data_new = np.tile(np.array([[0, 1]]), (N, 1)).flatten()
205 self.data[var] = data_new
207 # Expand the origins array to account for the new replicates
208 origins_new = np.tile(np.reshape(origins, (N, 1)), (1, 2)).flatten()
209 self.N = N * 2
211 # Send the new origins array back to the calling process
212 return origins_new
@dataclass(kw_only=True)
class DynamicEvent(ModelEvent):
    """
    Event that represents a statement of model dynamics: an expression is
    evaluated on named arguments and the result is assigned to this event's
    variables.

    Parameters
    ----------
    expr : Callable
        Function or expression to be evaluated for the assigned variables.
    args : list[str]
        Ordered list of argument names for the expression.
    """

    expr: Callable = field(default_factory=NullFunc, repr=False)
    args: list[str] = field(default_factory=list, repr=False)

    def evaluate(self):
        """
        Evaluate the expression on the current values of its arguments.
        Argument values are looked up by name; when a name appears in both
        data and parameters, the value in parameters is used.
        """
        namespace = {**self.data, **self.parameters}
        return self.expr(*(namespace[name] for name in self.args))

    def run(self):
        """
        Evaluate the expression and store the result in the assigned variables.
        """
        result = self.evaluate()
        self.assign(result)

    def quasi_run(self, origins, norm=None):
        """
        Run the event as usual; the origins array is returned unchanged and
        the norm argument is not used.
        """
        self.run()
        return origins
@dataclass(kw_only=True)
class RandomEvent(ModelEvent):
    """
    Class for representing the realization of random variables for an agent,
    consisting of a shock distribution and variables to which the results are assigned.

    Parameters
    ----------
    dstn : Distribution
        Distribution of one or more random variables that are drawn from during
        this event and assigned to the corresponding variables.
    """

    dstn: Distribution = field(default_factory=Distribution, repr=False)

    def reset(self):
        """
        Reset the distribution and clear this event's data.
        """
        self.dstn.reset()
        ModelEvent.reset(self)

    def draw(self):
        """
        Draw realizations of the assigned variables from the distribution.
        If the event is common, a single draw is made and broadcast to all N
        agents; otherwise N draws are made.

        Returns
        -------
        out : np.array
            Array of shape (len(assigns), N), flattened to shape (N,) when
            exactly one variable is assigned.
        """
        out = np.empty((len(self.assigns), self.N))
        n_draws = 1 if self.common else self.N
        out[:, :] = self.dstn.draw(n_draws)
        if len(self.assigns) == 1:
            out = out.flatten()
        return out

    def run(self):
        """
        Draw from the distribution and store the result in the assigned variables.
        """
        self.assign(self.draw())

    def quasi_run(self, origins, norm=None):
        """
        Expand the set of probability blobs by the atoms of the distribution.

        Parameters
        ----------
        origins : np.array
            Array tracking the arrival state space node of each blob.
        norm : str or None
            Name of the shock variable used for Harmenberg normalization. If it
            is one of the variables assigned by this event, the probabilities
            are multiplied by that variable's atoms; otherwise it is ignored.

        Returns
        -------
        origins_new : np.array
            Expanded origins array.
        """
        # Get distribution
        atoms = self.dstn.atoms
        probs = self.dstn.pmv.copy()

        # Apply Harmenberg normalization if applicable. Only the "norm is not
        # assigned here" case is skipped; genuine errors in the reweighting
        # are allowed to surface rather than being silently swallowed.
        if norm in self.assigns:
            probs *= atoms[self.assigns.index(norm)]

        # Expand the set of simulated blobs
        return self.expand_information(origins, probs, atoms)
@dataclass(kw_only=True)
class RandomIndexedEvent(RandomEvent):
    """
    Class for representing the realization of random variables for an agent,
    consisting of a list of shock distributions, an index for the list, and the
    variables to which the results are assigned.

    Parameters
    ----------
    dstn : [Distribution]
        List of distributions of one or more random variables that are drawn
        from during this event and assigned to the corresponding variables.
    index : str
        Name of the index that is used to choose a distribution for each agent.
    """

    index: str = field(default="", repr=False)
    dstn: list[Distribution] = field(default_factory=list, repr=False)

    def draw(self):
        """
        Draw realizations of the assigned variables, using for each agent the
        distribution selected by that agent's value of the index variable.

        Returns
        -------
        out : np.array
            Array of shape (len(assigns), N), flattened to shape (N,) when
            exactly one variable is assigned. Entries for agents whose index
            matches no distribution are left as NaN.
        """
        idx = self.data[self.index]
        K = len(self.assigns)
        out = np.empty((K, self.N))
        out.fill(np.nan)

        if self.common:
            k = idx[0]  # this will behave badly if index is not itself common
            out[:, :] = self.dstn[k].draw(1)
        else:
            for k in range(len(self.dstn)):
                these = idx == k
                if not np.any(these):
                    continue
                out[:, these] = self.dstn[k].draw(np.sum(these))

        # Flatten in both the common and idiosyncratic cases, so that a single
        # assigned variable is always stored as a 1D array (as RandomEvent does).
        if K == 1:
            out = out.flatten()
        return out

    def reset(self):
        """
        Reset every distribution in the list and clear this event's data.
        """
        for dstn in self.dstn:
            dstn.reset()
        ModelEvent.reset(self)

    def quasi_run(self, origins, norm=None):
        """
        Expand the set of probability blobs, applying each distribution in the
        list only to the blobs whose index variable selects it.

        Parameters
        ----------
        origins : np.array
            Array tracking the arrival state space node of each blob.
        norm : str or None
            Name of the shock variable used for Harmenberg normalization. If it
            is one of the variables assigned by this event, the probabilities
            are multiplied by that variable's atoms; otherwise it is ignored.

        Returns
        -------
        origins_new : np.array
            Expanded origins array.
        """
        origins_new = origins.copy()
        harm_idx = self.assigns.index(norm) if norm in self.assigns else None

        for j in range(len(self.dstn)):
            # The index data must be re-read on every pass: each call to
            # expand_information replaces and reorders the arrays in self.data.
            idx = self.data[self.index]
            these = idx == j

            # Get distribution
            atoms = self.dstn[j].atoms
            probs = self.dstn[j].pmv.copy()

            # Apply Harmenberg normalization if applicable
            if harm_idx is not None:
                probs *= atoms[harm_idx]

            # Expand the set of simulated blobs
            origins_new = self.expand_information(
                origins_new, probs, atoms, which=these
            )

        # Return the altered origins array
        return origins_new
@dataclass(kw_only=True)
class MarkovEvent(ModelEvent):
    """
    Event that represents a discrete Markov-style draw for an agent. The
    probabilities used for the draw can take three forms: a 2D array is read as
    a Markov matrix (rows sum to 1) and requires an index variable that picks
    the row; a vector is read as a stochastic vector; a single float is read as
    a Bernoulli probability.
    """

    probs: str = field(default="", repr=False)
    index: str = field(default="", repr=False)
    N: int = field(default=1, repr=False)
    seed: int = field(default=0, repr=False)
    # seed is overwritten when each period is created

    def __post_init__(self):
        self.reset_rng()

    def reset(self):
        """
        Restart the random number generator and clear this event's data.
        """
        self.reset_rng()
        ModelEvent.reset(self)

    def reset_rng(self):
        """
        Create a fresh random number generator from the current seed.
        """
        self.RNG = np.random.RandomState(self.seed)

    def _get_probs(self):
        """
        Look up the probabilities by name, preferring parameters over data.
        Returns the probabilities and a flag for whether they are a parameter.
        """
        if self.probs in self.parameters:
            return self.parameters[self.probs], True
        return self.data[self.probs], False

    def draw(self):
        """
        Draw a discrete outcome for each agent, or a single shared outcome if
        the event is common.
        """
        # Initialize the output
        out = -np.ones(self.N, dtype=int)
        probs, probs_are_param = self._get_probs()

        # Make the base draw(s)
        n_draws = 1 if self.common else self.N
        X = self.RNG.rand(n_draws)

        # Case 1: Markov matrix, with the row chosen by the index variable
        if self.index:
            idx = self.data[self.index]
            for j in range(probs.shape[0]):
                these = idx == j
                if np.any(these):
                    P = np.cumsum(probs[j, :])
                    if self.common:
                        out[:] = np.searchsorted(P, X[0])  # only one value of X!
                    else:
                        out[these] = np.searchsorted(P, X[these])
            return out

        # Case 2: stochastic vector
        if probs_are_param and isinstance(probs, np.ndarray):
            P = np.cumsum(probs)
            if not self.common:
                return np.searchsorted(P, X)
            out[:] = np.searchsorted(P, X[0])
            return out

        # Case 3: Bernoulli random variable
        P = probs
        if not self.common:
            return X < P  # basic Bernoulli
        out[:] = X < P
        return out

    def run(self):
        """
        Make the draw and store the result in the assigned variable.
        """
        self.assign(self.draw())

    def quasi_run(self, origins, norm=None):
        """
        Expand the set of probability blobs by the possible discrete outcomes
        of this event. The norm argument is not used.
        """
        probs, probs_are_param = self._get_probs()

        # If it's a Markov matrix:
        if self.index:
            atoms = np.array([np.arange(probs.shape[1], dtype=int)])
            origins_new = origins.copy()
            for k in range(probs.shape[0]):
                # Re-read the index each pass; expand_information replaces it
                these = self.data[self.index] == k
                origins_new = self.expand_information(
                    origins_new, probs[k, :], atoms, which=these
                )
            return origins_new

        # If it's a stochastic vector:
        if probs_are_param and isinstance(probs, np.ndarray):
            atoms = np.array([np.arange(probs.shape[0], dtype=int)])
            return self.expand_information(origins, probs, atoms)

        # Final case: probability is idiosyncratic Bernoulli
        if not probs_are_param:
            return self.add_idiosyncratic_bernoulli_info(origins, probs)

        # Otherwise, this is a Bernoulli RV with one shared probability
        P = probs
        atoms = np.array([[False, True]])
        return self.expand_information(origins, np.array([1 - P, P]), atoms)
@dataclass(kw_only=True)
class EvaluationEvent(ModelEvent):
    """
    Event that represents the evaluation of a model function, such as a policy
    function or decision rule from the model's solution, or any non-algebraic
    function used in the model. It closely resembles DynamicEvent.

    Parameters
    ----------
    func : Callable
        Model function that is evaluated in this event, with the output assigned
        to the appropriate variables.
    arguments : list[str]
        Ordered list of argument names for the function.
    """

    func: Callable = field(default_factory=NullFunc, repr=False)
    arguments: list[str] = field(default_factory=list, repr=False)

    def evaluate(self):
        """
        Evaluate the function on the current values of its arguments. Argument
        values are looked up by name; when a name appears in both data and
        parameters, the value in parameters is used.
        """
        namespace = {**self.data, **self.parameters}
        return self.func(*(namespace[name] for name in self.arguments))

    def run(self):
        """
        Evaluate the function and store the result in the assigned variables.
        """
        result = self.evaluate()
        self.assign(result)

    def quasi_run(self, origins, norm=None):
        """
        Run the event as usual; the origins array is returned unchanged and
        the norm argument is not used.
        """
        self.run()
        return origins
@dataclass(kw_only=True)
class SimBlock:
    """
    Class for representing a "block" of a simulated model, which might be a whole
    period or a "stage" within a period.

    Parameters
    ----------
    description : str
        Textual description of what happens in this simulated block.
    statement : str
        Verbatim model statement that was used to create this block.
    content : dict
        Dictionary of objects that are constant / universal within the block.
        This includes both traditional numeric parameters as well as functions.
    arrival : list[str]
        List of inbound states: information available at the *start* of the block.
    events: list[ModelEvent]
        Ordered list of events that happen during the block.
    data: dict
        Dictionary that stores current variable values.
    N : int
        Number of idiosyncratic agents in this block.
    """

    statement: str = field(default="", repr=False)
    content: dict = field(default_factory=dict)
    description: str = field(default="", repr=False)
    arrival: list[str] = field(default_factory=list, repr=False)
    events: list[ModelEvent] = field(default_factory=list, repr=False)
    data: dict = field(default_factory=dict, repr=False)
    N: int = field(default=1, repr=False)

    def run(self):
        """
        Run this simulated block by running each of its events in order.
        """
        for event in self.events:
            # Clear stale values of the variables this event is about to assign
            for var in event.assigns:
                event.data.pop(var, None)
            # Hand the event the current values of the variables it needs
            for var in event.needs:
                event.data[var] = self.data[var]
            event.N = self.N
            event.run()
            # Copy the event's results back into the block's data
            for var in event.assigns:
                self.data[var] = event.data[var]

    def reset(self):
        """
        Reset the simulated block by resetting each of its events.
        """
        self.data = {}
        for event in self.events:
            event.reset()

    def distribute_content(self):
        """
        Fill in parameters, functions, and distributions to each event.

        Raises
        ------
        ValueError
            If a parameter, distribution, or function named by an event is not
            present in this block's content dictionary.
        """
        for event in self.events:
            for param in event.parameters.keys():
                try:
                    event.parameters[param] = self.content[param]
                except KeyError as err:
                    raise ValueError(
                        "Could not distribute the parameter called " + param + "!"
                    ) from err
            if (type(event) is RandomEvent) or (type(event) is RandomIndexedEvent):
                try:
                    event.dstn = self.content[event._dstn_name]
                except KeyError as err:
                    raise ValueError(
                        "Could not find a distribution called " + event._dstn_name + "!"
                    ) from err
            if type(event) is EvaluationEvent:
                try:
                    event.func = self.content[event._func_name]
                except KeyError as err:
                    raise ValueError(
                        "Could not find a function called " + event._func_name + "!"
                    ) from err

    def make_transition_matrices(self, grid_specs, twist=None, norm=None):
        """
        Construct a transition matrix for this block, moving from a discretized
        grid of arrival variables to a discretized grid of end-of-block variables.
        User specifies how the grids of pre-states should be built. Output is
        stored in attributes of self as follows:

        - matrices : A dictionary of arrays that cast from the arrival state space
            to the grid of outcome variables. Doing np.dot(dstn, matrices[var])
            will yield the discretized distribution of that outcome variable.
        - grids : A dictionary of discretized grids for outcome variables. Doing
            np.dot(np.dot(dstn, matrices[var]), grids[var]) yields the *average*
            of that outcome in the population.
        - mesh : A list with one entry per arrival state space node, each a list
            of the arrival variable values at that node.
        - trans_array : The arrival-to-arrival transition array; only stored
            when twist is specified.
        - init_dstn : The flattened initial distribution; only stored when this
            block has no arrival variables (the initializer block).

        Parameters
        ----------
        grid_specs : dict
            Dictionary of dictionaries of grid specifications. For now, these have
            at most a minimum value, a maximum value, a number of nodes, and a poly-
            nomial order. They are equispaced if a min and max are specified, and
            polynomially spaced with the specified order > 0 if provided. Otherwise,
            they are set at 0,..,N if only N is provided.
        twist : dict or None
            Mapping from end-of-period (continuation) variables to successor's
            arrival variables. When this is specified, additional output is created
            for the "full period" arrival-to-arrival transition matrix.
        norm : str or None
            Name of the shock variable by which to normalize for Harmenberg
            aggregation. By default, no normalization happens.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If no grid specification was provided for an arrival variable, or if
            a grid was specified for a variable that does not exist after the
            events have been run.
        """
        # Initialize dictionaries of input and output grids
        arrival_N = len(self.arrival)
        completed = arrival_N * [False]
        grids_in = {}
        grids_out = {}
        if arrival_N == 0:  # should only be for initializer block
            dummy_grid = np.array([0])
            grids_in["_dummy"] = dummy_grid

        # Construct a grid for each requested variable
        continuous_grid_out_bool = []
        grid_orders = {}
        for var, spec in grid_specs.items():
            try:
                idx = self.arrival.index(var)
            except ValueError:
                is_arrival = False
            else:
                completed[idx] = True
                is_arrival = True

            if ("min" in spec) and ("max" in spec):
                Q = spec["order"] if "order" in spec else 1.0
                new_grid = make_exponential_grid(spec["min"], spec["max"], spec["N"], Q)
                is_cont = True
                grid_orders[var] = Q
            elif "N" in spec:
                new_grid = np.arange(spec["N"], dtype=int)
                is_cont = False
                grid_orders[var] = -1
            else:
                new_grid = None  # could not make grid, construct later
                is_cont = False
                grid_orders[var] = None

            if is_arrival:
                grids_in[var] = new_grid
            else:
                grids_out[var] = new_grid
                continuous_grid_out_bool.append(is_cont)

        # Verify that specifications were passed for all arrival variables
        for j in range(arrival_N):
            if not completed[j]:
                raise ValueError(
                    "No grid specification was provided for " + self.arrival[j] + "!"
                )

        # If an intertemporal twist was specified, make result grids for
        # continuation variables by copying the grid of the matching arrival
        # variable. A grid that was explicitly specified for a continuation
        # variable is left as it is.
        if twist is not None:
            for cont_var, arr_var in twist.items():
                if cont_var not in grids_out:
                    is_cont = grids_in[arr_var].dtype == np.dtype(np.float64)
                    continuous_grid_out_bool.append(is_cont)
                    grids_out[cont_var] = copy(grids_in[arr_var])
                    grid_orders[cont_var] = grid_orders[arr_var]
        grid_out_is_continuous = np.array(continuous_grid_out_bool)

        # Make meshes of all the arrival grids, which will be the initial simulation data
        if arrival_N > 0:
            state_meshes = np.meshgrid(
                *[grids_in[k] for k in self.arrival], indexing="ij"
            )
        else:  # this only happens in the initializer block
            state_meshes = [dummy_grid.copy()]
        state_init = {
            self.arrival[k]: state_meshes[k].flatten() for k in range(arrival_N)
        }
        N_orig = state_meshes[0].size
        self.N = N_orig
        mesh_tuples = [
            [state_init[self.arrival[k]][n] for k in range(arrival_N)]
            for n in range(self.N)
        ]

        # Make the initial vector of probability masses
        state_init["pmv_"] = np.ones(self.N)

        # Initialize the array of arrival states
        origin_array = np.arange(self.N, dtype=int)

        # Reset the block's state and give it the initial state data
        self.reset()
        self.data.update(state_init)

        # Loop through each event in order and quasi-simulate it
        for event in self.events:
            event.data = self.data  # Give event *all* data directly
            event.N = self.N
            origin_array = event.quasi_run(origin_array, norm=norm)
            self.N = self.data["pmv_"].size

        # Add survival to output if mortality is in the model
        if "dead" in self.data.keys():
            grids_out["dead"] = None

        # Get continuation variable names, making sure they're in the same order
        # as named by the arrival variables. This should maybe be done in the
        # simulator when it's initialized.
        if twist is not None:
            arrival_to_cont = {twist[var]: var for var in twist.keys()}
            cont_vars = [arrival_to_cont[var] for var in self.arrival]
            if "dead" in self.data.keys():
                cont_vars.append("dead")
                grid_out_is_continuous = np.concatenate(
                    (grid_out_is_continuous, [False])
                )
        else:
            cont_vars = list(grids_out.keys())  # all outcomes are arrival vars
        D = len(cont_vars)

        # Now project the final results onto the output or result grids
        N = self.N
        J = state_meshes[0].size
        matrices_out = {}
        cont_idx = {}
        cont_alpha = {}
        cont_M = {}
        cont_discrete = {}
        # Bound once, outside the loop, because it is also needed afterward
        pmv = self.data["pmv_"]
        for k, var in enumerate(grids_out):
            if var not in self.data.keys():
                raise ValueError(
                    "Variable " + var + " does not exist but a grid was specified!"
                )
            grid = grids_out[var]
            vals = self.data[var]
            M = grid.size if grid is not None else 0

            # Semi-hacky fix to deal with omitted arrival variables
            if (M == 1) and (vals.dtype == np.dtype(np.float64)):
                grid = grid.astype(float)
                grids_out[var] = grid
                grid_out_is_continuous[k] = True

            if grid_out_is_continuous[k]:
                # Split the final values among discrete gridpoints on the interior.
                # NB: This will only work properly if the grid is equispaced
                if M > 1:
                    Q = grid_orders[var]
                    if var in cont_vars:
                        trans_matrix, cont_idx[var], cont_alpha[var] = (
                            aggregate_blobs_onto_polynomial_grid_alt(
                                vals, pmv, origin_array, grid, J, Q
                            )
                        )
                        cont_M[var] = M
                        cont_discrete[var] = False
                    else:
                        trans_matrix = aggregate_blobs_onto_polynomial_grid(
                            vals, pmv, origin_array, grid, J, Q
                        )
                else:  # Skip if the grid is a dummy with only one value.
                    trans_matrix = np.ones((J, M))
                    if var in cont_vars:
                        cont_idx[var] = np.zeros(N, dtype=int)
                        cont_alpha[var] = np.zeros(N)
                        cont_M[var] = M
                        cont_discrete[var] = False

            else:  # Grid is discrete, can use simpler method
                if grid is None:
                    # Values 0,..,max need max + 1 nodes, so that the largest
                    # observed value still has a node on the inferred grid.
                    M = np.max(vals.astype(int)) + 1
                    if var == "dead":
                        M = 2
                    grid = np.arange(M, dtype=int)
                    grids_out[var] = grid
                M = grid.size
                vals = vals.astype(int)
                trans_matrix = aggregate_blobs_onto_discrete_grid(
                    vals, pmv, origin_array, M, J
                )
                if var in cont_vars:
                    cont_idx[var] = vals
                    cont_alpha[var] = np.zeros(N)
                    cont_M[var] = M
                    cont_discrete[var] = True

            # Store the transition matrix for this variable
            matrices_out[var] = trans_matrix

        # Construct an overall transition matrix from arrival to continuation variables.
        # If this is the initializer block, the "arrival" variable is just the initial
        # dummy state, and the "continuation" variables are actually the arrival variables
        # for ordinary blocks/periods.

        # Count the number of non-trivial dimensions. A continuation dimension
        # is non-trivial if it is both continuous and has more than one grid node.
        C = 0
        shape = [N_orig]
        trivial = []
        for var in cont_vars:
            shape.append(cont_M[var])
            if (not cont_discrete[var]) and (cont_M[var] > 1):
                C += 1
                trivial.append(False)
            else:
                trivial.append(True)
        trivial = np.array(trivial)

        # Make a binary array of offsets from the base index
        bin_array_base = np.array(list(product([0, 1], repeat=C)))
        bin_array = np.empty((2**C, D), dtype=int)
        some_zeros = np.zeros(2**C, dtype=int)
        c = 0
        for d in range(D):
            bin_array[:, d] = some_zeros if trivial[d] else bin_array_base[:, c]
            c += not trivial[d]

        # Make a vector of dimensional offsets from the base index
        dim_offsets = np.ones(D, dtype=int)
        for d in range(D - 1):
            dim_offsets[d] = np.prod(shape[(d + 2) :])
        dim_offsets_X = np.tile(dim_offsets, (2**C, 1))
        offsets = np.sum(bin_array * dim_offsets_X, axis=1)

        # Make combined arrays of indices and alphas
        index_array = np.empty((N, D), dtype=int)
        alpha_array = np.empty((N, D, 2))
        for d in range(D):
            var = cont_vars[d]
            index_array[:, d] = cont_idx[var]
            alpha_array[:, d, 0] = 1.0 - cont_alpha[var]
            alpha_array[:, d, 1] = cont_alpha[var]
        idx_array = np.dot(index_array, dim_offsets)

        # Make the master transition array. The product is cast to int because
        # np.prod of an empty sequence is the float 1.0, which is not a valid size.
        blank = np.zeros((N_orig, int(np.prod(shape[1:]))))
        master_trans_array_X = calc_overall_trans_probs(
            blank, idx_array, alpha_array, bin_array, offsets, pmv, origin_array
        )

        # Condition on survival if relevant
        if "dead" in self.data.keys():
            master_trans_array_X = np.reshape(master_trans_array_X, (N_orig, N_orig, 2))
            survival_probs = np.reshape(matrices_out["dead"][:, 0], [N_orig, 1])
            master_trans_array_X = master_trans_array_X[..., 0] / survival_probs

        # Reshape the transition matrix depending on what kind of block this is
        if arrival_N == 0:
            # If this is the initializer block, the "transition" matrix is really
            # just the initial distribution of states at model birth; flatten it.
            master_init_array = master_trans_array_X.flatten()
        else:
            # In an ordinary period, reshape the transition array so it's square.
            master_trans_array = np.reshape(master_trans_array_X, (N_orig, N_orig))

        # Store the results as attributes of self
        grids = {}
        grids.update(grids_in)
        grids.update(grids_out)
        self.grids = grids
        self.matrices = matrices_out
        self.mesh = mesh_tuples
        if twist is not None:
            self.trans_array = master_trans_array
        if arrival_N == 0:
            self.init_dstn = master_init_array
909@dataclass(kw_only=True)
910class AgentSimulator:
911 """
912 A class for representing an entire simulator structure for an AgentType.
913 It includes a sequence of SimBlocks representing periods of the model, which
914 could be built from the information on an AgentType instance.
916 Parameters
917 ----------
918 name : str
919 Short name of this model.s
920 description : str
921 Textual description of what happens in this simulated block.
922 statement : str
923 Verbatim model statement that was used to create this simulator.
924 comments : dict
925 Dictionary of comments or descriptions for various model objects.
926 parameters : list[str]
927 List of parameter names used in the model.
928 distributions : list[str]
929 List of distribution names used in the model.
930 functions : list[str]
931 List of function names used in the model.
932 common: list[str]
933 Names of variables that are common across idiosyncratic agents.
934 types: dict
935 Dictionary of data types for all variables in the model.
936 N_agents: int
937 Number of idiosyncratic agents in this simulation.
938 T_total: int
939 Total number of periods in these agents' model.
940 T_sim: int
941 Maximum number of periods that will be simulated, determining the size
942 of the history arrays.
943 T_age: int
944 Period after which to automatically terminate an agent if they would
945 survive past this period.
946 stop_dead : bool
947 Whether simulated agents who draw dead=True should actually cease acting.
948 Default is True. Setting to False allows "cohort-style" simulation that
949 will generate many agents that survive to old ages. In most cases, T_sim
950 should not exceed T_age, unless the user really does want multiple succ-
951 essive cohorts to be born and fully simulated.
952 replace_dead : bool
953 Whether simulated agents who are marked as dead should be replaced with
954 newborns (default True) or simply cease acting without replacement (False).
955 The latter option is useful for models with state-dependent mortality,
956 to allow "cohort-style" simulation with the correct distribution of states
957 for survivors at each age. Setting to False has no effect if stop_dead is True.
958 periods: list[SimBlock]
959 Ordered list of simulation blocks, each representing a period.
960 twist : dict
961 Dictionary that maps period t-1 variables to period t variables, as a
962 relabeling "between" periods.
963 initializer : SimBlock
964 A special simulated block that should have *no* arrival variables, because
965 it represents the initialization of "newborn" agents.
966 data : dict
967 Dictionary that holds *current* values of model variables.
968 track_vars : list[str]
969 List of names of variables whose history should be tracked in the simulation.
970 history : dict
971 Dictionary that holds the histories of tracked variables.
972 """
# Descriptive metadata about the model and the symbols it uses
name: str = field(default="")
description: str = field(default="")
statement: str = field(default="", repr=False)
comments: dict = field(default_factory=dict, repr=False)
parameters: list[str] = field(default_factory=list, repr=False)
distributions: list[str] = field(default_factory=list, repr=False)
functions: list[str] = field(default_factory=list, repr=False)
common: list[str] = field(default_factory=list, repr=False)
types: dict = field(default_factory=dict, repr=False)

# Simulation dimensions and options for handling death / replacement
N_agents: int = field(default=1)
T_total: int = field(default=1, repr=False)
T_sim: int = field(default=1)
T_age: int = field(default=0, repr=False)
stop_dead: bool = field(default=True)
replace_dead: bool = field(default=True)

# Model structure and working data
periods: list[SimBlock] = field(default_factory=list, repr=False)
twist: dict = field(default_factory=dict, repr=False)
data: dict = field(default_factory=dict, repr=False)
# The field() specification must be *assigned*, with SimBlock as the annotation.
# Writing `initializer: field(...)` makes the Field object the annotation, so the
# default factory and repr=False are never applied and the field is required.
initializer: SimBlock = field(default_factory=SimBlock, repr=False)
track_vars: list[str] = field(default_factory=list, repr=False)
history: dict = field(default_factory=dict, repr=False)
def simulate(self, T=None):
    """
    Advance the simulation by T periods. This is the main entry point for
    users: each period is simulated, agents at the age limit are flagged as
    dead, tracked variables are written to the history, ages are advanced,
    and (if stop_dead is set) dead agents are marked for replacement.

    Parameters
    ----------
    T : int or None
        Number of periods to simulate. When None, all periods remaining
        up to T_sim are simulated.

    Returns
    -------
    None
    """
    if T is None:
        T = self.T_sim - self.t_sim  # everything that is left to simulate
    if (T + self.t_sim) > self.T_sim:
        raise ValueError("Can't simulate more than T_sim periods!")

    for _ in range(T):
        # Ordinary within-period dynamics for all agents
        self.sim_one_period()

        # Agents who have hit the maximum allowable age are flagged as dead
        if "dead" in self.data:
            if self.T_age > 0:
                self.data["dead"][self.t_age == self.T_age] = True

        # Write tracked variables to the history, then age everyone
        self.store_tracked_vars()
        self.advance_age()

        # Only act on deaths when this simulation style calls for it
        if "dead" in self.data:
            if self.stop_dead:
                self.mark_dead_agents()
        self.t_sim += 1
def reset(self):
    """
    Return the simulator to its starting state so that it can be run again
    from scratch: the time index goes back to zero, current data is blanked,
    history arrays are reallocated, every block is reset, and all agents
    are flagged as unborn.

    Returns
    -------
    None
    """
    n_agents = self.N_agents
    n_periods = self.T_sim
    self.t_sim = 0  # simulation clock

    # Blank the current data, then allocate fresh history arrays
    self.clear_data()
    tracked = {}
    self.history = tracked
    for name in self.track_vars:
        tracked[name] = np.empty((n_periods, n_agents), dtype=self.types[name])

    # Reset the initializer and every period block
    self.initializer.reset()
    for period in self.periods:
        period.reset()

    # No agent is in any period yet; age -1 designates an unborn agent
    self.t_seq_bool_array = np.zeros((self.T_total, n_agents), dtype=bool)
    self.t_age = np.full(n_agents, -1, dtype=int)
def clear_data(self, skip=None):
    """
    Overwrite the current data arrays with blank values, leaving untouched
    any variables that are named in skip.

    Parameters
    ----------
    skip : [str] or None
        Names of variables that should keep their current values. The
        default of None means that every variable is blanked.

    Returns
    -------
    None
    """
    excluded = [] if skip is None else skip
    shape = (self.N_agents,)
    for name, dtype in self.types.items():
        if name in excluded:
            continue

        # The blank value depends on the declared type of the variable
        if dtype is float:
            blank = np.full(shape, np.nan)
        elif dtype is bool:
            blank = np.zeros(shape, dtype=bool)
        elif dtype is int:
            blank = np.zeros(shape, dtype=np.int32)
        elif dtype is complex:
            blank = np.full(shape, np.nan, dtype=complex)
        else:
            raise ValueError(
                "Type " + str(dtype) + " of variable " + name + " was not recognized!"
            )
        self.data[name] = blank
def mark_dead_agents(self):
    """
    Look at the special data field "dead" and mark those agents for replacement
    by removing them from every period and setting their age to -1. If no
    variable called "dead" exists in the current data, nothing happens.

    Returns
    -------
    None
    """
    if "dead" not in self.data:
        return  # model has no mortality, so there is nobody to mark

    who_died = self.data["dead"]
    self.t_seq_bool_array[:, who_died] = False
    self.t_age[who_died] = -1
def create_newborns(self):
    """
    Run the initializer block for every agent whose age is -1, write the
    resulting arrival variables into the current data, and place those
    agents in the first period with age zero. Does nothing if there are
    no such agents.

    Returns
    -------
    None
    """
    is_newborn = self.t_age == -1
    if not np.any(is_newborn):
        return  # nobody needs to be born

    # Draw initial arrival variables for exactly as many agents as needed
    count = np.sum(is_newborn)
    birth_block = self.initializer
    birth_block.data = {}  # the initializer has no arrival variables
    birth_block.N = count
    birth_block.run()

    # Arrival variables come from the initializer; everything else is
    # overwritten with an unfilled array of the right type
    arrival_names = self.periods[0].arrival
    for name, dtype in self.types.items():
        if name in arrival_names:
            fresh = birth_block.data[name]
        else:
            fresh = np.empty(count, dtype=dtype)
        self.data[name][is_newborn] = fresh

    # Newborns start in the first period at age zero
    self.t_age[is_newborn] = 0
    self.t_seq_bool_array[0, is_newborn] = True
def store_tracked_vars(self):
    """
    Copy the current value of each tracked variable into the row of its
    history array that corresponds to the current simulation period.

    Returns
    -------
    None
    """
    for name in self.track_vars:
        record = self.history[name]
        record[self.t_sim, :] = self.data[name]
def advance_age(self):
    """
    Move every living agent forward by one period, updating both t_age and
    t_seq_bool_array. The period assignment wraps around, so agents in the
    last period of the sequence are moved to the initial period. In a
    lifecycle model, those agents should be marked as dead and replaced
    in short order.

    Returns
    -------
    None
    """
    living = self.t_age >= 0  # agents with age -1 are not aged
    self.t_age[living] += 1

    # Shift each living agent's period indicator down one row, cyclically
    seq = self.t_seq_bool_array
    seq[:, living] = np.roll(seq[:, living], 1, axis=0)
def sim_one_period(self):
    """
    Advance every agent through one period of the model. Newborns are
    created here when appropriate, but dead agents are NOT removed and
    tracked variables are NOT stored. Users should generally call
    simulate(1) rather than this method.

    Returns
    -------
    None
    """
    # Relabel last period's end-of-period values as this period's arrival
    # variables according to the twist, then blank every other variable.
    carried = []
    for prev_name, next_name in self.twist.items():
        carried.append(next_name)
        self.data[next_name] = self.data[prev_name].copy()
    self.clear_data(skip=carried)

    # Newborns must exist before the periods run. They are made in the very
    # first simulated period, or whenever decedents are to be replaced.
    if self.t_sim == 0 or self.replace_dead:
        self.create_newborns()

    # Run each period of the sequence on the agents currently in it
    for t in range(self.T_total):
        mask = self.t_seq_bool_array[t, :]
        if not np.any(mask):
            continue  # nobody is in this period right now
        period = self.periods[t]

        period.data = {name: self.data[name][mask] for name in period.arrival}
        period.N = np.sum(mask)
        period.run()

        # Write everything the period produced back into the master data
        for name, values in period.data.items():
            self.data[name][mask] = values

    # Record timing information alongside the model variables
    self.data["t_age"] = self.t_age.copy()
    self.data["t_seq"] = np.argmax(self.t_seq_bool_array, axis=0).astype(int)
def make_transition_matrices(
    self, grid_specs, norm=None, fake_news_timing=False, for_t=None
):
    """
    Construct Markov-style transition matrices for the periods of the model
    and the distribution of arrival states among newborns. The following
    attributes of self are set:

    - trans_arrays : List of Markov matrices, one per period, mapping the arrival
                     state space in period t to the arrival state space in t+1.
                     Death and replacement are included in this transition.
    - newborn_dstn : Stochastic vector (NumPy array) with the distribution of
                     arrival states for agents who were just initialized.
    - state_grids : List with the arrival state mesh of each period, taken from
                    periods[t].mesh. Elements line up by index with trans_arrays
                    (and newborn_dstn).
    - outcome_arrays : List of dictionaries of arrays, taken from periods[t].matrices,
                       that map the arrival state space onto the grid of each
                       outcome variable. np.dot(state_dstn, outcome_arrays[t][var])
                       is the discretized distribution of that outcome.
    - outcome_grids : List of dictionaries of discretized outcome values, taken
                      from periods[t].grids. The population average of var is
                      np.dot(np.dot(state_dstn, outcome_arrays[t][var]), outcome_grids[t][var]).

    Parameters
    ----------
    grid_specs : dict
        Dictionary of dictionaries specifying the discretized grid for each
        variable of interest. Arrival variables that are left out receive a
        trivial one-node grid, which should only be relied on when that
        variable is closely tied to the Harmenberg normalizing variable.
        A specification must include a number of gridpoints N, and should
        include a min and max for continuous variables. Discrete variables
        are assumed to take on values 0,..,N.
    norm : str or None
        Name of the variable to which Harmenberg normalization is applied,
        if any. It should be drawn directly from a distribution rather than
        being a "downstream" variable.
    fake_news_timing : bool
        Whether this call is part of the "fake news" algorithm for building
        sequence space Jacobians. Only set this to True in that situation;
        it makes decedents be replaced by newborns in every period rather
        than only in the last one.
    for_t : list or None
        Optional list of period indices whose matrices should be built. All
        periods are built when this is not given.

    Returns
    -------
    None
    """
    # Split the grid specifications: the initializer only needs those for
    # arrival variables, while ordinary periods get all of them.
    arrival_names = self.periods[0].arrival
    has_spec = np.zeros(len(arrival_names), dtype=bool)
    given_init_specs = {}
    period_specs = {}
    for key in grid_specs:
        if key in arrival_names:
            has_spec[arrival_names.index(key)] = True
            given_init_specs[key] = copy(grid_specs[key])
        period_specs[key] = copy(grid_specs[key])

    # Assemble the initializer's specifications in the same order as the
    # arrival variables, using a trivial grid for any that were omitted.
    init_specs = {}
    for key, found in zip(arrival_names, has_spec):
        if found:
            init_specs[key] = given_init_specs[key]
        else:
            trivial_spec = {"N": 1}
            init_specs[key] = trivial_spec
            period_specs[key] = trivial_spec

    # Distribution of arrival states for newborns
    self.initializer.make_transition_matrices(init_specs)
    self.newborn_dstn = self.initializer.init_dstn
    K = self.newborn_dstn.size

    # Transition matrices for each requested period
    period_indices = range(len(self.periods)) if for_t is None else for_t
    for t in period_indices:
        block = self.periods[t]
        block.make_transition_matrices(period_specs, twist=self.twist, norm=norm)
        block.reset()

    # Collect the period-to-period transition matrices from every block
    p2p_trans_arrays = [block.trans_array for block in self.periods]

    # Replace decedents with newborns. Ordinarily this applies only to the last
    # period, which usually matters only in "one period infinite horizon"
    # models. Under fake news timing it applies to *all* periods.
    if fake_news_timing:
        replace_in = np.arange(len(self.periods)).tolist()
    else:
        replace_in = [-1]
    newborn_row = np.reshape(self.newborn_dstn, (1, K))
    for t in replace_in:
        outcome_mats = self.periods[t].matrices
        if "dead" not in outcome_mats.keys():
            continue
        death_prbs = outcome_mats["dead"][:, 1]
        survival = np.tile(np.reshape(1 - death_prbs, (K, 1)), (1, K))
        # These are in-place operations on the arrays held by the blocks
        p2p_trans_arrays[t] *= survival
        p2p_trans_arrays[t] += np.reshape(death_prbs, (K, 1)) * newborn_row

    # Store the results as attributes of self
    self.trans_arrays = p2p_trans_arrays
    self.state_grids = [block.mesh for block in self.periods]
    self.outcome_grids = [block.grids for block in self.periods]
    self.outcome_arrays = [block.matrices for block in self.periods]
def find_steady_state(self):
    """
    Calculate the steady state distribution of arrival states for a "one period
    infinite horizon" model, storing the result to the attribute steady_state_dstn.
    Should only be run after make_transition_matrices(), and only if T_total = 1
    and the model is infinite horizon.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If T_total is not 1, if the transition arrays have not been built yet,
        or if the largest eigenvalue of the transition matrix is not close to 1.
    """
    if self.T_total != 1:
        raise ValueError(
            "This method currently only works with one period infinite horizon problems."
        )
    if not hasattr(self, "trans_arrays"):
        raise ValueError(
            "The transition arrays do not exist; make_transition_matrices() must be run before find_steady_state()!"
        )

    # Find the eigenvector associated with the largest eigenvalue of the
    # infinite horizon transition matrix. The largest eigenvalue *should*
    # be 1 for any Markov matrix, but double check to be sure.
    trans_T = csr_matrix(self.trans_arrays[0].transpose())
    v, V = eigs(trans_T, k=1)
    if not np.isclose(v[0], 1.0):
        raise ValueError(
            "The largest eigenvalue of the transition matrix isn't close to 1!"
        )

    # Normalize that eigenvector so it sums to one and keep only its real part
    D = V[:, 0]
    SS_dstn = (D / np.sum(D)).real
    self.steady_state_dstn = SS_dstn
def get_long_run_average(self, var):
    """
    Compute the steady state population average of a single named variable.
    Can only be used once find_steady_state() has been run.

    Parameters
    ----------
    var : str
        Name of the variable whose long run average is wanted.

    Returns
    -------
    var_mean : float
        Long run / steady state population average of the variable.
    """
    if not hasattr(self, "steady_state_dstn"):
        raise ValueError("This method can only be run after find_steady_state()!")

    state_dstn = self.steady_state_dstn
    to_outcome = self.outcome_arrays[0][var]
    outcome_values = self.outcome_grids[0][var]

    # Push the state distribution onto the outcome grid, then average over it
    return np.dot(np.dot(state_dstn, to_outcome), outcome_values)
def simulate_cohort_by_grids(
    self,
    outcomes,
    T_max=None,
    calc_dstn=False,
    calc_avg=True,
    from_dstn=None,
):
    """
    Generate a simulated "cohort style" history for this type of agents using
    discretized grid methods. Can only be run after running make_transition_matrices().
    Starting from the distribution of states at birth, the population is moved
    forward in time via the transition matrices, and the distribution and/or
    average of specified outcomes are stored in the dictionary attributes
    history_dstn and history_avg respectively. The state distribution at each
    age is stored in the attribute state_dstn_by_age.

    Parameters
    ----------
    outcomes : str or [str]
        Names of one or more outcome variables to be tracked during the grid
        simulation. Each named variable should have an outcome grid specified
        when make_transition_matrices() was called, whether explicitly or
        implicitly. The existence of these grids is checked as a first step.
    T_max : int or None
        If specified, the number of periods of the model to actually generate
        output for. If not specified, all periods are run.
    calc_dstn : bool
        Whether outcome distributions should be stored in the dictionary
        attribute history_dstn. The default is False.
    calc_avg : bool
        Whether outcome averages should be stored in the dictionary attribute
        history_avg. The default is True.
    from_dstn : np.array or None
        Optional initial distribution of arrival states. If not specified, the
        newborn distribution in the initializer is assumed to be used.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If make_transition_matrices() has not been run, if there are too few
        transition arrays, or if a requested outcome has no outcome array
        (or grid, when averages are requested) in one of the periods.
    """
    # First, verify that newborn and transition matrices exist for all periods
    if not hasattr(self, "newborn_dstn"):
        raise ValueError(
            "The newborn state distribution does not exist; make_transition_matrices() must be run before grid simulations!"
        )
    if T_max is None:
        T_max = self.T_total
    T_max = np.minimum(T_max, self.T_total)
    if not hasattr(self, "trans_arrays"):
        raise ValueError(
            "The transition arrays do not exist; make_transition_matrices() must be run before grid simulations!"
        )
    if len(self.trans_arrays) < T_max:
        raise ValueError(
            "There are somehow fewer elements of trans_array than there should be!"
        )
    if not (calc_dstn or calc_avg):
        return  # No work actually requested, we're done here

    # Verify up front that every requested outcome can actually be computed,
    # rather than failing partway through the simulation loop
    if isinstance(outcomes, str):
        outcomes = [outcomes]
    for t in range(T_max):
        period = self.periods[t]
        for name in outcomes:
            if name not in period.matrices:
                raise ValueError(
                    f"There is no outcome array for {name} in period {t}; "
                    "was a grid specified for it in make_transition_matrices()?"
                )
            if calc_avg and name not in period.grids:
                raise ValueError(
                    f"There is no outcome grid for {name} in period {t}; "
                    "was a grid specified for it in make_transition_matrices()?"
                )

    # Initialize generated output as requested
    if calc_dstn:
        # Lists are stacked into arrays at the end if all distributions are same size
        history_dstn = {name: [] for name in outcomes}
    if calc_avg:
        history_avg = {name: np.empty(T_max) for name in outcomes}

    # Initialize the state distribution
    current_dstn = (
        self.newborn_dstn.copy() if from_dstn is None else from_dstn.copy()
    )
    state_dstn_by_age = []

    # Loop over requested periods of this agent type's model
    for t in range(T_max):
        state_dstn_by_age.append(current_dstn)

        # Calculate outcome distributions and averages as requested
        for name in outcomes:
            this_outcome = self.periods[t].matrices[name].transpose()
            this_dstn = np.dot(this_outcome, current_dstn)
            if calc_dstn:
                history_dstn[name].append(this_dstn)
            if calc_avg:
                this_grid = self.periods[t].grids[name]
                history_avg[name][t] = np.dot(this_dstn, this_grid)

        # Advance the distribution to the next period
        current_dstn = np.dot(self.trans_arrays[t].transpose(), current_dstn)

    # Stack the distribution histories into arrays when they all have the same
    # size; an empty history (no periods simulated) is left as an empty list
    if calc_dstn:
        for name in outcomes:
            dstn_sizes = {dstn.size for dstn in history_dstn[name]}
            if len(dstn_sizes) == 1:
                history_dstn[name] = np.stack(history_dstn[name], axis=1)

    # Store results as attributes of self
    self.state_dstn_by_age = state_dstn_by_age
    if calc_dstn:
        self.history_dstn = history_dstn
    if calc_avg:
        self.history_avg = history_avg
def describe_model(self, display=True):
    """
    Build a text summary of the model: the initializer's statement, the
    within-period model statement, and the twist relabeling. The summary
    is printed when display is True (and nothing is returned); otherwise
    it is returned as a string.
    """
    rule = "----------------------------------\n"

    # One line per twist pairing
    twist_lines = [
        prev_name + "[t-1] <---> " + next_name + "[t]\n"
        for prev_name, next_name in self.twist.items()
    ]

    # Put the three sections together, each under its own banner
    pieces = [
        rule,
        "%%%%% INITIALIZATION AT BIRTH %%%%\n",
        rule,
        self.initializer.statement,
        rule,
        "%%%% DYNAMICS WITHIN PERIOD t %%%%\n",
        rule,
        self.statement,
        rule,
        "%%%%%%% RELABELING / TWIST %%%%%%%\n",
        rule,
        "".join(twist_lines),
        "-----------------------------------",
    ]
    output = "".join(pieces)

    # Print the summary or hand it back to the caller
    if not display:
        return output
    print(output)
    return
def describe_symbols(self, display=True):
    """
    Convenience method that prints symbol information to screen. Each symbol
    that has an entry in the comments dictionary is listed with its kind (data
    type, "dstn", "param", or "func"; "unknown" if it is none of these), a
    "common" tag where applicable, and its comment, with comments aligned.

    Parameters
    ----------
    display : bool
        Whether to print the text (default True) or return it as a string.

    Returns
    -------
    output : str or None
        The symbol description if display is False, otherwise None.
    """
    # Get names and types
    symbols_lines = []
    comments = []
    for key, comment in self.comments.items():
        comments.append(comment)

        # Get type of object
        if key in self.types:
            this_type = str(self.types[key].__name__)
        elif key in self.distributions:
            this_type = "dstn"
        elif key in self.parameters:
            this_type = "param"
        elif key in self.functions:
            this_type = "func"
        else:
            # Without this branch the label would be undefined for the first
            # symbol, or silently carried over from the previous symbol
            this_type = "unknown"

        # Add tags
        if key in self.common:
            this_type += ", common"
        symbols_lines.append(key + " (" + this_type + ")")

    # Add comments, aligned one space past the longest symbol line; the
    # default handles a model with no commented symbols at all
    longest = max((len(line) for line in symbols_lines), default=0)
    output = "".join(
        f"{line.ljust(longest + 1)}: {comment}\n"
        for line, comment in zip(symbols_lines, comments)
    )

    # Return or print the output
    if display:
        print(output)
        return
    return output
def describe(self, symbols=True, model=True, display=True):
    """
    Show everything known about the model: a header with its name and
    description, optionally followed by the symbol listing and/or the
    model statement. The text is printed when display is True (and
    nothing is returned); otherwise it is returned as a string.
    """
    rule = "----------------------------------\n"

    # Assemble the requested pieces in order
    pieces = [self.name + ": " + self.description + "\n"]
    if symbols or model:
        pieces.append("\n")
    if symbols:
        pieces.append(rule)
        pieces.append("%%%%%%%%%%%%% SYMBOLS %%%%%%%%%%%%\n")
        pieces.append(rule)
        pieces.append(self.describe_symbols(display=False))
    if model:
        pieces.append(self.describe_model(display=False))
    elif symbols:
        # Close off the symbols section when no model section follows
        pieces.append("----------------------------------")
    output = "".join(pieces)

    # Print the text or hand it back to the caller
    if not display:
        return output
    print(output)
    return
def _get_model_field(model, default, *keys):
    """Return model[keys[0]][keys[1]]..., or default if any lookup along the way fails."""
    value = model
    try:
        for key in keys:
            value = value[key]
    except (KeyError, IndexError, TypeError):
        return default
    return value


def make_simulator_from_agent(agent, stop_dead=True, replace_dead=True, common=None):
    """
    Build an AgentSimulator instance based on an AgentType instance. The AgentType
    should have its model attribute defined so that it can be parsed and translated
    into the simulator structure. The names of objects in the model statement
    should correspond to attributes of the AgentType.

    Parameters
    ----------
    agent : AgentType
        Agents for whom a new simulator is to be constructed.
    stop_dead : bool
        Whether simulated agents who draw dead=True should actually cease acting.
        Default is True. Setting to False allows "cohort-style" simulation that
        will generate many agents that survive to old ages. In most cases, T_sim
        should not exceed T_age, unless the user really does want multiple succ-
        essive cohorts to be born and fully simulated.
    replace_dead : bool
        Whether simulated agents who are marked as dead should be replaced with
        newborns (default True) or simply cease acting without replacement (False).
        The latter option is useful for models with state-dependent mortality,
        to allow "cohort-style" simulation with the correct distribution of states
        for survivors at each age. Setting False has no effect if stop_dead is True.
    common : [str] or None
        List of random variables that should be treated as commonly shared across
        all agents, rather than idiosyncratically drawn. If this is provided, it
        will override the model defaults. The list that is passed is not modified.

    Returns
    -------
    new_simulator : AgentSimulator
        A simulator structure based on the agents.

    Raises
    ------
    ValueError
        If a declared variable type cannot be understood, if a value cannot be
        retrieved from the agent for an object named in the model, or if an
        object in the model is in neither time_inv nor time_vary.
    """
    # Read the model statement into a dictionary, and get names of attributes
    if hasattr(agent, "model_statement"):  # look for a custom model statement
        model_statement = copy(agent.model_statement)
    else:  # otherwise use the default model file
        with importlib.resources.open_text("HARK.models", agent.model_file) as f:
            model_statement = f.read()
    model = yaml.safe_load(model_statement)
    time_vary = agent.time_vary
    time_inv = agent.time_inv
    cycles = agent.cycles
    T_age = agent.T_age
    comments = {}
    RNG = agent.RNG  # this is only for generating seeds for MarkovEvents

    # Extract basic fields from the model, with defaults for anything absent
    model_name = _get_model_field(model, "DEFAULT_NAME", "name")
    description = _get_model_field(model, "(no description provided)", "description")
    variables = _get_model_field(model, [], "symbols", "variables")
    twist = _get_model_field(model, {}, "twist")
    if common is None:
        common = _get_model_field(model, [], "symbols", "common")
    else:
        common = list(common)  # flagged variables are appended below

    # Extract arrival variable names that were explicitly listed
    arrival = _get_model_field(model, [], "symbols", "arrival")

    # Make a dictionary of declared data types and add comments
    types = {}
    for var_line in variables:  # Loop through declared variables
        var_name, var_type, flags, desc = parse_declaration_for_parts(var_line)
        if var_type is not None:
            try:
                # NOTE(security): eval() executes text taken from the model
                # statement; only use model files from trusted sources.
                var_type = eval(var_type)
            except Exception as err:
                raise ValueError(
                    "Couldn't understand type "
                    + var_type
                    + " for declared variable "
                    + var_name
                    + "!"
                ) from err
        else:
            var_type = float
        types[var_name] = var_type
        comments[var_name] = desc
        if ("arrival" in flags) and (var_name not in arrival):
            arrival.append(var_name)
        if ("common" in flags) and (var_name not in common):
            common.append(var_name)

    # Make a blank "template" period with structure but no data
    template_period, information, offset, solution, block_comments = (
        make_template_block(model, arrival, common)
    )
    comments.update(block_comments)

    # Make the agent initializer, without parameter values (etc)
    initializer, init_info = make_initializer(model, arrival, common)

    # Extract basic fields from the template period and model
    statement = template_period.statement
    content = template_period.content

    # Sort the referenced objects into parameters, functions, and distributions;
    # anything else is a variable, which defaults to float if it was undeclared
    parameters = []
    functions = []
    distributions = []
    for key, val in information.items():
        if val is None:
            parameters.append(key)
        elif type(val) is NullFunc:
            functions.append(key)
        elif type(val) is Distribution:
            distributions.append(key)
        elif key not in types:
            types[key] = float
            comments[key] = ""
    if "dead" in types:
        types["dead"] = bool
        comments["dead"] = "whether agent died this period"
    types["t_seq"] = int
    types["t_age"] = int
    comments["t_seq"] = "which period of the sequence the agent is on"
    comments["t_age"] = "how many periods the agent has already lived for"

    # Make a dictionary for the initializer and distribute information
    init_dict = {}
    for name in init_info:
        try:
            init_dict[name] = getattr(agent, name)
        except Exception as err:
            raise ValueError(
                "Couldn't get a value for initializer object " + name + "!"
            ) from err
    initializer.content = init_dict
    initializer.distribute_content()

    # Make a dictionary of time-invariant parameters
    time_inv_dict = {}
    for name in content:
        if name in time_inv:
            try:
                time_inv_dict[name] = getattr(agent, name)
            except Exception as err:
                raise ValueError(
                    "Couldn't get a value for time-invariant object " + name + "!"
                ) from err

    # Create a list of periods, pulling appropriate data from the agent for each one
    T_seq = len(agent.solution)  # Number of periods in the solution sequence
    periods = []
    T_cycle = agent.T_cycle
    t_cycle = 0
    for t in range(T_seq):
        # Make a fresh copy of the template period
        new_period = deepcopy(template_period)

        # Make sure each period's events have unique seeds; this is only for MarkovEvents
        for event in new_period.events:
            if hasattr(event, "seed"):
                event.seed = RNG.integers(0, 2**31 - 1)

        # Make the parameter dictionary for this period
        new_param_dict = deepcopy(time_inv_dict)
        for name in content:
            if name in solution:
                solution_t = agent.solution[t]
                if isinstance(solution_t, dict):
                    new_param_dict[name] = solution_t[name]
                else:
                    new_param_dict[name] = getattr(solution_t, name)
            elif name in time_vary:
                # Offset objects are indexed one period back within the cycle
                s = (t_cycle - 1) if name in offset else t_cycle
                try:
                    new_param_dict[name] = getattr(agent, name)[s]
                except Exception as err:
                    raise ValueError(
                        "Couldn't get a value for time-varying object "
                        + name
                        + " at time index "
                        + str(s)
                        + "!"
                    ) from err
            elif name in time_inv:
                continue
            else:
                raise ValueError(
                    "The object called "
                    + name
                    + " is not named in time_inv nor time_vary!"
                )

        # Fill in content for this period, then add it to the list
        new_period.content = new_param_dict
        new_period.distribute_content()
        periods.append(new_period)

        # Advance time according to the cycle
        t_cycle += 1
        if t_cycle == T_cycle:
            t_cycle = 0

    # Calculate maximum age
    if T_age is None:
        T_age = 0
    if cycles > 0:
        T_age_max = T_seq - 1
        T_age = np.minimum(T_age_max, T_age)
    T_sim = getattr(agent, "T_sim", 0)  # very boring default!

    # Make and return the new simulator
    new_simulator = AgentSimulator(
        name=model_name,
        description=description,
        statement=statement,
        comments=comments,
        parameters=parameters,
        functions=functions,
        distributions=distributions,
        common=common,
        types=types,
        N_agents=agent.AgentCount,
        T_total=T_seq,
        T_sim=T_sim,
        T_age=T_age,
        stop_dead=stop_dead,
        replace_dead=replace_dead,
        periods=periods,
        twist=twist,
        initializer=initializer,
        track_vars=agent.track_vars,
    )
    new_simulator.solution = solution  # this is for use by SSJ constructor
    return new_simulator
def make_template_block(model, arrival=None, common=None):
    """
    Construct a new SimBlock object as a "template" of the model block. It has
    events and reference information, but no values filled in.

    Parameters
    ----------
    model : dict
        Dictionary with model block information, probably read in as a yaml.
        It must have "symbols" and "dynamics" entries; "name" is optional. Any
        "offset" and "solution" lists under "symbols" are copied, not extended
        in place.
    arrival : [str] or None
        List of arrival variables that were flagged or explicitly listed.
    common : [str] or None
        List of variables that are common or shared across all agents, rather
        than idiosyncratically drawn.

    Returns
    -------
    template_block : SimBlock
        A "template" of this model block, with no parameters (etc) on it.
    info : dict
        Dictionary of model objects that were referenced within the block. Keys
        are object names and entries reveal what kind of object they are:
        - None --> parameter
        - 0 --> outcome/data variable (including arrival variables)
        - NullFunc --> function
        - Distribution --> distribution
    offset : [str]
        List of object names that are offset in time by one period.
    solution : [str]
        List of object names that are part of the model solution.
    comments : dict
        Dictionary of comments included with declared functions, distributions,
        and parameters.

    Raises
    ------
    ValueError
        If a function or distribution is declared with a conflicting datatype,
        or if a dynamics line assigns a variable name that already exists.
    """
    if arrival is None:
        arrival = []
    if common is None:
        common = []

    # Extract explicitly listed metadata, using defaults for anything missing.
    # The lists are copied so that flags found below don't alter the input model.
    try:
        name = model["name"]
    except (KeyError, TypeError):
        name = "DEFAULT_NAME"
    try:
        offset = list(model["symbols"]["offset"])
    except (KeyError, TypeError):
        offset = []
    try:
        solution = list(model["symbols"]["solution"])
    except (KeyError, TypeError):
        solution = []

    # Each kind of declared object: its section in the model file, the datatype
    # it is allowed to declare, how to name it in errors, and its "blank" value
    declaration_kinds = (
        ("parameters", None, None, lambda: None),
        ("functions", "func", "function", NullFunc),
        ("distributions", "dstn", "distribution", Distribution),
    )

    # Extract parameters, functions, and distributions. Processing the sections
    # in this order means a later kind overrides an earlier one of the same name.
    symbol_table = model["symbols"]
    comments = {}
    content = {}
    for section, allowed_type, label, make_blank in declaration_kinds:
        for line in symbol_table.get(section) or []:
            obj_name, datatype, flags, desc = parse_declaration_for_parts(line)
            # TODO: what to do with parameter types?
            if (allowed_type is not None) and (datatype not in (None, allowed_type)):
                raise ValueError(
                    obj_name
                    + " was declared as a "
                    + label
                    + ", but given a different datatype!"
                )
            content[obj_name] = make_blank()
            comments[obj_name] = desc
            if ("offset" in flags) and (obj_name not in offset):
                offset.append(obj_name)
            if ("solution" in flags) and (obj_name not in solution):
                solution.append(obj_name)

    # Make the "information" dictionary, which represents objects available
    # *at that point* in the dynamic block
    info = deepcopy(content)
    for var in arrival:
        info[var] = 0  # Mark as a state variable

    # Parse the model dynamics
    dynamics = format_block_statement(model["dynamics"])

    # Make the list of ordered events
    events = []
    names_used_in_dynamics = set()
    for line in dynamics:
        # Make the new event and add it to the list
        new_event, names_used = make_new_event(line, info)
        events.append(new_event)
        names_used_in_dynamics.update(names_used)

        # Add newly assigned variables to the information set
        for var in new_event.assigns:
            if var in info:
                raise ValueError(var + " is assigned, but already exists!")
            info[var] = 0

        # If any assigned variables are common, mark the event as common
        if any(var in common for var in new_event.assigns):
            new_event.common = True

    # Keep only the content that is actually referenced within the dynamics
    content = {
        key: blank for key, blank in content.items() if key in names_used_in_dynamics
    }

    # Make a single string model statement, with descriptions aligned in a column;
    # default=0 keeps this from failing when there are no events at all
    width = max((len(event.statement) for event in events), default=0) + 1
    statement = "".join(
        event.statement.ljust(width) + ": " + event.description + "\n"
        for event in events
    )

    # Make a description for the template block
    if name is None:
        description = "template block for unnamed block"
    else:
        description = "template block for " + name

    # Make and return the new SimBlock
    template_block = SimBlock(
        description=description,
        arrival=arrival,
        content=content,
        statement=statement,
        events=events,
    )
    return template_block, info, offset, solution, comments
def make_initializer(model, arrival=None, common=None):
    """
    Construct a new SimBlock object to be the agent initializer, based on the
    model dictionary. It has structure and events, but no parameters (etc).

    Parameters
    ----------
    model : dict
        Dictionary with model initializer information, probably read in as a yaml.
        It must have "symbols" and "initialize" entries; "name" is optional.
    arrival : [str] or None
        List of arrival variables that were flagged or explicitly listed.
    common : [str] or None
        List of variables that are common or shared across all agents; any
        event that assigns one of them is marked as common.

    Returns
    -------
    initializer : SimBlock
        A "template" of this model block, with no parameters (etc) on it.
    init_requires : dict
        Dictionary of model objects that are needed by the initializer to run.
        Keys are object names and entries reveal what kind of object they are:
        - None --> parameter
        - NullFunc --> function
        - Distribution --> distribution

    Raises
    ------
    ValueError
        If a declaration has a conflicting datatype, a variable is assigned
        twice, an arrival variable is never assigned, or an event refers to a
        parameter, distribution, or function that was not declared as such.
    """
    if arrival is None:
        arrival = []
    if common is None:
        common = []
    try:
        name = model["name"]
    except (KeyError, TypeError):
        name = "DEFAULT_NAME"

    # Extract parameters, functions, and distributions
    symbol_table = model["symbols"]
    parameters = {}
    for line in symbol_table.get("parameters") or []:
        param_name, datatype, flags, desc = parse_declaration_for_parts(line)
        parameters[param_name] = None

    functions = {}
    for line in symbol_table.get("functions") or []:
        func_name, datatype, flags, desc = parse_declaration_for_parts(line)
        if (datatype is not None) and (datatype != "func"):
            raise ValueError(
                func_name
                + " was declared as a function, but given a different datatype!"
            )
        functions[func_name] = NullFunc()

    distributions = {}
    for line in symbol_table.get("distributions") or []:
        dstn_name, datatype, flags, desc = parse_declaration_for_parts(line)
        if (datatype is not None) and (datatype != "dstn"):
            raise ValueError(
                dstn_name
                + " was declared as a distribution, but given a different datatype!"
            )
        distributions[dstn_name] = Distribution()

    # Combine those dictionaries into a single "information" dictionary
    content = parameters.copy()
    content.update(functions)
    content.update(distributions)
    info = deepcopy(content)

    # Parse the initialization routine
    initialize = format_block_statement(model["initialize"])

    # Make the list of ordered events
    events = []
    for line in initialize:
        # Make the new event and add it to the list
        new_event, _ = make_new_event(line, info)
        events.append(new_event)

        # Add newly assigned variables to the information set
        for var in new_event.assigns:
            if var in info:
                raise ValueError(var + " is assigned, but already exists!")
            info[var] = 0

        # If any assigned variables are common, mark the event as common
        if any(var in common for var in new_event.assigns):
            new_event.common = True

    # Verify that all arrival variables were created in the initializer
    for var in arrival:
        if var not in info:
            raise ValueError(
                "The arrival variable " + var + " was not set in the initialize block!"
            )

    # Make a blank dictionary with information the initializer needs
    init_requires = {}
    for event in events:
        for var in event.parameters:
            if var in init_requires:
                continue
            if var not in parameters:
                raise ValueError(
                    var
                    + " was referenced in initialize, but not declared as a parameter!"
                )
            init_requires[var] = parameters[var]

        # Exact type checks are kept on purpose: subclasses of these event
        # types are not treated as requiring a distribution / function here.
        if type(event) is RandomEvent:
            dstn_name = event._dstn_name
            if dstn_name not in distributions:
                raise ValueError(
                    dstn_name
                    + " was referenced in initialize, but not declared as a distribution!"
                )
            init_requires[dstn_name] = distributions[dstn_name]

        if type(event) is EvaluationEvent:
            func_name = event._func_name
            if func_name not in functions:
                raise ValueError(
                    func_name
                    + " was referenced in initialize, but not declared as a function!"
                )
            # Store the function under its own name (not under a distribution's)
            init_requires[func_name] = functions[func_name]

    # Make a single string initializer statement, with descriptions aligned;
    # default=0 keeps this from failing when there are no events at all
    width = max((len(event.statement) for event in events), default=0) + 1
    statement = "".join(
        event.statement.ljust(width) + ": " + event.description + "\n"
        for event in events
    )

    # Make and return the new SimBlock
    initializer = SimBlock(
        description="agent initializer for " + name,
        content=init_requires,
        statement=statement,
        events=events,
    )
    return initializer, init_requires
def make_new_event(statement, info):
    """
    Build a "blank" model event from one statement line, choosing the kind of
    event from the special symbols that appear in the line.

    Parameters
    ----------
    statement : str
        One line of a model statement, which will be turned into an event.
    info : dict
        Empty dictionary of model information that already exists. Consists of
        arrival variables, already assigned variables, parameters, and functions.
        Typing of each is based on the kind of "empty" object.

    Returns
    -------
    new_event : ModelEvent
        A new model event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.
    """
    is_assignment = "=" in statement
    is_draw = "~" in statement

    # A line is either an assignment or a random draw, never both
    if is_assignment and is_draw:
        raise ValueError("A statement line can't have both an = and a ~!")

    if is_assignment:
        # An @ marks a function evaluation; otherwise it's ordinary dynamics
        if "@" in statement:
            builder = make_new_evaluation
        else:
            builder = make_new_dynamic
    elif is_draw:
        # Braces mark a Markov draw, brackets an indexed distribution
        if ("{" in statement) and ("}" in statement):
            builder = make_new_markov
        elif ("[" in statement) and ("]" in statement):
            builder = make_new_random_indexed
        else:
            builder = make_new_random
    else:
        raise ValueError("Statement line was not any valid type!")

    # Hand the line to the builder for that kind of event
    built_event, referenced = builder(statement, info)
    return built_event, referenced
def make_new_dynamic(statement, info):
    """
    Construct a new instance of DynamicEvent based on the given model statement
    line and a blank dictionary of parameters. The statement should already be
    verified to be a valid dynamic statement: it has an = but no ~ or @.

    Parameters
    ----------
    statement : str
        One line dynamics statement, which will be turned into a DynamicEvent.
    info : dict
        Empty dictionary of available information. Each name on the RHS must be
        a key; a value of None marks a parameter, and any other value that is
        not a NullFunc or Distribution marks a needed variable.

    Returns
    -------
    new_dynamic : DynamicEvent
        A new dynamic event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression: the assigned names
        followed by the names found on the RHS.

    Raises
    ------
    ValueError
        If a name on the RHS is not in info, or is a function or distribution.
    """
    # Cut the statement up into its LHS, RHS, and description
    lhs, rhs, description = parse_line_for_parts(statement, "=")

    # Parse the LHS (assignment) to get assigned variables
    assigns = parse_assignment(lhs)

    # Parse the RHS (dynamic statement) to extract object names used
    obj_names, is_indexed = extract_var_names_from_expr(rhs)

    # Allocate each variable to needed dynamic variables or parameters
    needs = []
    parameters = {}
    for j in range(len(obj_names)):
        var = obj_names[j]
        if var not in info.keys():
            raise ValueError(
                var + " is used in a dynamic expression, but does not (yet) exist!"
            )
        val = info[var]
        # Functions and distributions can't appear in a plain dynamic expression
        if type(val) is NullFunc:
            raise ValueError(
                var + " is used in a dynamic expression, but it's a function!"
            )
        if type(val) is Distribution:
            raise ValueError(
                var + " is used in a dynamic expression, but it's a distribution!"
            )
        if val is None:
            parameters[var] = None  # None is the placeholder for a parameter
        else:
            needs.append(var)

    # Declare a SymPy symbol for each variable used; these are temporary
    # NOTE(review): exec() inside a function can't reliably create local
    # variables, so these assignments appear to have no effect on the code
    # below; only _args is used afterward. Confirm before removing them.
    _args = []
    for j in range(len(obj_names)):
        _var = obj_names[j]
        if is_indexed[j]:
            exec(_var + " = IndexedBase('" + _var + "')")
        else:
            exec(_var + " = symbols('" + _var + "')")
        _args.append(symbols(_var))

    # Make a SymPy expression, then lambdify it
    # NOTE(review): symbols(rhs) wraps the RHS text as symbol name(s) rather than
    # parsing it as an expression; presumably lambdify writes that text verbatim
    # into the generated function -- TODO confirm.
    sympy_expr = symbols(rhs)
    expr = lambdify(_args, sympy_expr)

    # Make an overall list of object names referenced in this event
    names_used = assigns + obj_names

    # Make and return the new dynamic event
    new_dynamic = DynamicEvent(
        description=description,
        statement=lhs + " = " + rhs,
        assigns=assigns,
        needs=needs,
        parameters=parameters,
        expr=expr,
        args=obj_names,
    )
    return new_dynamic, names_used
def make_new_random(statement, info):
    """
    Build a random-draw event from one model statement line of the form
    "assigned ~ distribution". The line should already be known to be a plain
    random statement: it has a ~ but no = or [].

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Empty dictionary of available information; the name on the right-hand
        side must be a key whose entry is a Distribution.

    Returns
    -------
    new_random : RandomEvent
        A new random event with values and information missing, but structure set.
    names_used : [str]
        List of names of objects used in this expression.
    """
    # Split the line at the tilde and read off the assigned variables
    target, dstn_name, note = parse_line_for_parts(statement, "~")
    assigned = parse_assignment(target)

    # The right-hand side has to name a declared distribution
    if type(info[dstn_name]) is not Distribution:
        raise ValueError(
            dstn_name + " was treated as a distribution, but not declared as one!"
        )

    # Everything this event refers to: what it assigns plus the distribution
    referenced = assigned + [dstn_name]

    # Assemble the event, and remember the distribution's name on it
    draw_event = RandomEvent(
        description=note,
        statement=target + " ~ " + dstn_name,
        assigns=assigned,
        needs=[],
        parameters={},
        dstn=info[dstn_name],
    )
    draw_event._dstn_name = dstn_name
    return draw_event, referenced
def make_new_random_indexed(statement, info):
    """
    Build an indexed random-draw event from one model statement line of the
    form "assigned ~ distribution[index]". The line should already be known to
    be a valid indexed random statement: it has a ~ and [].

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Empty dictionary of available information; the distribution name must
        be a key whose entry is a Distribution.

    Returns
    -------
    new_random_indexed : RandomIndexedEvent
        A new random indexed event with values and information missing, but
        structure set.
    names_used : [str]
        List of names of objects used in this expression.
    """
    # Split the line at the tilde and read off the assigned variables
    target, draw_expr, note = parse_line_for_parts(statement, "~")
    assigned = parse_assignment(target)

    # Separate the distribution's name from the variable that indexes it
    dstn_name, index_name = parse_random_indexed(draw_expr)

    # The indexed object has to be a declared distribution
    if type(info[dstn_name]) is not Distribution:
        raise ValueError(
            dstn_name + " was treated as a distribution, but not declared as one!"
        )

    # Everything this event refers to: assigned names, distribution, and index
    referenced = assigned + [dstn_name, index_name]

    # Assemble the event; the index variable is data the event needs to run
    indexed_event = RandomIndexedEvent(
        description=note,
        statement=target + " ~ " + draw_expr,
        assigns=assigned,
        needs=[index_name],
        parameters={},
        index=index_name,
    )
    indexed_event._dstn_name = dstn_name
    return indexed_event, referenced
def make_new_markov(statement, info):
    """
    Build a Markov-type event from one model statement line of the form
    "assigned ~ {probs}" or "assigned ~ {probs}(index)". The line should already
    be known to be a valid Markov statement: it has a ~ and {} and maybe ().
    This can represent a Markov matrix transition event, a draw from a discrete
    index, or just a Bernoulli random variable. If a Bernoulli event, the
    "probabilities" can be idiosyncratic data.

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into a random event.
    info : dict
        Empty dictionary of available information; the probabilities name must
        be a key, with None marking it as a parameter.

    Returns
    -------
    new_markov : MarkovEvent
        A new Markov draw event with values and information missing, but
        structure set.
    names_used : [str]
        List of names of objects used in this expression.
    """
    # Split the line at the tilde and read off the assigned variables
    target, draw_expr, note = parse_line_for_parts(statement, "~")
    assigned = parse_assignment(target)

    # Pull the probabilities name and the (optional) index out of the RHS
    probs_name, index_name = parse_markov(draw_expr)
    required = [] if index_name is None else [index_name]

    # Probabilities are either a parameter (None in info) or per-agent data
    # that the event must be given each time it runs
    if info[probs_name] is None:
        fixed = {probs_name: None}
    else:
        required += [probs_name]
        fixed = {}

    # Everything this event refers to
    referenced = assigned + required + [probs_name]

    # Assemble and return the event
    markov_event = MarkovEvent(
        description=note,
        statement=target + " ~ " + draw_expr,
        assigns=assigned,
        needs=required,
        parameters=fixed,
        probs=probs_name,
        index=index_name,
    )
    return markov_event, referenced
def make_new_evaluation(statement, info):
    """
    Build a function evaluation event from one model statement line of the form
    "assigned = func@(arg1,arg2,...)". The line should already be known to be a
    valid evaluation statement: it has an @ and an = but no ~.

    Parameters
    ----------
    statement : str
        One line of the model statement, which will be turned into an eval event.
    info : dict
        Empty dictionary of available information. Every argument must be a
        key; None marks a parameter, and functions and distributions are not
        allowed as arguments.

    Returns
    -------
    new_evaluation : EvaluationEvent
        A new evaluation event with values and information missing, but
        structure set.
    names_used : [str]
        List of names of objects used in this expression.
    """
    # Split the line at the equals sign and read off the assigned variables
    target, call_expr, note = parse_line_for_parts(statement, "=")
    assigned = parse_assignment(target)

    # Separate the function's name from the names of its arguments
    func_name, arg_names = parse_evaluation(call_expr)

    # Sort each argument into fixed parameters or needed variables
    required = []
    fixed = {}
    for arg in arg_names:
        if arg not in info.keys():
            raise ValueError(
                arg + " is used in an evaluation statement, but does not (yet) exist!"
            )
        blank = info[arg]
        if type(blank) is NullFunc:
            raise ValueError(
                arg
                + " is used as an argument an evaluation statement, but it's a function!"
            )
        if type(blank) is Distribution:
            raise ValueError(
                arg + " is used in an evaluation statement, but it's a distribution!"
            )
        if blank is None:
            fixed[arg] = None
        else:
            required.append(arg)

    # Everything this event refers to: assigned names, arguments, and function
    referenced = assigned + arg_names + [func_name]

    # Assemble the event, and remember the function's name on it
    eval_event = EvaluationEvent(
        description=note,
        statement=target + " = " + call_expr,
        assigns=assigned,
        needs=required,
        parameters=fixed,
        arguments=arg_names,
        func=info[func_name],
    )
    eval_event._func_name = func_name
    return eval_event, referenced
def look_for_char_and_remove(phrase, symb):
    """
    Strip every occurrence of a symbol from a string and report whether the
    symbol was there to begin with.

    Parameters
    ----------
    phrase : str
        String to be searched for a symbol.
    symb : char
        Single character to be searched for.

    Returns
    -------
    out : str
        The input phrase with all occurrences of the symbol removed.
    found : bool
        Whether the symbol appeared in the phrase.
    """
    return phrase.replace(symb, ""), (symb in phrase)
def parse_declaration_for_parts(line):
    """
    Break a declaration line from a model file into the object's name, its
    datatype, any metadata flags, and any comment or description.

    Parameters
    ----------
    line : str
        Declaration line to be parsed.

    Returns
    -------
    name : str
        Name of the object, with flags and spaces removed.
    datatype : str or None
        Provided datatype string, in parentheses, if any.
    flags : [str]
        List of metadata flags that were detected: "offset" for +, "arrival"
        for !, "solution" for *, and "common" for &.
    desc : str
        Comment or description, if any. It begins two characters after the
        first backslash on the line (comments are marked by a double backslash).
    """
    flag_marks = {"offset": "+", "arrival": "!", "solution": "*", "common": "&"}

    # Peel the trailing comment (if any) off of the declaration itself
    cut = line.find("\\")
    if cut == -1:
        desc = ""
        declaration = line
    else:
        desc = line[(cut + 2) :].strip()
        declaration = line[:cut].strip()

    # A parenthesized datatype may follow the name and flags
    open_at = declaration.find("(")
    if open_at == -1:
        datatype = None
        remainder = declaration
    else:
        close_at = declaration.find(")")
        if close_at == -1:
            raise ValueError("Unclosed parentheses on object declaration line!")
        datatype = declaration[(open_at + 1) : close_at].strip()
        remainder = declaration[:open_at].strip()

    # Record each flag mark that is present, deleting it from the text
    flags = []
    for flag_name, mark in flag_marks.items():
        if mark in remainder:
            flags.append(flag_name)
        remainder = remainder.replace(mark, "")

    # With spaces removed, what remains *should* be the name
    # TODO: Check for valid name formatting based on characters.
    name = remainder.replace(" ", "")
    return name, datatype, flags, desc
def parse_line_for_parts(statement, symb):
    """
    Break one line of a model statement into its LHS, RHS, and description. The
    LHS and RHS sit on either side of the first occurrence of a special symbol.
    The description begins two characters after the first backslash that
    follows that symbol (comments are marked by a double backslash).

    Parameters
    ----------
    statement : str
        One line of a model statement, which will be parsed for its parts.
    symb : char
        The character that represents the divide between LHS and RHS.

    Returns
    -------
    lhs : str
        The left-hand (assignment) side of the expression, with spaces removed.
    rhs : str
        The right-hand (evaluation) side of the expression, with spaces removed.
    desc : str
        The provided description of the expression, or "" if there is none.
    """
    divide = statement.find(symb)
    left_text = statement[:divide]
    right_text = statement[(divide + 1) :]

    # Anything from the comment marker onward is description, not expression
    marker = right_text.find("\\")
    if marker == -1:
        desc = ""
    else:
        desc = right_text[(marker + 2) :].strip()
        right_text = right_text[:marker]

    return left_text.replace(" ", ""), right_text.replace(" ", ""), desc
def parse_assignment(lhs):
    """
    Read the ordered list of assigned variables off the LHS of a model line.
    A parenthesized LHS is treated as a comma-separated list of names; anything
    else is taken to be a single name.

    Parameters
    ----------
    lhs : str
        Left-hand side of a model expression.

    Returns
    -------
    assigns : List[str]
        List of variable names that are assigned in this model line.
    """
    # No opening parenthesis means exactly one variable is assigned
    if lhs[0] != "(":
        return [lhs]
    if lhs[-1] != ")":
        raise ValueError("Parentheses on assignment was not closed!")

    # Split what's inside the parentheses on commas, skipping empty entries
    inside = lhs[1:-1]
    return [piece for piece in inside.split(",") if piece != ""]
def extract_var_names_from_expr(expression):
    """
    Scan the RHS of a dynamic model statement for the variable names in it.
    Names are runs of characters between math symbols; a digit that would begin
    a name is skipped, so numeric literals made only of digits are not names.

    Parameters
    ----------
    expression : str
        RHS of a model statement to be parsed for variable names.

    Returns
    -------
    var_names : List[str]
        List of unique variable names used in the expression, in order of first
        appearance. These *should* be dynamic variables and parameters, but not
        functions.
    indexed : List[bool]
        Indicators for whether each variable was directly followed by [ the
        first time it appeared.
    """
    delimiters = "+-/*^%.(),[]{}<>"
    numerals = "01234567890"
    found_names = []
    found_indexed = []
    token = ""
    for char in expression:
        leading_digit = (char in numerals) and token == ""
        if (char not in delimiters) and not leading_digit:
            token += char  # still in the middle of a name
            continue

        # A name (if any) just ended; record it the first time it is seen
        if token != "" and token not in found_names:
            found_names.append(token)
            found_indexed.append(char == "[")
        token = ""

    # A name that runs to the end of the expression can't be indexed
    if token != "" and token not in found_names:
        found_names.append(token)
        found_indexed.append(False)
    return found_names, found_indexed
def parse_evaluation(expression):
    """
    Split a function evaluation expression of the form "func@(arg1,arg2,...)"
    into the function's name and the names of its arguments.

    Parameters
    ----------
    expression : str
        RHS of a function evaluation model statement, which will be parsed for
        the function and its inputs.

    Returns
    -------
    func_name : str
        Name of the function that will be called in this event.
    arg_names : List[str]
        List of arguments of the function, in order, skipping empty entries.
    """
    # The function's name is whatever comes before the @
    at_sign = expression.find("@")
    func_name = expression[:at_sign]

    # What follows the @ must be wrapped in parentheses
    call_text = expression[(at_sign + 1) :]
    if call_text[0] != "(":
        raise ValueError(
            "The @ in a function evaluation statement must be followed by (!"
        )
    if call_text[-1] != ")":
        raise ValueError("A function evaluation statement must end in )!")

    # Arguments are the non-empty comma-separated pieces inside the parentheses
    inside = call_text[1:-1]
    arg_names = [piece for piece in inside.split(",") if piece != ""]

    return func_name, arg_names
def parse_markov(expression):
    """
    Split a Markov draw declaration of the form "{probs}" or "{probs}(index)"
    into the name of the probabilities and the name of the index.

    Parameters
    ----------
    expression : str
        RHS of a Markov draw model statement, which will be parsed for the
        probabilities name and index name.

    Returns
    -------
    probs : str
        Name of the probabilities object in this statement.
    index : str or None
        Name of the indexing variable in this statement, or None if no
        parentheses follow the closing brace.
    """
    # The probabilities name sits between the braces and can't be empty
    open_brace = expression.find("{")  # this *should* be 0
    close_brace = expression.find("}")
    braces_ok = open_brace != -1 and close_brace != -1
    if not braces_ok or close_brace < (open_brace + 2):
        raise ValueError("A Markov assignment must have an {array}!")
    probs = expression[(open_brace + 1) : close_brace]

    # Look for an index in parentheses after the closing brace
    after = close_brace + 1
    open_paren = expression.find("(", after)
    close_paren = expression.find(")", after)
    if open_paren == -1 and close_paren == -1:
        return probs, None  # no index present at all
    parens_ok = open_paren != -1 and close_paren != -1
    if not parens_ok or close_paren < (open_paren + 2):
        raise ValueError("Improper Markov formatting: should be {probs}(index)!")

    return probs, expression[(open_paren + 1) : close_paren]
def parse_random_indexed(expression):
    """
    Split an indexed random variable expression of the form "dstn[index]" into
    the name of the distribution and the name of its index.

    Parameters
    ----------
    expression : str
        RHS of an indexed random draw model statement, which will be parsed for
        the distribution name and index name.

    Returns
    -------
    dstn : str
        Name of the distribution in this statement.
    index : str
        Name of the indexing variable in this statement.
    """
    open_at = expression.find("[")
    close_at = expression.find("]")

    # Both brackets must be present, with at least one character between them
    brackets_ok = open_at != -1 and close_at != -1
    if not brackets_ok or close_at < (open_at + 2):
        raise ValueError("An indexed random variable assignment must have an [index]!")

    # The distribution comes before the bracket, the index sits inside it
    return expression[:open_at], expression[(open_at + 1) : close_at]
def format_block_statement(statement):
    """
    Ensure that a string statement of a model block (maybe a period, maybe an
    initializer) is formatted as a list of strings, one statement per entry.

    Parameters
    ----------
    statement : str or [str]
        A model statement, which might be for a block or an initializer. The
        statement might be formatted as a list or as a single string with one
        statement per line.

    Returns
    -------
    block_statements: [str]
        A list of model statements, one per entry. This is always a new list.

    Raises
    ------
    ValueError
        If the statement is a list with a non-string entry, or is neither a
        string nor a list.
    """
    if isinstance(statement, str):
        # Split on newlines. A newline at the very end terminates the last line
        # rather than starting an empty one, so drop the empty piece it leaves.
        # A single-line string comes back as a one-entry list.
        block_statements = statement.split("\n")
        if statement.endswith("\n"):
            block_statements.pop()
        return block_statements

    if isinstance(statement, list):
        for line in statement:
            if not isinstance(line, str):
                raise ValueError("The model statement somehow includes a non-string!")
        return statement.copy()

    raise ValueError("The model statement must be a string or a list of strings!")
@njit
def aggregate_blobs_onto_polynomial_grid(
    vals, pmv, origins, grid, J, Q
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto a discretized
    grid of outcome values, based on their origin in the arrival state space. This
    version is for non-continuation variables, returning only the probability array
    mapping from arrival states to the outcome variable.

    Parameters
    ----------
    vals : np.array
        Outcome value for each blob.
    pmv : np.array
        Probability mass for each blob; its size sets the number of blobs.
    origins : np.array
        Integer row index for each blob, used as the first index of the output.
    grid : np.array
        Grid of outcome values. The lower bracketing gridpoint is located in
        closed form as floor(((x - grid[0]) / (grid[-1] - grid[0])) ** (1 / Q) * (M - 1)),
        which assumes the grid is spaced accordingly (polynomial spacing of order Q).
    J : int
        Number of rows in the output array.
    Q : float
        Order of the grid spacing, used when inverting the grid as above.

    Returns
    -------
    probs : np.array
        Array of shape (J, grid.size). Each blob strictly inside the grid has its
        mass split linearly between its two bracketing gridpoints; blobs at or
        below the bottom gridpoint put all mass on the first column, and all
        remaining blobs put all mass on the last column.
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    Mm1 = M - 1
    N = pmv.size
    scale = 1.0 / (top - bot)
    order = 1.0 / Q
    diffs = grid[1:] - grid[:-1]

    probs = np.zeros((J, M))

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            ii = int(np.floor(((x - bot) * scale) ** order * Mm1))
            # Floating point rounding can push the computed index up to M-1
            # when x is just below the top gridpoint. That would index past the
            # end of diffs and probs, which numba does not bounds-check by
            # default, so keep the index on the last valid segment.
            if ii > Mm1 - 1:
                ii = Mm1 - 1
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
        elif x <= bot:
            probs[jj, 0] += p
        else:
            probs[jj, -1] += p
    return probs
@njit
def aggregate_blobs_onto_polynomial_grid_alt(
    vals, pmv, origins, grid, J, Q
):  # pragma: no cover
    """
    Numba-compatible helper function for casting "probability blobs" onto a discretized
    grid of outcome values, based on their origin in the arrival state space. This
    version is for continuation variables, returning the probability array mapping
    from arrival states to the outcome variable, the index in the outcome variable grid
    for each blob, and the alpha weighting between gridpoints.

    Parameters
    ----------
    vals : np.array
        Outcome value for each blob.
    pmv : np.array
        Probability mass for each blob; its size sets the number of blobs.
    origins : np.array
        Integer row index for each blob, used as the first index of probs.
    grid : np.array
        Grid of outcome values. The lower bracketing gridpoint is located in
        closed form as floor(((x - grid[0]) / (grid[-1] - grid[0])) ** (1 / Q) * (M - 1)),
        which assumes the grid is spaced accordingly (polynomial spacing of order Q).
    J : int
        Number of rows in the probs array.
    Q : float
        Order of the grid spacing, used when inverting the grid as above.

    Returns
    -------
    probs : np.array
        Array of shape (J, grid.size) of accumulated probability mass.
    idx : np.array
        int32 array with the lower bracketing gridpoint index for each blob:
        0 for blobs at or below the bottom of the grid, and grid.size - 2 for
        blobs that fall in neither of the first two cases.
    alpha : np.array
        Weight on the upper bracketing gridpoint (idx + 1) for each blob: 0.0 for
        blobs at or below the bottom of the grid, 1.0 for blobs that fall in
        neither of the first two cases, and the linear interpolation weight otherwise.
    """
    bot = grid[0]
    top = grid[-1]
    M = grid.size
    Mm1 = M - 1
    N = pmv.size
    scale = 1.0 / (top - bot)
    order = 1.0 / Q
    diffs = grid[1:] - grid[:-1]

    probs = np.zeros((J, M))
    idx = np.empty(N, dtype=np.dtype(np.int32))
    alpha = np.empty(N)

    for n in range(N):
        x = vals[n]
        jj = origins[n]
        p = pmv[n]
        if (x > bot) and (x < top):
            ii = int(np.floor(((x - bot) * scale) ** order * Mm1))
            # Floating point rounding can push the computed index up to M-1
            # when x is just below the top gridpoint. That would index past the
            # end of diffs and probs, which numba does not bounds-check by
            # default, so keep the index on the last valid segment.
            if ii > Mm1 - 1:
                ii = Mm1 - 1
            temp = (x - grid[ii]) / diffs[ii]
            probs[jj, ii] += (1.0 - temp) * p
            probs[jj, ii + 1] += temp * p
            alpha[n] = temp
            idx[n] = ii
        elif x <= bot:
            probs[jj, 0] += p
            alpha[n] = 0.0
            idx[n] = 0
        else:
            # All mass on the top gridpoint, expressed as full weight on the
            # upper end of the last segment
            probs[jj, -1] += p
            alpha[n] = 1.0
            idx[n] = M - 2
    return probs, idx, alpha
@njit
def aggregate_blobs_onto_discrete_grid(vals, pmv, origins, M, J):  # pragma: no cover
    """
    Numba-compatible helper function that sums "probability blobs" into a J-by-M
    array for a state that is truly discrete. Each blob adds its probability mass
    at the row given by its origin and the column given by its outcome value, so
    both origins and vals are used directly as integer indices.
    """
    mass = np.zeros((J, M))
    for blob in range(pmv.size):
        mass[origins[blob], vals[blob]] += pmv[blob]
    return mass
@njit
def calc_overall_trans_probs(
    out, idx, alpha, binary, offset, pmv, origins
):  # pragma: no cover
    """
    Numba-compatible helper function that merges transition probabilities from
    the arrival state space to *multiple* continuation variables into one
    unified transition matrix.

    For each blob n and each row b of binary, the blob's mass pmv[n] is multiplied
    by alpha[n, d, binary[b, d]] for every column d, and the product is added to
    out[origins[n], idx[n] + offset[b]]. The array out is modified in place and
    also returned. Presumably each row of binary selects the lower or upper
    gridpoint for each continuation variable, and offset[b] is the matching shift
    in the flattened outcome index -- confirm against the caller.
    """
    n_blobs = alpha.shape[0]
    n_combos = binary.shape[0]
    n_dims = binary.shape[1]
    for n in range(n_blobs):
        row = origins[n]
        base_col = idx[n]
        mass = pmv[n]
        for b in range(n_combos):
            # Joint weight for this combination of gridpoint choices
            weight = mass
            for d in range(n_dims):
                weight *= alpha[n, d, binary[b, d]]
            out[row, base_col + offset[b]] += weight
    return out