Coverage for HARK / core.py: 96%
1088 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 05:31 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 05:31 +0000
1"""
2High-level functions and classes for solving a wide variety of economic models.
3The "core" of HARK is a framework for "microeconomic" and "macroeconomic"
4models. A micro model concerns the dynamic optimization problem for some type
5of agents, where agents take the inputs to their problem as exogenous. A macro
6model adds an additional layer, endogenizing some of the inputs to the micro
7problem by finding a general equilibrium dynamic rule.
8"""
10# Import basic modules
11import inspect
12import sys
13from collections import namedtuple
14from copy import copy, deepcopy
15from dataclasses import dataclass, field
16from time import time
17from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
18from warnings import warn
19import multiprocessing
20from joblib import Parallel, delayed
21from pandas import DataFrame
23import numpy as np
24import pandas as pd
25from xarray import DataArray
27from HARK.distributions import (
28 Distribution,
29 IndexDistribution,
30 combine_indep_dstns,
31)
32from HARK.utilities import NullFunc, get_arg_names, get_it_from
33from HARK.simulator import make_simulator_from_agent
34from HARK.SSJutils import (
35 make_basic_SSJ_matrices,
36 calc_shock_response_manually,
37)
38from HARK.metric import MetricObject, distance_metric
# Public API of this module: the names exported by ``from <module> import *``.
# NullFunc is imported from HARK.utilities above and re-exported here.
__all__ = [
    "AgentType",
    "Market",
    "Parameters",
    "Model",
    "AgentPopulation",
    "multi_thread_commands",
    "multi_thread_commands_fake",
    "NullFunc",
    "make_one_period_oo_solver",
    "distribute_params",
]
class Parameters:
    """
    A smart container for model parameters that handles age-varying dynamics.

    This class stores parameters as an internal dictionary and manages their
    age-varying properties, providing both attribute-style and dictionary-style
    access. It is designed to handle the time-varying dynamics of parameters
    in economic models.

    Attributes
    ----------
    _length : int
        The terminal age of the agents in the model.
    _invariant_params : Set[str]
        A set of parameter names that are invariant over time.
    _varying_params : Set[str]
        A set of parameter names that vary over time.
    _parameters : Dict[str, Any]
        The internal dictionary storing all parameters.
    _frozen : bool
        When True, setting a parameter raises RuntimeError.
    _namedtuple_cache : Optional[type]
        The namedtuple class most recently built by ``to_namedtuple``.
    """

    __slots__ = (
        "_length",
        "_invariant_params",
        "_varying_params",
        "_parameters",
        "_frozen",
        "_namedtuple_cache",
    )

    def __init__(self, **parameters: Any) -> None:
        """
        Initialize a Parameters object and parse the age-varying dynamics of parameters.

        Parameters
        ----------
        T_cycle : int, optional
            The number of time periods in the model cycle (default: 1).
            Must be >= 1.
        frozen : bool, optional
            If True, the Parameters object will be immutable after initialization
            (default: False).
        _time_inv : List[str], optional
            List of parameter names to explicitly mark as time-invariant,
            overriding automatic inference.
        _time_vary : List[str], optional
            List of parameter names to explicitly mark as time-varying,
            overriding automatic inference.
        **parameters : Any
            Any number of parameters in the form key=value.

        Raises
        ------
        ValueError
            If T_cycle is less than 1.

        Notes
        -----
        Automatic time-variance inference rules:
        - Scalars (int, float, bool, None) are time-invariant
        - NumPy arrays are time-invariant (use lists/tuples for time-varying)
        - Single-element lists/tuples [x] are unwrapped to x and time-invariant
        - Multi-element lists/tuples are time-varying if length matches T_cycle
        - 2D arrays with first dimension matching T_cycle are time-varying
        - Distributions and Callables are time-invariant

        Use _time_inv or _time_vary to override automatic inference when needed.
        """
        # Extract special parameters so they are not stored as ordinary entries
        self._length: int = parameters.pop("T_cycle", 1)
        frozen: bool = parameters.pop("frozen", False)
        time_inv_override: List[str] = parameters.pop("_time_inv", [])
        time_vary_override: List[str] = parameters.pop("_time_vary", [])

        # Validate T_cycle
        if self._length < 1:
            raise ValueError(f"T_cycle must be >= 1, got {self._length}")

        # Initialize internal state
        self._invariant_params: Set[str] = set()
        self._varying_params: Set[str] = set()
        self._parameters: Dict[str, Any] = {"T_cycle": self._length}
        self._frozen: bool = False  # Stay mutable until setup is finished
        self._namedtuple_cache: Optional[type] = None

        # Set parameters using automatic inference
        for key, value in parameters.items():
            self[key] = value

        # Apply explicit overrides; names that were never set are ignored
        for param in time_inv_override:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)

        for param in time_vary_override:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)

        # Freeze if requested
        self._frozen = frozen

    def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
        """
        Access parameters by age index or parameter name.

        If item_or_key is an integer, returns a Parameters object with the parameters
        that apply to that age. This includes all invariant parameters and the
        `item_or_key`th element of all age-varying parameters. If item_or_key is a
        string, it returns the value of the parameter with that name.

        Parameters
        ----------
        item_or_key : Union[int, str]
            Age index or parameter name.

        Returns
        -------
        Union[Parameters, Any]
            A new Parameters object for the specified age, or the value of the
            specified parameter.

        Raises
        ------
        ValueError:
            If the age index is out of bounds.
        KeyError:
            If the parameter name is not found.
        TypeError:
            If the key is neither an integer nor a string.
        """
        if isinstance(item_or_key, str):
            return self._parameters[item_or_key]
        if not isinstance(item_or_key, int):
            raise TypeError("Key must be an integer (age) or string (parameter name).")

        age = item_or_key
        if age < 0 or age >= self._length:
            raise ValueError(
                f"Age {age} is out of bounds (valid: 0-{self._length - 1})."
            )

        params = {key: self._parameters[key] for key in self._invariant_params}
        for key in self._varying_params:
            value = self._parameters[key]
            # A parameter flagged as varying by hand may hold a non-sequence
            # value; in that case it is passed through unchanged.
            if isinstance(value, (list, tuple, np.ndarray)):
                value = value[age]
            params[key] = value
        return Parameters(**params)

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Set parameter values, automatically inferring time variance.

        If the parameter is a scalar, numpy array, boolean, distribution, callable
        or None, it is assumed to be invariant over time. If the parameter is a
        list or tuple, it is assumed to be varying over time. If the parameter
        is a list or tuple of length greater than 1, the length of the list or
        tuple must match the `_length` attribute of the Parameters object.

        2D numpy arrays with first dimension matching T_cycle are treated as
        time-varying parameters.

        Parameters
        ----------
        key : str
            Name of the parameter.
        value : Any
            Value of the parameter.

        Raises
        ------
        ValueError:
            If the parameter name is not a string or if the value type is unsupported.
            If the parameter value is inconsistent with the current model length.
        RuntimeError:
            If the Parameters object is frozen.
        """
        if self._frozen:
            raise RuntimeError("Cannot modify frozen Parameters object")

        if not isinstance(key, str):
            raise ValueError(f"Parameter name must be a string, got {type(key)}")

        if isinstance(value, np.ndarray) and value.ndim >= 2:
            # Arrays with 2+ dimensions vary over time only when their first
            # dimension matches the cycle length.
            time_varying = value.shape[0] == self._length
        elif (
            isinstance(value, (int, float, np.ndarray, type(None), bool))
            or isinstance(value, Callable)
            or isinstance(value, (Distribution, MetricObject))
        ):
            time_varying = False
        elif isinstance(value, (list, tuple)):
            if len(value) == 1:
                # Single-element sequences are unwrapped to their only element
                value = value[0]
                time_varying = False
            elif self._length is None or self._length == 1:
                # The cycle length is inferred from the sequence; keep the
                # stored T_cycle entry in sync with it.
                self._length = len(value)
                self._parameters["T_cycle"] = self._length
                time_varying = True
            elif len(value) == self._length:
                time_varying = True
            else:
                raise ValueError(
                    f"Parameter {key} must have length 1 or {self._length}, not {len(value)}"
                )
        else:
            raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")

        if time_varying:
            self._varying_params.add(key)
            self._invariant_params.discard(key)
        else:
            self._invariant_params.add(key)
            self._varying_params.discard(key)

        self._parameters[key] = value

    def __iter__(self) -> Iterator[str]:
        """Allow iteration over parameter names."""
        return iter(self._parameters)

    def __len__(self) -> int:
        """Return the number of parameters."""
        return len(self._parameters)

    def keys(self) -> Iterator[str]:
        """Return a view of parameter names."""
        return self._parameters.keys()

    def values(self) -> Iterator[Any]:
        """Return a view of parameter values."""
        return self._parameters.values()

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Return a view of parameter (name, value) pairs."""
        return self._parameters.items()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a plain dictionary.

        Returns
        -------
        Dict[str, Any]
            A shallow copy of the internal dictionary containing all parameters.
        """
        return dict(self._parameters)

    def to_namedtuple(self) -> Tuple[Any, ...]:
        """
        Convert parameters to a namedtuple.

        The namedtuple class is cached for efficiency on repeated calls, and
        rebuilt whenever the set or order of parameter names has changed since
        the cached class was created.

        Returns
        -------
        namedtuple
            A namedtuple containing all parameters.
        """
        field_names = tuple(self._parameters)
        cached = self._namedtuple_cache
        if cached is None or cached._fields != field_names:
            cached = namedtuple("Parameters", field_names)
            self._namedtuple_cache = cached
        return cached(**self._parameters)

    def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
        """
        Update parameters from another Parameters object or dictionary.

        Parameters
        ----------
        other : Union[Parameters, Dict[str, Any]]
            The source of parameters to update from.

        Raises
        ------
        TypeError
            If the input is neither a Parameters object nor a dictionary.
        """
        if isinstance(other, Parameters):
            source = other._parameters
        elif isinstance(other, dict):
            source = other
        else:
            raise TypeError(
                "Update source must be a Parameters object or a dictionary."
            )
        # Each entry goes through __setitem__ so time variance is re-inferred
        for key, value in source.items():
            self[key] = value

    def __repr__(self) -> str:
        """Return a detailed string representation of the Parameters object."""
        return (
            f"Parameters(_length={self._length}, "
            f"_invariant_params={self._invariant_params}, "
            f"_varying_params={self._varying_params}, "
            f"_parameters={self._parameters})"
        )

    def __str__(self) -> str:
        """Return a simple string representation of the Parameters object."""
        return f"Parameters({str(self._parameters)})"

    def __getattr__(self, name: str) -> Any:
        """
        Allow attribute-style access to parameters.

        Only called when normal attribute lookup fails. Names starting with an
        underscore are never looked up in the parameter dictionary.

        Parameters
        ----------
        name : str
            Name of the parameter to access.

        Returns
        -------
        Any
            The value of the specified parameter.

        Raises
        ------
        AttributeError:
            If the parameter name is not found.
        """
        if name.startswith("_"):
            return super().__getattribute__(name)
        try:
            return self._parameters[name]
        except KeyError:
            raise AttributeError(
                f"'Parameters' object has no attribute '{name}'"
            ) from None

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Allow attribute-style setting of parameters.

        Names starting with an underscore are set as real attributes (slots);
        every other name is stored as a parameter via item assignment.

        Parameters
        ----------
        name : str
            Name of the parameter to set.
        value : Any
            Value to set for the parameter.
        """
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            self[name] = value

    def __contains__(self, item: str) -> bool:
        """Check if a parameter exists in the Parameters object."""
        return item in self._parameters

    def copy(self) -> "Parameters":
        """
        Create a deep copy of the Parameters object.

        Returns
        -------
        Parameters
            A new Parameters object with the same contents.
        """
        return deepcopy(self)

    def add_to_time_vary(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-varying set.

        A warning is issued for each name that is not an existing parameter.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_vary.
        """
        for param in params:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_vary."
                )

    def add_to_time_inv(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-invariant set.

        A warning is issued for each name that is not an existing parameter.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_inv.
        """
        for param in params:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_inv."
                )

    def del_from_time_vary(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_vary.
        """
        for param in params:
            self._varying_params.discard(param)

    def del_from_time_inv(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_inv.
        """
        for param in params:
            self._invariant_params.discard(param)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a parameter value, returning a default if not found.

        Parameters
        ----------
        key : str
            The parameter name.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The parameter value or the default.
        """
        return self._parameters.get(key, default)

    def set_many(self, **kwargs: Any) -> None:
        """
        Set multiple parameters at once.

        Parameters
        ----------
        **kwargs : Keyword arguments representing parameter names and values.
        """
        for key, value in kwargs.items():
            self[key] = value

    def is_time_varying(self, key: str) -> bool:
        """
        Check if a parameter is time-varying.

        Parameters
        ----------
        key : str
            The parameter name.

        Returns
        -------
        bool
            True if the parameter is time-varying, False otherwise.
        """
        return key in self._varying_params

    def at_age(self, age: int) -> "Parameters":
        """
        Get parameters for a specific age.

        This is an alternative to integer indexing (params[age]) that is more
        explicit and avoids potential confusion with dictionary-style access.

        Parameters
        ----------
        age : int
            The age index to retrieve parameters for.

        Returns
        -------
        Parameters
            A new Parameters object with parameters for the specified age.

        Raises
        ------
        ValueError
            If the age index is out of bounds.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97], sigma=2.0)
        >>> age_1_params = params.at_age(1)
        >>> age_1_params.beta
        0.96
        """
        return self[age]

    def validate(self) -> None:
        """
        Validate parameter consistency.

        Checks that all time-varying parameters have length matching T_cycle.
        This is useful after manual modifications or when parameters are set
        programmatically. Parameters are checked in sorted name order so the
        error message is deterministic.

        Raises
        ------
        ValueError
            If any time-varying parameter has incorrect length.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97])
        >>> params.validate()  # Passes
        >>> params.add_to_time_vary("beta")
        >>> params.validate()  # Still passes
        """
        errors = []
        for param in sorted(self._varying_params):
            value = self._parameters[param]
            if isinstance(value, (list, tuple)):
                if len(value) != self._length:
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )
            elif isinstance(value, np.ndarray):
                if value.ndim == 0:
                    errors.append(
                        f"Parameter '{param}' is a 0-dimensional array (scalar), "
                        "which should not be time-varying"
                    )
                elif value.shape[0] != self._length:
                    if value.ndim >= 2:
                        errors.append(
                            f"Parameter '{param}' has first dimension {value.shape[0]}, expected {self._length}"
                        )
                    else:
                        errors.append(
                            f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                        )

        if errors:
            raise ValueError(
                "Parameter validation failed:\n" + "\n".join(f"  - {e}" for e in errors)
            )
class Model:
    """
    A class with special handling of parameters assignment.

    Instances keep a ``parameters`` dictionary mirroring the values assigned as
    attributes, and a ``constructors`` dictionary mapping names of constructed
    objects to the callables that build them.
    """

    def __init__(self):
        # Subclasses may define these before calling this initializer;
        # only create them when they are absent.
        if not hasattr(self, "parameters"):
            self.parameters = {}
        if not hasattr(self, "constructors"):
            self.constructors = {}

    def assign_parameters(self, **kwds):
        """
        Assign an arbitrary number of attributes to this agent.

        Parameters
        ----------
        **kwds : keyword arguments
            Any number of keyword arguments of the form key=value.
            Each value will be assigned to the attribute named in self.

        Returns
        -------
        None
        """
        self.parameters.update(kwds)
        for key, value in kwds.items():
            setattr(self, key, value)

    def get_parameter(self, name):
        """
        Returns a parameter of this model

        Parameters
        ----------
        name : str
            The name of the parameter to get

        Returns
        -------
        value : The value of the parameter
        """
        return self.parameters[name]

    def __eq__(self, other):
        # Two models of the same type are equal when their parameters match
        if isinstance(other, type(self)):
            return self.parameters == other.parameters

        return NotImplemented

    def __str__(self):
        type_ = type(self)
        header = (
            f"<{type_.__module__}.{type_.__qualname__} object at {hex(id(self))}.\n"
        )
        body = "".join(f"\n{p}: {self.parameters[p]}" for p in self.parameters)
        return header + "Parameters:" + body + ">"

    def describe(self):
        """Return the same text as ``str(self)``."""
        return self.__str__()

    def del_param(self, param_name):
        """
        Deletes a parameter from this instance, removing it both from the object's
        namespace (if it's there) and the parameters dictionary (likewise).

        Parameters
        ----------
        param_name : str
            A string naming a parameter or data to be deleted from this instance.
            Removes information from self.parameters dictionary and own namespace.

        Returns
        -------
        None
        """
        if param_name in self.parameters:
            del self.parameters[param_name]
        if hasattr(self, param_name):
            try:
                delattr(self, param_name)
            except AttributeError:
                # The name is visible through the class (class attribute or
                # property) but not stored on the instance, so there is
                # nothing to remove from this object's namespace.
                pass

    def _gather_constructor_args(self, key, constructor):
        """
        Gather all arguments needed to call the given constructor.

        Handles both the special ``get_it_from`` case and normal callables.
        For a normal callable the method inspects the function signature to
        find every required argument, then resolves each argument from the
        instance namespace (``self.<arg>``) or from ``self.parameters``.
        Arguments that have a default value are silently skipped when they
        cannot be found; required arguments that cannot be resolved are
        recorded as missing.

        Parameters
        ----------
        key : str
            The name of the constructed object (used to record missing pairs).
        constructor : callable or get_it_from
            The constructor to be called.

        Returns
        -------
        temp_dict : dict
            Keyword arguments to pass to the constructor.
        any_missing : bool
            True when at least one required argument could not be resolved.
        missing_args : list of str
            Names of unresolved required arguments.
        missing_key_data : list of tuple
            ``(key, arg)`` pairs for every unresolved required argument.
        """
        missing_key_data = []

        # SPECIAL: if the constructor is get_it_from, handle it separately
        if isinstance(constructor, get_it_from):
            try:
                parent = getattr(self, constructor.name)
                query = key
                any_missing = False
                missing_args = []
            except AttributeError:
                parent = None
                query = None
                any_missing = True
                missing_args = [constructor.name]
            temp_dict = {"parent": parent, "query": query}
            return temp_dict, any_missing, missing_args, missing_key_data

        # Normal constructor: inspect signature and gather arguments
        args_needed = get_arg_names(constructor)
        has_no_default = {
            k: v.default is inspect.Parameter.empty
            for k, v in inspect.signature(constructor).parameters.items()
        }
        temp_dict = {}
        missing_args = []
        for this_arg in args_needed:
            # Attributes on the instance take precedence over self.parameters
            if hasattr(self, this_arg):
                temp_dict[this_arg] = getattr(self, this_arg)
                continue
            try:
                temp_dict[this_arg] = self.parameters[this_arg]
            except KeyError:
                if has_no_default[this_arg]:
                    # Record missing key-data pair
                    missing_key_data.append((key, this_arg))
                    missing_args.append(this_arg)

        any_missing = len(missing_args) > 0
        return temp_dict, any_missing, missing_args, missing_key_data

    def _attempt_construct(self, key, i, keys_complete, backup, errors, force):
        """
        Attempt to construct the object for a single key.

        The method looks up the constructor, gathers its arguments, runs it,
        and records any errors. It also handles the ``None``-constructor case
        (restore from backup) and the missing-args case (record and defer).

        Parameters
        ----------
        key : str
            The name of the object to construct.
        i : int
            Index of *key* inside the ``keys`` array (used to update
            ``keys_complete``).
        keys_complete : np.ndarray of bool
            Boolean array indicating which keys have been completed; mutated
            in place when this key succeeds.
        backup : dict
            Dictionary of pre-construction attribute values.
        errors : dict
            ``self._constructor_errors``; mutated in place.
        force : bool
            When True, swallow exceptions and continue; when False, re-raise.

        Returns
        -------
        accomplished : bool
            True when the key was completed (constructor ran or was None).
        missing_key_data : list of tuple
            ``(key, arg)`` pairs recorded for missing required arguments.
        """
        missing_key_data = []

        # Look up the constructor for this key
        try:
            constructor = self.constructors[key]
        except Exception as not_found:
            errors[key] = "No constructor found for " + str(not_found)
            if force:
                return False, missing_key_data
            raise KeyError("No constructor found for " + key) from None

        # If the constructor is None, restore from backup and mark complete
        if constructor is None:
            if key in backup:
                setattr(self, key, backup[key])
                self.parameters[key] = backup[key]
            keys_complete[i] = True
            return True, missing_key_data

        # Gather arguments for the constructor
        temp_dict, any_missing, missing_args, missing_key_data = (
            self._gather_constructor_args(key, constructor)
        )

        if any_missing:
            # Some required arguments were missing; record and defer.
            # Never raise exceptions here, as the arguments might be filled in later
            errors[key] = "Missing required arguments: " + ", ".join(missing_args)
            self.del_param(key)
            return False, missing_key_data

        # All required data was found: run the constructor and store the result
        try:
            temp = constructor(**temp_dict)
        except Exception as problem:
            errors[key] = str(type(problem)) + ": " + str(problem)
            self.del_param(key)
            if force:
                return False, missing_key_data
            raise
        setattr(self, key, temp)
        self.parameters[key] = temp
        errors.pop(key, None)  # Clear any error left over from an earlier pass
        keys_complete[i] = True
        return True, missing_key_data

    def _construct_pass(self, keys, keys_complete, backup, errors, force):
        """
        Perform one full sweep over all incomplete keys.

        Calls ``_attempt_construct`` for every key that has not yet been
        completed and accumulates results.

        Parameters
        ----------
        keys : sequence of str
            All keys requested for construction.
        keys_complete : np.ndarray of bool
            Boolean array indicating which keys have been completed; mutated
            in place by ``_attempt_construct``.
        backup : dict
            Dictionary of pre-construction attribute values.
        errors : dict
            ``self._constructor_errors``; mutated in place.
        force : bool
            Passed through to ``_attempt_construct``.

        Returns
        -------
        anything_accomplished : bool
            True when at least one key was completed during this pass.
        missing_key_data : list of tuple
            Accumulated ``(key, arg)`` pairs for all unresolved required
            arguments across every incomplete key.
        """
        anything_accomplished = False
        missing_key_data = []

        for i, key in enumerate(keys):
            if keys_complete[i]:
                continue  # This key has already been built

            accomplished, key_missing = self._attempt_construct(
                key, i, keys_complete, backup, errors, force
            )
            anything_accomplished = anything_accomplished or accomplished
            missing_key_data.extend(key_missing)

        return anything_accomplished, missing_key_data

    def construct(self, *args, force=False):
        """
        Top-level method for building constructed inputs. If called without any
        inputs, construct builds each of the objects named in the keys of the
        constructors dictionary; it draws inputs for the constructors from the
        parameters dictionary and adds its results to the same. If passed one or
        more strings as arguments, the method builds only the named keys. The
        method will do multiple "passes" over the requested keys, as some cons-
        tructors require inputs built by other constructors. If any requested
        constructors failed to build due to missing data, those keys (and the
        missing data) will be named in self._missing_key_data. Other errors are
        recorded in the dictionary attribute _constructor_errors.

        This method tries to "start from scratch" by removing prior constructed
        objects, holding them in a backup dictionary during construction. This
        is done so that dependencies among constructors are resolved properly,
        without mistakenly relying on "old information". A backup value is used
        if a constructor function is set to None (i.e. "don't do anything"), or
        if the construct method fails to produce a new object.

        Parameters
        ----------
        *args : str, optional
            Keys of self.constructors that are requested to be constructed.
            If no arguments are passed, *all* elements of the dictionary are implied.
        force : bool, optional
            When True, the method will force its way past any errors, including
            missing constructors, missing arguments for constructors, and errors
            raised during execution of constructors. Information about all such
            errors is stored in the dictionary attributes described above. When
            False (default), any errors or exception will be raised.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If some requested objects could not be constructed and force is False.
        """
        # Set up the requested work
        keys = args if len(args) > 0 else list(self.constructors.keys())
        N_keys = len(keys)
        if N_keys == 0:
            return  # Do nothing if there are no constructed objects
        keys_complete = np.zeros(N_keys, dtype=bool)

        # Remove pre-existing constructed objects, preventing "incomplete" updates,
        # but store the current values in a backup dictionary in case something fails
        backup = {}
        for key in keys:
            if hasattr(self, key):
                backup[key] = getattr(self, key)
                self.del_param(key)

        # Get the dictionary of constructor errors
        if not hasattr(self, "_constructor_errors"):
            self._constructor_errors = {}
        errors = self._constructor_errors

        # As long as the work isn't complete and we made some progress on the last
        # pass, repeatedly perform passes of trying to construct objects
        while True:
            anything_accomplished, missing_key_data = self._construct_pass(
                keys, keys_complete, backup, errors, force
            )
            any_keys_incomplete = not keys_complete.all()
            if not (any_keys_incomplete and anything_accomplished):
                break

        # Store missing key-data pairs and exit
        self._missing_key_data = missing_key_data
        if any_keys_incomplete:
            unfinished = [key for key, done in zip(keys, keys_complete) if not done]
            # Restore the previous value of every object that was not rebuilt
            for key in unfinished:
                if key in backup:
                    setattr(self, key, backup[key])
                    self.parameters[key] = backup[key]
            if not force:
                raise ValueError(
                    "Did not construct these objects: " + ", ".join(unfinished)
                )
        return

    def describe_constructors(self, *args):
        """
        Prints to screen a string describing this instance's constructed objects,
        including their names, the function that constructs them, the names of
        those functions inputs, and whether those inputs are present.

        A constructor set to None is reported as such, without an argument list.

        Parameters
        ----------
        *args : str, optional
            Optional list of strings naming constructed inputs to be described.
            If none are passed, all constructors are described.

        Returns
        -------
        None
        """
        keys = args if len(args) > 0 else list(self.constructors.keys())
        yes = "\u2713"
        no = "X"
        maybe = "*"

        lines = []
        for key in keys:
            has_val = hasattr(self, key) or (key in self.parameters)
            prefix = (yes if has_val else no) + " " + key + " : "

            try:
                constructor = self.constructors[key]
            except KeyError:
                lines.append(prefix + "NO CONSTRUCTOR FOUND\n")
                continue

            # A None constructor means "don't build anything" in construct()
            if constructor is None:
                lines.append(prefix + "None (no constructor; existing value is kept)\n")
                continue

            if isinstance(constructor, get_it_from):
                lines.append(prefix + "get it from " + constructor.name + "\n")
                continue

            # Not every callable has __name__ (e.g. functools.partial objects)
            lines.append(
                prefix + getattr(constructor, "__name__", repr(constructor)) + "\n"
            )

            # Get constructor argument names
            arg_names = get_arg_names(constructor)
            has_no_default = {
                k: v.default is inspect.Parameter.empty
                for k, v in inspect.signature(constructor).parameters.items()
            }

            # Check whether each argument exists
            for this_arg in arg_names:
                if hasattr(self, this_arg) or this_arg in self.parameters:
                    symb = yes
                elif not has_no_default[this_arg]:
                    symb = maybe  # Absent, but the constructor has a default
                else:
                    symb = no
                lines.append(" " + symb + " " + this_arg + "\n")

        # Print the string to screen
        print("".join(lines))
        return

    # This is a "synonym" method so that old calls to update() still work
    def update(self, *args, **kwargs):
        self.construct(*args, **kwargs)
1076class AgentType(Model):
1077 """
1078 A superclass for economic agents in the HARK framework. Each model should
1079 specify its own subclass of AgentType, inheriting its methods and overwriting
1080 as necessary. Critically, every subclass of AgentType should define class-
1081 specific static values of the attributes time_vary and time_inv as lists of
1082 strings. Each element of time_vary is the name of a field in AgentSubType
1083 that varies over time in the model. Each element of time_inv is the name of
1084 a field in AgentSubType that is constant over time in the model.
1086 Parameters
1087 ----------
1088 solution_terminal : Solution
1089 A representation of the solution to the terminal period problem of
1090 this AgentType instance, or an initial guess of the solution if this
1091 is an infinite horizon problem.
1092 cycles : int
1093 The number of times the sequence of periods is experienced by this
1094 AgentType in their "lifetime". cycles=1 corresponds to a lifecycle
1095 model, with a certain sequence of one period problems experienced
1096 once before terminating. cycles=0 corresponds to an infinite horizon
1097 model, with a sequence of one period problems repeating indefinitely.
1098 pseudo_terminal : bool
1099 Indicates whether solution_terminal isn't actually part of the
1100 solution to the problem (as a known solution to the terminal period
1101 problem), but instead represents a "scrap value"-style termination.
1102 When True, solution_terminal is not included in the solution; when
1103 False, solution_terminal is the last element of the solution.
1104 tolerance : float
1105 Maximum acceptable "distance" between successive solutions to the
1106 one period problem in an infinite horizon (cycles=0) model in order
1107 for the solution to be considered as having "converged". Inoperative
1108 when cycles>0.
1109 verbose : int
1110 Level of output to be displayed by this instance, default is 1.
1111 quiet : bool
1112 Indicator for whether this instance should operate "quietly", default False.
1113 seed : int
1114 A seed for this instance's random number generator.
1115 construct : bool
1116 Indicator for whether this instance's construct() method should be run
1117 when initialized (default True). When False, an instance of the class
1118 can be created even if not all of its attributes can be constructed.
1119 use_defaults : bool
1120 Indicator for whether this instance should use the values in the class'
1121 default dictionary to fill in parameters and constructors for those not
1122 provided by the user (default True). Setting this to False is useful for
1123 situations where the user wants to be absolutely sure that they know what
1124 is being passed to the class initializer, without resorting to defaults.
1126 Attributes
1127 ----------
1128 AgentCount : int
1129 The number of agents of this type to use in simulation.
1131 state_vars : list of string
1132 The string labels for this AgentType's model state variables.
1133 """
# Class-level templates; __init__ deep-copies these into the per-instance
# lists time_vary, time_inv, and shock_vars respectively.
time_vary_ = []
time_inv_ = []
shock_vars_ = []
# Names of state variables; __init__ creates one state_now entry per name.
state_vars = []
poststate_vars = []
# Names of attributes that get_market_params() copies from a Market.
market_vars = []
# Names of attributes holding distributions that reset_rng() resets.
distributions = []
# Class defaults read by __init__: "params" and "solver" here, plus the
# optional keys "track_vars" and "model" when a subclass supplies them.
default_ = {"params": {}, "solver": NullFunc()}
def __init__(
    self,
    solution_terminal=None,
    pseudo_terminal=True,
    tolerance=0.000001,
    verbose=1,
    quiet=False,
    seed=0,
    construct=True,
    use_defaults=True,
    **kwds,
):
    """
    Initialize a new agent type. Class defaults (when use_defaults is True)
    are merged with user-supplied keyword arguments, basic simulation
    bookkeeping attributes are created, parameters are assigned, the RNG is
    reset, and (when construct is True) constructed inputs are built.
    See the class docstring for a description of each argument.
    """
    super().__init__()
    defaults = self.default_

    # Begin from a private copy of the default parameters (if requested);
    # anything passed by the user takes precedence over the defaults.
    if use_defaults:
        params = deepcopy(defaults["params"])
    else:
        params = {}
    params.update(kwds)

    # Constructors are merged entry-by-entry: user-provided constructors
    # override same-named defaults but leave the other defaults in place.
    if "constructors" in defaults["params"] and use_defaults:
        constructors = deepcopy(defaults["params"]["constructors"])
    else:
        constructors = {}
    if "constructors" in kwds:
        constructors.update(kwds["constructors"])
    params["constructors"] = constructors

    # Default list of tracked variables, if the class provides one
    use_default_tracking = "track_vars" in defaults and use_defaults
    self.track_vars = copy(defaults["track_vars"]) if use_default_tracking else []

    # Model file name, if the class provides one
    try:
        self.model_file = copy(defaults["model"])
    except (KeyError, TypeError):
        # No usable "model" entry in the class defaults
        self.model_file = None

    if solution_terminal is None:
        solution_terminal = NullFunc()

    # Solver and basic settings
    self.solve_one_period = defaults["solver"]  # NOQA
    self.solution_terminal = solution_terminal  # NOQA
    self.pseudo_terminal = pseudo_terminal  # NOQA
    self.tolerance = tolerance  # NOQA
    self.verbose = verbose
    self.quiet = quiet
    self.seed = seed  # NOQA

    # Empty containers used by the simulation methods
    self.state_now = dict.fromkeys(self.state_vars)
    self.state_prev = self.state_now.copy()
    self.controls = {}
    self.shocks = {}
    self.read_shocks = False  # NOQA
    self.shock_history = {}
    self.newborn_init_history = {}
    self.history = {}

    # Parameters are assigned after the settings above, then the RNG is reset
    self.assign_parameters(**params)  # NOQA
    self.reset_rng()  # NOQA
    self.bilt = {}
    if construct:
        self.construct()

    # Instance-level copies of the class-level lists
    self.time_vary = deepcopy(self.time_vary_)
    self.time_inv = deepcopy(self.time_inv_)
    self.shock_vars = deepcopy(self.shock_vars_)
def add_to_time_vary(self, *params):
    """
    Register any number of attribute names as time-varying for this instance.
    Names already present in time_vary are left alone (no duplicates are added).

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be added to time_vary

    Returns
    -------
    None
    """
    registered = self.time_vary
    for name in params:
        if name in registered:
            continue
        registered.append(name)
def add_to_time_inv(self, *params):
    """
    Register any number of attribute names as time-invariant for this instance.
    Names already present in time_inv are left alone (no duplicates are added).

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be added to time_inv

    Returns
    -------
    None
    """
    registered = self.time_inv
    for name in params:
        if name in registered:
            continue
        registered.append(name)
def del_from_time_vary(self, *params):
    """
    Unregister any number of attribute names from time_vary for this instance.
    Names that are not currently in time_vary are silently ignored.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be removed from time_vary

    Returns
    -------
    None
    """
    registered = self.time_vary
    for name in params:
        if name not in registered:
            continue
        registered.remove(name)
def del_from_time_inv(self, *params):
    """
    Unregister any number of attribute names from time_inv for this instance.
    Names that are not currently in time_inv are silently ignored.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be removed from time_inv

    Returns
    -------
    None
    """
    registered = self.time_inv
    for name in params:
        if name not in registered:
            continue
        registered.remove(name)
def unpack(self, name):
    """
    Unpacks an attribute from a solution object for easier access.
    After the model has been solved, its components (like consumption function)
    reside in the attributes of each element of `ThisType.solution` (e.g. `cFunc`).
    This method creates a (time varying) attribute of the given attribute name
    that contains a list with one element per entry of `ThisType.solution`.

    Each solution entry may be either a dictionary (or any dict subclass), in
    which case `name` is used as a key, or an ordinary object, in which case
    `name` is looked up in that object's instance `__dict__`. The format is
    detected from the first entry of the solution.

    Parameters
    ----------
    name: str
        Name of the attribute to unpack from the solution

    Returns
    -------
    none
    """
    solution = self.solution
    # isinstance (rather than an exact type match) so that dict subclasses
    # such as OrderedDict are treated as mappings, not as plain objects
    if isinstance(solution[0], dict):
        unpacked = [soln_t[name] for soln_t in solution]
    else:
        unpacked = [soln_t.__dict__[name] for soln_t in solution]
    setattr(self, name, unpacked)
    self.add_to_time_vary(name)
def solve(
    self,
    verbose=False,
    presolve=True,
    postsolve=True,
    from_solution=None,
    from_t=None,
):
    """
    Solve this agent type's model by backward induction, storing the result
    in the attribute solution. The sequence of one period problems is looped
    over, with the solution for period t+1 passed to the problem for period t.

    Parameters
    ----------
    verbose : bool, optional
        If True, solution progress is printed to screen. Default False.
    presolve : bool, optional
        If True (default), the pre_solve method is run before solving.
    postsolve : bool, optional
        If True (default), the post_solve method is run after solving.
    from_solution: Solution
        If different from None, will be used as the starting point of backward
        induction, instead of self.solution_terminal.
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. It should usually only be used in combination with from_solution.
        Stands for the time index that from_solution represents, and thus is
        only compatible with cycles=1 and will be reset to None otherwise.

    Returns
    -------
    none
    """
    # Floating point "errors" are silenced for the whole solution process:
    # operations like 1.0/0.0 or np.inf/np.inf have well-defined results
    # (np.inf, np.nan, ...) that should not generate warnings here.
    fp_handling = {
        "divide": "ignore",
        "over": "ignore",
        "under": "ignore",
        "invalid": "ignore",
    }
    with np.errstate(**fp_handling):
        if presolve:
            self.pre_solve()  # pre-solution preparation and checks

        # Backward induction over the sequence of one period problems
        self.solution = solve_agent(self, verbose, from_solution, from_t)

        if postsolve:
            self.post_solve()  # post-solution processing
def reset_rng(self):
    """
    Reset this type's random number generator (using its seed) and reset every
    distribution named in the distributions attribute. Names in distributions
    that are not attributes of this instance are skipped. Three layouts of a
    named attribute are supported:

    1) A single distribution object
    2) A list of distribution objects (probably time-varying)
    3) A nested list of distributions, as in ConsMarkovModel.
    """
    self.RNG = np.random.default_rng(self.seed)
    for name in self.distributions:
        if not hasattr(self, name):
            continue

        target = getattr(self, name)
        if not isinstance(target, list):
            target.reset()  # case 1: a lone distribution
            continue

        for entry in target:
            # case 2 entries are handled as one-element groups, case 3 as-is
            group = entry if isinstance(entry, list) else [entry]
            for member in group:
                member.reset()
def check_elements_of_time_vary_are_lists(self):
    """
    Verify that every attribute named in time_vary is either a list or an
    IndexDistribution. Names in time_vary that are not (yet) attributes of
    this instance are skipped. Fails with an AssertionError otherwise.
    """
    for param in self.time_vary:
        if not hasattr(self, param):
            continue
        value = getattr(self, param)
        if isinstance(value, (IndexDistribution,)):
            continue  # time-varying distributions are acceptable as-is
        assert type(value) == list, (
            param
            + " is not a list or time varying distribution,"
            + " but should be because it is in time_vary"
        )
def check_restrictions(self):
    """
    Hook for checking that model-specific restrictions are satisfied.
    Does nothing in this base implementation.
    """
    return None
def pre_solve(self):
    """
    Runs immediately before the model is solved. Here it checks model
    restrictions and then verifies that time-varying attributes are lists.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    for check in (
        self.check_restrictions,
        self.check_elements_of_time_vary_are_lists,
    ):
        check()
