Coverage for HARK/core.py: 83%
925 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-02 05:14 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-02 05:14 +0000
1"""
2High-level functions and classes for solving a wide variety of economic models.
3The "core" of HARK is a framework for "microeconomic" and "macroeconomic"
4models. A micro model concerns the dynamic optimization problem for some type
5of agents, where agents take the inputs to their problem as exogenous. A macro
6model adds an additional layer, endogenizing some of the inputs to the micro
7problem by finding a general equilibrium dynamic rule.
8"""
10# Set logging and define basic functions
11import inspect
12import logging
13import sys
14from collections import namedtuple
15from copy import copy, deepcopy
16from dataclasses import dataclass, field
17from time import time
18from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
19from warnings import warn
20import multiprocessing
21from joblib import Parallel, delayed
23import numpy as np
24import pandas as pd
25from xarray import DataArray
27from HARK.distributions import (
28 Distribution,
29 IndexDistribution,
30 combine_indep_dstns,
31)
32from HARK.utilities import NullFunc, get_arg_names, get_it_from
33from HARK.simulator import make_simulator_from_agent
34from HARK.SSJutils import (
35 make_basic_SSJ_matrices,
36 calc_shock_response_manually,
37)
39logging.basicConfig(format="%(message)s")
40_log = logging.getLogger("HARK")
41_log.setLevel(logging.ERROR)
44def disable_logging():
45 _log.disabled = True
48def enable_logging():
49 _log.disabled = False
52def warnings():
53 _log.setLevel(logging.WARNING)
56def quiet():
57 _log.setLevel(logging.ERROR)
60def verbose():
61 _log.setLevel(logging.INFO)
64def set_verbosity_level(level):
65 _log.setLevel(level)
68class Parameters:
69 """
70 A smart container for model parameters that handles age-varying dynamics.
72 This class stores parameters as an internal dictionary and manages their
73 age-varying properties, providing both attribute-style and dictionary-style
74 access. It is designed to handle the time-varying dynamics of parameters
75 in economic models.
77 Attributes
78 ----------
79 _length : int
80 The terminal age of the agents in the model.
81 _invariant_params : Set[str]
82 A set of parameter names that are invariant over time.
83 _varying_params : Set[str]
84 A set of parameter names that vary over time.
85 _parameters : Dict[str, Any]
86 The internal dictionary storing all parameters.
87 """
89 __slots__ = ("_length", "_invariant_params", "_varying_params", "_parameters")
91 def __init__(self, **parameters: Any) -> None:
92 """
93 Initialize a Parameters object and parse the age-varying dynamics of parameters.
95 Parameters
96 ----------
97 **parameters : Any
98 Any number of parameters in the form key=value.
99 """
100 self._length: int = parameters.pop("T_cycle", 1)
101 self._invariant_params: Set[str] = set()
102 self._varying_params: Set[str] = set()
103 self._parameters: Dict[str, Any] = {"T_cycle": self._length}
105 for key, value in parameters.items():
106 self[key] = value
108 def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
109 """
110 Access parameters by age index or parameter name.
112 If item_or_key is an integer, returns a Parameters object with the parameters
113 that apply to that age. This includes all invariant parameters and the
114 `item_or_key`th element of all age-varying parameters. If item_or_key is a
115 string, it returns the value of the parameter with that name.
117 Parameters
118 ----------
119 item_or_key : Union[int, str]
120 Age index or parameter name.
122 Returns
123 -------
124 Union[Parameters, Any]
125 A new Parameters object for the specified age, or the value of the
126 specified parameter.
128 Raises
129 ------
130 ValueError:
131 If the age index is out of bounds.
132 KeyError:
133 If the parameter name is not found.
134 TypeError:
135 If the key is neither an integer nor a string.
136 """
137 if isinstance(item_or_key, int):
138 if item_or_key >= self._length:
139 raise ValueError(
140 f"Age {item_or_key} is out of bounds (max: {self._length - 1})."
141 )
143 params = {key: self._parameters[key] for key in self._invariant_params}
144 params.update(
145 {
146 key: (
147 self._parameters[key][item_or_key]
148 if isinstance(self._parameters[key], (list, tuple, np.ndarray))
149 else self._parameters[key]
150 )
151 for key in self._varying_params
152 }
153 )
154 return Parameters(**params)
155 elif isinstance(item_or_key, str):
156 return self._parameters[item_or_key]
157 else:
158 raise TypeError("Key must be an integer (age) or string (parameter name).")
160 def __setitem__(self, key: str, value: Any) -> None:
161 """
162 Set parameter values, automatically inferring time variance.
164 If the parameter is a scalar, numpy array, boolean, distribution, callable
165 or None, it is assumed to be invariant over time. If the parameter is a
166 list or tuple, it is assumed to be varying over time. If the parameter
167 is a list or tuple of length greater than 1, the length of the list or
168 tuple must match the `_length` attribute of the Parameters object.
170 Parameters
171 ----------
172 key : str
173 Name of the parameter.
174 value : Any
175 Value of the parameter.
177 Raises
178 ------
179 ValueError:
180 If the parameter name is not a string or if the value type is unsupported.
181 If the parameter value is inconsistent with the current model length.
182 """
183 if not isinstance(key, str):
184 raise ValueError(f"Parameter name must be a string, got {type(key)}")
186 if isinstance(
187 value, (int, float, np.ndarray, type(None), Distribution, bool, Callable)
188 ):
189 self._invariant_params.add(key)
190 self._varying_params.discard(key)
191 elif isinstance(value, (list, tuple)):
192 if len(value) == 1:
193 value = value[0]
194 self._invariant_params.add(key)
195 self._varying_params.discard(key)
196 elif self._length is None or self._length == 1:
197 self._length = len(value)
198 self._varying_params.add(key)
199 self._invariant_params.discard(key)
200 elif len(value) == self._length:
201 self._varying_params.add(key)
202 self._invariant_params.discard(key)
203 else:
204 raise ValueError(
205 f"Parameter {key} must have length 1 or {self._length}, not {len(value)}"
206 )
207 else:
208 raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")
210 self._parameters[key] = value
212 def __iter__(self) -> Iterator[str]:
213 """Allow iteration over parameter names."""
214 return iter(self._parameters)
216 def __len__(self) -> int:
217 """Return the number of parameters."""
218 return len(self._parameters)
220 def keys(self) -> Iterator[str]:
221 """Return a view of parameter names."""
222 return self._parameters.keys()
224 def values(self) -> Iterator[Any]:
225 """Return a view of parameter values."""
226 return self._parameters.values()
228 def items(self) -> Iterator[Tuple[str, Any]]:
229 """Return a view of parameter (name, value) pairs."""
230 return self._parameters.items()
232 def to_dict(self) -> Dict[str, Any]:
233 """
234 Convert parameters to a plain dictionary.
236 Returns
237 -------
238 Dict[str, Any]
239 A dictionary containing all parameters.
240 """
241 return dict(self._parameters)
243 def to_namedtuple(self) -> namedtuple:
244 """
245 Convert parameters to a namedtuple.
247 Returns
248 -------
249 namedtuple
250 A namedtuple containing all parameters.
251 """
252 return namedtuple("Parameters", self.keys())(**self.to_dict())
254 def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
255 """
256 Update parameters from another Parameters object or dictionary.
258 Parameters
259 ----------
260 other : Union[Parameters, Dict[str, Any]]
261 The source of parameters to update from.
263 Raises
264 ------
265 TypeError
266 If the input is neither a Parameters object nor a dictionary.
267 """
268 if isinstance(other, Parameters):
269 for key, value in other._parameters.items():
270 self[key] = value
271 elif isinstance(other, dict):
272 for key, value in other.items():
273 self[key] = value
274 else:
275 raise TypeError(
276 "Update source must be a Parameters object or a dictionary."
277 )
279 def __repr__(self) -> str:
280 """Return a detailed string representation of the Parameters object."""
281 return (
282 f"Parameters(_length={self._length}, "
283 f"_invariant_params={self._invariant_params}, "
284 f"_varying_params={self._varying_params}, "
285 f"_parameters={self._parameters})"
286 )
288 def __str__(self) -> str:
289 """Return a simple string representation of the Parameters object."""
290 return f"Parameters({str(self._parameters)})"
292 def __getattr__(self, name: str) -> Any:
293 """
294 Allow attribute-style access to parameters.
296 Parameters
297 ----------
298 name : str
299 Name of the parameter to access.
301 Returns
302 -------
303 Any
304 The value of the specified parameter.
306 Raises
307 ------
308 AttributeError:
309 If the parameter name is not found.
310 """
311 if name.startswith("_"):
312 return super().__getattribute__(name)
313 try:
314 return self._parameters[name]
315 except KeyError:
316 raise AttributeError(f"'Parameters' object has no attribute '{name}'")
318 def __setattr__(self, name: str, value: Any) -> None:
319 """
320 Allow attribute-style setting of parameters.
322 Parameters
323 ----------
324 name : str
325 Name of the parameter to set.
326 value : Any
327 Value to set for the parameter.
328 """
329 if name.startswith("_"):
330 super().__setattr__(name, value)
331 else:
332 self[name] = value
334 def __contains__(self, item: str) -> bool:
335 """Check if a parameter exists in the Parameters object."""
336 return item in self._parameters
338 def copy(self) -> "Parameters":
339 """
340 Create a deep copy of the Parameters object.
342 Returns
343 -------
344 Parameters
345 A new Parameters object with the same contents.
346 """
347 return deepcopy(self)
349 def add_to_time_vary(self, *params: str) -> None:
350 """
351 Adds any number of parameters to the time-varying set.
353 Parameters
354 ----------
355 *params : str
356 Any number of strings naming parameters to be added to time_vary.
357 """
358 for param in params:
359 if param in self._parameters:
360 self._varying_params.add(param)
361 self._invariant_params.discard(param)
362 else:
363 warn(
364 f"Parameter '{param}' does not exist and cannot be added to time_vary."
365 )
367 def add_to_time_inv(self, *params: str) -> None:
368 """
369 Adds any number of parameters to the time-invariant set.
371 Parameters
372 ----------
373 *params : str
374 Any number of strings naming parameters to be added to time_inv.
375 """
376 for param in params:
377 if param in self._parameters:
378 self._invariant_params.add(param)
379 self._varying_params.discard(param)
380 else:
381 warn(
382 f"Parameter '{param}' does not exist and cannot be added to time_inv."
383 )
385 def del_from_time_vary(self, *params: str) -> None:
386 """
387 Removes any number of parameters from the time-varying set.
389 Parameters
390 ----------
391 *params : str
392 Any number of strings naming parameters to be removed from time_vary.
393 """
394 for param in params:
395 self._varying_params.discard(param)
397 def del_from_time_inv(self, *params: str) -> None:
398 """
399 Removes any number of parameters from the time-invariant set.
401 Parameters
402 ----------
403 *params : str
404 Any number of strings naming parameters to be removed from time_inv.
405 """
406 for param in params:
407 self._invariant_params.discard(param)
409 def get(self, key: str, default: Any = None) -> Any:
410 """
411 Get a parameter value, returning a default if not found.
413 Parameters
414 ----------
415 key : str
416 The parameter name.
417 default : Any, optional
418 The default value to return if the key is not found.
420 Returns
421 -------
422 Any
423 The parameter value or the default.
424 """
425 return self._parameters.get(key, default)
427 def set_many(self, **kwargs: Any) -> None:
428 """
429 Set multiple parameters at once.
431 Parameters
432 ----------
433 **kwargs : Keyword arguments representing parameter names and values.
434 """
435 for key, value in kwargs.items():
436 self[key] = value
438 def is_time_varying(self, key: str) -> bool:
439 """
440 Check if a parameter is time-varying.
442 Parameters
443 ----------
444 key : str
445 The parameter name.
447 Returns
448 -------
449 bool
450 True if the parameter is time-varying, False otherwise.
451 """
452 return key in self._varying_params
455class Model:
456 """
457 A class with special handling of parameters assignment.
458 """
460 def __init__(self):
461 if not hasattr(self, "parameters"):
462 self.parameters = {}
463 if not hasattr(self, "constructors"):
464 self.constructors = {}
466 def assign_parameters(self, **kwds):
467 """
468 Assign an arbitrary number of attributes to this agent.
470 Parameters
471 ----------
472 **kwds : keyword arguments
473 Any number of keyword arguments of the form key=value.
474 Each value will be assigned to the attribute named in self.
476 Returns
477 -------
478 None
479 """
480 self.parameters.update(kwds)
481 for key in kwds:
482 setattr(self, key, kwds[key])
484 def get_parameter(self, name):
485 """
486 Returns a parameter of this model
488 Parameters
489 ----------
490 name : str
491 The name of the parameter to get
493 Returns
494 -------
495 value : The value of the parameter
496 """
497 return self.parameters[name]
499 def __eq__(self, other):
500 if isinstance(other, type(self)):
501 return self.parameters == other.parameters
503 return NotImplemented
505 def __str__(self):
506 type_ = type(self)
507 module = type_.__module__
508 qualname = type_.__qualname__
510 s = f"<{module}.{qualname} object at {hex(id(self))}.\n"
511 s += "Parameters:"
513 for p in self.parameters:
514 s += f"\n{p}: {self.parameters[p]}"
516 s += ">"
517 return s
519 def describe(self):
520 return self.__str__()
522 def del_param(self, param_name):
523 """
524 Deletes a parameter from this instance, removing it both from the object's
525 namespace (if it's there) and the parameters dictionary (likewise).
527 Parameters
528 ----------
529 param_name : str
530 A string naming a parameter or data to be deleted from this instance.
531 Removes information from self.parameters dictionary and own namespace.
533 Returns
534 -------
535 None
536 """
537 if param_name in self.parameters:
538 del self.parameters[param_name]
539 if hasattr(self, param_name):
540 delattr(self, param_name)
542 def construct(self, *args, force=False):
543 """
544 Top-level method for building constructed inputs. If called without any
545 inputs, construct builds each of the objects named in the keys of the
546 constructors dictionary; it draws inputs for the constructors from the
547 parameters dictionary and adds its results to the same. If passed one or
548 more strings as arguments, the method builds only the named keys. The
549 method will do multiple "passes" over the requested keys, as some cons-
550 tructors require inputs built by other constructors. If any requested
551 constructors failed to build due to missing data, those keys (and the
552 missing data) will be named in self._missing_key_data. Other errors are
553 recorded in the dictionary attribute _constructor_errors.
555 This method tries to "start from scratch" by removing prior constructed
556 objects, holding them in a backup dictionary during construction. This
557 is done so that dependencies among constructors are resolved properly,
558 without mistakenly relying on "old information". A backup value is used
559 if a constructor function is set to None (i.e. "don't do anything"), or
560 if the construct method fails to produce a new object.
562 Parameters
563 ----------
564 *args : str, optional
565 Keys of self.constructors that are requested to be constructed.
566 If no arguments are passed, *all* elements of the dictionary are implied.
567 force : bool, optional
568 When True, the method will force its way past any errors, including
569 missing constructors, missing arguments for constructors, and errors
570 raised during execution of constructors. Information about all such
571 errors is stored in the dictionary attributes described above. When
572 False (default), any errors or exception will be raised.
574 Returns
575 -------
576 None
577 """
578 # Set up the requested work
579 if len(args) > 0:
580 keys = args
581 else:
582 keys = list(self.constructors.keys())
583 N_keys = len(keys)
584 keys_complete = np.zeros(N_keys, dtype=bool)
585 if N_keys == 0:
586 return # Do nothing if there are no constructed objects
588 # Remove pre-existing constructed objects, preventing "incomplete" updates,
589 # but store the current values in a backup dictionary in case something fails
590 backup = {}
591 for key in keys:
592 if hasattr(self, key):
593 backup[key] = getattr(self, key)
594 self.del_param(key)
596 # Get the dictionary of constructor errors
597 if not hasattr(self, "_constructor_errors"):
598 self._constructor_errors = {}
599 errors = self._constructor_errors
601 # As long as the work isn't complete and we made some progress on the last
602 # pass, repeatedly perform passes of trying to construct objects
603 any_keys_incomplete = np.any(np.logical_not(keys_complete))
604 go = any_keys_incomplete
605 while go:
606 anything_accomplished_this_pass = False # Nothing done yet!
607 missing_key_data = [] # Keep this up-to-date on each pass
609 # Loop over keys to be constructed
610 for i in range(N_keys):
611 if keys_complete[i]:
612 continue # This key has already been built
614 # Get this key and its constructor function
615 key = keys[i]
616 try:
617 constructor = self.constructors[key]
618 except Exception as not_found:
619 errors[key] = "No constructor found for " + str(not_found)
620 if force:
621 continue
622 else:
623 raise ValueError("No constructor found for " + key) from None
625 # If this constructor is None, do nothing and mark it as completed;
626 # this includes restoring the previous value if it exists
627 if constructor is None:
628 if key in backup.keys():
629 setattr(self, key, backup[key])
630 self.parameters[key] = backup[key]
631 keys_complete[i] = True
632 anything_accomplished_this_pass = True # We did something!
633 continue
635 # SPECIAL: if the constructor is get_it_from, handle it separately
636 if isinstance(constructor, get_it_from):
637 try:
638 parent = getattr(self, constructor.name)
639 query = key
640 any_missing = False
641 missing_args = []
642 except:
643 parent = None
644 query = None
645 any_missing = True
646 missing_args = [constructor.name]
647 temp_dict = {"parent": parent, "query": query}
649 # Get the names of arguments for this constructor and try to gather them
650 else: # (if it's not the special case of get_it_from)
651 args_needed = get_arg_names(constructor)
652 has_no_default = {
653 k: v.default is inspect.Parameter.empty
654 for k, v in inspect.signature(constructor).parameters.items()
655 }
656 temp_dict = {}
657 any_missing = False
658 missing_args = []
659 for j in range(len(args_needed)):
660 this_arg = args_needed[j]
661 if hasattr(self, this_arg):
662 temp_dict[this_arg] = getattr(self, this_arg)
663 else:
664 try:
665 temp_dict[this_arg] = self.parameters[this_arg]
666 except:
667 if has_no_default[this_arg]:
668 # Record missing key-data pair
669 any_missing = True
670 missing_key_data.append((key, this_arg))
671 missing_args.append(this_arg)
673 # If all of the required data was found, run the constructor and
674 # store the result in parameters (and on self)
675 if not any_missing:
676 try:
677 temp = constructor(**temp_dict)
678 except Exception as problem:
679 errors[key] = str(type(problem)) + ": " + str(problem)
680 self.del_param(key)
681 if force:
682 continue
683 else:
684 raise
685 setattr(self, key, temp)
686 self.parameters[key] = temp
687 if key in errors:
688 del errors[key]
689 keys_complete[i] = True
690 anything_accomplished_this_pass = True # We did something!
691 else:
692 msg = "Missing required arguments:"
693 for arg in missing_args:
694 msg += " " + arg + ","
695 msg = msg[:-1]
696 errors[key] = msg
697 self.del_param(key)
698 # Never raise exceptions here, as the arguments might be filled in later
700 # Check whether another pass should be performed
701 any_keys_incomplete = np.any(np.logical_not(keys_complete))
702 go = any_keys_incomplete and anything_accomplished_this_pass
704 # Store missing key-data pairs and exit
705 self._missing_key_data = missing_key_data
706 self._constructor_errors = errors
707 if any_keys_incomplete:
708 msg = "Did not construct these objects:"
709 for i in range(N_keys):
710 if keys_complete[i]:
711 continue
712 msg += " " + keys[i] + ","
713 key = keys[i]
714 if key in backup.keys():
715 setattr(self, key, backup[key])
716 self.parameters[key] = backup[key]
717 msg = msg[:-1]
718 if not force:
719 raise ValueError(msg)
720 return
722 def describe_constructors(self, *args):
723 """
724 Prints to screen a string describing this instance's constructed objects,
725 including their names, the function that constructs them, the names of
726 those functions inputs, and whether those inputs are present.
728 Parameters
729 ----------
730 *args : str, optional
731 Optional list of strings naming constructed inputs to be described.
732 If none are passed, all constructors are described.
734 Returns
735 -------
736 None
737 """
738 if len(args) > 0:
739 keys = args
740 else:
741 keys = list(self.constructors.keys())
742 yes = "\u2713"
743 no = "X"
744 maybe = "*"
745 noyes = [no, yes]
747 out = ""
748 for key in keys:
749 has_val = hasattr(self, key) or (key in self.parameters)
751 # Get the constructor function if possible
752 try:
753 constructor = self.constructors[key]
754 out += (
755 noyes[int(has_val)]
756 + " "
757 + key
758 + " : "
759 + constructor.__name__
760 + "\n"
761 )
762 except:
763 if isinstance(constructor, get_it_from):
764 parent_name = self.constructors[key].name
765 out += (
766 noyes[int(has_val)]
767 + " "
768 + key
769 + " : get it from "
770 + parent_name
771 + "\n"
772 )
773 else:
774 out += noyes[int(has_val)] + " " + key + " : NO CONSTRUCTOR FOUND\n"
775 continue
777 # Get constructor argument names
778 arg_names = get_arg_names(constructor)
779 has_no_default = {
780 k: v.default is inspect.Parameter.empty
781 for k, v in inspect.signature(constructor).parameters.items()
782 }
784 # Check whether each argument exists
785 for j in range(len(arg_names)):
786 this_arg = arg_names[j]
787 if hasattr(self, this_arg) or this_arg in self.parameters:
788 symb = yes
789 elif not has_no_default[this_arg]:
790 symb = maybe
791 else:
792 symb = no
793 out += " " + symb + " " + this_arg + "\n"
795 # Print the string to screen
796 print(out)
797 return
799 # This is a "synonym" method so that old calls to update() still work
800 def update(self, *args):
801 self.construct(*args)
804class AgentType(Model):
805 """
806 A superclass for economic agents in the HARK framework. Each model should
807 specify its own subclass of AgentType, inheriting its methods and overwriting
808 as necessary. Critically, every subclass of AgentType should define class-
809 specific static values of the attributes time_vary and time_inv as lists of
810 strings. Each element of time_vary is the name of a field in AgentSubType
811 that varies over time in the model. Each element of time_inv is the name of
812 a field in AgentSubType that is constant over time in the model.
814 Parameters
815 ----------
816 solution_terminal : Solution
817 A representation of the solution to the terminal period problem of
818 this AgentType instance, or an initial guess of the solution if this
819 is an infinite horizon problem.
820 cycles : int
821 The number of times the sequence of periods is experienced by this
822 AgentType in their "lifetime". cycles=1 corresponds to a lifecycle
823 model, with a certain sequence of one period problems experienced
824 once before terminating. cycles=0 corresponds to an infinite horizon
825 model, with a sequence of one period problems repeating indefinitely.
826 pseudo_terminal : bool
827 Indicates whether solution_terminal isn't actually part of the
828 solution to the problem (as a known solution to the terminal period
829 problem), but instead represents a "scrap value"-style termination.
830 When True, solution_terminal is not included in the solution; when
831 False, solution_terminal is the last element of the solution.
832 tolerance : float
833 Maximum acceptable "distance" between successive solutions to the
834 one period problem in an infinite horizon (cycles=0) model in order
835 for the solution to be considered as having "converged". Inoperative
836 when cycles>0.
837 verbose : int
838 Level of output to be displayed by this instance, default is 1.
839 quiet : bool
840 Indicator for whether this instance should operate "quietly", default False.
841 seed : int
842 A seed for this instance's random number generator.
843 construct : bool
844 Indicator for whether this instance's construct() method should be run
845 when initialized (default True). When False, an instance of the class
846 can be created even if not all of its attributes can be constructed.
847 use_defaults : bool
848 Indicator for whether this instance should use the values in the class'
849 default dictionary to fill in parameters and constructors for those not
850 provided by the user (default True). Setting this to False is useful for
851 situations where the user wants to be absolutely sure that they know what
852 is being passed to the class initializer, without resorting to defaults.
854 Attributes
855 ----------
856 AgentCount : int
857 The number of agents of this type to use in simulation.
859 state_vars : list of string
860 The string labels for this AgentType's model state variables.
861 """
863 time_vary_ = []
864 time_inv_ = []
865 shock_vars_ = []
866 state_vars = []
867 poststate_vars = []
868 distributions = []
869 default_ = {"params": {}, "solver": NullFunc()}
871 def __init__(
872 self,
873 solution_terminal=None,
874 pseudo_terminal=True,
875 tolerance=0.000001,
876 verbose=1,
877 quiet=False,
878 seed=0,
879 construct=True,
880 use_defaults=True,
881 **kwds,
882 ):
883 super().__init__()
884 params = deepcopy(self.default_["params"]) if use_defaults else {}
885 params.update(kwds)
887 # Correctly handle constructors that have been passed in kwds
888 if "constructors" in self.default_["params"].keys() and use_defaults:
889 constructors = deepcopy(self.default_["params"]["constructors"])
890 else:
891 constructors = {}
892 if "constructors" in kwds.keys():
893 constructors.update(kwds["constructors"])
894 params["constructors"] = constructors
896 # Set model file name if possible
897 try:
898 self.model_file = copy(self.default_["model"])
899 except (KeyError, TypeError):
900 # Fallback to None if "model" key is missing or invalid for copying
901 self.model_file = None
903 if solution_terminal is None:
904 solution_terminal = NullFunc()
906 self.solve_one_period = self.default_["solver"] # NOQA
907 self.solution_terminal = solution_terminal # NOQA
908 self.pseudo_terminal = pseudo_terminal # NOQA
909 self.tolerance = tolerance # NOQA
910 self.verbose = verbose
911 self.quiet = quiet
912 set_verbosity_level((4 - verbose) * 10)
913 self.seed = seed # NOQA
914 self.track_vars = [] # NOQA
915 self.state_now = {sv: None for sv in self.state_vars}
916 self.state_prev = self.state_now.copy()
917 self.controls = {}
918 self.shocks = {}
919 self.read_shocks = False # NOQA
920 self.shock_history = {}
921 self.newborn_init_history = {}
922 self.history = {}
923 self.assign_parameters(**params) # NOQA
924 self.reset_rng() # NOQA
925 self.bilt = {}
926 if construct:
927 self.construct()
929 # Add instance-level lists and objects
930 self.time_vary = deepcopy(self.time_vary_)
931 self.time_inv = deepcopy(self.time_inv_)
932 self.shock_vars = deepcopy(self.shock_vars_)
934 def add_to_time_vary(self, *params):
935 """
936 Adds any number of parameters to time_vary for this instance.
938 Parameters
939 ----------
940 params : string
941 Any number of strings naming attributes to be added to time_vary
943 Returns
944 -------
945 None
946 """
947 for param in params:
948 if param not in self.time_vary:
949 self.time_vary.append(param)
951 def add_to_time_inv(self, *params):
952 """
953 Adds any number of parameters to time_inv for this instance.
955 Parameters
956 ----------
957 params : string
958 Any number of strings naming attributes to be added to time_inv
960 Returns
961 -------
962 None
963 """
964 for param in params:
965 if param not in self.time_inv:
966 self.time_inv.append(param)
968 def del_from_time_vary(self, *params):
969 """
970 Removes any number of parameters from time_vary for this instance.
972 Parameters
973 ----------
974 params : string
975 Any number of strings naming attributes to be removed from time_vary
977 Returns
978 -------
979 None
980 """
981 for param in params:
982 if param in self.time_vary:
983 self.time_vary.remove(param)
985 def del_from_time_inv(self, *params):
986 """
987 Removes any number of parameters from time_inv for this instance.
989 Parameters
990 ----------
991 params : string
992 Any number of strings naming attributes to be removed from time_inv
994 Returns
995 -------
996 None
997 """
998 for param in params:
999 if param in self.time_inv:
1000 self.time_inv.remove(param)
1002 def unpack(self, parameter):
1003 """
1004 Unpacks an attribute from a solution object for easier access.
1005 After the model has been solved, its components (like consumption function)
1006 reside in the attributes of each element of `ThisType.solution` (e.g. `cFunc`).
1007 This method creates a (time varying) attribute of the given attribute name
1008 that contains a list of elements accessible by `ThisType.parameter`.
1010 Parameters
1011 ----------
1012 parameter: str
1013 Name of the attribute to unpack from the solution
1015 Returns
1016 -------
1017 none
1018 """
1019 setattr(self, parameter, list())
1020 for solution_t in self.solution:
1021 self.__dict__[parameter].append(solution_t.__dict__[parameter])
1022 self.add_to_time_vary(parameter)
1024 def solve(self, verbose=False, presolve=True, from_solution=None, from_t=None):
1025 """
1026 Solve the model for this instance of an agent type by backward induction.
1027 Loops through the sequence of one period problems, passing the solution
1028 from period t+1 to the problem for period t.
1030 Parameters
1031 ----------
1032 verbose : bool, optional
1033 If True, solution progress is printed to screen. Default False.
1034 presolve : bool, optional
1035 If True (default), the pre_solve method is run before solving.
1036 from_solution: Solution
1037 If different from None, will be used as the starting point of backward
1038 induction, instead of self.solution_terminal.
1039 from_t : int or None
1040 If not None, indicates which period of the model the solver should start
1041 from. It should usually only be used in combination with from_solution.
1042 Stands for the time index that from_solution represents, and thus is
1043 only compatible with cycles=1 and will be reset to None otherwise.
1045 Returns
1046 -------
1047 none
1048 """
1050 # Ignore floating point "errors". Numpy calls it "errors", but really it's excep-
1051 # tions with well-defined answers such as 1.0/0.0 that is np.inf, -1.0/0.0 that is
1052 # -np.inf, np.inf/np.inf is np.nan and so on.
1053 with np.errstate(
1054 divide="ignore", over="ignore", under="ignore", invalid="ignore"
1055 ):
1056 if presolve:
1057 self.pre_solve() # Do pre-solution stuff
1058 self.solution = solve_agent(
1059 self,
1060 verbose,
1061 from_solution,
1062 from_t,
1063 ) # Solve the model by backward induction
1064 self.post_solve() # Do post-solution stuff
1066 def reset_rng(self):
1067 """
1068 Reset the random number generator and all distributions for this type.
1069 Type-checking for lists is to handle the following three cases:
1071 1) The target is a single distribution object
1072 2) The target is a list of distribution objects (probably time-varying)
1073 3) The target is a nested list of distributions, as in ConsMarkovModel.
1074 """
1075 self.RNG = np.random.default_rng(self.seed)
1076 for name in self.distributions:
1077 if not hasattr(self, name):
1078 continue
1080 dstn = getattr(self, name)
1081 if isinstance(dstn, list):
1082 for D in dstn:
1083 if isinstance(D, list):
1084 for d in D:
1085 d.reset()
1086 else:
1087 D.reset()
1088 else:
1089 dstn.reset()
1091 def check_elements_of_time_vary_are_lists(self):
1092 """
1093 A method to check that elements of time_vary are lists.
1094 """
1095 for param in self.time_vary:
1096 if not hasattr(self, param):
1097 continue
1098 if not isinstance(
1099 getattr(self, param),
1100 (IndexDistribution,),
1101 ):
1102 assert type(getattr(self, param)) == list, (
1103 param
1104 + " is not a list or time varying distribution,"
1105 + " but should be because it is in time_vary"
1106 )
1108 def check_restrictions(self):
1109 """
1110 A method to check that various restrictions are met for the model class.
1111 """
1112 return
1114 def pre_solve(self):
1115 """
1116 A method that is run immediately before the model is solved, to check inputs or to prepare
1117 the terminal solution, perhaps.
1119 Parameters
1120 ----------
1121 none
1123 Returns
1124 -------
1125 none
1126 """
1127 self.check_restrictions()
1128 self.check_elements_of_time_vary_are_lists()
1129 return None
1131 def post_solve(self):
1132 """
1133 A method that is run immediately after the model is solved, to finalize
1134 the solution in some way. Does nothing here.
1136 Parameters
1137 ----------
1138 none
1140 Returns
1141 -------
1142 none
1143 """
1144 return None
1146 def initialize_sym(self, **kwargs):
1147 """
1148 Use the new simulator structure to build a simulator from the agents'
1149 attributes, storing it in a private attribute.
1150 """
1151 self.reset_rng() # ensure seeds are set identically each time
1152 self._simulator = make_simulator_from_agent(self, **kwargs)
1153 self._simulator.reset()
1155 def initialize_sim(self):
1156 """
1157 Prepares this AgentType for a new simulation. Resets the internal random number generator,
1158 makes initial states for all agents (using sim_birth), clears histories of tracked variables.
1160 Parameters
1161 ----------
1162 None
1164 Returns
1165 -------
1166 None
1167 """
1168 if not hasattr(self, "T_sim"):
1169 raise Exception(
1170 "To initialize simulation variables it is necessary to first "
1171 + "set the attribute T_sim to the largest number of observations "
1172 + "you plan to simulate for each agent including re-births."
1173 )
1174 elif self.T_sim <= 0:
1175 raise Exception(
1176 "T_sim represents the largest number of observations "
1177 + "that can be simulated for an agent, and must be a positive number."
1178 )
1180 self.reset_rng()
1181 self.t_sim = 0
1182 all_agents = np.ones(self.AgentCount, dtype=bool)
1183 blank_array = np.empty(self.AgentCount)
1184 blank_array[:] = np.nan
1185 for var in self.state_now:
1186 if self.state_now[var] is None:
1187 self.state_now[var] = copy(blank_array)
1189 # elif self.state_prev[var] is None:
1190 # self.state_prev[var] = copy(blank_array)
1191 self.t_age = np.zeros(
1192 self.AgentCount, dtype=int
1193 ) # Number of periods since agent entry
1194 self.t_cycle = np.zeros(
1195 self.AgentCount, dtype=int
1196 ) # Which cycle period each agent is on
1197 self.sim_birth(all_agents)
1199 # If we are asked to use existing shocks and a set of initial conditions
1200 # exist, use them
1201 if self.read_shocks and bool(self.newborn_init_history):
1202 for var_name in self.state_now:
1203 # Check that we are actually given a value for the variable
1204 if var_name in self.newborn_init_history.keys():
1205 # Copy only array-like idiosyncratic states. Aggregates should
1206 # not be set by newborns
1207 idio = (
1208 isinstance(self.state_now[var_name], np.ndarray)
1209 and len(self.state_now[var_name]) == self.AgentCount
1210 )
1211 if idio:
1212 self.state_now[var_name] = self.newborn_init_history[var_name][
1213 0
1214 ]
1216 else:
1217 warn(
1218 "The option for reading shocks was activated but "
1219 + "the model requires state "
1220 + var_name
1221 + ", not contained in "
1222 + "newborn_init_history."
1223 )
1225 self.clear_history()
1226 return None
1228 def sim_one_period(self):
1229 """
1230 Simulates one period for this type. Calls the methods get_mortality(), get_shocks() or
1231 read_shocks, get_states(), get_controls(), and get_poststates(). These should be defined for
1232 AgentType subclasses, except get_mortality (define its components sim_death and sim_birth
1233 instead) and read_shocks.
1235 Parameters
1236 ----------
1237 None
1239 Returns
1240 -------
1241 None
1242 """
1243 if not hasattr(self, "solution"):
1244 raise Exception(
1245 "Model instance does not have a solution stored. To simulate, it is necessary"
1246 " to run the `solve()` method first."
1247 )
1249 # Mortality adjusts the agent population
1250 self.get_mortality() # Replace some agents with "newborns"
1252 # state_{t-1}
1253 for var in self.state_now:
1254 self.state_prev[var] = self.state_now[var]
1256 if isinstance(self.state_now[var], np.ndarray):
1257 self.state_now[var] = np.empty(self.AgentCount)
1258 else:
1259 # Probably an aggregate variable. It may be getting set by the Market.
1260 pass
1262 if self.read_shocks: # If shock histories have been pre-specified, use those
1263 self.read_shocks_from_history()
1264 else: # Otherwise, draw shocks as usual according to subclass-specific method
1265 self.get_shocks()
1266 self.get_states() # Determine each agent's state at decision time
1267 self.get_controls() # Determine each agent's choice or control variables based on states
1268 self.get_poststates() # Calculate variables that come *after* decision-time
1270 # Advance time for all agents
1271 self.t_age = self.t_age + 1 # Age all consumers by one period
1272 self.t_cycle = self.t_cycle + 1 # Age all consumers within their cycle
1273 self.t_cycle[self.t_cycle == self.T_cycle] = (
1274 0 # Resetting to zero for those who have reached the end
1275 )
1277 def make_shock_history(self):
1278 """
1279 Makes a pre-specified history of shocks for the simulation. Shock variables should be named
1280 in self.shock_vars, a list of strings that is subclass-specific. This method runs a subset
1281 of the standard simulation loop by simulating only mortality and shocks; each variable named
1282 in shock_vars is stored in a T_sim x AgentCount array in history dictionary self.history[X].
1283 Automatically sets self.read_shocks to True so that these pre-specified shocks are used for
1284 all subsequent calls to simulate().
1286 Parameters
1287 ----------
1288 None
1290 Returns
1291 -------
1292 None
1293 """
1294 # Re-initialize the simulation
1295 self.initialize_sim()
1297 # Make blank history arrays for each shock variable (and mortality)
1298 for var_name in self.shock_vars:
1299 self.shock_history[var_name] = (
1300 np.zeros((self.T_sim, self.AgentCount)) + np.nan
1301 )
1302 self.shock_history["who_dies"] = np.zeros(
1303 (self.T_sim, self.AgentCount), dtype=bool
1304 )
1306 # Also make blank arrays for the draws of newborns' initial conditions
1307 for var_name in self.state_vars:
1308 self.newborn_init_history[var_name] = (
1309 np.zeros((self.T_sim, self.AgentCount)) + np.nan
1310 )
1312 # Record the initial condition of the newborns created by
1313 # initialize_sim -> sim_births
1314 for var_name in self.state_vars:
1315 # Check whether the state is idiosyncratic or an aggregate
1316 idio = (
1317 isinstance(self.state_now[var_name], np.ndarray)
1318 and len(self.state_now[var_name]) == self.AgentCount
1319 )
1320 if idio:
1321 self.newborn_init_history[var_name][self.t_sim] = self.state_now[
1322 var_name
1323 ]
1324 else:
1325 # Aggregate state is a scalar. Assign it to every agent.
1326 self.newborn_init_history[var_name][self.t_sim, :] = self.state_now[
1327 var_name
1328 ]
1330 # Make and store the history of shocks for each period
1331 for t in range(self.T_sim):
1332 # Deaths
1333 self.get_mortality()
1334 self.shock_history["who_dies"][t, :] = self.who_dies
1336 # Initial conditions of newborns
1337 if self.who_dies.any():
1338 for var_name in self.state_vars:
1339 # Check whether the state is idiosyncratic or an aggregate
1340 idio = (
1341 isinstance(self.state_now[var_name], np.ndarray)
1342 and len(self.state_now[var_name]) == self.AgentCount
1343 )
1344 if idio:
1345 self.newborn_init_history[var_name][t, self.who_dies] = (
1346 self.state_now[var_name][self.who_dies]
1347 )
1348 else:
1349 self.newborn_init_history[var_name][t, self.who_dies] = (
1350 self.state_now[var_name]
1351 )
1353 # Other Shocks
1354 self.get_shocks()
1355 for var_name in self.shock_vars:
1356 self.shock_history[var_name][t, :] = self.shocks[var_name]
1358 self.t_sim += 1
1359 self.t_age = self.t_age + 1 # Age all consumers by one period
1360 self.t_cycle = self.t_cycle + 1 # Age all consumers within their cycle
1361 self.t_cycle[self.t_cycle == self.T_cycle] = (
1362 0 # Resetting to zero for those who have reached the end
1363 )
1365 # Flag that shocks can be read rather than simulated
1366 self.read_shocks = True
1368 def get_mortality(self):
1369 """
1370 Simulates mortality or agent turnover according to some model-specific rules named sim_death
1371 and sim_birth (methods of an AgentType subclass). sim_death takes no arguments and returns
1372 a Boolean array of size AgentCount, indicating which agents of this type have "died" and
1373 must be replaced. sim_birth takes such a Boolean array as an argument and generates initial
1374 post-decision states for those agent indices.
1376 Parameters
1377 ----------
1378 None
1380 Returns
1381 -------
1382 None
1383 """
1384 if self.read_shocks:
1385 who_dies = self.shock_history["who_dies"][self.t_sim, :]
1386 # Instead of simulating births, assign the saved newborn initial conditions
1387 if who_dies.any():
1388 for var_name in self.state_now:
1389 if var_name in self.newborn_init_history.keys():
1390 # Copy only array-like idiosyncratic states. Aggregates should
1391 # not be set by newborns
1392 idio = (
1393 isinstance(self.state_now[var_name], np.ndarray)
1394 and len(self.state_now[var_name]) == self.AgentCount
1395 )
1396 if idio:
1397 self.state_now[var_name][who_dies] = (
1398 self.newborn_init_history[var_name][
1399 self.t_sim, who_dies
1400 ]
1401 )
1403 else:
1404 warn(
1405 "The option for reading shocks was activated but "
1406 + "the model requires state "
1407 + var_name
1408 + ", not contained in "
1409 + "newborn_init_history."
1410 )
1412 # Reset ages of newborns
1413 self.t_age[who_dies] = 0
1414 self.t_cycle[who_dies] = 0
1415 else:
1416 who_dies = self.sim_death()
1417 self.sim_birth(who_dies)
1418 self.who_dies = who_dies
1419 return None
1421 def sim_death(self):
1422 """
1423 Determines which agents in the current population "die" or should be replaced. Takes no
1424 inputs, returns a Boolean array of size self.AgentCount, which has True for agents who die
1425 and False for those that survive. Returns all False by default, must be overwritten by a
1426 subclass to have replacement events.
1428 Parameters
1429 ----------
1430 None
1432 Returns
1433 -------
1434 who_dies : np.array
1435 Boolean array of size self.AgentCount indicating which agents die and are replaced.
1436 """
1437 who_dies = np.zeros(self.AgentCount, dtype=bool)
1438 return who_dies
1440 def sim_birth(self, which_agents):
1441 """
1442 Makes new agents for the simulation. Takes a boolean array as an input, indicating which
1443 agent indices are to be "born". Does nothing by default, must be overwritten by a subclass.
1445 Parameters
1446 ----------
1447 which_agents : np.array(Bool)
1448 Boolean array of size self.AgentCount indicating which agents should be "born".
1450 Returns
1451 -------
1452 None
1453 """
1454 print("AgentType subclass must define method sim_birth!")
1455 return None
1457 def get_shocks(self):
1458 """
1459 Gets values of shock variables for the current period. Does nothing by default, but can
1460 be overwritten by subclasses of AgentType.
1462 Parameters
1463 ----------
1464 None
1466 Returns
1467 -------
1468 None
1469 """
1470 return None
1472 def read_shocks_from_history(self):
1473 """
1474 Reads values of shock variables for the current period from history arrays.
1475 For each variable X named in self.shock_vars, this attribute of self is
1476 set to self.history[X][self.t_sim,:].
1478 This method is only ever called if self.read_shocks is True. This can
1479 be achieved by using the method make_shock_history() (or manually after
1480 storing a "handcrafted" shock history).
1482 Parameters
1483 ----------
1484 None
1486 Returns
1487 -------
1488 None
1489 """
1490 for var_name in self.shock_vars:
1491 self.shocks[var_name] = self.shock_history[var_name][self.t_sim, :]
1493 def get_states(self):
1494 """
1495 Gets values of state variables for the current period.
1496 By default, calls transition function and assigns values
1497 to the state_now dictionary.
1499 Parameters
1500 ----------
1501 None
1503 Returns
1504 -------
1505 None
1506 """
1507 new_states = self.transition()
1509 for i, var in enumerate(self.state_now):
1510 # a hack for now to deal with 'post-states'
1511 if i < len(new_states):
1512 self.state_now[var] = new_states[i]
1514 def transition(self):
1515 """
1517 Parameters
1518 ----------
1519 None
1521 [Eventually, to match dolo spec:
1522 exogenous_prev, endogenous_prev, controls, exogenous, parameters]
1524 Returns
1525 -------
1527 endogenous_state: ()
1528 Tuple with new values of the endogenous states
1529 """
1530 return ()
1532 def get_controls(self):
1533 """
1534 Gets values of control variables for the current period, probably by using current states.
1535 Does nothing by default, but can be overwritten by subclasses of AgentType.
1537 Parameters
1538 ----------
1539 None
1541 Returns
1542 -------
1543 None
1544 """
1545 return None
1547 def get_poststates(self):
1548 """
1549 Gets values of post-decision state variables for the current period,
1550 probably by current
1551 states and controls and maybe market-level events or shock variables.
1552 Does nothing by
1553 default, but can be overwritten by subclasses of AgentType.
1555 Parameters
1556 ----------
1557 None
1559 Returns
1560 -------
1561 None
1562 """
1563 return None
1565 def symulate(self, T=None):
1566 """
1567 Run the new simulation structure, with history results written to the
1568 hystory attribute of self.
1569 """
1570 self._simulator.simulate(T)
1571 self.hystory = self._simulator.history
1573 def describe_model(self, display=True):
1574 """
1575 Print to screen information about this agent's model, based on its model
1576 file. This is useful for learning about outcome variable names for tracking
1577 during simulation, or for use with sequence space Jacobians.
1578 """
1579 if not hasattr(self, "_simulator"):
1580 self.initialize_sym()
1581 self._simulator.describe(display=display)
1583 def simulate(self, sim_periods=None):
1584 """
1585 Simulates this agent type for a given number of periods. Defaults to self.T_sim,
1586 or all remaining periods to simulate (T_sim - t_sim). Records histories of
1587 attributes named in self.track_vars in self.history[varname].
1589 Parameters
1590 ----------
1591 sim_periods : int or None
1592 Number of periods to simulate. Default is all remaining periods (usually T_sim).
1594 Returns
1595 -------
1596 history : dict
1597 The history tracked during the simulation.
1598 """
1599 if not hasattr(self, "t_sim"):
1600 raise Exception(
1601 "It seems that the simulation variables were not initialize before calling "
1602 + "simulate(). Call initialize_sim() to initialize the variables before calling simulate() again."
1603 )
1605 if not hasattr(self, "T_sim"):
1606 raise Exception(
1607 "This agent type instance must have the attribute T_sim set to a positive integer."
1608 + "Set T_sim to match the largest dataset you might simulate, and run this agent's"
1609 + "initialize_sim() method before running simulate() again."
1610 )
1612 if sim_periods is not None and self.T_sim < sim_periods:
1613 raise Exception(
1614 "To simulate, sim_periods has to be larger than the maximum data set size "
1615 + "T_sim. Either increase the attribute T_sim of this agent type instance "
1616 + "and call the initialize_sim() method again, or set sim_periods <= T_sim."
1617 )
1619 # Ignore floating point "errors". Numpy calls it "errors", but really it's excep-
1620 # tions with well-defined answers such as 1.0/0.0 that is np.inf, -1.0/0.0 that is
1621 # -np.inf, np.inf/np.inf is np.nan and so on.
1622 with np.errstate(
1623 divide="ignore", over="ignore", under="ignore", invalid="ignore"
1624 ):
1625 if sim_periods is None:
1626 sim_periods = self.T_sim - self.t_sim
1628 for t in range(sim_periods):
1629 self.sim_one_period()
1631 for var_name in self.track_vars:
1632 if var_name in self.state_now:
1633 self.history[var_name][self.t_sim, :] = self.state_now[var_name]
1634 elif var_name in self.shocks:
1635 self.history[var_name][self.t_sim, :] = self.shocks[var_name]
1636 elif var_name in self.controls:
1637 self.history[var_name][self.t_sim, :] = self.controls[var_name]
1638 else:
1639 if var_name == "who_dies" and self.t_sim > 1:
1640 self.history[var_name][self.t_sim - 1, :] = getattr(
1641 self, var_name
1642 )
1643 else:
1644 self.history[var_name][self.t_sim, :] = getattr(
1645 self, var_name
1646 )
1647 self.t_sim += 1
1649 def clear_history(self):
1650 """
1651 Clears the histories of the attributes named in self.track_vars.
1653 Parameters
1654 ----------
1655 None
1657 Returns
1658 -------
1659 None
1660 """
1661 for var_name in self.track_vars:
1662 self.history[var_name] = np.empty((self.T_sim, self.AgentCount))
1663 self.history[var_name].fill(np.nan)
1665 def make_basic_SSJ(self, shock, outcomes, grids, **kwargs):
1666 """
1667 Construct and return sequence space Jacobian matrices for specified outcomes
1668 with respect to specified "shock" variable. This "basic" method only works
1669 for "one period infinite horizon" models (cycles=0, T_cycle=1). See documen-
1670 tation for simulator.make_basic_SSJ_matrices for more information.
1671 """
1672 return make_basic_SSJ_matrices(self, shock, outcomes, grids, **kwargs)
1674 def calc_impulse_response_manually(self, shock, outcomes, grids, **kwargs):
1675 """
1676 Calculate and return the impulse response(s) of a perturbation to the shock
1677 parameter in period t=s, essentially computing one column of the sequence
1678 space Jacobian matrix manually. This "basic" method only works for "one
1679 period infinite horizon" models (cycles=0, T_cycle=1). See documentation
1680 for simulator.calc_shock_response_manually for more information.
1681 """
1682 return calc_shock_response_manually(self, shock, outcomes, grids, **kwargs)
1685def solve_agent(agent, verbose, from_solution=None, from_t=None):
1686 """
1687 Solve the dynamic model for one agent type using backwards induction. This
1688 function iterates on "cycles" of an agent's model either a given number of
1689 times or until solution convergence if an infinite horizon model is used
1690 (with agent.cycles = 0).
1692 Parameters
1693 ----------
1694 agent : AgentType
1695 The microeconomic AgentType whose dynamic problem
1696 is to be solved.
1697 verbose : boolean
1698 If True, solution progress is printed to screen (when cycles != 1).
1699 from_solution: Solution
1700 If different from None, will be used as the starting point of backward
1701 induction, instead of self.solution_terminal
1702 from_t : int or None
1703 If not None, indicates which period of the model the solver should start
1704 from. It should usually only be used in combination with from_solution.
1705 Stands for the time index that from_solution represents, and thus is
1706 only compatible with cycles=1 and will be reset to None otherwise.
1708 Returns
1709 -------
1710 solution : [Solution]
1711 A list of solutions to the one period problems that the agent will
1712 encounter in his "lifetime".
1713 """
1714 # Check to see whether this is an (in)finite horizon problem
1715 cycles_left = agent.cycles # NOQA
1716 infinite_horizon = cycles_left == 0 # NOQA
1718 if from_solution is None:
1719 solution_last = agent.solution_terminal # NOQA
1720 else:
1721 solution_last = from_solution
1722 if agent.cycles != 1:
1723 from_t = None
1725 # Initialize the solution, which includes the terminal solution if it's not a pseudo-terminal period
1726 solution = []
1727 if not agent.pseudo_terminal:
1728 solution.insert(0, deepcopy(solution_last))
1730 # Initialize the process, then loop over cycles
1731 go = True # NOQA
1732 completed_cycles = 0 # NOQA
1733 max_cycles = 5000 # NOQA - escape clause
1734 if verbose:
1735 t_last = time()
1736 while go:
1737 # Solve a cycle of the model, recording it if horizon is finite
1738 solution_cycle = solve_one_cycle(agent, solution_last, from_t)
1739 if not infinite_horizon:
1740 solution = solution_cycle + solution
1742 # Check for termination: identical solutions across
1743 # cycle iterations or run out of cycles
1744 solution_now = solution_cycle[0]
1745 if infinite_horizon:
1746 if completed_cycles > 0:
1747 solution_distance = solution_now.distance(solution_last)
1748 agent.solution_distance = (
1749 solution_distance # Add these attributes so users can
1750 )
1751 agent.completed_cycles = (
1752 completed_cycles # query them to see if solution is ready
1753 )
1754 go = (
1755 solution_distance > agent.tolerance
1756 and completed_cycles < max_cycles
1757 )
1758 else: # Assume solution does not converge after only one cycle
1759 solution_distance = 100.0
1760 go = True
1761 else:
1762 cycles_left += -1
1763 go = cycles_left > 0
1765 # Update the "last period solution"
1766 solution_last = solution_now
1767 completed_cycles += 1
1769 # Display progress if requested
1770 if verbose:
1771 t_now = time()
1772 if infinite_horizon:
1773 print(
1774 "Finished cycle #"
1775 + str(completed_cycles)
1776 + " in "
1777 + str(t_now - t_last)
1778 + " seconds, solution distance = "
1779 + str(solution_distance)
1780 )
1781 else:
1782 print(
1783 "Finished cycle #"
1784 + str(completed_cycles)
1785 + " of "
1786 + str(agent.cycles)
1787 + " in "
1788 + str(t_now - t_last)
1789 + " seconds."
1790 )
1791 t_last = t_now
1793 # Record the last cycle if horizon is infinite (solution is still empty!)
1794 if infinite_horizon:
1795 solution = (
1796 solution_cycle # PseudoTerminal=False impossible for infinite horizon
1797 )
1799 return solution
1802def solve_one_cycle(agent, solution_last, from_t):
1803 """
1804 Solve one "cycle" of the dynamic model for one agent type. This function
1805 iterates over the periods within an agent's cycle, updating the time-varying
1806 parameters and passing them to the single period solver(s).
1808 Parameters
1809 ----------
1810 agent : AgentType
1811 The microeconomic AgentType whose dynamic problem is to be solved.
1812 solution_last : Solution
1813 A representation of the solution of the period that comes after the
1814 end of the sequence of one period problems. This might be the term-
1815 inal period solution, a "pseudo terminal" solution, or simply the
1816 solution to the earliest period from the succeeding cycle.
1817 from_t : int or None
1818 If not None, indicates which period of the model the solver should start
1819 from. When used, represents the time index that solution_last is from.
1821 Returns
1822 -------
1823 solution_cycle : [Solution]
1824 A list of one period solutions for one "cycle" of the AgentType's
1825 microeconomic model.
1826 """
1828 # Check if the agent has a 'Parameters' attribute of the 'Parameters' class
1829 # if so, take advantage of it. Else, use the old method
1830 if hasattr(agent, "params") and isinstance(agent.params, Parameters):
1831 T = agent.params._length if from_t is None else from_t
1833 # Initialize the solution for this cycle, then iterate on periods
1834 solution_cycle = []
1835 solution_next = solution_last
1837 cycles_range = [0] + list(range(T - 1, 0, -1))
1838 for k in range(T - 1, -1, -1) if agent.cycles == 1 else cycles_range:
1839 # Update which single period solver to use (if it depends on time)
1840 if hasattr(agent.solve_one_period, "__getitem__"):
1841 solve_one_period = agent.solve_one_period[k]
1842 else:
1843 solve_one_period = agent.solve_one_period
1845 if hasattr(solve_one_period, "solver_args"):
1846 these_args = solve_one_period.solver_args
1847 else:
1848 these_args = get_arg_names(solve_one_period)
1850 # Make a temporary dictionary for this period
1851 temp_pars = agent.params[k]
1852 temp_dict = {
1853 name: solution_next if name == "solution_next" else temp_pars[name]
1854 for name in these_args
1855 }
1857 # Solve one period, add it to the solution, and move to the next period
1858 solution_t = solve_one_period(**temp_dict)
1859 solution_cycle.insert(0, solution_t)
1860 solution_next = solution_t
1862 else:
1863 # Calculate number of periods per cycle, defaults to 1 if all variables are time invariant
1864 if len(agent.time_vary) > 0:
1865 T = agent.T_cycle if from_t is None else from_t
1866 else:
1867 T = 1
1869 solve_dict = {
1870 parameter: agent.__dict__[parameter] for parameter in agent.time_inv
1871 }
1872 solve_dict.update({parameter: None for parameter in agent.time_vary})
1874 # Initialize the solution for this cycle, then iterate on periods
1875 solution_cycle = []
1876 solution_next = solution_last
1878 cycles_range = [0] + list(range(T - 1, 0, -1))
1879 for k in range(T - 1, -1, -1) if agent.cycles == 1 else cycles_range:
1880 # Update which single period solver to use (if it depends on time)
1881 if hasattr(agent.solve_one_period, "__getitem__"):
1882 solve_one_period = agent.solve_one_period[k]
1883 else:
1884 solve_one_period = agent.solve_one_period
1886 if hasattr(solve_one_period, "solver_args"):
1887 these_args = solve_one_period.solver_args
1888 else:
1889 these_args = get_arg_names(solve_one_period)
1891 # Update time-varying single period inputs
1892 for name in agent.time_vary:
1893 if name in these_args:
1894 solve_dict[name] = agent.__dict__[name][k]
1895 solve_dict["solution_next"] = solution_next
1897 # Make a temporary dictionary for this period
1898 temp_dict = {name: solve_dict[name] for name in these_args}
1900 # Solve one period, add it to the solution, and move to the next period
1901 solution_t = solve_one_period(**temp_dict)
1902 solution_cycle.insert(0, solution_t)
1903 solution_next = solution_t
1905 # Return the list of per-period solutions
1906 return solution_cycle
1909def make_one_period_oo_solver(solver_class):
1910 """
1911 Returns a function that solves a single period consumption-saving
1912 problem.
1913 Parameters
1914 ----------
1915 solver_class : Solver
1916 A class of Solver to be used.
1917 -------
1918 solver_function : function
1919 A function for solving one period of a problem.
1920 """
1922 def one_period_solver(**kwds):
1923 solver = solver_class(**kwds)
1925 # not ideal; better if this is defined in all Solver classes
1926 if hasattr(solver, "prepare_to_solve"):
1927 solver.prepare_to_solve()
1929 solution_now = solver.solve()
1930 return solution_now
1932 one_period_solver.solver_class = solver_class
1933 # This can be revisited once it is possible to export parameters
1934 one_period_solver.solver_args = get_arg_names(solver_class.__init__)[1:]
1936 return one_period_solver
1939# ========================================================================
1940# ========================================================================
1943class Market(Model):
1944 """
1945 A superclass to represent a central clearinghouse of information. Used for
1946 dynamic general equilibrium models to solve the "macroeconomic" model as a
1947 layer on top of the "microeconomic" models of one or more AgentTypes.
1949 Parameters
1950 ----------
1951 agents : [AgentType]
1952 A list of all the AgentTypes in this market.
1953 sow_vars : [string]
1954 Names of variables generated by the "aggregate market process" that should
1955 be "sown" to the agents in the market. Aggregate state, etc.
1956 reap_vars : [string]
1957 Names of variables to be collected ("reaped") from agents in the market
1958 to be used in the "aggregate market process".
1959 const_vars : [string]
1960 Names of attributes of the Market instance that are used in the "aggregate
1961 market process" but do not come from agents-- they are constant or simply
1962 parameters inherent to the process.
1963 track_vars : [string]
1964 Names of variables generated by the "aggregate market process" that should
1965 be tracked as a "history" so that a new dynamic rule can be calculated.
1966 This is often a subset of sow_vars.
1967 dyn_vars : [string]
1968 Names of variables that constitute a "dynamic rule".
1969 mill_rule : function
1970 A function that takes inputs named in reap_vars and returns a tuple the
1971 same size and order as sow_vars. The "aggregate market process" that
1972 transforms individual agent actions/states/data into aggregate data to
1973 be sent back to agents.
1974 calc_dynamics : function
1975 A function that takes inputs named in track_vars and returns an object
1976 with attributes named in dyn_vars. Looks at histories of aggregate
1977 variables and generates a new "dynamic rule" for agents to believe and
1978 act on.
1979 act_T : int
1980 The number of times that the "aggregate market process" should be run
1981 in order to generate a history of aggregate variables.
1982 tolerance: float
1983 Minimum acceptable distance between "dynamic rules" to consider the
1984 Market solution process converged. Distance is a user-defined metric.
1985 """
1987 def __init__(
1988 self,
1989 agents=None,
1990 sow_vars=None,
1991 reap_vars=None,
1992 const_vars=None,
1993 track_vars=None,
1994 dyn_vars=None,
1995 mill_rule=None,
1996 calc_dynamics=None,
1997 act_T=1000,
1998 tolerance=0.000001,
1999 **kwds,
2000 ):
2001 super().__init__()
2002 self.agents = agents if agents is not None else list() # NOQA
2004 self.reap_vars = reap_vars if reap_vars is not None else list() # NOQA
2005 self.reap_state = {var: [] for var in self.reap_vars}
2007 self.sow_vars = sow_vars if sow_vars is not None else list() # NOQA
2008 # dictionaries for tracking initial and current values
2009 # of the sow variables.
2010 self.sow_init = {var: None for var in self.sow_vars}
2011 self.sow_state = {var: None for var in self.sow_vars}
2013 const_vars = const_vars if const_vars is not None else list() # NOQA
2014 self.const_vars = {var: None for var in const_vars}
2016 self.track_vars = track_vars if track_vars is not None else list() # NOQA
2017 self.dyn_vars = dyn_vars if dyn_vars is not None else list() # NOQA
2019 if mill_rule is not None: # To prevent overwriting of method-based mill_rules
2020 self.mill_rule = mill_rule
2021 if calc_dynamics is not None: # Ditto for calc_dynamics
2022 self.calc_dynamics = calc_dynamics
2023 self.act_T = act_T # NOQA
2024 self.tolerance = tolerance # NOQA
2025 self.max_loops = 1000 # NOQA
2026 self.history = {}
2027 self.assign_parameters(**kwds)
2029 self.print_parallel_error_once = True
2030 # Print the error associated with calling the parallel method
2031 # "solve_agents" one time. If set to false, the error will never
2032 # print. See "solve_agents" for why this prints once or never.
2034 def solve_agents(self):
2035 """
2036 Solves the microeconomic problem for all AgentTypes in this market.
2038 Parameters
2039 ----------
2040 None
2042 Returns
2043 -------
2044 None
2045 """
2046 try:
2047 multi_thread_commands(self.agents, ["solve()"])
2048 except Exception as err:
2049 if self.print_parallel_error_once:
2050 # Set flag to False so this is only printed once.
2051 self.print_parallel_error_once = False
2052 print(
2053 "**** WARNING: could not execute multi_thread_commands in HARK.core.Market.solve_agents() ",
2054 "so using the serial version instead. This will likely be slower. "
2055 "The multiTreadCommands() functions failed with the following error:",
2056 "\n",
2057 sys.exc_info()[0],
2058 ":",
2059 err,
2060 ) # sys.exc_info()[0])
2061 multi_thread_commands_fake(self.agents, ["solve()"])
2063 def solve(self):
2064 """
2065 "Solves" the market by finding a "dynamic rule" that governs the aggregate
2066 market state such that when agents believe in these dynamics, their actions
2067 collectively generate the same dynamic rule.
2069 Parameters
2070 ----------
2071 None
2073 Returns
2074 -------
2075 None
2076 """
2077 go = True
2078 max_loops = self.max_loops # Failsafe against infinite solution loop
2079 completed_loops = 0
2080 old_dynamics = None
2082 while go: # Loop until the dynamic process converges or we hit the loop cap
2083 self.solve_agents() # Solve each AgentType's micro problem
2084 self.make_history() # "Run" the model while tracking aggregate variables
2085 new_dynamics = self.update_dynamics() # Find a new aggregate dynamic rule
2087 # Check to see if the dynamic rule has converged (if this is not the first loop)
2088 if completed_loops > 0:
2089 distance = new_dynamics.distance(old_dynamics)
2090 else:
2091 distance = 1000000.0
2093 # Move to the next loop if the terminal conditions are not met
2094 old_dynamics = new_dynamics
2095 completed_loops += 1
2096 go = distance >= self.tolerance and completed_loops < max_loops
2098 self.dynamics = new_dynamics # Store the final dynamic rule in self
2100 def reap(self):
2101 """
2102 Collects attributes named in reap_vars from each AgentType in the market,
2103 storing them in respectively named attributes of self.
2105 Parameters
2106 ----------
2107 none
2109 Returns
2110 -------
2111 none
2112 """
2113 for var in self.reap_state:
2114 harvest = []
2116 for agent in self.agents:
2117 # TODO: generalized variable lookup across namespaces
2118 if var in agent.state_now:
2119 # or state_now ??
2120 harvest.append(agent.state_now[var])
2122 self.reap_state[var] = harvest
2124 def sow(self):
2125 """
2126 Distributes attrributes named in sow_vars from self to each AgentType
2127 in the market, storing them in respectively named attributes.
2129 Parameters
2130 ----------
2131 none
2133 Returns
2134 -------
2135 none
2136 """
2137 for sow_var in self.sow_state:
2138 for this_type in self.agents:
2139 if sow_var in this_type.state_now:
2140 this_type.state_now[sow_var] = self.sow_state[sow_var]
2141 if sow_var in this_type.shocks:
2142 this_type.shocks[sow_var] = self.sow_state[sow_var]
2143 else:
2144 setattr(this_type, sow_var, self.sow_state[sow_var])
2146 def mill(self):
2147 """
2148 Processes the variables collected from agents using the function mill_rule,
2149 storing the results in attributes named in aggr_sow.
2151 Parameters
2152 ----------
2153 none
2155 Returns
2156 -------
2157 none
2158 """
2159 # Make a dictionary of inputs for the mill_rule
2160 mill_dict = copy(self.reap_state)
2161 mill_dict.update(self.const_vars)
2163 # Run the mill_rule and store its output in self
2164 product = self.mill_rule(**mill_dict)
2166 for i, sow_var in enumerate(self.sow_state):
2167 self.sow_state[sow_var] = product[i]
2169 def cultivate(self):
2170 """
2171 Has each AgentType in agents perform their market_action method, using
2172 variables sown from the market (and maybe also "private" variables).
2173 The market_action method should store new results in attributes named in
2174 reap_vars to be reaped later.
2176 Parameters
2177 ----------
2178 none
2180 Returns
2181 -------
2182 none
2183 """
2184 for this_type in self.agents:
2185 this_type.market_action()
2187 def reset(self):
2188 """
2189 Reset the state of the market (attributes in sow_vars, etc) to some
2190 user-defined initial state, and erase the histories of tracked variables.
2192 Parameters
2193 ----------
2194 none
2196 Returns
2197 -------
2198 none
2199 """
2200 # Reset the history of tracked variables
2201 self.history = {var_name: [] for var_name in self.track_vars}
2203 # Set the sow variables to their initial levels
2204 for var_name in self.sow_state:
2205 self.sow_state[var_name] = self.sow_init[var_name]
2207 # Reset each AgentType in the market
2208 for this_type in self.agents:
2209 this_type.reset()
2211 def store(self):
2212 """
2213 Record the current value of each variable X named in track_vars in an
2214 dictionary field named history[X].
2216 Parameters
2217 ----------
2218 none
2220 Returns
2221 -------
2222 none
2223 """
2224 for var_name in self.track_vars:
2225 if var_name in self.sow_state:
2226 value_now = self.sow_state[var_name]
2227 elif var_name in self.reap_state:
2228 value_now = self.reap_state[var_name]
2229 elif var_name in self.const_vars:
2230 value_now = self.const_vars[var_name]
2231 else:
2232 value_now = getattr(self, var_name)
2234 self.history[var_name].append(value_now)
2236 def make_history(self):
2237 """
2238 Runs a loop of sow-->cultivate-->reap-->mill act_T times, tracking the
2239 evolution of variables X named in track_vars in dictionary fields
2240 history[X].
2242 Parameters
2243 ----------
2244 none
2246 Returns
2247 -------
2248 none
2249 """
2250 self.reset() # Initialize the state of the market
2251 for t in range(self.act_T):
2252 self.sow() # Distribute aggregated information/state to agents
2253 self.cultivate() # Agents take action
2254 self.reap() # Collect individual data from agents
2255 self.mill() # Process individual data into aggregate data
2256 self.store() # Record variables of interest
2258 def update_dynamics(self):
2259 """
2260 Calculates a new "aggregate dynamic rule" using the history of variables
2261 named in track_vars, and distributes this rule to AgentTypes in agents.
2263 Parameters
2264 ----------
2265 none
2267 Returns
2268 -------
2269 dynamics : instance
2270 The new "aggregate dynamic rule" that agents believe in and act on.
2271 Should have attributes named in dyn_vars.
2272 """
2273 # Make a dictionary of inputs for the dynamics calculator
2274 arg_names = list(get_arg_names(self.calc_dynamics))
2275 if "self" in arg_names:
2276 arg_names.remove("self")
2277 update_dict = {name: self.history[name] for name in arg_names}
2278 # Calculate a new dynamic rule and distribute it to the agents in agent_list
2279 dynamics = self.calc_dynamics(**update_dict) # User-defined dynamics calculator
2280 for var_name in self.dyn_vars:
2281 this_obj = getattr(dynamics, var_name)
2282 for this_type in self.agents:
2283 setattr(this_type, var_name, this_obj)
2284 return dynamics
2287def distribute_params(agent, param_name, param_count, distribution):
2288 """
2289 Distributes heterogeneous values of one parameter to the AgentTypes in self.agents.
2290 Parameters
2291 ----------
2292 agent: AgentType
2293 An agent to clone.
2294 param_name : string
2295 Name of the parameter to be assigned.
2296 param_count : int
2297 Number of different values the parameter will take on.
2298 distribution : Distribution
2299 A 1-D distribution.
2301 Returns
2302 -------
2303 agent_set : [AgentType]
2304 A list of param_count agents, ex ante heterogeneous with
2305 respect to param_name. The AgentCount of the original
2306 will be split between the agents of the returned
2307 list in proportion to the given distribution.
2308 """
2309 param_dist = distribution.discretize(N=param_count)
2311 agent_set = [deepcopy(agent) for i in range(param_count)]
2313 for j in range(param_count):
2314 agent_set[j].assign_parameters(
2315 **{"AgentCount": int(agent.AgentCount * param_dist.pmv[j])}
2316 )
2317 # agent_set[j].__dict__[param_name] = param_dist.atoms[j]
2319 agent_set[j].assign_parameters(**{param_name: param_dist.atoms[0, j]})
2321 return agent_set
2324@dataclass
2325class AgentPopulation:
2326 """
2327 A class for representing a population of ex-ante heterogeneous agents.
2328 """
2330 agent_type: AgentType # type of agent in the population
2331 parameters: dict # dictionary of parameters
2332 seed: int = 0 # random seed
2333 time_var: List[str] = field(init=False)
2334 time_inv: List[str] = field(init=False)
2335 distributed_params: List[str] = field(init=False)
2336 agent_type_count: Optional[int] = field(init=False)
2337 term_age: Optional[int] = field(init=False)
2338 continuous_distributions: Dict[str, Distribution] = field(init=False)
2339 discrete_distributions: Dict[str, Distribution] = field(init=False)
2340 population_parameters: List[Dict[str, Union[List[float], float]]] = field(
2341 init=False
2342 )
2343 agents: List[AgentType] = field(init=False)
2344 agent_database: pd.DataFrame = field(init=False)
2345 solution: List[Any] = field(init=False)
2347 def __post_init__(self):
2348 """
2349 Initialize the population of agents, determine distributed parameters,
2350 and infer `agent_type_count` and `term_age`.
2351 """
2352 # create a dummy agent and obtain its time-varying
2353 # and time-invariant attributes
2354 dummy_agent = self.agent_type()
2355 self.time_var = dummy_agent.time_vary
2356 self.time_inv = dummy_agent.time_inv
2358 # create list of distributed parameters
2359 # these are parameters that differ across agents
2360 self.distributed_params = [
2361 key
2362 for key, param in self.parameters.items()
2363 if (isinstance(param, list) and isinstance(param[0], list))
2364 or isinstance(param, Distribution)
2365 or (isinstance(param, DataArray) and param.dims[0] == "agent")
2366 ]
2368 self.__infer_counts__()
2370 self.print_parallel_error_once = True
2371 # Print warning once if parallel simulation fails
2373 def __infer_counts__(self):
2374 """
2375 Infer `agent_type_count` and `term_age` from the parameters.
2376 If parameters include a `Distribution` type, a list of lists,
2377 or a `DataArray` with `agent` as the first dimension, then
2378 the AgentPopulation contains ex-ante heterogenous agents.
2379 """
2381 # infer agent_type_count from distributed parameters
2382 agent_type_count = 1
2383 for key in self.distributed_params:
2384 param = self.parameters[key]
2385 if isinstance(param, Distribution):
2386 agent_type_count = None
2387 warn(
2388 "Cannot infer agent_type_count from a Distribution. "
2389 "Please provide approximation parameters."
2390 )
2391 break
2392 elif isinstance(param, list):
2393 agent_type_count = max(agent_type_count, len(param))
2394 elif isinstance(param, DataArray) and param.dims[0] == "agent":
2395 agent_type_count = max(agent_type_count, param.shape[0])
2397 self.agent_type_count = agent_type_count
2399 # infer term_age from all parameters
2400 term_age = 1
2401 for param in self.parameters.values():
2402 if isinstance(param, Distribution):
2403 term_age = None
2404 warn(
2405 "Cannot infer term_age from a Distribution. "
2406 "Please provide approximation parameters."
2407 )
2408 break
2409 elif isinstance(param, list) and isinstance(param[0], list):
2410 term_age = max(term_age, len(param[0]))
2411 elif isinstance(param, DataArray) and param.dims[-1] == "age":
2412 term_age = max(term_age, param.shape[-1])
2414 self.term_age = term_age
2416 def approx_distributions(self, approx_params: dict):
2417 """
2418 Approximate continuous distributions with discrete ones. If the initial
2419 parameters include a `Distribution` type, then the AgentPopulation is
2420 not ready to solve, and stands for an abstract population. To solve the
2421 AgentPopulation, we need discretization parameters for each continuous
2422 distribution. This method approximates the continuous distributions with
2423 discrete ones, and updates the parameters dictionary.
2424 """
2425 self.continuous_distributions = {}
2426 self.discrete_distributions = {}
2428 for key, args in approx_params.items():
2429 param = self.parameters[key]
2430 if key in self.distributed_params and isinstance(param, Distribution):
2431 self.continuous_distributions[key] = param
2432 self.discrete_distributions[key] = param.discretize(**args)
2433 else:
2434 raise ValueError(
2435 f"Warning: parameter {key} is not a Distribution found "
2436 f"in agent type {self.agent_type}"
2437 )
2439 if len(self.discrete_distributions) > 1:
2440 joint_dist = combine_indep_dstns(*self.discrete_distributions.values())
2441 else:
2442 joint_dist = list(self.discrete_distributions.values())[0]
2444 for i, key in enumerate(self.discrete_distributions):
2445 self.parameters[key] = DataArray(joint_dist.atoms[i], dims=("agent"))
2447 self.__infer_counts__()
2449 def __parse_parameters__(self) -> None:
2450 """
2451 Creates distributed dictionaries of parameters for each ex-ante
2452 heterogeneous agent in the parameterized population. The parameters
2453 are stored in a list of dictionaries, where each dictionary contains
2454 the parameters for one agent. Expands parameters that vary over time
2455 to a list of length `term_age`.
2456 """
2458 population_parameters = [] # container for dictionaries of each agent subgroup
2459 for agent in range(self.agent_type_count):
2460 agent_parameters = {}
2461 for key, param in self.parameters.items():
2462 if key in self.time_var:
2463 # parameters that vary over time have to be repeated
2464 if isinstance(param, (int, float)):
2465 parameter_per_t = [param] * self.term_age
2466 elif isinstance(param, list):
2467 if isinstance(param[0], list):
2468 parameter_per_t = param[agent]
2469 else:
2470 parameter_per_t = param
2471 elif isinstance(param, DataArray):
2472 if param.dims[0] == "agent":
2473 if param.dims[-1] == "age":
2474 parameter_per_t = param[agent].item()
2475 else:
2476 parameter_per_t = param.item()
2477 elif param.dims[0] == "age":
2478 parameter_per_t = param.item()
2480 agent_parameters[key] = parameter_per_t
2482 elif key in self.time_inv:
2483 if isinstance(param, (int, float)):
2484 agent_parameters[key] = param
2485 elif isinstance(param, list):
2486 if isinstance(param[0], list):
2487 agent_parameters[key] = param[agent]
2488 else:
2489 agent_parameters[key] = param
2490 elif isinstance(param, DataArray) and param.dims[0] == "agent":
2491 agent_parameters[key] = param[agent].item()
2493 else:
2494 if isinstance(param, (int, float)):
2495 agent_parameters[key] = param # assume time inv
2496 elif isinstance(param, list):
2497 if isinstance(param[0], list):
2498 agent_parameters[key] = param[agent] # assume agent vary
2499 else:
2500 agent_parameters[key] = param # assume time vary
2501 elif isinstance(param, DataArray):
2502 if param.dims[0] == "agent":
2503 if param.dims[-1] == "age":
2504 agent_parameters[key] = param[
2505 agent
2506 ].item() # assume agent vary
2507 else:
2508 agent_parameters[key] = param.item() # assume time vary
2509 elif param.dims[0] == "age":
2510 agent_parameters[key] = param.item() # assume time vary
2512 population_parameters.append(agent_parameters)
2514 self.population_parameters = population_parameters
2516 def create_distributed_agents(self):
2517 """
2518 Parses the parameters dictionary and creates a list of agents with the
2519 appropriate parameters. Also sets the seed for each agent.
2520 """
2522 self.__parse_parameters__()
2524 rng = np.random.default_rng(self.seed)
2526 self.agents = [
2527 self.agent_type(seed=rng.integers(0, 2**31 - 1), **agent_dict)
2528 for agent_dict in self.population_parameters
2529 ]
2531 def create_database(self):
2532 """
2533 Optionally creates a pandas DataFrame with the parameters for each agent.
2534 """
2535 database = pd.DataFrame(self.population_parameters)
2536 database["agents"] = self.agents
2538 self.agent_database = database
2540 def solve(self):
2541 """
2542 Solves each agent of the population serially.
2543 """
2545 # see Market class for an example of how to solve distributed agents in parallel
2547 for agent in self.agents:
2548 agent.solve()
2550 def unpack_solutions(self):
2551 """
2552 Unpacks the solutions of each agent into an attribute of the population.
2553 """
2554 self.solution = [agent.solution for agent in self.agents]
2556 def initialize_sim(self):
2557 """
2558 Initializes the simulation for each agent.
2559 """
2560 for agent in self.agents:
2561 agent.initialize_sim()
2563 def simulate(self, num_jobs=None):
2564 """
2565 Simulates each agent of the population.
2567 Parameters
2568 ----------
2569 num_jobs : int, optional
2570 Number of parallel jobs to use. Defaults to using all available
2571 cores when ``None``. Falls back to serial execution if parallel
2572 processing fails.
2573 """
2574 try:
2575 multi_thread_commands(self.agents, ["simulate()"], num_jobs)
2576 except Exception as err:
2577 if getattr(self, "print_parallel_error_once", False):
2578 self.print_parallel_error_once = False
2579 print(
2580 "**** WARNING: could not execute multi_thread_commands in HARK.core.AgentPopulation.simulate() ",
2581 "so using the serial version instead. This will likely be slower. ",
2582 "The multi_thread_commands() function failed with the following error:\n",
2583 sys.exc_info()[0],
2584 ":",
2585 err,
2586 )
2587 multi_thread_commands_fake(self.agents, ["simulate()"], num_jobs)
2589 def __iter__(self):
2590 """
2591 Allows for iteration over the agents in the population.
2592 """
2593 return iter(self.agents)
2595 def __getitem__(self, idx):
2596 """
2597 Allows for indexing into the population.
2598 """
2599 return self.agents[idx]
2602###############################################################################
2605def multi_thread_commands_fake(
2606 agent_list: List, command_list: List, num_jobs=None
2607) -> None:
2608 """
2609 Executes the list of commands in command_list for each AgentType in agent_list
2610 in an ordinary, single-threaded loop. Each command should be a method of
2611 that AgentType subclass. This function exists so as to easily disable
2612 multithreading, as it uses the same syntax as multi_thread_commands.
2614 Parameters
2615 ----------
2616 agent_list : [AgentType]
2617 A list of instances of AgentType on which the commands will be run.
2618 command_list : [string]
2619 A list of commands to run for each AgentType.
2620 num_jobs : None
2621 Dummy input to match syntax of multi_thread_commands. Does nothing.
2623 Returns
2624 -------
2625 none
2626 """
2627 for agent in agent_list:
2628 for command in command_list:
2629 # TODO: Code should be updated to pass in the method name instead of method()
2630 getattr(agent, command[:-2])()
2633def multi_thread_commands(agent_list: List, command_list: List, num_jobs=None) -> None:
2634 """
2635 Executes the list of commands in command_list for each AgentType in agent_list
2636 using a multithreaded system. Each command should be a method of that AgentType subclass.
2638 Parameters
2639 ----------
2640 agent_list : [AgentType]
2641 A list of instances of AgentType on which the commands will be run.
2642 command_list : [string]
2643 A list of commands to run for each AgentType in agent_list.
2645 Returns
2646 -------
2647 None
2648 """
2649 if len(agent_list) == 1:
2650 multi_thread_commands_fake(agent_list, command_list)
2651 return None
2653 # Default number of parallel jobs is the smaller of number of AgentTypes in
2654 # the input and the number of available cores.
2655 if num_jobs is None:
2656 num_jobs = min(len(agent_list), multiprocessing.cpu_count())
2658 # Send each command in command_list to each of the types in agent_list to be run
2659 agent_list_out = Parallel(n_jobs=num_jobs)(
2660 delayed(run_commands)(*args)
2661 for args in zip(agent_list, len(agent_list) * [command_list])
2662 )
2664 # Replace the original types with the output from the parallel call
2665 for j in range(len(agent_list)):
2666 agent_list[j] = agent_list_out[j]
2669def run_commands(agent: Any, command_list: List) -> Any:
2670 """
2671 Executes each command in command_list on a given AgentType. The commands
2672 should be methods of that AgentType's subclass.
2674 Parameters
2675 ----------
2676 agent : AgentType
2677 An instance of AgentType on which the commands will be run.
2678 command_list : [string]
2679 A list of commands that the agent should run, as methods.
2681 Returns
2682 -------
2683 agent : AgentType
2684 The same AgentType instance passed as input, after running the commands.
2685 """
2686 for command in command_list:
2687 # TODO: Code should be updated to pass in the method name instead of method()
2688 getattr(agent, command[:-2])()
2689 return agent