Coverage for HARK / core.py: 96%
1092 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-10 06:19 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-10 06:19 +0000
1"""
2High-level functions and classes for solving a wide variety of economic models.
3The "core" of HARK is a framework for "microeconomic" and "macroeconomic"
4models. A micro model concerns the dynamic optimization problem for some type
5of agents, where agents take the inputs to their problem as exogenous. A macro
6model adds an additional layer, endogenizing some of the inputs to the micro
7problem by finding a general equilibrium dynamic rule.
8"""
10# Import basic modules
11import inspect
12import sys
13from collections import namedtuple
14from copy import copy, deepcopy
15from dataclasses import dataclass, field
16from time import time
17from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
18from warnings import warn
19import multiprocessing
20from joblib import Parallel, delayed
21from pandas import DataFrame
23import numpy as np
24import pandas as pd
25from xarray import DataArray
27from HARK.distributions import (
28 Distribution,
29 IndexDistribution,
30 combine_indep_dstns,
31)
32from HARK.utilities import NullFunc, get_arg_names, get_it_from
33from HARK.simulator import make_simulator_from_agent
34from HARK.SSJutils import (
35 make_basic_SSJ_matrices,
36 make_flat_LC_SSJ_matrices,
37 calc_shock_response_manually,
38)
39from HARK.metric import MetricObject, distance_metric
# Public API of this module: the names exported by ``from HARK.core import *``.
__all__ = [
    "AgentType",
    "Market",
    "Parameters",
    "Model",
    "AgentPopulation",
    "multi_thread_commands",
    "multi_thread_commands_fake",
    "NullFunc",
    "make_one_period_oo_solver",
    "distribute_params",
]
class Parameters:
    """
    A smart container for model parameters that handles age-varying dynamics.

    This class stores parameters as an internal dictionary and manages their
    age-varying properties, providing both attribute-style and dictionary-style
    access. It is designed to handle the time-varying dynamics of parameters
    in economic models.

    Attributes
    ----------
    _length : int
        The terminal age of the agents in the model.
    _invariant_params : Set[str]
        A set of parameter names that are invariant over time.
    _varying_params : Set[str]
        A set of parameter names that vary over time.
    _parameters : Dict[str, Any]
        The internal dictionary storing all parameters.
    _frozen : bool
        When True, item/attribute assignment of parameters raises RuntimeError.
    _namedtuple_cache : Optional[type]
        Namedtuple class most recently built by ``to_namedtuple`` (or None).
    """

    __slots__ = (
        "_length",
        "_invariant_params",
        "_varying_params",
        "_parameters",
        "_frozen",
        "_namedtuple_cache",
    )

    def __init__(self, **parameters: Any) -> None:
        """
        Initialize a Parameters object and parse the age-varying dynamics of parameters.

        Parameters
        ----------
        T_cycle : int, optional
            The number of time periods in the model cycle (default: 1).
            Must be >= 1.
        frozen : bool, optional
            If True, the Parameters object will be immutable after initialization
            (default: False).
        _time_inv : List[str], optional
            List of parameter names to explicitly mark as time-invariant,
            overriding automatic inference.
        _time_vary : List[str], optional
            List of parameter names to explicitly mark as time-varying,
            overriding automatic inference.
        **parameters : Any
            Any number of parameters in the form key=value.

        Raises
        ------
        ValueError
            If T_cycle is less than 1.

        Notes
        -----
        Automatic time-variance inference rules:
        - Scalars (int, float, bool, None) are time-invariant
        - 0-D and 1-D NumPy arrays are time-invariant (use lists/tuples for
          time-varying)
        - Single-element lists/tuples [x] are unwrapped to x and time-invariant
        - Multi-element lists/tuples are time-varying if length matches T_cycle
        - Arrays with 2 or more dimensions whose first dimension matches
          T_cycle are time-varying
        - Distributions and Callables are time-invariant

        Use _time_inv or _time_vary to override automatic inference when needed.
        """
        # Extract special parameters
        self._length: int = parameters.pop("T_cycle", 1)
        frozen: bool = parameters.pop("frozen", False)
        time_inv_override: List[str] = parameters.pop("_time_inv", [])
        time_vary_override: List[str] = parameters.pop("_time_vary", [])

        # Validate T_cycle
        if self._length < 1:
            raise ValueError(f"T_cycle must be >= 1, got {self._length}")

        # Initialize internal state
        self._invariant_params: Set[str] = set()
        self._varying_params: Set[str] = set()
        self._parameters: Dict[str, Any] = {"T_cycle": self._length}
        self._frozen: bool = False  # Set to False initially to allow setup
        self._namedtuple_cache: Optional[type] = None

        # Set parameters using automatic inference
        for key, value in parameters.items():
            self[key] = value

        # Apply explicit overrides (names that were never set are ignored)
        for param in time_inv_override:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)

        for param in time_vary_override:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)

        # Freeze if requested
        self._frozen = frozen

    def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
        """
        Access parameters by age index or parameter name.

        If item_or_key is an integer, returns a Parameters object with the parameters
        that apply to that age. This includes all invariant parameters and the
        `item_or_key`th element of all age-varying parameters. If item_or_key is a
        string, it returns the value of the parameter with that name.

        Parameters
        ----------
        item_or_key : Union[int, str]
            Age index (Python int or NumPy integer) or parameter name.

        Returns
        -------
        Union[Parameters, Any]
            A new Parameters object for the specified age, or the value of the
            specified parameter.

        Raises
        ------
        ValueError:
            If the age index is out of bounds.
        KeyError:
            If the parameter name is not found.
        TypeError:
            If the key is neither an integer nor a string.
        """
        # NumPy integers (e.g. from np.arange) are accepted as ages too.
        if isinstance(item_or_key, (int, np.integer)):
            if item_or_key < 0 or item_or_key >= self._length:
                raise ValueError(
                    f"Age {item_or_key} is out of bounds (valid: 0-{self._length - 1})."
                )

            params = {key: self._parameters[key] for key in self._invariant_params}
            for key in self._varying_params:
                value = self._parameters[key]
                # Only sequence-like values are sliced by age; anything else
                # that was marked time-varying is passed through whole.
                if isinstance(value, (list, tuple, np.ndarray)):
                    params[key] = value[item_or_key]
                else:
                    params[key] = value
            return Parameters(**params)
        elif isinstance(item_or_key, str):
            return self._parameters[item_or_key]
        else:
            raise TypeError("Key must be an integer (age) or string (parameter name).")

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Set parameter values, automatically inferring time variance.

        If the parameter is a scalar, numpy array, boolean, distribution, callable
        or None, it is assumed to be invariant over time. If the parameter is a
        list or tuple, it is assumed to be varying over time. If the parameter
        is a list or tuple of length greater than 1, the length of the list or
        tuple must match the `_length` attribute of the Parameters object.

        Numpy arrays with two or more dimensions whose first dimension matches
        T_cycle are treated as time-varying parameters.

        Parameters
        ----------
        key : str
            Name of the parameter.
        value : Any
            Value of the parameter.

        Raises
        ------
        ValueError:
            If the parameter name is not a string or if the value type is unsupported.
            If the parameter value is inconsistent with the current model length.
        RuntimeError:
            If the Parameters object is frozen.
        """
        if self._frozen:
            raise RuntimeError("Cannot modify frozen Parameters object")

        if not isinstance(key, str):
            raise ValueError(f"Parameter name must be a string, got {type(key)}")

        # Check for multi-dimensional numpy arrays with time-varying first dimension
        if isinstance(value, np.ndarray) and value.ndim >= 2:
            if value.shape[0] == self._length:
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            else:
                self._invariant_params.add(key)
                self._varying_params.discard(key)
        elif isinstance(
            value,
            (
                int,
                float,
                np.ndarray,
                type(None),
                Distribution,
                bool,
                Callable,
                MetricObject,
            ),
        ):
            self._invariant_params.add(key)
            self._varying_params.discard(key)
        elif isinstance(value, (list, tuple)):
            if len(value) == 1:
                # A one-element sequence is unwrapped and stored as a constant
                value = value[0]
                self._invariant_params.add(key)
                self._varying_params.discard(key)
            elif self._length is None or self._length == 1:
                # No cycle length established yet: adopt this sequence's length.
                # NOTE(review): the stored "T_cycle" entry in self._parameters is
                # not updated here, so it can differ from self._length -- confirm
                # whether callers rely on that.
                self._length = len(value)
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            elif len(value) == self._length:
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            else:
                raise ValueError(
                    f"Parameter {key} must have length 1 or {self._length}, not {len(value)}"
                )
        else:
            raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")

        self._parameters[key] = value

    def __iter__(self) -> Iterator[str]:
        """Allow iteration over parameter names."""
        return iter(self._parameters)

    def __len__(self) -> int:
        """Return the number of parameters."""
        return len(self._parameters)

    def keys(self) -> Iterator[str]:
        """Return a view of parameter names."""
        return self._parameters.keys()

    def values(self) -> Iterator[Any]:
        """Return a view of parameter values."""
        return self._parameters.values()

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Return a view of parameter (name, value) pairs."""
        return self._parameters.items()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a plain dictionary.

        Returns
        -------
        Dict[str, Any]
            A (shallow) dictionary copy containing all parameters.
        """
        return dict(self._parameters)

    def to_namedtuple(self) -> namedtuple:
        """
        Convert parameters to a namedtuple.

        The namedtuple class is cached for efficiency on repeated calls; the
        cache is rebuilt whenever the set (or order) of parameter names has
        changed since it was created.

        Returns
        -------
        namedtuple
            A namedtuple containing all parameters.
        """
        field_names = tuple(self._parameters)
        nt_class = self._namedtuple_cache
        # A cached class built before parameters were added would reject the
        # new keywords, so compare its fields against the current names.
        if nt_class is None or nt_class._fields != field_names:
            nt_class = namedtuple("Parameters", field_names)
            self._namedtuple_cache = nt_class
        return nt_class(**self._parameters)

    def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
        """
        Update parameters from another Parameters object or dictionary.

        Each (name, value) pair is assigned through item assignment, so time
        variance is re-inferred for every updated parameter.

        Parameters
        ----------
        other : Union[Parameters, Dict[str, Any]]
            The source of parameters to update from.

        Raises
        ------
        TypeError
            If the input is neither a Parameters object nor a dictionary.
        """
        if isinstance(other, Parameters):
            source = other._parameters
        elif isinstance(other, dict):
            source = other
        else:
            raise TypeError(
                "Update source must be a Parameters object or a dictionary."
            )
        for key, value in source.items():
            self[key] = value

    def __repr__(self) -> str:
        """Return a detailed string representation of the Parameters object."""
        return (
            f"Parameters(_length={self._length}, "
            f"_invariant_params={self._invariant_params}, "
            f"_varying_params={self._varying_params}, "
            f"_parameters={self._parameters})"
        )

    def __str__(self) -> str:
        """Return a simple string representation of the Parameters object."""
        return f"Parameters({str(self._parameters)})"

    def __getattr__(self, name: str) -> Any:
        """
        Allow attribute-style access to parameters.

        Parameters
        ----------
        name : str
            Name of the parameter to access.

        Returns
        -------
        Any
            The value of the specified parameter.

        Raises
        ------
        AttributeError:
            If the parameter name is not found.
        """
        # Underscore names are internal slots, never parameters
        if name.startswith("_"):
            return super().__getattribute__(name)
        try:
            return self._parameters[name]
        except KeyError:
            # Hide the internal KeyError; callers expect a plain AttributeError
            raise AttributeError(
                f"'Parameters' object has no attribute '{name}'"
            ) from None

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Allow attribute-style setting of parameters.

        Names beginning with an underscore are stored as internal attributes;
        every other name is routed through item assignment.

        Parameters
        ----------
        name : str
            Name of the parameter to set.
        value : Any
            Value to set for the parameter.
        """
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            self[name] = value

    def __contains__(self, item: str) -> bool:
        """Check if a parameter exists in the Parameters object."""
        return item in self._parameters

    def copy(self) -> "Parameters":
        """
        Create a deep copy of the Parameters object.

        Returns
        -------
        Parameters
            A new Parameters object with the same contents.
        """
        return deepcopy(self)

    def add_to_time_vary(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-varying set.

        A warning is issued for each name that is not an existing parameter.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_vary.
        """
        for param in params:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_vary."
                )

    def add_to_time_inv(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-invariant set.

        A warning is issued for each name that is not an existing parameter.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_inv.
        """
        for param in params:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_inv."
                )

    def del_from_time_vary(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_vary.
        """
        for param in params:
            self._varying_params.discard(param)

    def del_from_time_inv(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_inv.
        """
        for param in params:
            self._invariant_params.discard(param)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a parameter value, returning a default if not found.

        Parameters
        ----------
        key : str
            The parameter name.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The parameter value or the default.
        """
        return self._parameters.get(key, default)

    def set_many(self, **kwargs: Any) -> None:
        """
        Set multiple parameters at once.

        Parameters
        ----------
        **kwargs : Keyword arguments representing parameter names and values.
        """
        for key, value in kwargs.items():
            self[key] = value

    def is_time_varying(self, key: str) -> bool:
        """
        Check if a parameter is time-varying.

        Parameters
        ----------
        key : str
            The parameter name.

        Returns
        -------
        bool
            True if the parameter is time-varying, False otherwise.
        """
        return key in self._varying_params

    def at_age(self, age: int) -> "Parameters":
        """
        Get parameters for a specific age.

        This is an alternative to integer indexing (params[age]) that is more
        explicit and avoids potential confusion with dictionary-style access.

        Parameters
        ----------
        age : int
            The age index to retrieve parameters for.

        Returns
        -------
        Parameters
            A new Parameters object with parameters for the specified age.

        Raises
        ------
        ValueError
            If the age index is out of bounds.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97], sigma=2.0)
        >>> age_1_params = params.at_age(1)
        >>> age_1_params.beta
        0.96
        """
        return self[age]

    def validate(self) -> None:
        """
        Validate parameter consistency.

        Checks that all time-varying parameters have length matching T_cycle.
        This is useful after manual modifications or when parameters are set
        programmatically. Parameters are checked in sorted name order so that
        the resulting error message is deterministic.

        Raises
        ------
        ValueError
            If any time-varying parameter has incorrect length.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97])
        >>> params.validate()  # Passes
        >>> params.add_to_time_vary("beta")
        >>> params.validate()  # Still passes
        """
        errors = []
        for param in sorted(self._varying_params):
            value = self._parameters[param]
            if isinstance(value, (list, tuple)):
                if len(value) != self._length:
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )
            elif isinstance(value, np.ndarray):
                if value.ndim == 0:
                    errors.append(
                        f"Parameter '{param}' is a 0-dimensional array (scalar), "
                        "which should not be time-varying"
                    )
                elif value.ndim >= 2:
                    if value.shape[0] != self._length:
                        errors.append(
                            f"Parameter '{param}' has first dimension {value.shape[0]}, expected {self._length}"
                        )
                elif len(value) != self._length:
                    # Only 1-D arrays reach this branch
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )

        if errors:
            raise ValueError(
                "Parameter validation failed:\n" + "\n".join(f" - {e}" for e in errors)
            )
class Model:
    """
    A class with special handling of parameters assignment.

    Instances carry a ``parameters`` mapping of named inputs and a
    ``constructors`` mapping from names to the callables that build them.
    """

    def __init__(self):
        # Subclasses may define these before calling this initializer
        if not hasattr(self, "parameters"):
            self.parameters = {}
        if not hasattr(self, "constructors"):
            self.constructors = {}

    def assign_parameters(self, **kwds):
        """
        Assign an arbitrary number of attributes to this agent.

        Parameters
        ----------
        **kwds : keyword arguments
            Any number of keyword arguments of the form key=value.
            Each value will be assigned to the attribute named in self.

        Returns
        -------
        None
        """
        self.parameters.update(kwds)
        for key, value in kwds.items():
            setattr(self, key, value)

    def get_parameter(self, name):
        """
        Returns a parameter of this model

        Parameters
        ----------
        name : str
            The name of the parameter to get

        Returns
        -------
        value : The value of the parameter
        """
        return self.parameters[name]

    def __eq__(self, other):
        # Two models are equal when they are of compatible type and share parameters
        if isinstance(other, type(self)):
            return self.parameters == other.parameters

        return NotImplemented

    def __str__(self):
        type_ = type(self)
        module = type_.__module__
        qualname = type_.__qualname__

        header = f"<{module}.{qualname} object at {hex(id(self))}.\nParameters:"
        # One "name: value" line per parameter
        body = "".join(f"\n{p}: {self.parameters[p]}" for p in self.parameters)
        return header + body + ">"

    def describe(self):
        """Return the same description as ``str(self)``."""
        return self.__str__()

    def del_param(self, param_name):
        """
        Deletes a parameter from this instance, removing it both from the object's
        namespace (if it's there) and the parameters dictionary (likewise).

        Parameters
        ----------
        param_name : str
            A string naming a parameter or data to be deleted from this instance.
            Removes information from self.parameters dictionary and own namespace.

        Returns
        -------
        None
        """
        if param_name in self.parameters:
            del self.parameters[param_name]
        if hasattr(self, param_name):
            delattr(self, param_name)

    def _gather_constructor_args(self, key, constructor):
        """
        Gather all arguments needed to call the given constructor.

        Handles both the special ``get_it_from`` case and normal callables.
        For a normal callable the method inspects the function signature to
        find every required argument, then resolves each argument from the
        instance namespace (``self.<arg>``) or from ``self.parameters``.
        Arguments that have a default value are silently skipped when they
        cannot be found; required arguments that cannot be resolved are
        recorded as missing.

        Parameters
        ----------
        key : str
            The name of the constructed object (used to record missing pairs).
        constructor : callable or get_it_from
            The constructor to be called.

        Returns
        -------
        temp_dict : dict
            Keyword arguments to pass to the constructor.
        any_missing : bool
            True when at least one required argument could not be resolved.
        missing_args : list of str
            Names of unresolved required arguments.
        missing_key_data : list of tuple
            ``(key, arg)`` pairs for every unresolved required argument of a
            normal constructor (left empty in the ``get_it_from`` case).
        """
        missing_key_data = []

        # SPECIAL: if the constructor is get_it_from, handle it separately
        if isinstance(constructor, get_it_from):
            try:
                parent = getattr(self, constructor.name)
                query = key
                any_missing = False
                missing_args = []
            except AttributeError:
                parent = None
                query = None
                any_missing = True
                missing_args = [constructor.name]
            temp_dict = {"parent": parent, "query": query}
            return temp_dict, any_missing, missing_args, missing_key_data

        # Normal constructor: inspect signature and gather arguments
        args_needed = get_arg_names(constructor)
        has_no_default = {
            k: v.default is inspect.Parameter.empty
            for k, v in inspect.signature(constructor).parameters.items()
        }
        temp_dict = {}
        any_missing = False
        missing_args = []
        for this_arg in args_needed:
            # The instance namespace takes precedence over self.parameters
            if hasattr(self, this_arg):
                temp_dict[this_arg] = getattr(self, this_arg)
            else:
                try:
                    temp_dict[this_arg] = self.parameters[this_arg]
                except KeyError:
                    if has_no_default[this_arg]:
                        # Record missing key-data pair
                        any_missing = True
                        missing_key_data.append((key, this_arg))
                        missing_args.append(this_arg)

        return temp_dict, any_missing, missing_args, missing_key_data

    def _attempt_construct(self, key, i, keys_complete, backup, errors, force):
        """
        Attempt to construct the object for a single key.

        The method looks up the constructor, gathers its arguments, runs it,
        and records any errors. It also handles the ``None``-constructor case
        (restore from backup) and the missing-args case (record and defer).

        Parameters
        ----------
        key : str
            The name of the object to construct.
        i : int
            Index of *key* inside the ``keys`` array (used to update
            ``keys_complete``).
        keys_complete : np.ndarray of bool
            Boolean array indicating which keys have been completed; mutated
            in place when this key succeeds.
        backup : dict
            Dictionary of pre-construction attribute values.
        errors : dict
            ``self._constructor_errors``; mutated in place.
        force : bool
            When True, swallow exceptions and continue; when False, re-raise.

        Returns
        -------
        accomplished : bool
            True when the key was completed (constructor ran or was None).
        missing_key_data : list of tuple
            ``(key, arg)`` pairs recorded for missing required arguments.

        Raises
        ------
        KeyError
            If no constructor is registered for *key* and ``force`` is False.
        Exception
            Whatever the constructor raises, when ``force`` is False.
        """
        missing_key_data = []

        # Look up the constructor for this key
        try:
            constructor = self.constructors[key]
        except KeyError as not_found:
            errors[key] = "No constructor found for " + str(not_found)
            if force:
                return False, missing_key_data
            else:
                raise KeyError("No constructor found for " + key) from None

        # If the constructor is None, restore from backup and mark complete
        if constructor is None:
            if key in backup:
                setattr(self, key, backup[key])
                self.parameters[key] = backup[key]
            keys_complete[i] = True
            return True, missing_key_data

        # Gather arguments for the constructor
        temp_dict, any_missing, missing_args, missing_key_data = (
            self._gather_constructor_args(key, constructor)
        )

        # If all required data was found, run the constructor and store the result
        if not any_missing:
            try:
                temp = constructor(**temp_dict)
            except Exception as problem:
                errors[key] = str(type(problem)) + ": " + str(problem)
                self.del_param(key)
                if force:
                    return False, missing_key_data
                else:
                    raise
            setattr(self, key, temp)
            self.parameters[key] = temp
            # A success clears any error recorded on an earlier pass
            if key in errors:
                del errors[key]
            keys_complete[i] = True
            return True, missing_key_data

        # Some required arguments were missing; record and defer
        msg = "Missing required arguments: " + ", ".join(missing_args)
        errors[key] = msg
        self.del_param(key)
        # Never raise exceptions here, as the arguments might be filled in later
        return False, missing_key_data

    def _construct_pass(self, keys, keys_complete, backup, errors, force):
        """
        Perform one full sweep over all incomplete keys.

        Calls ``_attempt_construct`` for every key that has not yet been
        completed and accumulates results.

        Parameters
        ----------
        keys : sequence of str
            All keys requested for construction.
        keys_complete : np.ndarray of bool
            Boolean array indicating which keys have been completed; mutated
            in place by ``_attempt_construct``.
        backup : dict
            Dictionary of pre-construction attribute values.
        errors : dict
            ``self._constructor_errors``; mutated in place.
        force : bool
            Passed through to ``_attempt_construct``.

        Returns
        -------
        anything_accomplished : bool
            True when at least one key was completed during this pass.
        missing_key_data : list of tuple
            Accumulated ``(key, arg)`` pairs for all unresolved required
            arguments across every incomplete key.
        """
        anything_accomplished = False
        missing_key_data = []

        for i, key in enumerate(keys):
            if keys_complete[i]:
                continue  # This key has already been built

            accomplished, key_missing = self._attempt_construct(
                key, i, keys_complete, backup, errors, force
            )
            if accomplished:
                anything_accomplished = True
            missing_key_data.extend(key_missing)

        return anything_accomplished, missing_key_data

    def construct(self, *args, force=False):
        """
        Top-level method for building constructed inputs. If called without any
        inputs, construct builds each of the objects named in the keys of the
        constructors dictionary; it draws inputs for the constructors from the
        parameters dictionary and adds its results to the same. If passed one or
        more strings as arguments, the method builds only the named keys. The
        method will do multiple "passes" over the requested keys, as some cons-
        tructors require inputs built by other constructors. If any requested
        constructors failed to build due to missing data, those keys (and the
        missing data) will be named in self._missing_key_data. Other errors are
        recorded in the dictionary attribute _constructor_errors.

        This method tries to "start from scratch" by removing prior constructed
        objects, holding them in a backup dictionary during construction. This
        is done so that dependencies among constructors are resolved properly,
        without mistakenly relying on "old information". A backup value is used
        if a constructor function is set to None (i.e. "don't do anything"), or
        if the construct method fails to produce a new object.

        Parameters
        ----------
        *args : str, optional
            Keys of self.constructors that are requested to be constructed.
            If no arguments are passed, *all* elements of the dictionary are implied.
        force : bool, optional
            When True, the method will force its way past any errors, including
            missing constructors, missing arguments for constructors, and errors
            raised during execution of constructors. Information about all such
            errors is stored in the dictionary attributes described above. When
            False (default), any errors or exception will be raised.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If some requested objects could not be built and ``force`` is False.
        """
        # Set up the requested work
        if len(args) > 0:
            keys = args
        else:
            keys = list(self.constructors.keys())
        N_keys = len(keys)
        keys_complete = np.zeros(N_keys, dtype=bool)
        if N_keys == 0:
            return  # Do nothing if there are no constructed objects

        # Remove pre-existing constructed objects, preventing "incomplete" updates,
        # but store the current values in a backup dictionary in case something fails
        backup = {}
        for key in keys:
            if hasattr(self, key):
                backup[key] = getattr(self, key)
            self.del_param(key)

        # Get the dictionary of constructor errors
        if not hasattr(self, "_constructor_errors"):
            self._constructor_errors = {}
        errors = self._constructor_errors

        # As long as the work isn't complete and we made some progress on the last
        # pass, repeatedly perform passes of trying to construct objects. At least
        # one pass always runs because N_keys > 0 and nothing is complete yet.
        while True:
            anything_accomplished, missing_key_data = self._construct_pass(
                keys, keys_complete, backup, errors, force
            )
            any_keys_incomplete = not keys_complete.all()
            if not (any_keys_incomplete and anything_accomplished):
                break

        # Store missing key-data pairs and exit
        self._missing_key_data = missing_key_data
        self._constructor_errors = errors
        if any_keys_incomplete:
            unbuilt = [key for i, key in enumerate(keys) if not keys_complete[i]]
            # Fall back on the pre-construction value for anything not rebuilt
            for key in unbuilt:
                if key in backup:
                    setattr(self, key, backup[key])
                    self.parameters[key] = backup[key]
            msg = "Did not construct these objects: " + ", ".join(unbuilt)
            if not force:
                raise ValueError(msg)
        return

    def describe_constructors(self, *args):
        """
        Prints to screen a string describing this instance's constructed objects,
        including their names, the function that constructs them, the names of
        those functions inputs, and whether those inputs are present.

        Constructors set to None are reported as such and have no argument
        listing. Callables without a ``__name__`` (e.g. ``functools.partial``
        objects) are shown by their ``repr``.

        Parameters
        ----------
        *args : str, optional
            Optional list of strings naming constructed inputs to be described.
            If none are passed, all constructors are described.

        Returns
        -------
        None
        """
        if len(args) > 0:
            keys = args
        else:
            keys = list(self.constructors.keys())
        yes = "\u2713"
        no = "X"
        maybe = "*"
        noyes = [no, yes]

        out = ""
        for key in keys:
            has_val = hasattr(self, key) or (key in self.parameters)
            prefix = noyes[int(has_val)] + " " + key + " : "

            try:
                constructor = self.constructors[key]
            except KeyError:
                out += prefix + "NO CONSTRUCTOR FOUND\n"
                continue

            # A None constructor means "keep the existing value"; it has no
            # name or signature to describe
            if constructor is None:
                out += prefix + "None (existing value is kept)\n"
                continue

            # Get the constructor function if possible
            if isinstance(constructor, get_it_from):
                out += prefix + "get it from " + constructor.name + "\n"
                continue
            out += prefix + getattr(constructor, "__name__", repr(constructor)) + "\n"

            # Get constructor argument names
            arg_names = get_arg_names(constructor)
            has_no_default = {
                k: v.default is inspect.Parameter.empty
                for k, v in inspect.signature(constructor).parameters.items()
            }

            # Check whether each argument exists
            for this_arg in arg_names:
                if hasattr(self, this_arg) or this_arg in self.parameters:
                    symb = yes
                elif not has_no_default[this_arg]:
                    symb = maybe  # absent, but the constructor has a default
                else:
                    symb = no
                out += " " + symb + " " + this_arg + "\n"

        # Print the string to screen
        print(out)
        return

    # This is a "synonym" method so that old calls to update() still work
    def update(self, *args, **kwargs):
        self.construct(*args, **kwargs)
1077class AgentType(Model):
1078 """
1079 A superclass for economic agents in the HARK framework. Each model should
1080 specify its own subclass of AgentType, inheriting its methods and overwriting
1081 as necessary. Critically, every subclass of AgentType should define class-
1082 specific static values of the attributes time_vary and time_inv as lists of
1083 strings. Each element of time_vary is the name of a field in AgentSubType
1084 that varies over time in the model. Each element of time_inv is the name of
1085 a field in AgentSubType that is constant over time in the model.
1087 Parameters
1088 ----------
1089 solution_terminal : Solution
1090 A representation of the solution to the terminal period problem of
1091 this AgentType instance, or an initial guess of the solution if this
1092 is an infinite horizon problem.
1093 cycles : int
1094 The number of times the sequence of periods is experienced by this
1095 AgentType in their "lifetime". cycles=1 corresponds to a lifecycle
1096 model, with a certain sequence of one period problems experienced
1097 once before terminating. cycles=0 corresponds to an infinite horizon
1098 model, with a sequence of one period problems repeating indefinitely.
1099 pseudo_terminal : bool
1100 Indicates whether solution_terminal isn't actually part of the
1101 solution to the problem (as a known solution to the terminal period
1102 problem), but instead represents a "scrap value"-style termination.
1103 When True, solution_terminal is not included in the solution; when
1104 False, solution_terminal is the last element of the solution.
1105 tolerance : float
1106 Maximum acceptable "distance" between successive solutions to the
1107 one period problem in an infinite horizon (cycles=0) model in order
1108 for the solution to be considered as having "converged". Inoperative
1109 when cycles>0.
1110 verbose : int
1111 Level of output to be displayed by this instance, default is 1.
1112 quiet : bool
1113 Indicator for whether this instance should operate "quietly", default False.
1114 seed : int
1115 A seed for this instance's random number generator.
1116 construct : bool
1117 Indicator for whether this instance's construct() method should be run
1118 when initialized (default True). When False, an instance of the class
1119 can be created even if not all of its attributes can be constructed.
1120 use_defaults : bool
1121 Indicator for whether this instance should use the values in the class'
1122 default dictionary to fill in parameters and constructors for those not
1123 provided by the user (default True). Setting this to False is useful for
1124 situations where the user wants to be absolutely sure that they know what
1125 is being passed to the class initializer, without resorting to defaults.
1127 Attributes
1128 ----------
1129 AgentCount : int
1130 The number of agents of this type to use in simulation.
1132 state_vars : list of string
1133 The string labels for this AgentType's model state variables.
1134 """
1136 time_vary_ = []
1137 time_inv_ = []
1138 shock_vars_ = []
1139 state_vars = []
1140 poststate_vars = []
1141 market_vars = []
1142 distributions = []
1143 default_ = {"params": {}, "solver": NullFunc()}
def __init__(
    self,
    solution_terminal=None,
    pseudo_terminal=True,
    tolerance=0.000001,
    verbose=1,
    quiet=False,
    seed=0,
    construct=True,
    use_defaults=True,
    **kwds,
):
    """
    Set up a new agent type from class defaults and user-supplied values.

    The parameter dictionary is built from the "params" entry of the class
    attribute default_ (when use_defaults is True) and then overlaid with
    **kwds. Constructors are merged separately, key by key, so that passing
    a partial "constructors" dictionary does not discard the default ones.
    See the class docstring for the meaning of each argument.
    """
    super().__init__()
    default_params = self.default_["params"]

    # Class defaults first (if requested), user-supplied values on top
    params = deepcopy(default_params) if use_defaults else {}
    params.update(kwds)

    # Correctly handle constructors that have been passed in kwds: merge them
    # into the default constructors rather than replacing the whole dictionary
    if "constructors" in default_params and use_defaults:
        constructors = deepcopy(default_params["constructors"])
    else:
        constructors = {}
    if "constructors" in kwds:
        constructors.update(kwds["constructors"])
    params["constructors"] = constructors

    # Set default track_vars
    if "track_vars" in self.default_ and use_defaults:
        self.track_vars = copy(self.default_["track_vars"])
    else:
        self.track_vars = []

    # Set model file name if possible
    try:
        self.model_file = copy(self.default_["model"])
    except (KeyError, TypeError):
        # Fallback to None if "model" key is missing or invalid for copying
        self.model_file = None

    self.solve_one_period = self.default_["solver"]  # NOQA
    self.solution_terminal = (
        NullFunc() if solution_terminal is None else solution_terminal
    )  # NOQA
    self.pseudo_terminal = pseudo_terminal  # NOQA
    self.tolerance = tolerance  # NOQA
    self.verbose = verbose
    self.quiet = quiet
    self.seed = seed  # NOQA

    # Containers used by the simulation methods
    self.state_now = {sv: None for sv in self.state_vars}
    self.state_prev = self.state_now.copy()
    self.controls = {}
    self.shocks = {}
    self.read_shocks = False  # NOQA
    self.shock_history = {}
    self.newborn_init_history = {}
    self.history = {}

    self.assign_parameters(**params)  # NOQA
    self.reset_rng()  # NOQA
    self.bilt = {}
    if construct:
        self.construct()

    # Add instance-level lists and objects (copies of the class-level templates)
    self.time_vary = deepcopy(self.time_vary_)
    self.time_inv = deepcopy(self.time_inv_)
    self.shock_vars = deepcopy(self.shock_vars_)
def add_to_time_vary(self, *params):
    """
    Register attribute names as time-varying for this instance.

    Names already present in time_vary are left alone, so the list never
    gains duplicates through this method.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be added to time_vary

    Returns
    -------
    None
    """
    for name in params:
        if name in self.time_vary:
            continue
        self.time_vary.append(name)
def add_to_time_inv(self, *params):
    """
    Register attribute names as time-invariant for this instance.

    Names already present in time_inv are left alone, so the list never
    gains duplicates through this method.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be added to time_inv

    Returns
    -------
    None
    """
    for name in params:
        if name in self.time_inv:
            continue
        self.time_inv.append(name)
def del_from_time_vary(self, *params):
    """
    Unregister attribute names from time_vary for this instance.

    Names that are not currently in time_vary are silently skipped.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be removed from time_vary

    Returns
    -------
    None
    """
    for name in params:
        try:
            self.time_vary.remove(name)
        except ValueError:
            # Name was not in the list; nothing to remove
            pass
def del_from_time_inv(self, *params):
    """
    Unregister attribute names from time_inv for this instance.

    Names that are not currently in time_inv are silently skipped.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be removed from time_inv

    Returns
    -------
    None
    """
    for name in params:
        try:
            self.time_inv.remove(name)
        except ValueError:
            # Name was not in the list; nothing to remove
            pass
def unpack(self, name):
    """
    Unpacks an attribute from a solution object for easier access.
    After the model has been solved, its components (like consumption function)
    reside in the attributes of each element of `ThisType.solution` (e.g. `cFunc`).
    This method creates a (time varying) attribute of the given attribute name
    that contains a list with one element per entry of `ThisType.solution`.

    The first element of the solution decides how every element is read:
    if it is a dict (or any dict subclass) the value is looked up by key,
    otherwise it is read from the element's instance `__dict__`.

    Parameters
    ----------
    name: str
        Name of the attribute to unpack from the solution

    Returns
    -------
    none
    """
    # isinstance (rather than an exact type match) so that dict subclasses
    # such as OrderedDict are read by key instead of through __dict__
    if isinstance(self.solution[0], dict):
        unpacked = [soln_t[name] for soln_t in self.solution]
    else:
        unpacked = [soln_t.__dict__[name] for soln_t in self.solution]
    setattr(self, name, unpacked)
    self.add_to_time_vary(name)
def solve(
    self,
    verbose=False,
    presolve=True,
    postsolve=True,
    from_solution=None,
    from_t=None,
):
    """
    Solve this agent type's model by backward induction and store the result
    in self.solution. The actual work is delegated to solve_agent; this
    method wraps it with the optional pre_solve / post_solve hooks.

    Parameters
    ----------
    verbose : bool, optional
        If True, solution progress is printed to screen. Default False.
    presolve : bool, optional
        If True (default), the pre_solve method is run before solving.
    postsolve : bool, optional
        If True (default), the post_solve method is run after solving.
    from_solution: Solution
        If different from None, will be used as the starting point of backward
        induction, instead of self.solution_terminal.
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. It should usually only be used in combination with from_solution.
        Stands for the time index that from_solution represents, and thus is
        only compatible with cycles=1 and will be reset to None otherwise.

    Returns
    -------
    none
    """
    # Silence numpy's floating point "errors" for the duration of the solve.
    # These are really well-defined results (1.0/0.0 is np.inf, -1.0/0.0 is
    # -np.inf, np.inf/np.inf is np.nan, and so on), not genuine failures.
    fp_handling = dict.fromkeys(("divide", "over", "under", "invalid"), "ignore")
    with np.errstate(**fp_handling):
        if presolve:
            self.pre_solve()  # Do pre-solution stuff
        # Solve the model by backward induction
        self.solution = solve_agent(self, verbose, from_solution, from_t)
        if postsolve:
            self.post_solve()  # Do post-solution stuff
def reset_rng(self):
    """
    Re-seed this instance's random number generator from self.seed and call
    reset() on every distribution named in self.distributions.

    Names in self.distributions that are not attributes of this instance
    are skipped. Three layouts of the named attribute are handled:

    1) a single distribution object
    2) a list of distribution objects (probably time-varying)
    3) a nested list of distributions, as in ConsMarkovModel.
    """
    self.RNG = np.random.default_rng(self.seed)
    for attr_name in self.distributions:
        if not hasattr(self, attr_name):
            continue
        target = getattr(self, attr_name)

        # Case 1: a lone distribution
        if not isinstance(target, list):
            target.reset()
            continue

        # Cases 2 and 3: treat each entry as a (possibly one-element) group
        for entry in target:
            group = entry if isinstance(entry, list) else [entry]
            for member in group:
                member.reset()
def check_elements_of_time_vary_are_lists(self):
    """
    Verify that every existing attribute named in time_vary is either a list
    or an IndexDistribution. Names in time_vary that are not attributes of
    this instance are skipped.

    Raises
    ------
    AssertionError
        If an attribute named in time_vary is neither a list nor an
        IndexDistribution. The error is raised explicitly (not through an
        assert statement) so that the check still runs under ``python -O``.
    """
    for param in self.time_vary:
        if not hasattr(self, param):
            continue
        value = getattr(self, param)
        # Lists (including list subclasses) and time-varying distributions pass
        if isinstance(value, list) or isinstance(value, IndexDistribution):
            continue
        raise AssertionError(
            param
            + " is not a list or time varying distribution,"
            + " but should be because it is in time_vary"
        )
def check_restrictions(self):
    """
    Hook for checking that model-specific restrictions are met.
    This base implementation performs no checks.
    """
    return None
def pre_solve(self):
    """
    Hook that runs immediately before the model is solved. Here it validates
    inputs by calling check_restrictions and then
    check_elements_of_time_vary_are_lists, in that order.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    self.check_restrictions()
    self.check_elements_of_time_vary_are_lists()