def post_solve(self):
    """
    Runs immediately after the model is solved, as a hook for finalizing the
    solution in some way. This base implementation does nothing.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    pass
def get_market_params(self, mkt, construct=True):
    """
    Copy each attribute named in class attribute market_vars from a Market and
    assign the copies as parameters of self. By default, the construct method
    is run afterward, because the market parameters often have information
    needed to "complete" the microeconomic problem.

    This method is called automatically by the Market.give_agent_params()
    method for all agents.

    Parameters
    ----------
    mkt : Market
        Market to which this AgentType belongs.
    construct : bool
        Indicator for whether constructed attributes should be updated after
        fetching data / parameters from mkt (default True)

    Returns
    -------
    None
    """
    # Shallow copies, so that self does not hold the Market's own objects
    fetched = {name: copy(getattr(mkt, name)) for name in self.market_vars}
    self.assign_parameters(**fetched)
    if construct:
        self.construct()
def initialize_sym(self, **kwargs):
    """
    Build a simulator from this agent's attributes using the new simulator
    structure, store it in the private attribute _simulator, and reset it.
    Keyword arguments are passed along to make_simulator_from_agent.
    """
    # Resetting the RNG first ensures seeds are set identically each time
    self.reset_rng()
    simulator = make_simulator_from_agent(self, **kwargs)
    self._simulator = simulator
    simulator.reset()
def find_target(self, target_var, force_list=False, **kwargs):
    r"""
    Find the "target" level of a named variable such that E[\Delta x] = 0,
    with E[\Delta x-eps] > 0 and E[\Delta x+eps] < 0 (locally stable).
    Returns a single real value if there is only one target, and a list if multiple;
    returns np.nan if no target is found. Pass force_list=True to always get a list.
    See documentation for HARK.simulator.find_target_state for more options.

    Parameters
    ----------
    target_var : str
        Name of the variable whose target level is sought.
    force_list : bool
        If True, the collection of targets is returned as-is, even when it
        has zero or one elements. Default False.
    **kwargs
        Passed along to the simulator's find_target_state method.

    Returns
    -------
    target : float or list
        The target value(s), or np.nan if none are found (see above).

    Raises
    ------
    AttributeError
        If this instance has no solution attribute (model not yet solved).
    """
    # NOTE: the docstring is a raw string so that "\Delta" is not parsed as an
    # (invalid) escape sequence, which generates a warning on modern Python.
    if not hasattr(self, "solution"):
        raise AttributeError("Model must be solved before using find_target!")

    temp_simulator = make_simulator_from_agent(self)
    target_vals = temp_simulator.find_target_state(target_var, **kwargs)
    if force_list:
        return target_vals

    # Collapse the trivial cases for convenience
    found = len(target_vals)
    if found == 0:
        return np.nan
    if found == 1:
        return target_vals[0]
    return target_vals
def initialize_sim(self):
    """
    Prepares this AgentType for a new simulation: resets the random number
    generator, blanks all state variables, creates initial states for all
    agents (using sim_birth), and clears histories of tracked variables.
    When read_shocks is set and stored newborn initial conditions exist,
    idiosyncratic states are overwritten with the stored period-0 values.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    if not hasattr(self, "T_sim"):
        raise Exception(
            "To initialize simulation variables it is necessary to first "
            "set the attribute T_sim to the largest number of observations "
            "you plan to simulate for each agent including re-births."
        )
    if self.T_sim <= 0:
        raise Exception(
            "T_sim represents the largest number of observations "
            "that can be simulated for an agent, and must be a positive number."
        )

    self.reset_rng()
    self.t_sim = 0
    everyone = np.ones(self.AgentCount, dtype=bool)

    # Each state variable starts as its own array of NaNs
    for var in self.state_vars:
        self.state_now[var] = np.full(self.AgentCount, np.nan)

    # Number of periods since agent entry
    self.t_age = np.zeros(self.AgentCount, dtype=int)
    # Which cycle period each agent is on
    self.t_cycle = np.zeros(self.AgentCount, dtype=int)
    self.sim_birth(everyone)

    # When reading pre-specified shocks and stored initial conditions exist,
    # use those initial conditions instead of the freshly drawn ones
    if self.read_shocks and bool(self.newborn_init_history):
        for var_name in self.state_now:
            if var_name not in self.newborn_init_history:
                warn(
                    "The option for reading shocks was activated but "
                    "the model requires state "
                    + var_name
                    + ", not contained in "
                    "newborn_init_history."
                )
                continue

            # Only array-like idiosyncratic states are copied; aggregates
            # should not be set by newborns
            current = self.state_now[var_name]
            if isinstance(current, np.ndarray) and len(current) == self.AgentCount:
                self.state_now[var_name] = self.newborn_init_history[var_name][0]

    self.clear_history()
1561 def _export_single_var_by_time(self, history, var, t, dtype):
1562 """
1563 Mode 1a: single variable, by_age=False.
1565 Returns a DataFrame whose columns are simulation periods (t_sim) and
1566 whose rows are agent indices. When t is None all T_sim periods are
1567 included; when t is an array only those periods are included.
1568 """
1569 try:
1570 data = history[var]
1571 except KeyError:
1572 raise KeyError("Variable named " + var + " not found in simulated data!")
1574 if t is None:
1575 cols = [str(i) for i in range(self.T_sim)]
1576 df = DataFrame(data=data.T, columns=cols, dtype=dtype)
1577 else:
1578 cols = [str(t_val) for t_val in t]
1579 df = DataFrame(data=data[t, :].T, columns=cols, dtype=dtype)
1580 return df
1582 def _export_single_var_by_age(self, history, var, age, t, dtype, sym):
1583 """
1584 Mode 1b: single variable, by_age=True.
1586 Returns a DataFrame whose columns are within-agent model ages (t_age)
1587 and whose rows are individual agent lifetimes. Observations after
1588 death (or before birth) are NaN. When t is None all ages up to
1589 max(t_age) are included; when t is an array only those ages are used.
1590 The sym flag controls newborn detection: age == 0 when sym is True,
1591 age == 1 when sym is False.
1592 """
1593 try:
1594 data = history[var]
1595 except KeyError:
1596 raise KeyError("Variable named " + var + " not found in simulated data!")
1598 # Determine which ages to include and mark qualifying observations
1599 if t is None:
1600 age_set = np.arange(np.max(age) + 1)
1601 in_age_set = np.ones_like(data, dtype=bool)
1602 else:
1603 age_set = t
1604 in_age_set = np.zeros_like(data, dtype=bool)
1605 for j in age_set:
1606 these = age == j
1607 in_age_set[these] = True
1609 # Locate newborns to determine the number of individual lifetimes (rows)
1610 newborns = age == 0 if sym else age == 1
1611 T = age_set.size # number of age columns
1612 N = np.sum(newborns) # number of agent lifetimes (rows)
1613 out = np.full((N, T), np.nan)
1615 # Extract each individual's sequence and place it into the output array
1616 n = 0
1617 for i in range(self.AgentCount):
1618 data_i = data[:, i]
1619 births = np.where(newborns[:, i])[0]
1620 K = births.size
1621 for k in range(K):
1622 start = births[k]
1623 stop = births[k + 1] if (k < K - 1) else self.T_sim
1624 use = in_age_set[start:stop, i]
1625 temp = data_i[start:stop][use]
1626 out[n, : temp.size] = temp
1627 n += 1
1629 cols = [str(a) for a in age_set]
1630 df = DataFrame(data=out, columns=cols, dtype=dtype)
1631 return df
1633 def _export_single_t_by_time(self, history, var_list, t, dtype):
1634 """
1635 Mode 2a: single time period, by_age=False.
1637 Returns a DataFrame with one row per agent and one column per variable,
1638 drawn from the absolute simulation period t (i.e. history[name][t, :]).
1639 """
1640 K = len(var_list)
1641 N = self.AgentCount
1642 out = np.full((N, K), np.nan)
1643 for k, name in enumerate(var_list):
1644 out[:, k] = history[name][t, :]
1645 df = DataFrame(data=out, columns=var_list, dtype=dtype)
1646 return df
1648 def _export_single_t_by_age(self, history, var_list, age, t, dtype):
1649 """
1650 Mode 2b: single time period, by_age=True.
1652 Returns a DataFrame with one row per agent-period at which t_age == t
1653 and one column per variable.
1654 """
1655 right_age = age == t
1656 N = np.sum(right_age)
1657 K = len(var_list)
1658 out = np.full((N, K), np.nan)
1659 for k, name in enumerate(var_list):
1660 out[:, k] = history[name][right_age]
1661 df = DataFrame(data=out, columns=var_list, dtype=dtype)
1662 return df
def export_to_df(self, var=None, t=None, by_age=False, dtype=None, sym=False):
    """
    Export this instance's simulated data to a pandas dataframe. Exactly one
    of two things must hold: var is a single string, *or* t is a single
    integer. Any other combination raises an exception. This yields four
    construction modes:

    1a) var is a single string and by_age is False: columns are simulated
        periods in absolute simulation time t_sim, and each row is one *agent
        index* of the population, with death and replacement occurring within
        a row. Optionally, t can be an array of periods to include (default all).

    1b) var is a single string and by_age is True: columns are within-agent
        model ages t_age, and each row is one specific agent from model entry
        to model death; observations after death are NaN. Optionally, t can be
        an array of ages to include (default all). The number of columns
        depends on max(t_age) and/or argument t.

    2a) t is an integer and by_age is False: each column is a different
        simulated variable, with values from absolute simulated period t=t_sim.
        Optionally, var can be a list of strings naming the variables to
        include (default all).

    2b) t is an integer and by_age is True: each column is a different
        simulated variable, taken from all agent-periods at which t == t_age.
        Optionally, var can be a list of strings naming the variables to
        include (default all).

    Parameters
    ----------
    var : str or [str] or None
        If a single string, the name of the one simulated variable to export.
        If a list of strings, then t must also be provided to indicate which
        period the dataframe will represent. Name(s) must be keys of the
        history or hystory dictionary (i.e. named in track_vars). If not
        provided, all keys of history or hystory are included.
    t : int or np.array or None
        If an integer, the one period to include. When by_age is False
        (default), t refers to absolute simulated time t_sim: literally the
        t-th row of history[key]. When by_age is True, t refers to within-agent
        model age t_age. If var is a single string, then t is an optional array
        of periods (or ages) to include (default all).
    by_age : bool
        Whether observations are selected by within-agent model age t_age
        (True) rather than absolute simulated time t_sim (False, default).
        If True, t_age must be in track_vars so that it appears in the
        simulated data, and dtype must *not* be provided, because NaNs could
        be cast to a datatype that doesn't necessarily support them.
    dtype : type or None
        Optional data type to cast the dataframe. By default, uses the
        datatype from the entry in history or hystory.
    sym : bool
        Whether to read simulated data from the history (False, default) or
        hystory (True) dictionary attribute. This option will be deprecated in
        the future when legacy simulation methods are removed.

    Returns
    -------
    df : pandas.DataFrame
        The requested dataframe, constructed from this instance's simulated data.
    """
    # Exactly one of "single variable" and "single period" must be requested.
    # The exact type check on var is deliberate: only a plain str qualifies.
    single_var = type(var) is str
    single_t = isinstance(t, (int, np.integer))
    if single_var == single_t:
        raise ValueError(
            "Either var must be a single string, or t must be a single integer!"
        )
    if by_age and dtype is not None:
        raise ValueError(
            "Can't specify dtype when using by_age is True because of potential incompatibility with representing NaN"
        )

    # Pick the data source (the sym option is slated for deprecation)
    if sym:
        history = self.hystory
    else:
        history = self.history

    # Ages are needed by both by_age modes; fail clearly if they are missing
    if by_age:
        try:
            age = history["t_age"]
        except KeyError:
            raise KeyError(
                "t_age must be in track_vars if by_age=True will be used!"
            )

    # Modes 1a and 1b: one variable
    if single_var:
        if by_age:
            return self._export_single_var_by_age(history, var, age, t, dtype, sym)
        return self._export_single_var_by_time(history, var, t, dtype)

    # Modes 2a and 2b: one period; assemble and validate the variable names
    var_list = list(history.keys()) if var is None else copy(var)
    sim_keys = list(history.keys())
    for name in var_list:
        if name not in sim_keys:
            raise KeyError(
                "Variable called " + name + " not found in simulation data!"
            )
    if by_age:
        return self._export_single_t_by_age(history, var_list, age, t, dtype)
    return self._export_single_t_by_time(history, var_list, t, dtype)