def post_solve(self):
    """
    Hook that runs immediately after the model is solved, so that the
    solution can be finalized in some way. This base implementation is a
    no-op.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    return
def get_market_params(self, mkt, construct=True):
    """
    Copy the attributes named in market_vars from a market onto this agent.

    Each value is shallow-copied from mkt and handed to assign_parameters.
    By default construct() is run afterwards, because the market parameters
    often have information needed to "complete" the microeconomic problem.

    This method is called automatically by the Market.give_agent_params()
    method for all agents.

    Parameters
    ----------
    mkt : Market
        Market to which this AgentType belongs.
    construct : bool
        Indicator for whether constructed attributes should be updated after
        fetching data / parameters from mkt (default True)

    Returns
    -------
    None
    """
    fetched = {name: copy(getattr(mkt, name)) for name in self.market_vars}
    self.assign_parameters(**fetched)
    if construct:
        self.construct()
def initialize_sym(self, **kwargs):
    """
    Build a simulator for this agent with the new simulator structure and
    store it in the private attribute _simulator, then reset it.

    Keyword arguments are forwarded to make_simulator_from_agent.
    """
    self.reset_rng()  # ensure seeds are set identically each time
    simulator = make_simulator_from_agent(self, **kwargs)
    self._simulator = simulator
    simulator.reset()
def find_target(self, target_var, force_list=False, **kwargs):
    r"""
    Find the "target" level of a named variable such that $E[\Delta x] = 0$,
    with $E[\Delta x-\epsilon] > 0$ and $E[\Delta x+\epsilon] < 0$ (locally stable).

    The search is delegated to the find_target_state method of a temporary
    simulator built from this agent; extra keyword arguments are passed
    through to it. See documentation for HARK.simulator.find_target_state
    for more options.

    Returns a single value if exactly one target is found, the full
    collection if several are found, and np.nan if none is found. Pass
    force_list=True to always get the collection back unchanged.

    Raises AttributeError if this instance has no solution attribute.
    """
    if not hasattr(self, "solution"):
        raise AttributeError("Model must be solved before using find_target!")

    simulator = make_simulator_from_agent(self)
    targets = simulator.find_target_state(target_var, **kwargs)

    if force_list or len(targets) > 1:
        return targets
    if len(targets) == 1:
        return targets[0]
    return np.nan
def initialize_sim(self):
    """
    Prepares this AgentType for a new simulation. Resets the internal random number generator,
    makes initial states for all agents (using sim_birth), clears histories of tracked variables.

    If read_shocks is True and newborn_init_history is non-empty, the
    idiosyncratic states produced by sim_birth are then replaced by a copy of
    the first row of the stored newborn initial conditions.

    Parameters
    ----------
    None

    Returns
    -------
    None

    Raises
    ------
    Exception
        If the attribute T_sim has not been set, or is not positive.
    """
    if not hasattr(self, "T_sim"):
        raise Exception(
            "To initialize simulation variables it is necessary to first "
            + "set the attribute T_sim to the largest number of observations "
            + "you plan to simulate for each agent including re-births."
        )
    elif self.T_sim <= 0:
        raise Exception(
            "T_sim represents the largest number of observations "
            + "that can be simulated for an agent, and must be a positive number."
        )

    self.reset_rng()
    self.t_sim = 0
    all_agents = np.ones(self.AgentCount, dtype=bool)

    # Every state variable starts as its own all-NaN array
    for var in self.state_vars:
        self.state_now[var] = np.full(self.AgentCount, np.nan)

    # Number of periods since agent entry
    self.t_age = np.zeros(self.AgentCount, dtype=int)
    # Which cycle period each agent is on
    self.t_cycle = np.zeros(self.AgentCount, dtype=int)
    self.sim_birth(all_agents)

    # If we are asked to use existing shocks and a set of initial conditions
    # exist, use them
    if self.read_shocks and bool(self.newborn_init_history):
        for var_name in self.state_now:
            # Check that we are actually given a value for the variable
            if var_name not in self.newborn_init_history:
                warn(
                    "The option for reading shocks was activated but "
                    + "the model requires state "
                    + var_name
                    + ", not contained in "
                    + "newborn_init_history."
                )
                continue

            # Copy only array-like idiosyncratic states. Aggregates should
            # not be set by newborns
            current = self.state_now[var_name]
            idio = isinstance(current, np.ndarray) and len(current) == self.AgentCount
            if idio:
                # Take a copy, not a view: otherwise in-place edits of the
                # state would silently overwrite the stored history
                self.state_now[var_name] = copy(
                    self.newborn_init_history[var_name][0]
                )

    self.clear_history()
def _export_single_var_by_time(self, history, var, t, dtype):
    """
    Mode 1a: single variable, by_age=False.

    Builds a DataFrame from history[var] with one row per agent index and
    one column per simulation period (t_sim), labelled by the period number
    as a string. With t=None every period in range(T_sim) is included;
    otherwise t selects the periods (rows of history[var]) to include.

    Raises KeyError if var is not a key of history.
    """
    try:
        data = history[var]
    except KeyError:
        raise KeyError("Variable named " + var + " not found in simulated data!")

    if t is None:
        labels = [str(period) for period in range(self.T_sim)]
        values = data.T
    else:
        labels = [str(period) for period in t]
        values = data[t, :].T
    return DataFrame(data=values, columns=labels, dtype=dtype)
def _export_single_var_by_age(self, history, var, age, t, dtype, sym):
    """
    Mode 1b: single variable, by_age=True.

    Returns a DataFrame whose columns are within-agent model ages (t_age)
    and whose rows are individual agent lifetimes. Each lifetime's selected
    observations are written left-aligned into its row; remaining cells are
    NaN. When t is None all ages up to max(t_age) are included; otherwise t
    is any sequence of ages (array, list or tuple) to include.
    The sym flag controls newborn detection: age == 0 when sym is True,
    age == 1 when sym is False.

    Raises KeyError if var is not a key of history.
    """
    try:
        data = history[var]
    except KeyError:
        raise KeyError("Variable named " + var + " not found in simulated data!")

    # Determine which ages to include and mark qualifying observations
    if t is None:
        age_set = np.arange(np.max(age) + 1)
        in_age_set = np.ones_like(data, dtype=bool)
    else:
        # Accept any sequence of ages, not only numpy arrays
        age_set = np.asarray(t)
        in_age_set = np.isin(age, age_set)

    # Locate newborns to determine the number of individual lifetimes (rows)
    newborns = (age == 0) if sym else (age == 1)
    T = age_set.size  # number of age columns
    N = np.sum(newborns)  # number of agent lifetimes (rows)
    out = np.full((N, T), np.nan)

    # Extract each individual's sequence and place it into the output array
    n = 0
    for i in range(self.AgentCount):
        data_i = data[:, i]
        births = np.where(newborns[:, i])[0]
        K = births.size
        for k in range(K):
            # A lifetime runs from its birth to the next birth in this
            # agent slot, or to the end of the simulation
            start = births[k]
            stop = births[k + 1] if (k < K - 1) else self.T_sim
            use = in_age_set[start:stop, i]
            temp = data_i[start:stop][use]
            out[n, : temp.size] = temp
            n += 1

    cols = [str(a) for a in age_set]
    df = DataFrame(data=out, columns=cols, dtype=dtype)
    return df
def _export_single_t_by_time(self, history, var_list, t, dtype):
    """
    Mode 2a: single time period, by_age=False.

    Builds a DataFrame with one row per agent and one column per name in
    var_list, each column filled from history[name][t, :] (absolute
    simulation period t). Values pass through a NaN-initialized numpy
    array before the DataFrame is built.
    """
    table = np.full((self.AgentCount, len(var_list)), np.nan)
    for col, key in enumerate(var_list):
        table[:, col] = history[key][t, :]
    return DataFrame(data=table, columns=var_list, dtype=dtype)
def _export_single_t_by_age(self, history, var_list, age, t, dtype):
    """
    Mode 2b: single time period, by_age=True.

    Builds a DataFrame with one row per agent-period at which age == t and
    one column per name in var_list, each column filled from
    history[name] at those agent-periods. Values pass through a
    NaN-initialized numpy array before the DataFrame is built.
    """
    at_age = age == t
    table = np.full((np.sum(at_age), len(var_list)), np.nan)
    for col, key in enumerate(var_list):
        table[:, col] = history[key][at_age]
    return DataFrame(data=table, columns=var_list, dtype=dtype)
def export_to_df(self, var=None, t=None, by_age=False, dtype=None, sym=False):
    """
    Export this instance's simulated data to a pandas dataframe.

    Exactly one of "var is a single string" and "t is a single integer" must
    hold; any other combination raises an exception. Together with by_age
    this selects one of four construction modes:

    1a) var is a string, by_age is False: one column per simulated period in
        absolute simulation time t_sim and one row per *agent index* of the
        population, with death and replacement occurring within a row.
        Optionally, t can be an array of periods to include (default all).

    1b) var is a string, by_age is True: columns correspond to within-agent
        model age t_age and each row is one specific agent from model entry
        to model death, with NaN after death. Optionally, t can be an array
        of ages to include (default all). The number of columns depends on
        max(t_age) and/or argument t.

    2a) t is an integer, by_age is False: one column per simulated variable,
        using values from absolute simulated period t=t_sim. Optionally, var
        can be a list of variable names to include (default all).

    2b) t is an integer, by_age is True: one column per simulated variable,
        taken from all agent-periods at which t_age == t. Optionally, var
        can be a list of variable names to include (default all).

    Parameters
    ----------
    var : str or [str] or None
        A single string names the one simulated variable to export. A list of
        strings requires t to be an integer and names the variables to use as
        columns. Name(s) must be keys of the history or hystory dictionary
        (i.e. named in track_vars). If None, all keys are included.
    t : int or np.array or None
        An integer selects the one period to include: the t-th row of
        history[key] when by_age is False, or all agent-periods with
        t_age==t when by_age is True. If var is a single string, t is an
        optional array of periods (or ages) to include (default all).
    by_age : bool
        Whether observations are selected by absolute simulated time t_sim
        (False, default) or by within-agent model age t_age (True). If True,
        t_age must be in track_vars so that it appears in the simulated data,
        and dtype must not be provided, as NaNs could be cast to a datatype
        that doesn't necessarily support them.
    dtype : type or None
        Optional data type to cast the dataframe. By default, uses the
        datatype from the entry in history or hystory.
    sym : bool
        Whether to read simulated data from the history (False, default) or
        hystory (True) dictionary attribute. This option will be deprecated
        in the future when legacy simulation methods are removed.

    Returns
    -------
    df : pandas.DataFrame
        The requested dataframe, constructed from this instance's simulated data.

    Raises
    ------
    ValueError
        If var/t do not satisfy the "exactly one" rule, or if dtype is given
        together with by_age=True.
    KeyError
        If by_age is True but t_age is not in the simulated data, or if a
        requested variable name is not in the simulated data.
    """
    # Validate arguments: exactly one of the two selectors must be "single"
    single_var = type(var) is str
    single_t = isinstance(t, (int, np.integer))
    if single_var == single_t:
        raise ValueError(
            "Either var must be a single string, or t must be a single integer!"
        )
    if by_age and dtype is not None:
        raise ValueError(
            "Can't specify dtype when using by_age is True because of potential incompatibility with representing NaN"
        )

    # Get the relevant history dictionary (deprecate in future)
    history = self.hystory if sym else self.history

    # Retrieve age array once if needed (raises a clear error when missing)
    if by_age:
        try:
            age = history["t_age"]
        except KeyError:
            raise KeyError("t_age must be in track_vars if by_age=True will be used!")

    # Modes 1a / 1b: a single variable across periods or ages
    if single_var:
        if by_age:
            return self._export_single_var_by_age(history, var, age, t, dtype, sym)
        return self._export_single_var_by_time(history, var, t, dtype)

    # Modes 2a / 2b: build and validate the variable list first
    if var is None:
        var_list = list(history.keys())
    else:
        var_list = copy(var)
        sim_keys = list(history.keys())
        for name in var_list:
            if name not in sim_keys:
                raise KeyError(
                    "Variable called " + name + " not found in simulation data!"
                )

    if by_age:
        return self._export_single_t_by_age(history, var_list, age, t, dtype)
    return self._export_single_t_by_time(history, var_list, t, dtype)
def sim_one_period(self):
    """
    Simulates one period for this type. In order, calls get_mortality(), then either
    read_shocks_from_history() or get_shocks() depending on read_shocks, then get_states(),
    get_controls(), and get_poststates(). These should be defined for AgentType subclasses,
    except get_mortality (define its components sim_death and sim_birth instead) and
    read_shocks_from_history. Finally advances t_age and t_cycle by one period.

    Parameters
    ----------
    None

    Returns
    -------
    None

    Raises
    ------
    Exception
        If this instance has no solution attribute.
    """
    if not hasattr(self, "solution"):
        raise Exception(
            "Model instance does not have a solution stored. To simulate, it is necessary"
            " to run the `solve()` method first."
        )

    # Mortality adjusts the agent population
    self.get_mortality()  # Replace some agents with "newborns"

    # Current states become state_{t-1}; array states get a fresh, unfilled array
    for name in self.state_now:
        current = self.state_now[name]
        self.state_prev[name] = current
        # Non-array states are probably aggregate variables, which may be
        # getting set by the Market, so they are left in place
        if isinstance(current, np.ndarray):
            self.state_now[name] = np.empty(self.AgentCount)

    # Use pre-specified shock histories if present, otherwise draw shocks
    # according to the subclass-specific method
    draw_shocks = self.read_shocks_from_history if self.read_shocks else self.get_shocks
    draw_shocks()
    self.get_states()  # Determine each agent's state at decision time
    self.get_controls()  # Determine each agent's choice or control variables based on states
    self.get_poststates()  # Calculate variables that come *after* decision-time

    # Advance time for all agents
    self.t_age = self.t_age + 1  # Age all consumers by one period
    self.t_cycle = self.t_cycle + 1  # Age all consumers within their cycle
    # Resetting to zero for those who have reached the end
    self.t_cycle[self.t_cycle == self.T_cycle] = 0
def make_shock_history(self):
    """
    Makes a pre-specified history of shocks for the simulation. Shock variables should be named
    in self.shock_vars, a list of strings that is subclass-specific. This method runs a subset
    of the standard simulation loop by simulating only mortality and shocks; each variable named
    in shock_vars is stored in a T_sim x AgentCount array in self.shock_history[X], deaths are
    stored in self.shock_history["who_dies"], and the states given to newborns are stored in
    self.newborn_init_history[X] for each X in state_vars.

    The flag self.read_shocks is switched off while the history is generated, so that mortality
    and initial conditions are simulated rather than read back from an earlier history, and is
    set to True at the end so that these pre-specified shocks are used for all subsequent calls
    to simulate().

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    # A stale read_shocks=True (e.g. from a previous call) would make
    # initialize_sim and get_mortality read history instead of simulating
    self.read_shocks = False

    # Re-initialize the simulation
    self.initialize_sim()

    hist_shape = (self.T_sim, self.AgentCount)

    # Make blank history arrays for each shock variable (and mortality)
    for var_name in self.shock_vars:
        self.shock_history[var_name] = np.zeros(hist_shape) + np.nan
    self.shock_history["who_dies"] = np.zeros(hist_shape, dtype=bool)

    # Also make blank arrays for the draws of newborns' initial conditions
    for var_name in self.state_vars:
        self.newborn_init_history[var_name] = np.zeros(hist_shape) + np.nan

    # Record the initial condition of the newborns created by
    # initialize_sim -> sim_birth
    for var_name in self.state_vars:
        # Check whether the state is idiosyncratic or an aggregate
        idio = (
            isinstance(self.state_now[var_name], np.ndarray)
            and len(self.state_now[var_name]) == self.AgentCount
        )
        if idio:
            self.newborn_init_history[var_name][self.t_sim] = self.state_now[var_name]
        else:
            # Aggregate state is a scalar. Assign it to every agent.
            self.newborn_init_history[var_name][self.t_sim, :] = self.state_now[
                var_name
            ]

    # Make and store the history of shocks for each period
    for t in range(self.T_sim):
        # Deaths
        self.get_mortality()
        self.shock_history["who_dies"][t, :] = self.who_dies

        # Initial conditions of newborns
        if self.who_dies.any():
            for var_name in self.state_vars:
                # Check whether the state is idiosyncratic or an aggregate
                idio = (
                    isinstance(self.state_now[var_name], np.ndarray)
                    and len(self.state_now[var_name]) == self.AgentCount
                )
                if idio:
                    self.newborn_init_history[var_name][t, self.who_dies] = (
                        self.state_now[var_name][self.who_dies]
                    )
                else:
                    self.newborn_init_history[var_name][t, self.who_dies] = (
                        self.state_now[var_name]
                    )

        # Other Shocks
        self.get_shocks()
        for var_name in self.shock_vars:
            self.shock_history[var_name][t, :] = self.shocks[var_name]

        self.t_sim += 1
        self.t_age = self.t_age + 1  # Age all consumers by one period
        self.t_cycle = self.t_cycle + 1  # Age all consumers within their cycle
        # Resetting to zero for those who have reached the end
        self.t_cycle[self.t_cycle == self.T_cycle] = 0

    # Flag that shocks can be read rather than simulated
    self.read_shocks = True