def sim_one_period(self):
    """
    Simulates one period for this type. In order, this handles mortality
    (get_mortality), moves current states into state_prev, obtains shocks
    (get_shocks, or read_shocks_from_history when read_shocks is set), then
    calls get_states(), get_controls(), and get_poststates(), and finally
    advances t_age and t_cycle. The get_* methods should be defined by
    AgentType subclasses, except get_mortality (define its components
    sim_death and sim_birth instead).

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    if not hasattr(self, "solution"):
        raise Exception(
            "Model instance does not have a solution stored. To simulate, it is necessary"
            " to run the `solve()` method first."
        )

    # Mortality adjusts the agent population: some agents become "newborns"
    self.get_mortality()

    # Current states become state_{t-1}
    for var in self.state_now:
        self.state_prev[var] = self.state_now[var]
        # Arrays get a fresh (uninitialized) array for the new period. Anything
        # else is probably an aggregate variable, which may be set by the Market,
        # so it is left in place.
        if isinstance(self.state_now[var], np.ndarray):
            self.state_now[var] = np.empty(self.AgentCount)

    # Use pre-specified shock histories if requested; otherwise draw shocks
    # according to the subclass-specific method
    shock_step = self.read_shocks_from_history if self.read_shocks else self.get_shocks
    shock_step()

    self.get_states()  # each agent's state at decision time
    self.get_controls()  # each agent's choice or control variables
    self.get_poststates()  # variables that come *after* decision-time

    # Advance time for all agents
    self.t_age = self.t_age + 1  # age within lifetime
    self.t_cycle = self.t_cycle + 1  # position within the cycle
    # Those who reached the end of the cycle start over at zero
    at_cycle_end = self.t_cycle == self.T_cycle
    self.t_cycle[at_cycle_end] = 0
def make_shock_history(self):
    """
    Makes a pre-specified history of shocks for the simulation. Shock variables should be named
    in self.shock_vars, a list of strings that is subclass-specific. This method runs a subset
    of the standard simulation loop by simulating only mortality and shocks; each variable named
    in shock_vars is stored in a T_sim x AgentCount array in dictionary self.shock_history[X].
    The mortality history is stored in self.shock_history["who_dies"], and the initial conditions
    of newborns for each variable in state_vars are stored in self.newborn_init_history.
    Automatically sets self.read_shocks to True so that these pre-specified shocks are used for
    all subsequent calls to simulate().

    This method can safely be called more than once: any existing shock history is ignored
    and replaced by a freshly simulated one.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    # Histories must be *generated* here, never read. If read_shocks were left
    # True from a previous call, initialize_sim and get_mortality would read from
    # the old (or freshly blanked) histories, producing a history with no deaths.
    self.read_shocks = False

    # Re-initialize the simulation
    self.initialize_sim()

    history_shape = (self.T_sim, self.AgentCount)

    def is_idiosyncratic(value):
        # Agent-level states are arrays with one entry per agent; anything
        # else is treated as an aggregate
        return isinstance(value, np.ndarray) and len(value) == self.AgentCount

    # Make blank history arrays for each shock variable (and mortality)
    for var_name in self.shock_vars:
        self.shock_history[var_name] = np.full(history_shape, np.nan)
    self.shock_history["who_dies"] = np.zeros(history_shape, dtype=bool)

    # Also make blank arrays for the draws of newborns' initial conditions
    for var_name in self.state_vars:
        self.newborn_init_history[var_name] = np.full(history_shape, np.nan)

    # Record the initial condition of the newborns created by
    # initialize_sim -> sim_birth
    for var_name in self.state_vars:
        current = self.state_now[var_name]
        if is_idiosyncratic(current):
            self.newborn_init_history[var_name][self.t_sim] = current
        else:
            # Aggregate state is a scalar. Assign it to every agent.
            self.newborn_init_history[var_name][self.t_sim, :] = current

    # Make and store the history of shocks for each period
    for t in range(self.T_sim):
        # Deaths
        self.get_mortality()
        dead = self.who_dies
        self.shock_history["who_dies"][t, :] = dead

        # Initial conditions of newborns
        if dead.any():
            for var_name in self.state_vars:
                current = self.state_now[var_name]
                if is_idiosyncratic(current):
                    self.newborn_init_history[var_name][t, dead] = current[dead]
                else:
                    self.newborn_init_history[var_name][t, dead] = current

        # Other shocks
        self.get_shocks()
        for var_name in self.shock_vars:
            self.shock_history[var_name][t, :] = self.shocks[var_name]

        # Advance time for all agents
        self.t_sim += 1
        self.t_age = self.t_age + 1  # Age all consumers by one period
        self.t_cycle = self.t_cycle + 1  # Age all consumers within their cycle
        # Resetting to zero for those who have reached the end
        self.t_cycle[self.t_cycle == self.T_cycle] = 0

    # Flag that shocks can be read rather than simulated
    self.read_shocks = True
def get_mortality(self):
    """
    Simulates mortality or agent turnover, storing the result in self.who_dies.

    In ordinary operation, this uses the model-specific rules sim_death and
    sim_birth (methods of an AgentType subclass). sim_death takes no arguments
    and returns a Boolean array of size AgentCount, indicating which agents of
    this type have "died" and must be replaced. sim_birth takes such a Boolean
    array as an argument and generates initial post-decision states for those
    agent indices.

    When read_shocks is set, deaths are instead read from the stored shock
    history for the current period, and the replaced agents are given the
    stored newborn initial conditions rather than newly simulated ones.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    if not self.read_shocks:
        who_dies = self.sim_death()
        self.sim_birth(who_dies)
    else:
        who_dies = self.shock_history["who_dies"][self.t_sim, :]
        if who_dies.any():
            # Instead of simulating births, assign the saved newborn initial conditions
            for var_name in self.state_now:
                if var_name not in self.newborn_init_history:
                    warn(
                        "The option for reading shocks was activated but "
                        "the model requires state "
                        + var_name
                        + ", not contained in "
                        "newborn_init_history."
                    )
                    continue

                # Only array-like idiosyncratic states are copied; aggregates
                # should not be set by newborns
                current = self.state_now[var_name]
                if isinstance(current, np.ndarray) and len(current) == self.AgentCount:
                    stored = self.newborn_init_history[var_name]
                    current[who_dies] = stored[self.t_sim, who_dies]

            # Reset ages of newborns
            self.t_age[who_dies] = 0
            self.t_cycle[who_dies] = 0

    self.who_dies = who_dies
    return None
def sim_death(self):
    """
    Determines which agents in the current population "die" or should be replaced.
    This default implementation reports that nobody dies; it must be overwritten
    by a subclass to have replacement events.

    Parameters
    ----------
    None

    Returns
    -------
    who_dies : np.array
        Boolean array of size self.AgentCount indicating which agents die and are
        replaced: True for agents who die and False for those that survive.
    """
    return np.zeros(self.AgentCount, dtype=bool)
def sim_birth(self, which_agents):  # pragma: nocover
    """
    Makes new agents for the simulation. Takes a boolean array as an input,
    indicating which agent indices are to be "born". This base version is
    only a placeholder: it always raises, so a subclass must override it.

    Parameters
    ----------
    which_agents : np.array(Bool)
        Boolean array of size self.AgentCount indicating which agents should be "born".

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        Always, because the base class does not know how to create agents.
    """
    # NotImplementedError is the conventional signal for an un-overridden hook;
    # it is still an Exception, so existing `except Exception` handlers work.
    raise NotImplementedError("AgentType subclass must define method sim_birth!")
def get_shocks(self):  # pragma: nocover
    """
    Draw the values of shock variables for the current period.

    The base version is a no-op hook; subclasses of AgentType can override it.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    return
def read_shocks_from_history(self):
    """
    Load this period's shock values from the stored shock history.

    For each variable X named in self.shock_vars, self.shocks[X] is set to
    row self.t_sim of self.shock_history[X].

    This method is only ever called if self.read_shocks is True. This can
    be achieved by using the method make_shock_history() (or manually after
    storing a "handcrafted" shock history).

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    for shock_name in self.shock_vars:
        stored_path = self.shock_history[shock_name]
        # Row t_sim holds the values for the current simulated period.
        self.shocks[shock_name] = stored_path[self.t_sim, :]
def get_states(self):
    """
    Update the state variables for the current period.

    Calls self.transition() and writes its outputs, in order, into the
    leading entries of the state_now dictionary.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    updated = self.transition()

    for position, state_name in enumerate(self.state_now):
        # A hack for now to deal with 'post-states': entries of state_now
        # beyond the number of transition outputs are left as they are.
        if position >= len(updated):
            break
        self.state_now[state_name] = updated[position]
def transition(self):  # pragma: nocover
    """
    Placeholder transition function; this base version returns an empty tuple.

    Parameters
    ----------
    None

    [Eventually, to match dolo spec:
    exogenous_prev, endogenous_prev, controls, exogenous, parameters]

    Returns
    -------
    endogenous_state: ()
        Tuple with new values of the endogenous states (empty here).
    """
    return tuple()
def get_controls(self):  # pragma: nocover
    """
    Compute the control variables for the current period, probably from the
    current states.

    The base version is a no-op hook; subclasses of AgentType can override it.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    return
def get_poststates(self):
    """
    Compute the post-decision state variables for the current period,
    probably from current states and controls and maybe market-level events
    or shock variables.

    The base version is a no-op hook; subclasses of AgentType can override it.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    return