def get_mortality(self):
    """
    Simulates mortality or agent turnover and stores the result in self.who_dies.

    When read_shocks is False, the model-specific methods sim_death and sim_birth
    (methods of an AgentType subclass) are used: sim_death takes no arguments and
    returns a Boolean array of size AgentCount, indicating which agents of this type
    have "died" and must be replaced; sim_birth takes such a Boolean array as an
    argument and generates initial post-decision states for those agent indices.

    When read_shocks is True, deaths are read from shock_history["who_dies"] for the
    current t_sim, replaced agents receive the stored values from newborn_init_history
    for each idiosyncratic state, and their t_age and t_cycle are reset to zero. A
    warning is issued for each state variable missing from newborn_init_history.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    if not self.read_shocks:
        who_dies = self.sim_death()
        self.sim_birth(who_dies)
    else:
        who_dies = self.shock_history["who_dies"][self.t_sim, :]
        # Instead of simulating births, assign the saved newborn initial conditions
        if who_dies.any():
            for var_name in self.state_now:
                if var_name not in self.newborn_init_history:
                    warn(
                        "The option for reading shocks was activated but "
                        + "the model requires state "
                        + var_name
                        + ", not contained in "
                        + "newborn_init_history."
                    )
                    continue

                # Copy only array-like idiosyncratic states. Aggregates should
                # not be set by newborns
                current = self.state_now[var_name]
                idio = (
                    isinstance(current, np.ndarray)
                    and len(current) == self.AgentCount
                )
                if idio:
                    saved = self.newborn_init_history[var_name]
                    self.state_now[var_name][who_dies] = saved[self.t_sim, who_dies]

            # Reset ages of newborns
            self.t_age[who_dies] = 0
            self.t_cycle[who_dies] = 0

    self.who_dies = who_dies
    return None
def sim_death(self):
    """
    Determines which agents in the current population "die" or should be replaced.
    This default implementation reports that nobody dies; it must be overwritten by
    a subclass to have replacement events.

    Parameters
    ----------
    None

    Returns
    -------
    who_dies : np.array
        Boolean array of size self.AgentCount indicating which agents die and are
        replaced (all False here).
    """
    return np.zeros(self.AgentCount, dtype=bool)
def sim_birth(self, which_agents):  # pragma: nocover
    """
    Make new agents for the simulation. Takes a boolean array as an input,
    indicating which agent indices are to be "born". This base version is
    abstract: it always raises, so a subclass must override it.

    Parameters
    ----------
    which_agents : np.array(Bool)
        Boolean array of size self.AgentCount indicating which agents should be "born".

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        Always, in this base implementation. NotImplementedError is a subclass
        of Exception, so callers that catch Exception are unaffected.
    """
    raise NotImplementedError("AgentType subclass must define method sim_birth!")
def get_shocks(self):  # pragma: nocover
    """
    Draw or assign this period's shock variables.

    This base version is a no-op placeholder; subclasses of AgentType
    override it to populate their shocks.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    return