def symulate(self, T=None):
    """
    Run the new simulation structure by calling self._simulator.simulate(T),
    then expose the simulator's history as the hystory attribute of self.

    Parameters
    ----------
    T : int or None
        Passed through unchanged to the simulator's simulate method.

    Returns
    -------
    None
    """
    engine = self._simulator
    engine.simulate(T)
    self.hystory = engine.history
def describe_model(self, display=True):
    """
    Print to screen information about this agent's model, based on its model
    file. This is useful for learning about outcome variable names for tracking
    during simulation, or for use with sequence space Jacobians.

    Parameters
    ----------
    display : bool
        Forwarded as the display keyword to the simulator's describe method.

    Returns
    -------
    None
    """
    simulator_missing = not hasattr(self, "_simulator")
    if simulator_missing:
        # Build the simulator on demand before asking it to describe itself.
        self.initialize_sym()
    self._simulator.describe(display=display)
def simulate(self, sim_periods=None):
    """
    Simulates this agent type for a given number of periods. Defaults to
    all remaining periods to simulate (T_sim - t_sim). Records histories of
    attributes named in self.track_vars in self.history[varname].

    Parameters
    ----------
    sim_periods : int or None
        Number of periods to simulate. Default is all remaining periods (usually T_sim).

    Returns
    -------
    history : dict
        The history tracked during the simulation (self.history).

    Raises
    ------
    Exception
        If the t_sim or T_sim attribute is missing, or if sim_periods is
        larger than T_sim.
    """
    if not hasattr(self, "t_sim"):
        raise Exception(
            "It seems that the simulation variables were not initialized before calling "
            + "simulate(). Call initialize_sim() to initialize the variables before calling simulate() again."
        )

    if not hasattr(self, "T_sim"):
        raise Exception(
            "This agent type instance must have the attribute T_sim set to a positive integer. "
            + "Set T_sim to match the largest dataset you might simulate, and run this agent's "
            + "initialize_sim() method before running simulate() again."
        )

    if sim_periods is not None and self.T_sim < sim_periods:
        raise Exception(
            "To simulate, sim_periods cannot be larger than the maximum data set size "
            + "T_sim. Either increase the attribute T_sim of this agent type instance "
            + "and call the initialize_sim() method again, or set sim_periods <= T_sim."
        )

    # Ignore floating point "errors". Numpy calls it "errors", but really it's excep-
    # tions with well-defined answers such as 1.0/0.0 that is np.inf, -1.0/0.0 that is
    # -np.inf, np.inf/np.inf is np.nan and so on.
    with np.errstate(divide="ignore", over="ignore", under="ignore", invalid="ignore"):
        if sim_periods is None:
            sim_periods = self.T_sim - self.t_sim

        for _ in range(sim_periods):
            self.sim_one_period()

            # Record each tracked variable from the first namespace that has it:
            # states, then shocks, then controls, then plain attributes of self.
            for var_name in self.track_vars:
                if var_name in self.state_now:
                    self.history[var_name][self.t_sim, :] = self.state_now[var_name]
                elif var_name in self.shocks:
                    self.history[var_name][self.t_sim, :] = self.shocks[var_name]
                elif var_name in self.controls:
                    self.history[var_name][self.t_sim, :] = self.controls[var_name]
                else:
                    if var_name == "who_dies" and self.t_sim > 1:
                        # who_dies is written one row earlier than other variables.
                        self.history[var_name][self.t_sim - 1, :] = getattr(
                            self, var_name
                        )
                    else:
                        self.history[var_name][self.t_sim, :] = getattr(
                            self, var_name
                        )
            self.t_sim += 1

    return self.history