def read_shocks_from_history(self):
    """
    Load this period's shock values from the stored shock history.

    For each variable X named in self.shock_vars, self.shocks[X] is set to
    row self.t_sim of self.shock_history[X].

    This method is only ever called if self.read_shocks is True. This can
    be achieved by using the method make_shock_history() (or manually after
    storing a "handcrafted" shock history).

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    period = self.t_sim
    for shock_name in self.shock_vars:
        stored = self.shock_history[shock_name]
        self.shocks[shock_name] = stored[period, :]
def get_states(self):
    """
    Update the state variables for the current period.

    Calls self.transition() and writes its outputs, in order, into the
    entries of the state_now dictionary.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    updated = self.transition()

    # Pair state names with transition outputs positionally. Trailing names
    # with no matching output ('post-states', for now) are left untouched.
    for state_name, value in zip(self.state_now, updated):
        self.state_now[state_name] = value
def transition(self):  # pragma: nocover
    """
    Produce new values for the endogenous state variables.

    This base version returns an empty tuple, i.e. no states are updated.

    Parameters
    ----------
    None

    [Eventually, to match dolo spec:
    exogenous_prev, endogenous_prev, controls, exogenous, parameters]

    Returns
    -------
    endogenous_state: ()
        Tuple with new values of the endogenous states
    """
    return tuple()
def get_controls(self):  # pragma: nocover
    """
    Compute this period's control variables, probably from current states.

    This base version is a no-op placeholder; subclasses of AgentType
    override it.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    return
def get_poststates(self):
    """
    Compute this period's post-decision state variables, probably from
    current states and controls and maybe market-level events or shock
    variables.

    This base version is a no-op placeholder; subclasses of AgentType
    override it.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    return
def symulate(self, T=None):
    """
    Run the new simulation structure by delegating to self._simulator, then
    expose the simulator's history as the hystory attribute of self.
    """
    simulator = self._simulator
    simulator.simulate(T)
    self.hystory = simulator.history
def describe_model(self, display=True):
    """
    Print to screen information about this agent's model, based on its model
    file. This is useful for learning about outcome variable names for tracking
    during simulation, or for use with sequence space Jacobians.

    If no simulator has been built yet, initialize_sym() is called first; the
    description itself is delegated to the simulator's describe method.
    """
    simulator_missing = not hasattr(self, "_simulator")
    if simulator_missing:
        self.initialize_sym()
    self._simulator.describe(display=display)
def simulate(self, sim_periods=None):
    """
    Simulates this agent type for a given number of periods. Defaults to
    all remaining periods to simulate (T_sim - t_sim). Records histories of
    attributes named in self.track_vars in self.history[varname].

    Parameters
    ----------
    sim_periods : int or None
        Number of periods to simulate. Default is all remaining periods (usually T_sim).
        Must not exceed T_sim.

    Returns
    -------
    None
        The tracked histories are written to self.history in place.

    Raises
    ------
    Exception
        If t_sim or T_sim has not been set on this instance, or if
        sim_periods exceeds T_sim.
    """
    if not hasattr(self, "t_sim"):
        raise Exception(
            "It seems that the simulation variables were not initialized before calling "
            + "simulate(). Call initialize_sim() to initialize the variables before calling simulate() again."
        )

    if not hasattr(self, "T_sim"):
        raise Exception(
            "This agent type instance must have the attribute T_sim set to a positive integer. "
            + "Set T_sim to match the largest dataset you might simulate, and run this agent's "
            + "initialize_sim() method before running simulate() again."
        )

    if sim_periods is not None and self.T_sim < sim_periods:
        raise Exception(
            "To simulate, sim_periods has to be no larger than the maximum data set size "
            + "T_sim. Either increase the attribute T_sim of this agent type instance "
            + "and call the initialize_sim() method again, or set sim_periods <= T_sim."
        )

    # Ignore floating point "errors". Numpy calls it "errors", but really it's excep-
    # tions with well-defined answers such as 1.0/0.0 that is np.inf, -1.0/0.0 that is
    # -np.inf, np.inf/np.inf is np.nan and so on.
    with np.errstate(divide="ignore", over="ignore", under="ignore", invalid="ignore"):
        if sim_periods is None:
            sim_periods = self.T_sim - self.t_sim

        for _ in range(sim_periods):
            self.sim_one_period()

            # Record each tracked variable, looking it up first among the
            # states, then shocks, then controls, then plain attributes.
            for var_name in self.track_vars:
                if var_name in self.state_now:
                    self.history[var_name][self.t_sim, :] = self.state_now[var_name]
                elif var_name in self.shocks:
                    self.history[var_name][self.t_sim, :] = self.shocks[var_name]
                elif var_name in self.controls:
                    self.history[var_name][self.t_sim, :] = self.controls[var_name]
                else:
                    # who_dies is written one row earlier once t_sim exceeds 1
                    if var_name == "who_dies" and self.t_sim > 1:
                        self.history[var_name][self.t_sim - 1, :] = getattr(
                            self, var_name
                        )
                    else:
                        self.history[var_name][self.t_sim, :] = getattr(
                            self, var_name
                        )
            self.t_sim += 1