def clear_history(self):
    """
    Reset the history of every attribute named in self.track_vars to a fresh
    NaN-filled array of shape (self.T_sim, self.AgentCount).

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    for tracked in self.track_vars:
        self.history[tracked] = np.full((self.T_sim, self.AgentCount), np.nan)
def make_basic_SSJ(self, shock, outcomes, grids, **kwargs):
    """
    Build and return sequence space Jacobian matrices for the specified
    outcomes with respect to the specified "shock" variable. This "basic"
    method only works for "one period infinite horizon" models (cycles=0,
    T_cycle=1). All arguments are forwarded to make_basic_SSJ_matrices; see
    its documentation for more information.
    """
    jacobians = make_basic_SSJ_matrices(self, shock, outcomes, grids, **kwargs)
    return jacobians
def calc_impulse_response_manually(self, shock, outcomes, grids, **kwargs):
    """
    Compute and return the impulse response(s) of a perturbation to the shock
    parameter in period t=s, essentially computing one column of the sequence
    space Jacobian matrix manually. This "basic" method only works for "one
    period infinite horizon" models (cycles=0, T_cycle=1). All arguments are
    forwarded to calc_shock_response_manually; see its documentation for more
    information.
    """
    response = calc_shock_response_manually(self, shock, outcomes, grids, **kwargs)
    return response
def solve_agent(agent, verbose, from_solution=None, from_t=None):
    """
    Solve the dynamic model for one agent type by backwards induction.

    The function repeatedly solves "cycles" of the agent's model, either a
    fixed number of times (agent.cycles > 0) or, for an infinite horizon
    model (agent.cycles = 0), until successive solutions are within
    agent.tolerance of each other or an internal cap of 5000 cycles is hit.

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem
        is to be solved.
    verbose : boolean
        If True, solution progress is printed to screen after every cycle.
    from_solution: Solution
        If different from None, will be used as the starting point of backward
        induction, instead of agent.solution_terminal.
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. It should usually only be used in combination with from_solution.
        Stands for the time index that from_solution represents, and thus is
        only compatible with cycles=1 and will be reset to None otherwise.

    Returns
    -------
    solution : [Solution]
        A list of solutions to the one period problems that the agent will
        encounter in his "lifetime".
    """
    # A cycles value of zero marks an infinite horizon problem
    cycles_left = agent.cycles
    infinite_horizon = cycles_left == 0

    # Choose the starting point of the backward induction
    if from_solution is not None:
        solution_last = from_solution
        if agent.cycles != 1:
            from_t = None
    else:
        solution_last = agent.solution_terminal

    # The terminal solution is part of the output unless it is pseudo-terminal
    if agent.pseudo_terminal:
        solution = []
    else:
        solution = [deepcopy(solution_last)]

    completed_cycles = 0
    max_cycles = 5000  # escape clause
    if verbose:
        t_last = time()

    keep_going = True
    while keep_going:
        # Solve a cycle of the model, recording it if horizon is finite
        solution_cycle = solve_one_cycle(agent, solution_last, from_t)
        if not infinite_horizon:
            solution = solution_cycle + solution

        # Decide whether to stop: converged / capped (infinite horizon)
        # or out of cycles (finite horizon)
        solution_now = solution_cycle[0]
        if not infinite_horizon:
            cycles_left -= 1
            keep_going = cycles_left > 0
        elif completed_cycles > 0:
            solution_distance = distance_metric(solution_now, solution_last)
            # Stored on the agent so users can query whether the solution is ready
            agent.solution_distance = solution_distance
            agent.completed_cycles = completed_cycles
            keep_going = (
                solution_distance > agent.tolerance and completed_cycles < max_cycles
            )
        else:
            # Assume solution does not converge after only one cycle
            solution_distance = 100.0
            keep_going = True

        # Update the "last period solution"
        solution_last = solution_now
        completed_cycles += 1

        # Display progress if requested
        if verbose:
            t_now = time()
            elapsed = t_now - t_last
            if infinite_horizon:
                print(
                    f"Finished cycle #{completed_cycles!s} in {elapsed!s}"
                    f" seconds, solution distance = {solution_distance!s}"
                )
            else:
                print(
                    f"Finished cycle #{completed_cycles!s} of {agent.cycles!s}"
                    f" in {elapsed!s} seconds."
                )
            t_last = t_now

    # For an infinite horizon only the last cycle is the solution
    # (PseudoTerminal=False impossible for infinite horizon)
    if infinite_horizon:
        solution = solution_cycle

    return solution
def _solver_for_period(agent, k):
    """Return the one-period solver used in period k and its argument names."""
    # The solver may be a single callable or an indexable of per-period solvers
    if hasattr(agent.solve_one_period, "__getitem__"):
        solver = agent.solve_one_period[k]
    else:
        solver = agent.solve_one_period

    if hasattr(solver, "solver_args"):
        arg_names = solver.solver_args
    else:
        arg_names = get_arg_names(solver)
    return solver, arg_names


def _backward_periods(agent, T):
    """Return the order in which the T periods of one cycle are solved."""
    wrapped_order = [0] + list(range(T - 1, 0, -1))
    return range(T - 1, -1, -1) if agent.cycles == 1 else wrapped_order


def solve_one_cycle(agent, solution_last, from_t):
    """
    Solve one "cycle" of the dynamic model for one agent type. This function
    iterates over the periods within an agent's cycle, gathering the inputs
    of each period and passing them to the single period solver(s).

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    solution_last : Solution
        A representation of the solution of the period that comes after the
        end of the sequence of one period problems. This might be the term-
        inal period solution, a "pseudo terminal" solution, or simply the
        solution to the earliest period from the succeeding cycle.
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. When used, represents the time index that solution_last is from.

    Returns
    -------
    solution_cycle : [Solution]
        A list of one period solutions for one "cycle" of the AgentType's
        microeconomic model.
    """
    # Use the agent's 'Parameters' container when it has one; otherwise fall
    # back on the time_inv / time_vary attribute lists (the old method).
    use_parameters = hasattr(agent, "parameters") and isinstance(
        agent.parameters, Parameters
    )

    # Number of periods in the cycle
    if use_parameters:
        T = from_t if from_t is not None else agent.parameters._length
    elif len(agent.time_vary) > 0:
        T = from_t if from_t is not None else agent.T_cycle
    else:
        T = 1  # all variables are time invariant

    if not use_parameters:
        # Time-invariant inputs are fixed; time-varying ones are filled per period
        solve_dict = {name: agent.__dict__[name] for name in agent.time_inv}
        solve_dict.update(dict.fromkeys(agent.time_vary))

    solved_in_order = []  # solutions in the order solved (reversed before returning)
    solution_next = solution_last
    for k in _backward_periods(agent, T):
        solve_one_period, these_args = _solver_for_period(agent, k)

        # Assemble this period's solver inputs
        if use_parameters:
            period_pars = agent.parameters[k]
            inputs = {}
            for name in these_args:
                if name == "solution_next":
                    inputs[name] = solution_next
                else:
                    inputs[name] = period_pars[name]
        else:
            for name in agent.time_vary:
                if name in these_args:
                    solve_dict[name] = agent.__dict__[name][k]
            solve_dict["solution_next"] = solution_next
            inputs = {name: solve_dict[name] for name in these_args}

        # Solve one period and make it the successor of the next one solved
        solution_next = solve_one_period(**inputs)
        solved_in_order.append(solution_next)

    # The most recently solved period goes first in the returned list
    solved_in_order.reverse()
    return solved_in_order
def make_one_period_oo_solver(solver_class):
    """
    Wrap a solver class in a function that solves a single period of a problem.

    The returned function instantiates solver_class with its keyword
    arguments, calls prepare_to_solve() if the instance has that method, and
    returns the result of solve().

    Parameters
    ----------
    solver_class : Solver
        A class of Solver to be used.

    Returns
    -------
    solver_function : function
        A function for solving one period of a problem. It carries two extra
        attributes: solver_class, and solver_args (the argument names of
        solver_class.__init__ with the first one dropped).
    """

    def one_period_solver(**kwds):
        solver = solver_class(**kwds)

        # not ideal; better if this is defined in all Solver classes
        if hasattr(solver, "prepare_to_solve"):
            solver.prepare_to_solve()

        return solver.solve()

    one_period_solver.solver_class = solver_class
    # This can be revisited once it is possible to export parameters
    init_arg_names = get_arg_names(solver_class.__init__)
    one_period_solver.solver_args = init_arg_names[1:]

    return one_period_solver
2489# ========================================================================
2490# ========================================================================
class Market(Model):
    """
    A superclass to represent a central clearinghouse of information. Used for
    dynamic general equilibrium models to solve the "macroeconomic" model as a
    layer on top of the "microeconomic" models of one or more AgentTypes.

    Parameters
    ----------
    agents : [AgentType]
        A list of all the AgentTypes in this market.
    sow_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be "sown" to the agents in the market. Aggregate state, etc.
    reap_vars : [string]
        Names of variables to be collected ("reaped") from agents in the market
        to be used in the "aggregate market process".
    const_vars : [string]
        Names of attributes of the Market instance that are used in the "aggregate
        market process" but do not come from agents-- they are constant or simply
        parameters inherent to the process.
    track_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be tracked as a "history" so that a new dynamic rule can be calculated.
        This is often a subset of sow_vars.
    dyn_vars : [string]
        Names of variables that constitute a "dynamic rule".
    mill_rule : function
        A function that takes inputs named in reap_vars and returns a tuple the
        same size and order as sow_vars. The "aggregate market process" that
        transforms individual agent actions/states/data into aggregate data to
        be sent back to agents.
    calc_dynamics : function
        A function that takes inputs named in track_vars and returns an object
        with attributes named in dyn_vars. Looks at histories of aggregate
        variables and generates a new "dynamic rule" for agents to believe and
        act on.
    distributions : [string]
        Names of attributes of this instance holding an object (or a list of
        objects) with a reset() method; these are reset by Market.reset().
    act_T : int
        The number of times that the "aggregate market process" should be run
        in order to generate a history of aggregate variables.
    tolerance: float
        Minimum acceptable distance between "dynamic rules" to consider the
        Market solution process converged. Distance is a user-defined metric.
    seed : int
        Seed used to build the instance's random number generator (self.RNG).
    **kwds
        Any other keyword arguments are passed to self.assign_parameters.
    """

    def __init__(
        self,
        agents=None,
        sow_vars=None,
        reap_vars=None,
        const_vars=None,
        track_vars=None,
        dyn_vars=None,
        mill_rule=None,
        calc_dynamics=None,
        distributions=None,
        act_T=1000,
        tolerance=0.000001,
        seed=0,
        **kwds,
    ):
        super().__init__()
        self.agents = [] if agents is None else agents

        self.reap_vars = [] if reap_vars is None else reap_vars
        self.reap_state = {name: [] for name in self.reap_vars}

        self.sow_vars = [] if sow_vars is None else sow_vars
        # Initial and current values of the sow variables, all None to begin with
        self.sow_init = dict.fromkeys(self.sow_vars)
        self.sow_state = dict.fromkeys(self.sow_vars)

        # const_vars is stored as a dictionary keyed by name, values None
        self.const_vars = dict.fromkeys([] if const_vars is None else const_vars)

        self.track_vars = [] if track_vars is None else track_vars
        self.dyn_vars = [] if dyn_vars is None else dyn_vars
        self.distributions = [] if distributions is None else distributions

        # Only assign the rules when given, so that subclasses which define
        # mill_rule / calc_dynamics as methods are not overwritten
        if mill_rule is not None:
            self.mill_rule = mill_rule
        if calc_dynamics is not None:
            self.calc_dynamics = calc_dynamics

        self.act_T = act_T
        self.tolerance = tolerance
        self.seed = seed
        self.max_loops = 1000
        self.history = {}
        self.assign_parameters(**kwds)
        self.RNG = np.random.default_rng(self.seed)

        # Print the error associated with calling the parallel method
        # "solve_agents" one time. If set to false, the error will never
        # print. See "solve_agents" for why this prints once or never.
        self.print_parallel_error_once = True

    def give_agent_params(self, construct=True):
        """
        Distribute relevant market-level parameters to each AgentType in self.agents
        by having them call their get_market_params method.

        Parameters
        ----------
        construct : bool, optional
            Whether agents should run their construct method after fetching market
            data (default True).

        Returns
        -------
        None
        """
        for member in self.agents:
            member.get_market_params(self, construct)

    def solve_agents(self):
        """
        Solves the microeconomic problem for all AgentTypes in this market.
        The parallel routine is tried first; if it raises, the serial
        routine is used instead and a warning is printed the first time.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        commands = ["solve()"]
        try:
            multi_thread_commands(self.agents, commands)
        except Exception as err:
            if self.print_parallel_error_once:
                # Set flag to False so this is only printed once.
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.Market.solve_agents() ",
                    "so using the serial version instead. This will likely be slower. "
                    "The multi_thread_commands() functions failed with the following error:",
                    "\n",
                    sys.exc_info()[0],
                    ":",
                    err,
                )
            multi_thread_commands_fake(self.agents, ["solve()"])

    def solve(self):
        """
        "Solves" the market by finding a "dynamic rule" that governs the aggregate
        market state such that when agents believe in these dynamics, their actions
        collectively generate the same dynamic rule. The final rule is stored
        in self.dynamics.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        max_loops = self.max_loops  # Failsafe against infinite solution loop
        completed_loops = 0
        old_dynamics = None

        while True:
            self.solve_agents()  # Solve each AgentType's micro problem
            self.make_history()  # "Run" the model while tracking aggregate variables
            new_dynamics = self.update_dynamics()  # Find a new aggregate dynamic rule

            # The first loop has nothing to compare against, so use a huge distance
            if completed_loops > 0:
                distance = new_dynamics.distance(old_dynamics)
            else:
                distance = 1000000.0

            old_dynamics = new_dynamics
            completed_loops += 1

            # Stop once the rule has converged or the loop cap is reached
            if not (distance >= self.tolerance and completed_loops < max_loops):
                break

        self.dynamics = new_dynamics  # Store the final dynamic rule in self

    def reap(self):
        """
        Collects the variables named in reap_vars from the state_now dictionary
        of each AgentType in the market, storing the list of collected values
        under the same name in self.reap_state. Agents lacking the variable
        are skipped.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var in self.reap_state:
            # TODO: generalized variable lookup across namespaces
            self.reap_state[var] = [
                member.state_now[var]
                for member in self.agents
                if var in member.state_now
            ]

    def sow(self):
        """
        Distributes the current value of each variable in sow_vars from self
        to each AgentType in the market.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for sow_var in self.sow_state:
            value = self.sow_state[sow_var]
            for member in self.agents:
                if sow_var in member.state_now:
                    member.state_now[sow_var] = value
                # NOTE: the else below pairs with this shocks test only, so the
                # attribute is also set for variables found in state_now.
                if sow_var in member.shocks:
                    member.shocks[sow_var] = value
                else:
                    setattr(member, sow_var, value)

    def mill(self):
        """
        Processes the variables collected from agents using the function mill_rule,
        storing the results, in order, in self.sow_state.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Inputs for the mill_rule: reaped variables plus the constants
        mill_inputs = {**self.reap_state, **self.const_vars}

        # Run the mill_rule and store its outputs by position
        product = self.mill_rule(**mill_inputs)
        for position, sow_var in enumerate(self.sow_state):
            self.sow_state[sow_var] = product[position]

    def cultivate(self):
        """
        Has each AgentType in agents perform their market_action method, using
        variables sown from the market (and maybe also "private" variables).
        The market_action method should store new results in attributes named in
        reap_vars to be reaped later.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for member in self.agents:
            member.market_action()

    def reset(self):
        """
        Reset the state of the market (attributes in sow_vars, etc) to some
        user-defined initial state, and erase the histories of tracked variables.
        Also resets the distributions named in self.distributions and every
        AgentType in the market.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Reset the named distributions; names missing from self are skipped
        for name in self.distributions:
            if hasattr(self, name):
                dstn = getattr(self, name)
                targets = dstn if isinstance(dstn, list) else [dstn]
                for one_dstn in targets:
                    one_dstn.reset()

        # Reset the history of tracked variables
        self.history = {var_name: [] for var_name in self.track_vars}

        # Set the sow variables to their initial levels
        for var_name in self.sow_state:
            self.sow_state[var_name] = self.sow_init[var_name]

        # Reset each AgentType in the market
        for member in self.agents:
            member.reset()

    def store(self):
        """
        Record the current value of each variable X named in track_vars in a
        dictionary field named history[X]. The value is looked up in
        sow_state, then reap_state, then const_vars, then as an attribute.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var_name in self.track_vars:
            for source in (self.sow_state, self.reap_state, self.const_vars):
                if var_name in source:
                    value_now = source[var_name]
                    break
            else:
                value_now = getattr(self, var_name)

            self.history[var_name].append(value_now)

    def make_history(self):
        """
        Runs a loop of sow-->cultivate-->reap-->mill act_T times, tracking the
        evolution of variables X named in track_vars in dictionary fields
        history[X].

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        self.reset()  # Initialize the state of the market
        for _ in range(self.act_T):
            self.sow()  # Distribute aggregated information/state to agents
            self.cultivate()  # Agents take action
            self.reap()  # Collect individual data from agents
            self.mill()  # Process individual data into aggregate data
            self.store()  # Record variables of interest

    def update_dynamics(self):
        """
        Calculates a new "aggregate dynamic rule" using the history of variables
        named in track_vars, and distributes this rule to AgentTypes in agents.

        Parameters
        ----------
        none

        Returns
        -------
        dynamics : instance
            The new "aggregate dynamic rule" that agents believe in and act on.
            Should have attributes named in dyn_vars.
        """
        # Inputs for the dynamics calculator: tracked histories where
        # available, otherwise attributes of self
        arg_names = list(get_arg_names(self.calc_dynamics))
        if "self" in arg_names:
            arg_names.remove("self")
        update_dict = {
            name: self.history[name] if name in self.track_vars else getattr(self, name)
            for name in arg_names
        }

        # Calculate a new dynamic rule and distribute it to the agents
        dynamics = self.calc_dynamics(**update_dict)  # User-defined dynamics calculator
        for var_name in self.dyn_vars:
            rule_piece = getattr(dynamics, var_name)
            for member in self.agents:
                setattr(member, var_name, rule_piece)
        return dynamics
def distribute_params(agent, param_name, param_count, distribution):
    """
    Make copies of an agent that differ in the value of one parameter.

    The distribution is discretized into param_count points; each copy gets
    one atom as its value of param_name and a share of the original
    AgentCount.

    Parameters
    ----------
    agent: AgentType
        An agent to clone.
    param_name : string
        Name of the parameter to be assigned.
    param_count : int
        Number of different values the parameter will take on.
    distribution : Distribution
        A 1-D distribution.

    Returns
    -------
    agent_set : [AgentType]
        A list of param_count agents, ex ante heterogeneous with
        respect to param_name. The AgentCount of the original
        will be split between the agents of the returned
        list in proportion to the given distribution
        (each share is truncated to an integer).
    """
    param_dist = distribution.discretize(N=param_count)

    agent_set = []
    for j in range(param_count):
        clone = deepcopy(agent)
        share = int(agent.AgentCount * param_dist.pmv[j])
        clone.assign_parameters(AgentCount=share)
        clone.assign_parameters(**{param_name: param_dist.atoms[0, j]})
        agent_set.append(clone)

    return agent_set