def clear_history(self):
    """
    Reset the history of every variable named in self.track_vars to a fresh
    (T_sim, AgentCount) array filled with NaN.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    blank_shape = (self.T_sim, self.AgentCount)
    for tracked in self.track_vars:
        # Each variable gets its own array; nothing is shared between them
        self.history[tracked] = np.full(blank_shape, np.nan)
def make_basic_SSJ(self, shock, outcomes, grids, **kwargs):
    """
    Construct and return sequence space Jacobian matrices for specified outcomes
    with respect to specified "shock" variable. This "basic" method only works
    for "one period infinite horizon" models (cycles=0, T_cycle=1) and for life-
    cycle models (cycles=1). The work is delegated to make_basic_SSJ_matrices or
    make_flat_LC_SSJ_matrices respectively; see their documentation for more
    information.

    Raises
    ------
    ValueError
        If the model is neither of the two supported kinds.
    """
    one_period_infinite = (self.cycles == 0) and (self.T_cycle == 1)
    if one_period_infinite:
        builder = make_basic_SSJ_matrices
    elif self.cycles == 1:
        builder = make_flat_LC_SSJ_matrices
    else:
        raise ValueError(
            "Can only make HA-SSJ matrices for infinite horizon or life-cycle models!"
        )
    return builder(self, shock, outcomes, grids, **kwargs)
def calc_impulse_response_manually(self, shock, outcomes, grids, **kwargs):
    """
    Calculate and return the impulse response(s) of a perturbation to the shock
    parameter in period t=s, essentially computing one column of the sequence
    space Jacobian matrix manually. This "basic" method only works for "one
    period infinite horizon" models (cycles=0, T_cycle=1). All arguments are
    forwarded to calc_shock_response_manually; see its documentation for more
    information.
    """
    response = calc_shock_response_manually(self, shock, outcomes, grids, **kwargs)
    return response
def solve_agent(agent, verbose, from_solution=None, from_t=None):
    """
    Solve the dynamic model for one agent type using backwards induction. This
    function iterates on "cycles" of an agent's model either a given number of
    times or until solution convergence if an infinite horizon model is used
    (with agent.cycles = 0).

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem
        is to be solved.
    verbose : boolean
        If True, solution progress is printed to screen after every cycle.
    from_solution: Solution
        If different from None, will be used as the starting point of backward
        induction, instead of agent.solution_terminal
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. It should usually only be used in combination with from_solution.
        Stands for the time index that from_solution represents, and thus is
        only compatible with cycles=1 and will be reset to None otherwise.

    Returns
    -------
    solution : [Solution]
        A list of solutions to the one period problems that the agent will
        encounter in his "lifetime".
    """
    # A cycles value of zero flags an infinite horizon problem
    cycles_left = agent.cycles
    infinite_horizon = cycles_left == 0

    # Choose the starting point of the backward induction
    if from_solution is None:
        solution_last = agent.solution_terminal
    else:
        solution_last = from_solution
        if agent.cycles != 1:
            from_t = None

    # The terminal solution belongs to the result unless it is only pseudo-terminal
    solution = [] if agent.pseudo_terminal else [deepcopy(solution_last)]

    completed_cycles = 0
    max_cycles = 5000  # escape clause for infinite horizon problems
    if verbose:
        t_last = time()

    while True:
        # Solve a cycle of the model; its first element is the earliest period
        solution_cycle = solve_one_cycle(agent, solution_last, from_t)
        if not infinite_horizon:
            solution = solution_cycle + solution
        solution_now = solution_cycle[0]

        # Decide whether another cycle is needed
        if infinite_horizon:
            if completed_cycles > 0:
                solution_distance = distance_metric(solution_now, solution_last)
                # Stored on the agent so users can query whether the solution is ready
                agent.solution_distance = solution_distance
                agent.completed_cycles = completed_cycles
                keep_going = (
                    solution_distance > agent.tolerance
                    and completed_cycles < max_cycles
                )
            else:
                # Assume the solution does not converge after only one cycle
                solution_distance = 100.0
                keep_going = True
        else:
            cycles_left -= 1
            keep_going = cycles_left > 0

        # This cycle's earliest solution is the next cycle's continuation
        solution_last = solution_now
        completed_cycles += 1

        # Display progress if requested
        if verbose:
            t_now = time()
            if infinite_horizon:
                print(
                    f"Finished cycle #{completed_cycles!s} in "
                    f"{t_now - t_last:.6f} seconds, solution distance = "
                    f"{solution_distance!s}"
                )
            else:
                print(
                    f"Finished cycle #{completed_cycles!s} of {agent.cycles!s} in "
                    f"{t_now - t_last!s} seconds."
                )
            t_last = t_now

        if not keep_going:
            break

    # For an infinite horizon problem, only the last cycle is the solution
    # (PseudoTerminal=False impossible for infinite horizon)
    if infinite_horizon:
        solution = solution_cycle

    return solution
def solve_one_cycle(agent, solution_last, from_t):
    """
    Solve one "cycle" of the dynamic model for one agent type. This function
    iterates over the periods within an agent's cycle, updating the time-varying
    parameters and passing them to the single period solver(s).

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    solution_last : Solution
        A representation of the solution of the period that comes after the
        end of the sequence of one period problems. This might be the term-
        inal period solution, a "pseudo terminal" solution, or simply the
        solution to the earliest period from the succeeding cycle.
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. When used, represents the time index that solution_last is from.

    Returns
    -------
    solution_cycle : [Solution]
        A list of one period solutions for one "cycle" of the AgentType's
        microeconomic model.
    """
    # Use the agent's Parameters container when it has one; otherwise fall
    # back on the time_vary / time_inv attribute lists
    has_param_container = hasattr(agent, "parameters") and isinstance(
        agent.parameters, Parameters
    )

    # Number of periods to solve in this cycle
    if has_param_container:
        n_periods = agent.parameters._length if from_t is None else from_t
    else:
        # Defaults to 1 if all variables are time invariant
        if len(agent.time_vary) > 0:
            n_periods = agent.T_cycle if from_t is None else from_t
        else:
            n_periods = 1

        # Time-invariant inputs are fixed; time-varying ones are filled in per period
        solve_dict = {name: agent.__dict__[name] for name in agent.time_inv}
        solve_dict.update({name: None for name in agent.time_vary})

    # Order in which periods are solved: plain reverse order when cycles == 1,
    # otherwise period 0 first, followed by the remaining periods in reverse
    if agent.cycles == 1:
        period_order = range(n_periods - 1, -1, -1)
    else:
        period_order = [0] + list(range(n_periods - 1, 0, -1))

    solved = []  # collected in solve order, flipped before returning
    solution_next = solution_last
    for t in period_order:
        # Pick the single period solver (it may depend on time)
        solver = agent.solve_one_period
        if hasattr(solver, "__getitem__"):
            solver = solver[t]

        # Names of the inputs this solver expects
        if hasattr(solver, "solver_args"):
            arg_names = solver.solver_args
        else:
            arg_names = get_arg_names(solver)

        # Assemble the inputs for this period
        if has_param_container:
            period_pars = agent.parameters[t]
            inputs = {
                name: solution_next if name == "solution_next" else period_pars[name]
                for name in arg_names
            }
        else:
            for name in agent.time_vary:
                if name in arg_names:
                    solve_dict[name] = agent.__dict__[name][t]
            solve_dict["solution_next"] = solution_next
            inputs = {name: solve_dict[name] for name in arg_names}

        # Solve one period and carry its solution into the next iteration
        solution_next = solver(**inputs)
        solved.append(solution_next)

    # Most recently solved period comes first in the returned list
    solved.reverse()
    return solved
def make_one_period_oo_solver(solver_class):
    """
    Wrap an object-oriented solver class in a plain function that solves a
    single period of a problem.

    Parameters
    ----------
    solver_class : Solver
        A class of Solver to be used.

    Returns
    -------
    solver_function : function
        A function for solving one period of a problem. It carries two
        attributes: solver_class (the wrapped class) and solver_args (the
        argument names of the class's __init__, excluding the first).
    """

    def one_period_solver(**kwds):
        instance = solver_class(**kwds)

        # not ideal; better if this is defined in all Solver classes
        if hasattr(instance, "prepare_to_solve"):
            instance.prepare_to_solve()

        return instance.solve()

    # This can be revisited once it is possible to export parameters
    init_arg_names = get_arg_names(solver_class.__init__)
    one_period_solver.solver_class = solver_class
    one_period_solver.solver_args = init_arg_names[1:]

    return one_period_solver
2498# ========================================================================
2499# ========================================================================
class Market(Model):
    """
    A superclass to represent a central clearinghouse of information. Used for
    dynamic general equilibrium models to solve the "macroeconomic" model as a
    layer on top of the "microeconomic" models of one or more AgentTypes.

    Parameters
    ----------
    agents : [AgentType]
        A list of all the AgentTypes in this market.
    sow_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be "sown" to the agents in the market. Aggregate state, etc.
    reap_vars : [string]
        Names of variables to be collected ("reaped") from agents in the market
        to be used in the "aggregate market process".
    const_vars : [string]
        Names of attributes of the Market instance that are used in the "aggregate
        market process" but do not come from agents-- they are constant or simply
        parameters inherent to the process.
    track_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be tracked as a "history" so that a new dynamic rule can be calculated.
        This is often a subset of sow_vars.
    dyn_vars : [string]
        Names of variables that constitute a "dynamic rule".
    mill_rule : function
        A function that takes inputs named in reap_vars and returns a tuple the
        same size and order as sow_vars. The "aggregate market process" that
        transforms individual agent actions/states/data into aggregate data to
        be sent back to agents.
    calc_dynamics : function
        A function that takes inputs named in track_vars and returns an object
        with attributes named in dyn_vars. Looks at histories of aggregate
        variables and generates a new "dynamic rule" for agents to believe and
        act on.
    distributions : [string]
        Names of attributes of the Market instance that hold a distribution
        object (or a list of them); each is reset by the reset method.
    act_T : int
        The number of times that the "aggregate market process" should be run
        in order to generate a history of aggregate variables.
    tolerance: float
        Minimum acceptable distance between "dynamic rules" to consider the
        Market solution process converged. Distance is a user-defined metric.
    seed : int
        Seed for the market's internal random number generator (self.RNG).
    """

    def __init__(
        self,
        agents=None,
        sow_vars=None,
        reap_vars=None,
        const_vars=None,
        track_vars=None,
        dyn_vars=None,
        mill_rule=None,
        calc_dynamics=None,
        distributions=None,
        act_T=1000,
        tolerance=0.000001,
        seed=0,
        **kwds,
    ):
        super().__init__()
        self.agents = agents if agents is not None else list()  # NOQA

        self.reap_vars = reap_vars if reap_vars is not None else list()  # NOQA
        self.reap_state = {var: [] for var in self.reap_vars}

        self.sow_vars = sow_vars if sow_vars is not None else list()  # NOQA
        # dictionaries for tracking initial and current values
        # of the sow variables.
        self.sow_init = {var: None for var in self.sow_vars}
        self.sow_state = {var: None for var in self.sow_vars}

        # Unlike the other name lists, const_vars is stored as a dictionary
        # from name to value, with every value starting out as None
        const_vars = const_vars if const_vars is not None else list()  # NOQA
        self.const_vars = {var: None for var in const_vars}

        self.track_vars = track_vars if track_vars is not None else list()  # NOQA
        self.dyn_vars = dyn_vars if dyn_vars is not None else list()  # NOQA
        self.distributions = distributions if distributions is not None else list()  # NOQA

        if mill_rule is not None:  # To prevent overwriting of method-based mill_rules
            self.mill_rule = mill_rule
        if calc_dynamics is not None:  # Ditto for calc_dynamics
            self.calc_dynamics = calc_dynamics
        self.act_T = act_T  # NOQA
        self.tolerance = tolerance  # NOQA
        self.seed = seed
        self.max_loops = 1000  # NOQA
        self.history = {}
        # Extra keyword arguments may override any of the attributes set above
        # (including seed), so the RNG is only created afterwards
        self.assign_parameters(**kwds)
        self.RNG = np.random.default_rng(self.seed)

        self.print_parallel_error_once = True
        # Print the error associated with calling the parallel method
        # "solve_agents" one time. If set to false, the error will never
        # print. See "solve_agents" for why this prints once or never.

    def give_agent_params(self, construct=True):
        """
        Distribute relevant market-level parameters to each AgentType in self.agents
        by having them call their get_market_params method.

        Parameters
        ----------
        construct : bool, optional
            Whether agents should run their construct method after fetching market
            data (default True).

        Returns
        -------
        None
        """
        for agent in self.agents:
            agent.get_market_params(self, construct)

    def solve_agents(self):
        """
        Solves the microeconomic problem for all AgentTypes in this market.
        The parallel multi_thread_commands is tried first; if it raises, the
        serial multi_thread_commands_fake is used instead, and a warning is
        printed the first time this happens.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        try:
            multi_thread_commands(self.agents, ["solve()"])
        except Exception as err:
            # Deliberate best-effort: any failure of the parallel path falls
            # back on the serial version rather than aborting the solve
            if self.print_parallel_error_once:
                # Set flag to False so this is only printed once.
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.Market.solve_agents() ",
                    "so using the serial version instead. This will likely be slower. "
                    "The multi_thread_commands() functions failed with the following error:",
                    "\n",
                    sys.exc_info()[0],
                    ":",
                    err,
                )  # sys.exc_info()[0])
            multi_thread_commands_fake(self.agents, ["solve()"])

    def solve(self):
        """
        "Solves" the market by finding a "dynamic rule" that governs the aggregate
        market state such that when agents believe in these dynamics, their actions
        collectively generate the same dynamic rule. The final rule is stored in
        self.dynamics.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        go = True
        max_loops = self.max_loops  # Failsafe against infinite solution loop
        completed_loops = 0
        old_dynamics = None

        while go:  # Loop until the dynamic process converges or we hit the loop cap
            self.solve_agents()  # Solve each AgentType's micro problem
            self.make_history()  # "Run" the model while tracking aggregate variables
            new_dynamics = self.update_dynamics()  # Find a new aggregate dynamic rule

            # Check to see if the dynamic rule has converged (if this is not the first loop)
            if completed_loops > 0:
                distance = new_dynamics.distance(old_dynamics)
            else:
                distance = 1000000.0

            # Move to the next loop if the terminal conditions are not met
            old_dynamics = new_dynamics
            completed_loops += 1
            go = distance >= self.tolerance and completed_loops < max_loops

        self.dynamics = new_dynamics  # Store the final dynamic rule in self

    def reap(self):
        """
        Collects variables named in reap_vars from each AgentType in the market,
        storing them as lists (one entry per agent that has the variable in
        its state_now) in the reap_state dictionary.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var in self.reap_state:
            harvest = []

            for agent in self.agents:
                # TODO: generalized variable lookup across namespaces
                if var in agent.state_now:
                    # or state_now ??
                    harvest.append(agent.state_now[var])

            self.reap_state[var] = harvest

    def sow(self):
        """
        Distributes variables named in sow_vars from self to each AgentType
        in the market. A value is written to the agent's state_now if the name
        is a key there; independently, it is written to the agent's shocks if
        the name is a key there, and otherwise set as a plain attribute.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for sow_var in self.sow_state:
            for this_type in self.agents:
                if sow_var in this_type.state_now:
                    this_type.state_now[sow_var] = self.sow_state[sow_var]
                if sow_var in this_type.shocks:
                    this_type.shocks[sow_var] = self.sow_state[sow_var]
                else:
                    setattr(this_type, sow_var, self.sow_state[sow_var])

    def mill(self):
        """
        Processes the variables collected from agents using the function mill_rule,
        storing the results, in order, in the sow_state dictionary.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Make a dictionary of inputs for the mill_rule
        mill_dict = copy(self.reap_state)
        mill_dict.update(self.const_vars)

        # Run the mill_rule and store its output in self
        product = self.mill_rule(**mill_dict)

        for i, sow_var in enumerate(self.sow_state):
            self.sow_state[sow_var] = product[i]

    def cultivate(self):
        """
        Has each AgentType in agents perform their market_action method, using
        variables sown from the market (and maybe also "private" variables).
        The market_action method should store new results in attributes named in
        reap_vars to be reaped later.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for this_type in self.agents:
            this_type.market_action()

    def reset(self):
        """
        Reset the state of the market (attributes in sow_vars, etc) to some
        user-defined initial state, and erase the histories of tracked variables.
        Also resets the internal RNG so that draws can be reproduced.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Reset internal RNG to its seeded initial state, as in __init__
        self.RNG = np.random.default_rng(self.seed)

        # Reset each distribution named in self.distributions; names that are
        # not (yet) attributes of the market are skipped
        for name in self.distributions:
            if not hasattr(self, name):
                continue
            dstn = getattr(self, name)
            if isinstance(dstn, list):
                for D in dstn:
                    D.reset()
            else:
                dstn.reset()

        # Reset the history of tracked variables
        self.history = {var_name: [] for var_name in self.track_vars}

        # Set the sow variables to their initial levels
        for var_name in self.sow_state:
            self.sow_state[var_name] = self.sow_init[var_name]

        # Reset each AgentType in the market
        for this_type in self.agents:
            this_type.reset()

    def store(self):
        """
        Record the current value of each variable X named in track_vars in a
        dictionary field named history[X]. The value is looked up in sow_state,
        then reap_state, then const_vars, and finally as an attribute of self.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var_name in self.track_vars:
            if var_name in self.sow_state:
                value_now = self.sow_state[var_name]
            elif var_name in self.reap_state:
                value_now = self.reap_state[var_name]
            elif var_name in self.const_vars:
                value_now = self.const_vars[var_name]
            else:
                value_now = getattr(self, var_name)

            self.history[var_name].append(value_now)

    def make_history(self):
        """
        Runs a loop of sow-->cultivate-->reap-->mill act_T times, tracking the
        evolution of variables X named in track_vars in dictionary fields
        history[X].

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        self.reset()  # Initialize the state of the market
        for t in range(self.act_T):
            self.sow()  # Distribute aggregated information/state to agents
            self.cultivate()  # Agents take action
            self.reap()  # Collect individual data from agents
            self.mill()  # Process individual data into aggregate data
            self.store()  # Record variables of interest

    def update_dynamics(self):
        """
        Calculates a new "aggregate dynamic rule" using the history of variables
        named in track_vars, and distributes this rule to AgentTypes in agents.

        Parameters
        ----------
        none

        Returns
        -------
        dynamics : instance
            The new "aggregate dynamic rule" that agents believe in and act on.
            Should have attributes named in dyn_vars.
        """
        # Make a dictionary of inputs for the dynamics calculator: tracked
        # variables come from history, anything else from attributes of self
        arg_names = list(get_arg_names(self.calc_dynamics))
        if "self" in arg_names:
            arg_names.remove("self")
        update_dict = {}
        for name in arg_names:
            update_dict[name] = (
                self.history[name] if name in self.track_vars else getattr(self, name)
            )
        # Calculate a new dynamic rule and distribute it to the agents in agent_list
        dynamics = self.calc_dynamics(**update_dict)  # User-defined dynamics calculator
        for var_name in self.dyn_vars:
            this_obj = getattr(dynamics, var_name)
            for this_type in self.agents:
                setattr(this_type, var_name, this_obj)
        return dynamics