2911@dataclass
2912class AgentPopulation:
2913 """
2914 A class for representing a population of ex-ante heterogeneous agents.
2915 """
2917 agent_type: AgentType # type of agent in the population
2918 parameters: dict # dictionary of parameters
2919 seed: int = 0 # random seed
2920 time_var: List[str] = field(init=False)
2921 time_inv: List[str] = field(init=False)
2922 distributed_params: List[str] = field(init=False)
2923 agent_type_count: Optional[int] = field(init=False)
2924 term_age: Optional[int] = field(init=False)
2925 continuous_distributions: Dict[str, Distribution] = field(init=False)
2926 discrete_distributions: Dict[str, Distribution] = field(init=False)
2927 population_parameters: List[Dict[str, Union[List[float], float]]] = field(
2928 init=False
2929 )
2930 agents: List[AgentType] = field(init=False)
2931 agent_database: pd.DataFrame = field(init=False)
2932 solution: List[Any] = field(init=False)
2934 def __post_init__(self):
2935 """
2936 Initialize the population of agents, determine distributed parameters,
2937 and infer `agent_type_count` and `term_age`.
2938 """
2939 # create a dummy agent and obtain its time-varying
2940 # and time-invariant attributes
2941 dummy_agent = self.agent_type()
2942 self.time_var = dummy_agent.time_vary
2943 self.time_inv = dummy_agent.time_inv
2945 # create list of distributed parameters
2946 # these are parameters that differ across agents
2947 self.distributed_params = [
2948 key
2949 for key, param in self.parameters.items()
2950 if (isinstance(param, list) and isinstance(param[0], list))
2951 or isinstance(param, Distribution)
2952 or (isinstance(param, DataArray) and param.dims[0] == "agent")
2953 ]
2955 self.__infer_counts__()
2957 self.print_parallel_error_once = True
2958 # Print warning once if parallel simulation fails
def __infer_counts__(self):
    """
    Set `agent_type_count` and `term_age` from the current parameters.

    `agent_type_count` is the largest agent dimension found among the
    distributed parameters (length of a list, or size of the leading
    "agent" dimension of a `DataArray`), with a minimum of 1.
    `term_age` is the largest age dimension found among all parameters
    (length of the first inner list of a list of lists, or size of a
    trailing "age" dimension of a `DataArray`), with a minimum of 1.

    If a `Distribution` is encountered, the corresponding count cannot be
    inferred: a warning is issued and the attribute is set to None.
    """
    # --- number of ex-ante heterogeneous agent types ---
    type_count = 1
    for name in self.distributed_params:
        value = self.parameters[name]
        if isinstance(value, Distribution):
            type_count = None
            warn(
                "Cannot infer agent_type_count from a Distribution. "
                "Please provide approximation parameters."
            )
            break
        if isinstance(value, list):
            type_count = max(type_count, len(value))
            continue
        if isinstance(value, DataArray) and value.dims[0] == "agent":
            type_count = max(type_count, value.shape[0])
    self.agent_type_count = type_count

    # --- terminal age, looking at every parameter ---
    last_age = 1
    for value in self.parameters.values():
        if isinstance(value, Distribution):
            last_age = None
            warn(
                "Cannot infer term_age from a Distribution. "
                "Please provide approximation parameters."
            )
            break
        if isinstance(value, list) and isinstance(value[0], list):
            last_age = max(last_age, len(value[0]))
            continue
        if isinstance(value, DataArray) and value.dims[-1] == "age":
            last_age = max(last_age, value.shape[-1])
    self.term_age = last_age
def approx_distributions(self, approx_params: dict):
    """
    Approximate continuous distributions with discrete ones. If the initial
    parameters include a `Distribution` type, then the AgentPopulation is
    not ready to solve, and stands for an abstract population. To solve the
    AgentPopulation, we need discretization parameters for each continuous
    distribution. This method approximates the continuous distributions with
    discrete ones, and updates the parameters dictionary.

    Parameters
    ----------
    approx_params : dict
        Maps a parameter name to a dictionary of keyword arguments that is
        passed to that parameter's `discretize` method. An empty dictionary
        is allowed: nothing is discretized and the counts are simply
        re-inferred.

    Raises
    ------
    KeyError
        If a key of `approx_params` is not in `self.parameters`.
    ValueError
        If a named parameter is not a distributed `Distribution`.
    """
    self.continuous_distributions = {}
    self.discrete_distributions = {}

    for key, args in approx_params.items():
        param = self.parameters[key]
        if key in self.distributed_params and isinstance(param, Distribution):
            self.continuous_distributions[key] = param
            self.discrete_distributions[key] = param.discretize(**args)
        else:
            raise ValueError(
                f"Warning: parameter {key} is not a Distribution found "
                f"in agent type {self.agent_type}"
            )

    # Guard against an empty request: without it, indexing the first
    # discretized distribution below would raise IndexError.
    if self.discrete_distributions:
        discretized = list(self.discrete_distributions.values())
        if len(discretized) > 1:
            joint_dist = combine_indep_dstns(*discretized)
        else:
            joint_dist = discretized[0]

        # Row i of the joint atoms holds the values for the i-th
        # discretized parameter, one entry per agent type.
        for i, key in enumerate(self.discrete_distributions):
            self.parameters[key] = DataArray(joint_dist.atoms[i], dims=("agent",))

    self.__infer_counts__()
def __parse_parameters__(self) -> None:
    """
    Creates distributed dictionaries of parameters for each ex-ante
    heterogeneous agent in the parameterized population. The parameters
    are stored in a list of dictionaries, where each dictionary contains
    the parameters for one agent. Scalar parameters that vary over time
    are expanded to a list of length `term_age`.

    A parameter whose type (or `DataArray` layout) is not handled by the
    branch it falls into is left out of the agent's dictionary. Each
    branch assigns directly into the dictionary, so an unhandled
    time-varying parameter can never pick up the value computed for an
    earlier key.

    The result is stored in `self.population_parameters`.
    """
    population_parameters = []  # one parameter dictionary per agent subgroup
    for agent in range(self.agent_type_count):
        agent_parameters = {}
        for key, param in self.parameters.items():
            if key in self.time_var:
                # Time-varying attribute: a scalar is repeated per period.
                if isinstance(param, (int, float)):
                    agent_parameters[key] = [param] * self.term_age
                elif isinstance(param, list):
                    if isinstance(param[0], list):
                        agent_parameters[key] = param[agent]
                    else:
                        agent_parameters[key] = param
                elif isinstance(param, DataArray):
                    if param.dims[0] == "agent":
                        if len(param.dims) > 1 and param.dims[-1] == "age":
                            agent_parameters[key] = param[agent].values.tolist()
                        else:
                            agent_parameters[key] = param[agent].item()
                    elif param.dims[0] == "age":
                        agent_parameters[key] = param.values.tolist()

            elif key in self.time_inv:
                # Time-invariant attribute: pick this agent's value.
                if isinstance(param, (int, float)):
                    agent_parameters[key] = param
                elif isinstance(param, list):
                    if isinstance(param[0], list):
                        agent_parameters[key] = param[agent]
                    else:
                        agent_parameters[key] = param
                elif isinstance(param, DataArray) and param.dims[0] == "agent":
                    agent_parameters[key] = param[agent].item()

            else:
                # Not declared by the agent type: guess from the shape.
                if isinstance(param, (int, float)):
                    agent_parameters[key] = param  # assume time inv
                elif isinstance(param, list):
                    if isinstance(param[0], list):
                        agent_parameters[key] = param[agent]  # assume agent vary
                    else:
                        agent_parameters[key] = param  # assume time vary
                elif isinstance(param, DataArray):
                    if param.dims[0] == "agent":
                        if len(param.dims) > 1 and param.dims[-1] == "age":
                            agent_parameters[key] = param[agent].values.tolist()
                        else:
                            agent_parameters[key] = param[agent].item()
                    elif param.dims[0] == "age":
                        agent_parameters[key] = param.values.tolist()

        population_parameters.append(agent_parameters)

    self.population_parameters = population_parameters
def create_distributed_agents(self):
    """
    Build one agent per entry of `population_parameters`.

    The parameters dictionary is parsed first; then each agent is
    constructed from its own parameter dictionary together with a seed
    drawn from a generator initialized with the population's `seed`.
    The resulting list is stored in `self.agents`.
    """
    self.__parse_parameters__()

    seed_source = np.random.default_rng(self.seed)

    agents = []
    for agent_kwargs in self.population_parameters:
        # Seeds are drawn in population order, one per agent.
        agent_seed = seed_source.integers(0, 2**31 - 1)
        agents.append(self.agent_type(seed=agent_seed, **agent_kwargs))
    self.agents = agents
def create_database(self):
    """
    Optionally builds a pandas DataFrame with one row per agent: the
    columns are the agent's parameters plus an "agents" column holding
    the agent objects. The table is stored in `self.agent_database`.
    """
    parameter_table = pd.DataFrame(self.population_parameters)
    self.agent_database = parameter_table.assign(agents=self.agents)
def solve(self):
    """
    Solves the agents of the population one after another, in order.
    """
    # Serial on purpose; see the Market class for an example of how to
    # solve distributed agents in parallel.
    for member in self.agents:
        member.solve()
def unpack_solutions(self):
    """
    Collects the `solution` attribute of every agent, in population order,
    into the population's own `solution` list.
    """
    collected = []
    for member in self.agents:
        collected.append(member.solution)
    self.solution = collected
def initialize_sim(self):
    """
    Calls `initialize_sim()` on every agent of the population, in order.
    """
    for member in self.agents:
        member.initialize_sim()
def simulate(self, num_jobs=None):
    """
    Simulates each agent of the population.

    The agents are first handed to `multi_thread_commands`. If that call
    raises any exception, a warning is printed (only the first time, as
    tracked by `print_parallel_error_once`) and the agents are simulated
    with `multi_thread_commands_fake` instead.

    Parameters
    ----------
    num_jobs : int, optional
        Number of parallel jobs to use, passed through to
        `multi_thread_commands`. Defaults to ``None``.
    """
    commands = ["simulate()"]
    try:
        multi_thread_commands(self.agents, commands, num_jobs)
    except Exception as err:
        should_report = getattr(self, "print_parallel_error_once", False)
        if should_report:
            self.print_parallel_error_once = False
            print(
                "**** WARNING: could not execute multi_thread_commands in HARK.core.AgentPopulation.simulate() ",
                "so using the serial version instead. This will likely be slower. ",
                "The multi_thread_commands() function failed with the following error:\n",
                type(err),
                ":",
                err,
            )
        # Serial fallback, still inside the handler as in the parallel path's
        # failure context.
        multi_thread_commands_fake(self.agents, commands, num_jobs)
def __iter__(self):
    """
    Returns an iterator over the population's agents, so the population
    can be used directly in a `for` loop.
    """
    members = self.agents
    return iter(members)
def __getitem__(self, idx):
    """
    Returns `self.agents[idx]`, so the population can be indexed like its
    list of agents.
    """
    members = self.agents
    return members[idx]
3187###############################################################################
def multi_thread_commands_fake(
    agent_list: List, command_list: List, num_jobs=None
) -> None:
    """
    Runs every command in command_list on every agent in agent_list using a
    plain single-threaded loop. Each command must name a method of the agent.
    The signature mirrors multi_thread_commands so that the two functions
    can be swapped to switch parallel execution off.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
    command_list : [string]
        Method names to call on each agent, with or without a trailing "()".
    num_jobs : None
        Ignored; present only to match multi_thread_commands.

    Returns
    -------
    None
    """
    for this_agent in agent_list:
        for command in command_list:
            # Accept both "solve" and "solve()" spellings.
            method_name = command[:-2] if command.endswith("()") else command
            getattr(this_agent, method_name)()
def multi_thread_commands(agent_list: List, command_list: List, num_jobs=None) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    in parallel via joblib. Each command should be a method of that AgentType subclass.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
        Its entries are replaced in place by the objects returned from the
        parallel call. An empty list is a no-op.
    command_list : [string]
        A list of commands to run for each AgentType in agent_list.
    num_jobs : int, optional
        Number of parallel jobs. When None, the smaller of the number of
        agents and the number of available cores is used.

    Returns
    -------
    None
    """
    agent_count = len(agent_list)

    # Nothing to do. Without this guard the default num_jobs would be
    # min(0, cpu_count()) == 0, which joblib rejects.
    if agent_count == 0:
        return None

    # A single agent gains nothing from a parallel pool; run it serially.
    if agent_count == 1:
        multi_thread_commands_fake(agent_list, command_list)
        return None

    # Default number of parallel jobs is the smaller of number of AgentTypes in
    # the input and the number of available cores.
    if num_jobs is None:
        num_jobs = min(agent_count, multiprocessing.cpu_count())

    # Send each command in command_list to each of the types in agent_list to be run
    agent_list_out = Parallel(n_jobs=num_jobs)(
        delayed(run_commands)(agent, command_list) for agent in agent_list
    )

    # Replace the original types with the output from the parallel call
    for j, finished_agent in enumerate(agent_list_out):
        agent_list[j] = finished_agent
    return None
def run_commands(agent: Any, command_list: List) -> Any:
    """
    Calls, in order, each method named in command_list on the given agent
    and hands the agent back.

    Parameters
    ----------
    agent : AgentType
        An instance of AgentType on which the commands will be run.
    command_list : [string]
        Method names to call on the agent, with or without a trailing "()".

    Returns
    -------
    agent : AgentType
        The same object that was passed in, after the commands have run.
    """
    for command in command_list:
        # Strip an optional trailing "()" to get the bare method name.
        has_parens = command[-2:] == "()"
        method = getattr(agent, command[:-2] if has_parens else command)
        method()
    return agent