def distribute_params(agent, param_name, param_count, distribution):
    """
    Clone an agent into several copies that differ in the value of one
    parameter, using a discretization of the given distribution.

    Parameters
    ----------
    agent: AgentType
        An agent to clone.
    param_name : string
        Name of the parameter to be assigned.
    param_count : int
        Number of different values the parameter will take on.
    distribution : Distribution
        A 1-D distribution.

    Returns
    -------
    agent_set : [AgentType]
        A list of param_count agents, ex ante heterogeneous with
        respect to param_name. The AgentCount of the original
        will be split between the agents of the returned
        list in proportion to the given distribution, with each
        share truncated to an integer.
    """
    discretized = distribution.discretize(N=param_count)

    agent_set = []
    for idx in range(param_count):
        clone = deepcopy(agent)
        # Population share of this clone, truncated toward zero by int()
        share = int(agent.AgentCount * discretized.pmv[idx])
        clone.assign_parameters(AgentCount=share)
        clone.assign_parameters(**{param_name: discretized.atoms[0, idx]})
        agent_set.append(clone)

    return agent_set
2920@dataclass
2921class AgentPopulation:
2922 """
2923 A class for representing a population of ex-ante heterogeneous agents.
2924 """
2926 agent_type: AgentType # type of agent in the population
2927 parameters: dict # dictionary of parameters
2928 seed: int = 0 # random seed
2929 time_var: List[str] = field(init=False)
2930 time_inv: List[str] = field(init=False)
2931 distributed_params: List[str] = field(init=False)
2932 agent_type_count: Optional[int] = field(init=False)
2933 term_age: Optional[int] = field(init=False)
2934 continuous_distributions: Dict[str, Distribution] = field(init=False)
2935 discrete_distributions: Dict[str, Distribution] = field(init=False)
2936 population_parameters: List[Dict[str, Union[List[float], float]]] = field(
2937 init=False
2938 )
2939 agents: List[AgentType] = field(init=False)
2940 agent_database: pd.DataFrame = field(init=False)
2941 solution: List[Any] = field(init=False)
2943 def __post_init__(self):
2944 """
2945 Initialize the population of agents, determine distributed parameters,
2946 and infer `agent_type_count` and `term_age`.
2947 """
2948 # create a dummy agent and obtain its time-varying
2949 # and time-invariant attributes
2950 dummy_agent = self.agent_type()
2951 self.time_var = dummy_agent.time_vary
2952 self.time_inv = dummy_agent.time_inv
2954 # create list of distributed parameters
2955 # these are parameters that differ across agents
2956 self.distributed_params = [
2957 key
2958 for key, param in self.parameters.items()
2959 if (isinstance(param, list) and isinstance(param[0], list))
2960 or isinstance(param, Distribution)
2961 or (isinstance(param, DataArray) and param.dims[0] == "agent")
2962 ]
2964 self.__infer_counts__()
2966 self.print_parallel_error_once = True
2967 # Print warning once if parallel simulation fails
def __infer_counts__(self):
    """
    Infer `agent_type_count` and `term_age` from the parameters.

    `agent_type_count` is the largest leading length found among the
    distributed parameters (lists, or `DataArray`s whose first dimension
    is `agent`), and at least 1. `term_age` is the largest inner length
    among lists of lists and `DataArray`s whose last dimension is `age`,
    and at least 1. Either value is set to None, with a warning, as soon
    as a `Distribution` is encountered, because its size is not known
    until it has been discretized.
    """
    # Number of ex-ante heterogeneous agent types
    n_types = 1
    for name in self.distributed_params:
        value = self.parameters[name]
        if isinstance(value, Distribution):
            n_types = None
            warn(
                "Cannot infer agent_type_count from a Distribution. "
                "Please provide approximation parameters."
            )
            break
        if isinstance(value, list):
            size = len(value)
        elif isinstance(value, DataArray) and value.dims[0] == "agent":
            size = value.shape[0]
        else:
            continue
        n_types = max(n_types, size)
    self.agent_type_count = n_types

    # Number of periods, looking at every parameter
    n_periods = 1
    for value in self.parameters.values():
        if isinstance(value, Distribution):
            n_periods = None
            warn(
                "Cannot infer term_age from a Distribution. "
                "Please provide approximation parameters."
            )
            break
        if isinstance(value, list) and isinstance(value[0], list):
            span = len(value[0])
        elif isinstance(value, DataArray) and value.dims[-1] == "age":
            span = value.shape[-1]
        else:
            continue
        n_periods = max(n_periods, span)
    self.term_age = n_periods
def approx_distributions(self, approx_params: dict):
    """
    Approximate continuous distributions with discrete ones.

    If the initial parameters include a `Distribution`, the population is
    abstract and cannot be solved until each such parameter has been
    discretized. This method discretizes every parameter named in
    `approx_params`, replaces it in `self.parameters` with a `DataArray`
    over the `agent` dimension, and re-infers `agent_type_count` and
    `term_age`.

    Parameters
    ----------
    approx_params : dict
        Maps the name of a distributed `Distribution` parameter to the
        dictionary of keyword arguments passed to its `discretize` method.
        An empty dictionary leaves the parameters unchanged.

    Raises
    ------
    ValueError
        If a key of `approx_params` is not among the parameters, or does
        not name a distributed parameter that is a `Distribution`.
    """
    self.continuous_distributions = {}
    self.discrete_distributions = {}

    for key, args in approx_params.items():
        # .get() so that an unknown key reaches the ValueError below
        # rather than escaping as a bare KeyError
        param = self.parameters.get(key)
        if key in self.distributed_params and isinstance(param, Distribution):
            self.continuous_distributions[key] = param
            self.discrete_distributions[key] = param.discretize(**args)
        else:
            raise ValueError(
                f"Warning: parameter {key} is not a Distribution found "
                f"in agent type {self.agent_type}"
            )

    if not self.discrete_distributions:
        # Nothing was discretized, so there is no joint distribution to build
        self.__infer_counts__()
        return

    if len(self.discrete_distributions) > 1:
        joint_dist = combine_indep_dstns(*self.discrete_distributions.values())
    else:
        joint_dist = next(iter(self.discrete_distributions.values()))

    # Row i of the joint atoms holds the values of the i-th discretized parameter
    for i, key in enumerate(self.discrete_distributions):
        self.parameters[key] = DataArray(joint_dist.atoms[i], dims="agent")

    self.__infer_counts__()
def __parse_parameters__(self) -> None:
    """
    Build one parameter dictionary per ex-ante heterogeneous agent.

    The result is stored in `self.population_parameters` as a list with
    `agent_type_count` dictionaries. For each parameter:

    - a scalar (int or float) is repeated `term_age` times when its key is
      time-varying, and passed through unchanged otherwise;
    - a list of lists is indexed by agent; a flat list is shared as is;
    - a `DataArray` whose first dimension is `agent` is indexed by agent,
      yielding a list when its last dimension is `age` and a scalar
      otherwise (time-invariant keys always yield a scalar);
    - a `DataArray` whose first dimension is `age` is converted to a list,
      except for time-invariant keys, which omit it.

    Values of any other type or dimension layout are omitted from the
    dictionaries. For time-varying keys a warning is issued as well: such
    values previously either raised an `UnboundLocalError` or silently
    reused the value of the previously processed key.

    Raises
    ------
    TypeError
        If `agent_type_count` is None, i.e. the population still contains
        `Distribution` parameters that have not been discretized.
    """
    if self.agent_type_count is None:
        raise TypeError(
            "agent_type_count is None; call approx_distributions() to "
            "discretize Distribution parameters before creating agents."
        )

    unsupported = object()  # sentinel: the value cannot be interpreted

    def from_data_array(param, agent):
        # Interpret a DataArray according to its first and last dimensions
        if param.dims[0] == "agent":
            if len(param.dims) > 1 and param.dims[-1] == "age":
                return param[agent].values.tolist()
            return param[agent].item()
        if param.dims[0] == "age":
            return param.values.tolist()
        return unsupported

    population_parameters = []  # container for dictionaries of each agent subgroup
    for agent in range(self.agent_type_count):
        agent_parameters = {}
        for key, param in self.parameters.items():
            # time_var takes precedence when a key is listed in both
            is_time_var = key in self.time_var
            is_time_inv = (not is_time_var) and key in self.time_inv

            if isinstance(param, (int, float)):
                # parameters that vary over time have to be repeated
                value = [param] * self.term_age if is_time_var else param
            elif isinstance(param, list):
                # a list of lists varies by agent; a flat list is shared
                value = param[agent] if isinstance(param[0], list) else param
            elif isinstance(param, DataArray):
                if is_time_inv:
                    if param.dims[0] == "agent":
                        value = param[agent].item()
                    else:
                        value = unsupported
                else:
                    value = from_data_array(param, agent)
            else:
                value = unsupported

            if value is unsupported:
                if is_time_var:
                    warn(
                        f"Time-varying parameter {key!r} has an unsupported "
                        "type or dimensions and was not passed to the agents."
                    )
                continue

            agent_parameters[key] = value

        population_parameters.append(agent_parameters)

    self.population_parameters = population_parameters
def create_distributed_agents(self):
    """
    Parse the parameters and instantiate one agent per parameter dictionary.

    Each agent receives its own seed, drawn in order from a random number
    generator initialized with the population's `seed`. The agents are
    stored in `self.agents`.
    """
    self.__parse_parameters__()

    seeder = np.random.default_rng(self.seed)

    population = []
    for kwargs in self.population_parameters:
        agent_seed = seeder.integers(0, 2**31 - 1)
        population.append(self.agent_type(seed=agent_seed, **kwargs))
    self.agents = population
def create_database(self):
    """
    Optionally build a pandas DataFrame with one row per agent.

    The columns are the keys of the per-agent parameter dictionaries plus
    an `agents` column holding the agent objects. The frame is stored in
    `self.agent_database`.
    """
    parameter_table = pd.DataFrame(self.population_parameters)
    self.agent_database = parameter_table.assign(agents=self.agents)
def solve(self):
    """
    Solve every agent in the population, one after another.
    """
    # Serial on purpose; the Market class shows how distributed agents
    # can be solved in parallel instead.
    members = self.agents
    for member in members:
        member.solve()
def unpack_solutions(self):
    """
    Gather each agent's `solution` into the population's `solution` list.
    """
    collected = []
    for member in self.agents:
        collected.append(member.solution)
    self.solution = collected
def initialize_sim(self):
    """
    Call `initialize_sim` on every agent in the population, in order.
    """
    members = self.agents
    for member in members:
        member.initialize_sim()
def simulate(self, num_jobs=None):
    """
    Run the simulation of every agent in the population.

    The agents are first simulated through `multi_thread_commands`. If that
    raises any exception, a warning is printed (only the first time, as
    tracked by `print_parallel_error_once`) and the agents are simulated
    serially with `multi_thread_commands_fake` instead.

    Parameters
    ----------
    num_jobs : int, optional
        Number of parallel jobs, forwarded to `multi_thread_commands`.
        When ``None``, that function chooses the number of jobs itself.
    """
    commands = ["simulate()"]
    try:
        multi_thread_commands(self.agents, commands, num_jobs)
    except Exception as err:
        should_report = getattr(self, "print_parallel_error_once", False)
        if should_report:
            # Clear the flag first so the message appears at most once
            self.print_parallel_error_once = False
            print(
                "**** WARNING: could not execute multi_thread_commands in HARK.core.AgentPopulation.simulate() ",
                "so using the serial version instead. This will likely be slower. ",
                "The multi_thread_commands() function failed with the following error:\n",
                sys.exc_info()[0],
                ":",
                err,
            )
        # Serial fallback runs whether or not the warning was printed
        multi_thread_commands_fake(self.agents, commands, num_jobs)
def __iter__(self):
    """
    Return an iterator over the agents, so the population can be looped over.
    """
    members = self.agents
    return iter(members)
def __getitem__(self, idx):
    """
    Return the agent(s) stored at `idx`, so the population can be indexed.
    """
    members = self.agents
    return members[idx]
3196###############################################################################
def multi_thread_commands_fake(
    agent_list: List, command_list: List, num_jobs=None
) -> None:
    """
    Run every command on every agent in a plain, single-threaded loop.

    This mirrors the call signature of multi_thread_commands so that
    parallel execution can be switched off without changing the caller.
    Each command must name a method of the agent it is run on.

    Parameters
    ----------
    agent_list : [AgentType]
        The agents on which the commands will be run, in order.
    command_list : [string]
        Method names to call on each agent, with or without trailing "()".
    num_jobs : None
        Ignored; present only to match multi_thread_commands.

    Returns
    -------
    None
    """
    for this_type in agent_list:
        for cmd in command_list:
            # A trailing "()" is optional: strip it to get the bare method name
            method_name = cmd[:-2] if cmd[-2:] == "()" else cmd
            getattr(this_type, method_name)()
def multi_thread_commands(agent_list: List, command_list: List, num_jobs=None) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    in parallel via joblib. Each command should be a method of that AgentType subclass.

    The agents returned by the parallel workers replace the entries of
    agent_list in place. An empty agent_list is a no-op, and a single agent
    is handled serially without starting any parallel job.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
    command_list : [string]
        A list of commands to run for each AgentType in agent_list.
    num_jobs : int, optional
        Number of parallel jobs. When None, uses the smaller of the number
        of agents and the number of available cores.

    Returns
    -------
    None
    """
    n_agents = len(agent_list)
    if n_agents == 0:
        # Nothing to run; this also avoids requesting zero jobs from joblib
        return None
    if n_agents == 1:
        multi_thread_commands_fake(agent_list, command_list)
        return None

    # Default number of parallel jobs is the smaller of number of AgentTypes in
    # the input and the number of available cores.
    if num_jobs is None:
        num_jobs = min(n_agents, multiprocessing.cpu_count())

    # Send the full command list to each of the types in agent_list to be run
    agent_list_out = Parallel(n_jobs=num_jobs)(
        delayed(run_commands)(agent, command_list) for agent in agent_list
    )

    # Replace the original types with the output from the parallel call
    for j, updated in enumerate(agent_list_out):
        agent_list[j] = updated
    return None
def run_commands(agent: Any, command_list: List) -> Any:
    """
    Call each named method on a single agent, in order, and return the agent.

    Parameters
    ----------
    agent : AgentType
        The instance on which the commands will be run.
    command_list : [string]
        Method names to call on the agent, with or without trailing "()".

    Returns
    -------
    agent : AgentType
        The same instance that was passed in, after the commands have run.
    """
    for cmd in command_list:
        # Accept both "name()" and "name" as a way of naming the method
        has_parens = cmd[-2:] == "()"
        bound_method = getattr(agent, cmd[:-2]) if has_parens else getattr(agent, cmd)
        bound_method()
    return agent