Coverage for HARK / core.py: 96%
967 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
"""
High-level functions and classes for solving a wide variety of economic models.

The "core" of HARK is a framework for "microeconomic" and "macroeconomic"
models. A micro model concerns the dynamic optimization problem of some type
of agent, where the agents take the inputs to their problem as exogenous. A
macro model adds another layer: it endogenizes some of the inputs to the micro
problem by finding a general equilibrium dynamic rule.

This module provides the parameter container (``Parameters``), the base
``Model`` class with constructor handling, the ``AgentType`` superclass for
micro models, and the ``Market`` class for macro models.
"""
10# Import basic modules
11import inspect
12import sys
13from collections import namedtuple
14from copy import copy, deepcopy
15from dataclasses import dataclass, field
16from time import time
17from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
18from warnings import warn
19import multiprocessing
20from joblib import Parallel, delayed
22import numpy as np
23import pandas as pd
24from xarray import DataArray
26from HARK.distributions import (
27 Distribution,
28 IndexDistribution,
29 combine_indep_dstns,
30)
31from HARK.utilities import NullFunc, get_arg_names, get_it_from
32from HARK.simulator import make_simulator_from_agent
33from HARK.SSJutils import (
34 make_basic_SSJ_matrices,
35 calc_shock_response_manually,
36)
37from HARK.metric import MetricObject, distance_metric
# Public API of this module, i.e. the names exported by a star-import.
# The order matches the historical listing.
__all__ = [
    # Model and agent containers
    "AgentType",
    "Market",
    "Parameters",
    "Model",
    "AgentPopulation",
    # Helper utilities
    "multi_thread_commands",
    "multi_thread_commands_fake",
    "NullFunc",
    "make_one_period_oo_solver",
    "distribute_params",
]
class Parameters:
    """
    A smart container for model parameters that handles age-varying dynamics.

    This class stores parameters as an internal dictionary and manages their
    age-varying properties, providing both attribute-style and dictionary-style
    access. It is designed to handle the time-varying dynamics of parameters
    in economic models.

    Attributes
    ----------
    _length : int
        The number of periods in the model cycle; kept equal to the stored
        ``T_cycle`` parameter.
    _invariant_params : Set[str]
        A set of parameter names that are invariant over time.
    _varying_params : Set[str]
        A set of parameter names that vary over time.
    _parameters : Dict[str, Any]
        The internal dictionary storing all parameters.
    _frozen : bool
        When True, setting parameters and changing their time-variance
        classification raise RuntimeError.
    _namedtuple_cache : Optional[type]
        The namedtuple class most recently built by ``to_namedtuple``.
    """

    __slots__ = (
        "_length",
        "_invariant_params",
        "_varying_params",
        "_parameters",
        "_frozen",
        "_namedtuple_cache",
    )

    def __init__(self, **parameters: Any) -> None:
        """
        Initialize a Parameters object and parse the age-varying dynamics of parameters.

        Parameters
        ----------
        T_cycle : int, optional
            The number of time periods in the model cycle (default: 1).
            Must be >= 1.
        frozen : bool, optional
            If True, the Parameters object will be immutable after initialization
            (default: False).
        _time_inv : List[str], optional
            List of parameter names to explicitly mark as time-invariant,
            overriding automatic inference.
        _time_vary : List[str], optional
            List of parameter names to explicitly mark as time-varying,
            overriding automatic inference.
        **parameters : Any
            Any number of parameters in the form key=value.

        Raises
        ------
        ValueError
            If T_cycle is less than 1, or if a parameter value is rejected by
            automatic inference (see ``__setitem__``).

        Notes
        -----
        Automatic time-variance inference rules:
        - Scalars (int, float, bool, None, NumPy numeric scalars) are time-invariant
        - NumPy arrays are time-invariant (use lists/tuples for time-varying)
        - Single-element lists/tuples [x] are unwrapped to x and time-invariant
        - Multi-element lists/tuples are time-varying if length matches T_cycle;
          while T_cycle is 1, the first such sequence sets T_cycle to its length
        - 2D arrays with first dimension matching T_cycle are time-varying
        - Distributions and Callables are time-invariant

        Use _time_inv or _time_vary to override automatic inference when needed.
        """
        # Extract special parameters
        self._length: int = parameters.pop("T_cycle", 1)
        frozen: bool = parameters.pop("frozen", False)
        time_inv_override: List[str] = parameters.pop("_time_inv", [])
        time_vary_override: List[str] = parameters.pop("_time_vary", [])

        # Validate T_cycle
        if self._length < 1:
            raise ValueError(f"T_cycle must be >= 1, got {self._length}")

        # Initialize internal state; stay unfrozen until setup is complete
        self._invariant_params: Set[str] = set()
        self._varying_params: Set[str] = set()
        self._parameters: Dict[str, Any] = {"T_cycle": self._length}
        self._frozen: bool = False
        self._namedtuple_cache: Optional[type] = None

        # Set parameters using automatic inference
        for key, value in parameters.items():
            self[key] = value

        # Apply explicit overrides (only to parameters that were provided)
        for param in time_inv_override:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)

        for param in time_vary_override:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)

        # Freeze if requested
        self._frozen = frozen

    def _check_mutable(self) -> None:
        """Raise RuntimeError if this Parameters object is frozen."""
        if self._frozen:
            raise RuntimeError("Cannot modify frozen Parameters object")

    def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
        """
        Access parameters by age index or parameter name.

        If item_or_key is an integer (Python int or NumPy integer), returns a
        Parameters object with the parameters that apply to that age. This
        includes all invariant parameters and the `item_or_key`th element of all
        age-varying parameters that are lists, tuples or arrays. If item_or_key
        is a string, it returns the value of the parameter with that name.

        Parameters
        ----------
        item_or_key : Union[int, str]
            Age index or parameter name.

        Returns
        -------
        Union[Parameters, Any]
            A new Parameters object for the specified age, or the value of the
            specified parameter.

        Raises
        ------
        ValueError:
            If the age index is out of bounds.
        KeyError:
            If the parameter name is not found.
        TypeError:
            If the key is neither an integer nor a string.
        """
        if isinstance(item_or_key, (int, np.integer)):
            age = int(item_or_key)
            if age < 0 or age >= self._length:
                raise ValueError(
                    f"Age {item_or_key} is out of bounds (valid: 0-{self._length - 1})."
                )

            params = {key: self._parameters[key] for key in self._invariant_params}
            for key in self._varying_params:
                value = self._parameters[key]
                # Varying values that are not sequences (e.g. forced via
                # add_to_time_vary) are passed through unchanged.
                params[key] = (
                    value[age]
                    if isinstance(value, (list, tuple, np.ndarray))
                    else value
                )
            return Parameters(**params)
        elif isinstance(item_or_key, str):
            return self._parameters[item_or_key]
        else:
            raise TypeError("Key must be an integer (age) or string (parameter name).")

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Set parameter values, automatically inferring time variance.

        If the parameter is a scalar (including NumPy numeric scalars), numpy
        array, boolean, distribution, callable, MetricObject or None, it is
        assumed to be invariant over time. A list or tuple of length 1 is
        unwrapped and treated as invariant. A longer list or tuple is treated
        as varying over time, and its length must match the `_length`
        attribute. The exception is when `_length` is 1: then the sequence
        defines the new cycle length, and ``T_cycle`` is updated to match.

        2D numpy arrays with first dimension matching T_cycle are treated as
        time-varying parameters.

        Parameters
        ----------
        key : str
            Name of the parameter.
        value : Any
            Value of the parameter.

        Raises
        ------
        ValueError:
            If the parameter name is not a string or if the value type is unsupported.
            If the parameter value is inconsistent with the current model length.
        RuntimeError:
            If the Parameters object is frozen.
        """
        self._check_mutable()

        if not isinstance(key, str):
            raise ValueError(f"Parameter name must be a string, got {type(key)}")

        if isinstance(value, np.ndarray) and value.ndim >= 2:
            # Multi-dimensional arrays vary over time only when their first
            # axis has one entry per period.
            varying = value.shape[0] == self._length
        elif isinstance(
            value,
            (
                int,
                float,
                np.integer,
                np.floating,
                np.bool_,
                np.ndarray,
                type(None),
                Distribution,
                bool,
                Callable,
                MetricObject,
            ),
        ):
            varying = False
        elif isinstance(value, (list, tuple)):
            if len(value) == 1:
                value = value[0]
                varying = False
            elif self._length is None or self._length == 1:
                # The first multi-element sequence defines the cycle length;
                # keep the stored T_cycle consistent with it.
                self._length = len(value)
                self._parameters["T_cycle"] = self._length
                varying = True
            elif len(value) == self._length:
                varying = True
            else:
                raise ValueError(
                    f"Parameter {key} must have length 1 or {self._length}, not {len(value)}"
                )
        else:
            raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")

        if varying:
            self._varying_params.add(key)
            self._invariant_params.discard(key)
        else:
            self._invariant_params.add(key)
            self._varying_params.discard(key)
        self._parameters[key] = value

    def __iter__(self) -> Iterator[str]:
        """Allow iteration over parameter names."""
        return iter(self._parameters)

    def __len__(self) -> int:
        """Return the number of parameters."""
        return len(self._parameters)

    def keys(self) -> Iterator[str]:
        """Return a view of parameter names."""
        return self._parameters.keys()

    def values(self) -> Iterator[Any]:
        """Return a view of parameter values."""
        return self._parameters.values()

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Return a view of parameter (name, value) pairs."""
        return self._parameters.items()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a plain dictionary.

        Returns
        -------
        Dict[str, Any]
            A (shallow) copy of the dictionary containing all parameters.
        """
        return dict(self._parameters)

    def to_namedtuple(self) -> namedtuple:
        """
        Convert parameters to a namedtuple.

        The namedtuple class is cached for efficiency on repeated calls. It is
        rebuilt whenever the set of parameter names has changed since it was
        created.

        Returns
        -------
        namedtuple
            A namedtuple containing all parameters.
        """
        fields = tuple(self._parameters)
        cache = self._namedtuple_cache
        if cache is None or cache._fields != fields:
            # Rebuild: parameters were added since the class was cached
            cache = namedtuple("Parameters", fields)
            self._namedtuple_cache = cache
        return cache(**self._parameters)

    def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
        """
        Update parameters from another Parameters object or dictionary.

        Each entry is assigned through ``__setitem__``, so time variance is
        re-inferred for every value.

        Parameters
        ----------
        other : Union[Parameters, Dict[str, Any]]
            The source of parameters to update from.

        Raises
        ------
        TypeError
            If the input is neither a Parameters object nor a dictionary.
        """
        if isinstance(other, Parameters):
            source = other._parameters
        elif isinstance(other, dict):
            source = other
        else:
            raise TypeError(
                "Update source must be a Parameters object or a dictionary."
            )
        for key, value in source.items():
            self[key] = value

    def __repr__(self) -> str:
        """Return a detailed string representation of the Parameters object."""
        return (
            f"Parameters(_length={self._length}, "
            f"_invariant_params={self._invariant_params}, "
            f"_varying_params={self._varying_params}, "
            f"_parameters={self._parameters})"
        )

    def __str__(self) -> str:
        """Return a simple string representation of the Parameters object."""
        return f"Parameters({str(self._parameters)})"

    def __getattr__(self, name: str) -> Any:
        """
        Allow attribute-style access to parameters.

        Only called when normal attribute lookup fails. Names starting with an
        underscore are never looked up among the parameters, so that
        copy/pickle protocol probes raise AttributeError normally.

        Parameters
        ----------
        name : str
            Name of the parameter to access.

        Returns
        -------
        Any
            The value of the specified parameter.

        Raises
        ------
        AttributeError:
            If the parameter name is not found.
        """
        if name.startswith("_"):
            return super().__getattribute__(name)
        try:
            return self._parameters[name]
        except KeyError:
            raise AttributeError(f"'Parameters' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Allow attribute-style setting of parameters.

        Underscore-prefixed names are internal slots and are set directly; all
        other names are routed through ``__setitem__``.

        Parameters
        ----------
        name : str
            Name of the parameter to set.
        value : Any
            Value to set for the parameter.
        """
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            self[name] = value

    def __contains__(self, item: str) -> bool:
        """Check if a parameter exists in the Parameters object."""
        return item in self._parameters

    def copy(self) -> "Parameters":
        """
        Create a deep copy of the Parameters object.

        Returns
        -------
        Parameters
            A new Parameters object with the same contents.
        """
        return deepcopy(self)

    def add_to_time_vary(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_vary.
            Names that are not existing parameters are skipped with a warning.

        Raises
        ------
        RuntimeError
            If the Parameters object is frozen.
        """
        self._check_mutable()
        for param in params:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_vary."
                )

    def add_to_time_inv(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_inv.
            Names that are not existing parameters are skipped with a warning.

        Raises
        ------
        RuntimeError
            If the Parameters object is frozen.
        """
        self._check_mutable()
        for param in params:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_inv."
                )

    def del_from_time_vary(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_vary.

        Raises
        ------
        RuntimeError
            If the Parameters object is frozen.
        """
        self._check_mutable()
        for param in params:
            self._varying_params.discard(param)

    def del_from_time_inv(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_inv.

        Raises
        ------
        RuntimeError
            If the Parameters object is frozen.
        """
        self._check_mutable()
        for param in params:
            self._invariant_params.discard(param)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a parameter value, returning a default if not found.

        Parameters
        ----------
        key : str
            The parameter name.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The parameter value or the default.
        """
        return self._parameters.get(key, default)

    def set_many(self, **kwargs: Any) -> None:
        """
        Set multiple parameters at once.

        Parameters
        ----------
        **kwargs : Keyword arguments representing parameter names and values.
        """
        for key, value in kwargs.items():
            self[key] = value

    def is_time_varying(self, key: str) -> bool:
        """
        Check if a parameter is time-varying.

        Parameters
        ----------
        key : str
            The parameter name.

        Returns
        -------
        bool
            True if the parameter is time-varying, False otherwise.
        """
        return key in self._varying_params

    def at_age(self, age: int) -> "Parameters":
        """
        Get parameters for a specific age.

        This is an alternative to integer indexing (params[age]) that is more
        explicit and avoids potential confusion with dictionary-style access.

        Parameters
        ----------
        age : int
            The age index to retrieve parameters for.

        Returns
        -------
        Parameters
            A new Parameters object with parameters for the specified age.

        Raises
        ------
        ValueError
            If the age index is out of bounds.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97], sigma=2.0)
        >>> age_1_params = params.at_age(1)
        >>> age_1_params.beta
        0.96
        """
        return self[age]

    def validate(self) -> None:
        """
        Validate parameter consistency.

        Checks that all time-varying lists, tuples and arrays have length (or
        first dimension) matching T_cycle. This is useful after manual
        modifications or when parameters are set programmatically.

        Raises
        ------
        ValueError
            If any time-varying parameter has incorrect length. All problems
            are reported together, ordered by parameter name.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97])
        >>> params.validate()  # Passes
        >>> params.add_to_time_vary("beta")
        >>> params.validate()  # Still passes
        """
        errors = []
        # Sort so that the error report is deterministic
        for param in sorted(self._varying_params):
            value = self._parameters[param]
            if isinstance(value, (list, tuple)):
                if len(value) != self._length:
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )
            elif isinstance(value, np.ndarray):
                if value.ndim == 0:
                    errors.append(
                        f"Parameter '{param}' is a 0-dimensional array (scalar), "
                        "which should not be time-varying"
                    )
                elif value.ndim >= 2:
                    if value.shape[0] != self._length:
                        errors.append(
                            f"Parameter '{param}' has first dimension {value.shape[0]}, expected {self._length}"
                        )
                elif len(value) != self._length:  # 1-dimensional array
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )

        if errors:
            raise ValueError(
                "Parameter validation failed:\n" + "\n".join(f"  - {e}" for e in errors)
            )
class Model:
    """
    A class with special handling of parameters assignment.

    Instances keep two dictionaries. ``parameters`` mirrors every value
    assigned through ``assign_parameters``, and each of those values is also
    set as an attribute. ``constructors`` maps the names of constructed
    objects to the callables that build them from other parameters (see
    ``construct``).
    """

    def __init__(self):
        # A subclass may already have set these before calling
        # super().__init__(); only create empty containers when absent.
        if not hasattr(self, "parameters"):
            self.parameters = {}
        if not hasattr(self, "constructors"):
            self.constructors = {}

    def assign_parameters(self, **kwds):
        """
        Assign an arbitrary number of attributes to this agent.

        Parameters
        ----------
        **kwds : keyword arguments
            Any number of keyword arguments of the form key=value.
            Each value is stored in self.parameters and also assigned to the
            attribute of the same name on self.

        Returns
        -------
        None
        """
        self.parameters.update(kwds)
        for key in kwds:
            setattr(self, key, kwds[key])

    def get_parameter(self, name):
        """
        Returns a parameter of this model.

        Parameters
        ----------
        name : str
            The name of the parameter to get.

        Returns
        -------
        value : The value of the parameter, read from self.parameters.

        Raises
        ------
        KeyError
            If the name is not in self.parameters.
        """
        return self.parameters[name]

    def __eq__(self, other):
        # Two models of the same type are equal when their parameter
        # dictionaries are equal. Defining __eq__ without __hash__ makes
        # instances unhashable.
        if isinstance(other, type(self)):
            return self.parameters == other.parameters

        return NotImplemented

    def __str__(self):
        type_ = type(self)
        module = type_.__module__
        qualname = type_.__qualname__

        s = f"<{module}.{qualname} object at {hex(id(self))}.\n"
        s += "Parameters:"

        for p in self.parameters:
            s += f"\n{p}: {self.parameters[p]}"

        s += ">"
        return s

    def describe(self):
        """Return the same description string as ``str(self)``."""
        return self.__str__()

    def del_param(self, param_name):
        """
        Deletes a parameter from this instance, removing it both from the object's
        namespace (if it's there) and the parameters dictionary (likewise).

        Parameters
        ----------
        param_name : str
            A string naming a parameter or data to be deleted from this instance.
            Removes information from self.parameters dictionary and own namespace.

        Returns
        -------
        None
        """
        if param_name in self.parameters:
            del self.parameters[param_name]
        if hasattr(self, param_name):
            delattr(self, param_name)

    def _gather_constructor_inputs(self, key, constructor, missing_key_data):
        """
        Collect the keyword inputs for one constructor.

        Inputs are read from attributes of self first, then from
        self.parameters. Arguments that are absent but have a default value
        are left out, so the constructor's default applies. Every required
        input that is absent is appended to missing_key_data as (key, name).

        Returns
        -------
        (dict, list)
            The keyword arguments found and the names of missing required ones.
        """
        # get_it_from constructors pull `key` out of a parent object on self
        if isinstance(constructor, get_it_from):
            try:
                parent = getattr(self, constructor.name)
            except AttributeError:
                missing_key_data.append((key, constructor.name))
                return {"parent": None, "query": None}, [constructor.name]
            return {"parent": parent, "query": key}, []

        args_needed = get_arg_names(constructor)
        has_no_default = {
            k: v.default is inspect.Parameter.empty
            for k, v in inspect.signature(constructor).parameters.items()
        }
        inputs = {}
        missing_args = []
        for this_arg in args_needed:
            if hasattr(self, this_arg):
                inputs[this_arg] = getattr(self, this_arg)
            elif this_arg in self.parameters:
                inputs[this_arg] = self.parameters[this_arg]
            elif has_no_default[this_arg]:
                missing_key_data.append((key, this_arg))
                missing_args.append(this_arg)
        return inputs, missing_args

    def _restore_backups(self, keys, keys_complete, backup):
        """Put back the prior value of every requested key that was not rebuilt."""
        for key, done in zip(keys, keys_complete):
            if not done and key in backup:
                setattr(self, key, backup[key])
                self.parameters[key] = backup[key]

    def construct(self, *args, force=False):
        """
        Top-level method for building constructed inputs. If called without any
        inputs, construct builds each of the objects named in the keys of the
        constructors dictionary; it draws inputs for the constructors from the
        parameters dictionary and adds its results to the same. If passed one or
        more strings as arguments, the method builds only the named keys. The
        method will do multiple "passes" over the requested keys, as some cons-
        tructors require inputs built by other constructors. If any requested
        constructors failed to build due to missing data, those keys (and the
        missing data) will be named in self._missing_key_data. Other errors are
        recorded in the dictionary attribute _constructor_errors.

        This method tries to "start from scratch" by removing prior constructed
        objects, holding them in a backup dictionary during construction. This
        is done so that dependencies among constructors are resolved properly,
        without mistakenly relying on "old information". A backup value is used
        if a constructor function is set to None (i.e. "don't do anything"), or
        if the construct method fails to produce a new object. This includes the
        case where an exception is raised: before it propagates, every
        requested object that was not rebuilt gets its backup value back.

        Parameters
        ----------
        *args : str, optional
            Keys of self.constructors that are requested to be constructed.
            If no arguments are passed, *all* elements of the dictionary are implied.
        force : bool, optional
            When True, the method will force its way past any errors, including
            missing constructors, missing arguments for constructors, and errors
            raised during execution of constructors. Information about all such
            errors is stored in the dictionary attributes described above. When
            False (default), any errors or exception will be raised.

        Returns
        -------
        None

        Raises
        ------
        KeyError
            If force is False and a requested key has no constructor.
        ValueError
            If force is False and some requested objects could not be built.
        Exception
            Whatever a constructor raised, when force is False.
        """
        # Set up the requested work
        keys = args if len(args) > 0 else list(self.constructors.keys())
        N_keys = len(keys)
        if N_keys == 0:
            return  # Do nothing if there are no constructed objects
        keys_complete = np.zeros(N_keys, dtype=bool)

        # Remove pre-existing constructed objects, preventing "incomplete" updates,
        # but store the current values in a backup dictionary in case something fails
        backup = {}
        for key in keys:
            if hasattr(self, key):
                backup[key] = getattr(self, key)
            elif key in self.parameters:
                backup[key] = self.parameters[key]
            self.del_param(key)

        # Get the dictionary of constructor errors (updated in place)
        if not hasattr(self, "_constructor_errors"):
            self._constructor_errors = {}
        errors = self._constructor_errors

        missing_key_data = []
        try:
            # Keep making passes while work remains and the last pass made progress
            go = True
            while go:
                anything_accomplished_this_pass = False
                missing_key_data = []  # Rebuilt on every pass

                for i, key in enumerate(keys):
                    if keys_complete[i]:
                        continue  # This key has already been built

                    try:
                        constructor = self.constructors[key]
                    except KeyError as not_found:
                        errors[key] = "No constructor found for " + str(not_found)
                        if force:
                            continue
                        raise KeyError("No constructor found for " + key) from None

                    # A None constructor means "do nothing": keep the old value
                    if constructor is None:
                        if key in backup:
                            setattr(self, key, backup[key])
                            self.parameters[key] = backup[key]
                        keys_complete[i] = True
                        anything_accomplished_this_pass = True
                        continue

                    inputs, missing_args = self._gather_constructor_inputs(
                        key, constructor, missing_key_data
                    )
                    if missing_args:
                        # Never raise here, as the arguments might be filled in
                        # by another constructor on a later pass
                        errors[key] = "Missing required arguments: " + ", ".join(
                            missing_args
                        )
                        self.del_param(key)
                        continue

                    try:
                        temp = constructor(**inputs)
                    except Exception as problem:
                        errors[key] = str(type(problem)) + ": " + str(problem)
                        self.del_param(key)
                        if force:
                            continue
                        raise
                    setattr(self, key, temp)
                    self.parameters[key] = temp
                    errors.pop(key, None)
                    keys_complete[i] = True
                    anything_accomplished_this_pass = True

                go = (not keys_complete.all()) and anything_accomplished_this_pass
        except BaseException:
            # Don't leave the instance stripped of its prior constructed objects
            self._restore_backups(keys, keys_complete, backup)
            raise

        # Store missing key-data pairs and exit
        self._missing_key_data = missing_key_data
        self._constructor_errors = errors
        incomplete = [key for key, done in zip(keys, keys_complete) if not done]
        if incomplete:
            self._restore_backups(keys, keys_complete, backup)
            if not force:
                raise ValueError(
                    "Did not construct these objects: " + ", ".join(incomplete)
                )

    def describe_constructors(self, *args):
        """
        Prints to screen a string describing this instance's constructed objects,
        including their names, the function that constructs them, the names of
        those functions inputs, and whether those inputs are present.

        A check mark means the value or input is present. "X" means it is
        absent. "*" means an absent input that has a default value.

        Parameters
        ----------
        *args : str, optional
            Optional list of strings naming constructed inputs to be described.
            If none are passed, all constructors are described.

        Returns
        -------
        None
        """
        keys = args if len(args) > 0 else list(self.constructors.keys())
        yes = "\u2713"
        no = "X"
        maybe = "*"
        noyes = [no, yes]

        out = ""
        for key in keys:
            has_val = hasattr(self, key) or (key in self.parameters)
            status = noyes[int(has_val)]

            try:
                constructor = self.constructors[key]
            except KeyError:
                out += status + " " + key + " : NO CONSTRUCTOR FOUND\n"
                continue

            # construct() treats a None constructor as "keep the existing value"
            if constructor is None:
                out += status + " " + key + " : None (existing value is kept)\n"
                continue

            if isinstance(constructor, get_it_from):
                out += status + " " + key + " : get it from " + constructor.name + "\n"
                continue

            # Callables such as functools.partial have no __name__
            func_name = getattr(constructor, "__name__", repr(constructor))
            out += status + " " + key + " : " + func_name + "\n"

            # Report whether each constructor argument is available
            arg_names = get_arg_names(constructor)
            has_no_default = {
                k: v.default is inspect.Parameter.empty
                for k, v in inspect.signature(constructor).parameters.items()
            }
            for this_arg in arg_names:
                if hasattr(self, this_arg) or this_arg in self.parameters:
                    symb = yes
                elif not has_no_default[this_arg]:
                    symb = maybe
                else:
                    symb = no
                out += " " + symb + " " + this_arg + "\n"

        # Print the string to screen
        print(out)
        return

    # This is a "synonym" method so that old calls to update() still work
    def update(self, *args, **kwargs):
        """Alias for ``construct``; passes all arguments through unchanged."""
        self.construct(*args, **kwargs)
969class AgentType(Model):
970 """
971 A superclass for economic agents in the HARK framework. Each model should
972 specify its own subclass of AgentType, inheriting its methods and overwriting
973 as necessary. Critically, every subclass of AgentType should define class-
974 specific static values of the attributes time_vary and time_inv as lists of
975 strings. Each element of time_vary is the name of a field in AgentSubType
976 that varies over time in the model. Each element of time_inv is the name of
977 a field in AgentSubType that is constant over time in the model.
979 Parameters
980 ----------
981 solution_terminal : Solution
982 A representation of the solution to the terminal period problem of
983 this AgentType instance, or an initial guess of the solution if this
984 is an infinite horizon problem.
985 cycles : int
986 The number of times the sequence of periods is experienced by this
987 AgentType in their "lifetime". cycles=1 corresponds to a lifecycle
988 model, with a certain sequence of one period problems experienced
989 once before terminating. cycles=0 corresponds to an infinite horizon
990 model, with a sequence of one period problems repeating indefinitely.
991 pseudo_terminal : bool
992 Indicates whether solution_terminal isn't actually part of the
993 solution to the problem (as a known solution to the terminal period
994 problem), but instead represents a "scrap value"-style termination.
995 When True, solution_terminal is not included in the solution; when
996 False, solution_terminal is the last element of the solution.
997 tolerance : float
998 Maximum acceptable "distance" between successive solutions to the
999 one period problem in an infinite horizon (cycles=0) model in order
1000 for the solution to be considered as having "converged". Inoperative
1001 when cycles>0.
1002 verbose : int
1003 Level of output to be displayed by this instance, default is 1.
1004 quiet : bool
1005 Indicator for whether this instance should operate "quietly", default False.
1006 seed : int
1007 A seed for this instance's random number generator.
1008 construct : bool
1009 Indicator for whether this instance's construct() method should be run
1010 when initialized (default True). When False, an instance of the class
1011 can be created even if not all of its attributes can be constructed.
1012 use_defaults : bool
1013 Indicator for whether this instance should use the values in the class'
1014 default dictionary to fill in parameters and constructors for those not
1015 provided by the user (default True). Setting this to False is useful for
1016 situations where the user wants to be absolutely sure that they know what
1017 is being passed to the class initializer, without resorting to defaults.
1019 Attributes
1020 ----------
1021 AgentCount : int
1022 The number of agents of this type to use in simulation.
1024 state_vars : list of string
1025 The string labels for this AgentType's model state variables.
1026 """
1028 time_vary_ = []
1029 time_inv_ = []
1030 shock_vars_ = []
1031 state_vars = []
1032 poststate_vars = []
1033 distributions = []
1034 default_ = {"params": {}, "solver": NullFunc()}
def __init__(
    self,
    solution_terminal=None,
    pseudo_terminal=True,
    tolerance=0.000001,
    verbose=1,
    quiet=False,
    seed=0,
    construct=True,
    use_defaults=True,
    **kwds,
):
    """
    Build an AgentType from its class defaults overlaid with user keywords.

    Parameters are taken from ``self.default_["params"]`` (only when
    ``use_defaults`` is True) and then updated with ``kwds``. The
    ``constructors`` entry is merged key by key, so user-supplied
    constructors extend the default ones instead of replacing the whole
    dictionary. See the class docstring for the remaining arguments.
    """
    super().__init__()
    defaults = self.default_
    default_params = defaults["params"]

    params = deepcopy(default_params) if use_defaults else {}
    params.update(kwds)

    # Merge constructors: start from the defaults (if wanted) and let the
    # user's entries override individual keys.
    base_constructors = (
        deepcopy(default_params["constructors"])
        if use_defaults and "constructors" in default_params
        else {}
    )
    base_constructors.update(kwds.get("constructors", {}))
    params["constructors"] = base_constructors

    # Variables recorded during simulation, seeded from the class defaults
    use_default_tracks = use_defaults and "track_vars" in defaults
    self.track_vars = copy(defaults["track_vars"]) if use_default_tracks else []

    try:
        self.model_file = copy(defaults["model"])
    except (KeyError, TypeError):
        # No usable "model" entry in the defaults
        self.model_file = None

    if solution_terminal is None:
        solution_terminal = NullFunc()
    self.solve_one_period = defaults["solver"]
    self.solution_terminal = solution_terminal
    self.pseudo_terminal = pseudo_terminal
    self.tolerance = tolerance
    self.verbose = verbose
    self.quiet = quiet
    self.seed = seed

    # Containers for simulation state, shocks and histories
    self.state_now = dict.fromkeys(self.state_vars)
    self.state_prev = dict(self.state_now)
    self.controls = {}
    self.shocks = {}
    self.read_shocks = False
    self.shock_history = {}
    self.newborn_init_history = {}
    self.history = {}

    self.assign_parameters(**params)
    self.reset_rng()
    self.bilt = {}
    if construct:
        self.construct()

    # Per-instance copies, so add_to_time_vary & co. never mutate the
    # class-level lists
    self.time_vary = deepcopy(self.time_vary_)
    self.time_inv = deepcopy(self.time_inv_)
    self.shock_vars = deepcopy(self.shock_vars_)
def add_to_time_vary(self, *params):
    """
    Register parameter names as time-varying for this instance.

    Names already in ``self.time_vary`` are skipped, so the list never gains
    duplicates.

    Parameters
    ----------
    params : str
        Any number of attribute names to append to ``time_vary``.

    Returns
    -------
    None
    """
    registry = self.time_vary
    for name in params:
        if name in registry:
            continue
        registry.append(name)
def add_to_time_inv(self, *params):
    """
    Register parameter names as time-invariant for this instance.

    Names already in ``self.time_inv`` are skipped, so the list never gains
    duplicates.

    Parameters
    ----------
    params : str
        Any number of attribute names to append to ``time_inv``.

    Returns
    -------
    None
    """
    registry = self.time_inv
    for name in params:
        if name in registry:
            continue
        registry.append(name)
def del_from_time_vary(self, *params):
    """
    Unregister parameter names from this instance's ``time_vary`` list.

    Each name present is removed once (its first occurrence); names that are
    not registered are ignored.

    Parameters
    ----------
    params : str
        Any number of attribute names to drop from ``time_vary``.

    Returns
    -------
    None
    """
    registry = self.time_vary
    for name in params:
        if name in registry:
            registry.remove(name)
def del_from_time_inv(self, *params):
    """
    Unregister parameter names from this instance's ``time_inv`` list.

    Each name present is removed once (its first occurrence); names that are
    not registered are ignored.

    Parameters
    ----------
    params : str
        Any number of attribute names to drop from ``time_inv``.

    Returns
    -------
    None
    """
    registry = self.time_inv
    for name in params:
        if name in registry:
            registry.remove(name)
def unpack(self, name):
    """
    Collect one component from every per-period solution into a list.

    After the model has been solved, each element of ``self.solution`` holds
    that period's components (e.g. ``cFunc``), either as dictionary entries
    or as object attributes. This method gathers ``name`` from every period
    into a list stored as ``self.<name>`` and registers ``name`` in
    ``time_vary``.

    Parameters
    ----------
    name : str
        Name of the component to unpack from the solution.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If some period's solution does not contain ``name``.
    """
    solution = self.solution
    # isinstance (not an exact type check) so dict subclasses such as
    # OrderedDict are read by key instead of through their __dict__.
    # An empty solution simply yields an empty list.
    if solution and isinstance(solution[0], dict):
        values = [soln_t[name] for soln_t in solution]
    else:
        values = [soln_t.__dict__[name] for soln_t in solution]
    setattr(self, name, values)
    self.add_to_time_vary(name)
def solve(
    self,
    verbose=False,
    presolve=True,
    postsolve=True,
    from_solution=None,
    from_t=None,
):
    """
    Solve this agent type's model by backward induction.

    The sequence of one-period problems is solved in turn, each one receiving
    the solution of the period after it. The result is stored in
    ``self.solution``.

    Parameters
    ----------
    verbose : bool, optional
        If True, solution progress is printed to screen. Default False.
    presolve : bool, optional
        If True (default), ``pre_solve`` runs before solving.
    postsolve : bool, optional
        If True (default), ``post_solve`` runs after solving.
    from_solution : Solution, optional
        If given, used as the starting point of backward induction instead of
        ``self.solution_terminal``.
    from_t : int or None
        If given, the period the solver should start from, i.e. the time
        index that ``from_solution`` represents. Usually used together with
        ``from_solution``; only compatible with ``cycles=1``.

    Returns
    -------
    None
    """
    # Floating point "errors" such as 1/0 -> inf or inf/inf -> nan have
    # well-defined results here, so numpy is told not to complain about them.
    with np.errstate(all="ignore"):
        if presolve:
            self.pre_solve()
        solution = solve_agent(self, verbose, from_solution, from_t)
        self.solution = solution
        if postsolve:
            self.post_solve()
def reset_rng(self):
    """
    Reseed this type's random number generator and reset its distributions.

    ``self.RNG`` is rebuilt from ``self.seed``. Every name in
    ``self.distributions`` that exists as an attribute is then reset. The
    attribute may hold:

    1) a single distribution object,
    2) a list of distributions (probably time-varying), or
    3) a list whose items may themselves be lists of distributions, as in
       ConsMarkovModel.
    """
    self.RNG = np.random.default_rng(self.seed)
    for attr_name in self.distributions:
        if not hasattr(self, attr_name):
            continue
        target = getattr(self, attr_name)
        members = target if isinstance(target, list) else [target]
        for member in members:
            leaves = member if isinstance(member, list) else [member]
            for leaf in leaves:
                leaf.reset()
def check_elements_of_time_vary_are_lists(self):
    """
    Verify that every attribute named in ``time_vary`` is time-indexable.

    Each name in ``self.time_vary`` that exists as an attribute must hold a
    ``list`` (list subclasses included) or an ``IndexDistribution``. Names
    with no corresponding attribute are skipped.

    Raises
    ------
    AssertionError
        If a time-varying attribute is neither a list nor an
        IndexDistribution. Raised explicitly rather than through ``assert``
        so the check still runs when Python is started with ``-O``.
    """
    for param in self.time_vary:
        if not hasattr(self, param):
            continue
        value = getattr(self, param)
        if isinstance(value, list) or isinstance(value, IndexDistribution):
            continue
        raise AssertionError(
            param
            + " is not a list or time varying distribution,"
            + " but should be because it is in time_vary"
        )
def check_restrictions(self):
    """
    Hook for validating model-specific restrictions.

    The base class imposes no restrictions and returns ``None``; subclasses
    may override it to check their own requirements.
    """
    return None
def pre_solve(self):
    """
    Run input checks immediately before the model is solved.

    Calls ``check_restrictions`` and then
    ``check_elements_of_time_vary_are_lists``. Subclasses can extend this,
    e.g. to prepare the terminal solution.

    Returns
    -------
    None
    """
    self.check_restrictions()
    self.check_elements_of_time_vary_are_lists()
def post_solve(self):
    """
    Hook run immediately after the model is solved.

    The base implementation does nothing and returns ``None``; subclasses
    may override it to finalize the solution.
    """
def initialize_sym(self, **kwargs):
    """
    Build the new-style simulator for this agent and keep it privately.

    The RNG and distributions are reset first, so every call starts from the
    same seeds. ``kwargs`` are forwarded to ``make_simulator_from_agent``, and
    the resulting simulator is reset and stored in ``self._simulator``.
    """
    self.reset_rng()  # identical seeds on every initialization
    simulator = make_simulator_from_agent(self, **kwargs)
    self._simulator = simulator
    simulator.reset()
def initialize_sim(self):
    """
    Prepare this AgentType for a new simulation.

    Resets the random number generator and sets ``t_sim`` to 0. Every state
    variable in ``state_vars`` is filled with NaN, and ``t_age`` and
    ``t_cycle`` are zeroed. An initial population is created with
    ``sim_birth`` and the histories of tracked variables are cleared.

    When ``read_shocks`` is True and ``newborn_init_history`` is non-empty,
    per-agent (idiosyncratic) states are overwritten with a *copy* of the
    first stored newborn draw. Copying keeps later in-place edits of
    ``state_now`` from corrupting the stored history, which later
    simulations re-read.

    Returns
    -------
    None

    Raises
    ------
    Exception
        If ``T_sim`` is missing or not positive.
    """
    if not hasattr(self, "T_sim"):
        raise Exception(
            "To initialize simulation variables it is necessary to first "
            + "set the attribute T_sim to the largest number of observations "
            + "you plan to simulate for each agent including re-births."
        )
    elif self.T_sim <= 0:
        raise Exception(
            "T_sim represents the largest number of observations "
            + "that can be simulated for an agent, and must be a positive number."
        )

    self.reset_rng()
    self.t_sim = 0
    all_agents = np.ones(self.AgentCount, dtype=bool)
    for var in self.state_vars:
        self.state_now[var] = np.full(self.AgentCount, np.nan)

    # Number of periods since agent entry
    self.t_age = np.zeros(self.AgentCount, dtype=int)
    # Which cycle period each agent is on
    self.t_cycle = np.zeros(self.AgentCount, dtype=int)
    self.sim_birth(all_agents)

    # If we are asked to use existing shocks and a set of initial conditions
    # exists, use them
    if self.read_shocks and bool(self.newborn_init_history):
        for var_name in self.state_now:
            if var_name not in self.newborn_init_history:
                warn(
                    "The option for reading shocks was activated but "
                    + "the model requires state "
                    + var_name
                    + ", not contained in "
                    + "newborn_init_history."
                )
                continue
            # Copy only array-like idiosyncratic states. Aggregates should
            # not be set by newborns
            idio = (
                isinstance(self.state_now[var_name], np.ndarray)
                and len(self.state_now[var_name]) == self.AgentCount
            )
            if idio:
                self.state_now[var_name] = copy(
                    self.newborn_init_history[var_name][0]
                )

    self.clear_history()
    return None
def sim_one_period(self):
    """
    Simulate one period for this type.

    The steps run in order:

    1) ``get_mortality`` replaces some agents with newborns.
    2) Current states are shifted into ``state_prev``.
    3) Shocks come from ``read_shocks_from_history`` (when ``read_shocks``
       is True) or from ``get_shocks``.
    4) ``get_states``, ``get_controls`` and ``get_poststates`` run.
    5) Ages are advanced.

    Subclasses are expected to supply ``sim_death``/``sim_birth`` and the
    ``get_*`` methods rather than overriding ``get_mortality`` or
    ``read_shocks_from_history``.

    Returns
    -------
    None

    Raises
    ------
    Exception
        If the model has not been solved yet (no ``solution`` attribute).
    """
    if not hasattr(self, "solution"):
        raise Exception(
            "Model instance does not have a solution stored. To simulate, it is necessary"
            " to run the `solve()` method first."
        )

    # Replace some agents with "newborns"
    self.get_mortality()

    # state_{t-1}. Per-agent arrays get fresh storage; anything else
    # (probably an aggregate, possibly set by the Market) is left alone.
    for key, value in self.state_now.items():
        self.state_prev[key] = value
        if isinstance(value, np.ndarray):
            self.state_now[key] = np.empty(self.AgentCount)

    draw_shocks = self.read_shocks_from_history if self.read_shocks else self.get_shocks
    draw_shocks()
    self.get_states()  # states at decision time
    self.get_controls()  # choices given those states
    self.get_poststates()  # variables determined after the decision

    # Everyone ages one period; wrap the within-cycle index at T_cycle
    self.t_age = self.t_age + 1
    self.t_cycle = self.t_cycle + 1
    self.t_cycle[self.t_cycle == self.T_cycle] = 0
def make_shock_history(self):
    """
    Pre-draw and store the shocks (and newborn initial states) for a simulation.

    Runs a reduced simulation loop of ``T_sim`` periods that only applies
    mortality and draws shocks. The stored arrays, each of shape
    ``(T_sim, AgentCount)``, are:

    - ``self.shock_history[X]`` for each variable X in ``self.shock_vars``;
    - ``self.shock_history["who_dies"]`` for deaths;
    - ``self.newborn_init_history[S]`` for the initial value of each state S
      in ``self.state_vars`` given to newborns.

    Afterwards ``self.read_shocks`` is set to True, so later calls to
    ``simulate()`` replay these shocks.

    The history is always drawn fresh. ``read_shocks`` is switched off while
    it is generated, so calling this method again (e.g. after changing
    ``T_sim``) does not replay the blank arrays being filled.

    Returns
    -------
    None
    """
    # Without this, get_mortality would read the blank who_dies array
    # allocated below and no agent would ever die in the new history.
    self.read_shocks = False
    self.initialize_sim()

    shape = (self.T_sim, self.AgentCount)

    def is_idiosyncratic(var_name):
        # True for per-agent state arrays, False for aggregates
        value = self.state_now[var_name]
        return isinstance(value, np.ndarray) and len(value) == self.AgentCount

    # Blank history arrays for each shock variable and for mortality
    for var_name in self.shock_vars:
        self.shock_history[var_name] = np.full(shape, np.nan)
    self.shock_history["who_dies"] = np.zeros(shape, dtype=bool)

    # Blank arrays for the draws of newborns' initial conditions
    for var_name in self.state_vars:
        self.newborn_init_history[var_name] = np.full(shape, np.nan)

    # Record the initial population created by initialize_sim -> sim_birth.
    # Per-agent arrays fill the row directly; aggregate scalars broadcast.
    for var_name in self.state_vars:
        self.newborn_init_history[var_name][self.t_sim, :] = self.state_now[var_name]

    # Make and store the history of shocks for each period
    for t in range(self.T_sim):
        # Deaths
        self.get_mortality()
        self.shock_history["who_dies"][t, :] = self.who_dies

        # Initial conditions of newborns
        if self.who_dies.any():
            for var_name in self.state_vars:
                if is_idiosyncratic(var_name):
                    self.newborn_init_history[var_name][t, self.who_dies] = (
                        self.state_now[var_name][self.who_dies]
                    )
                else:
                    self.newborn_init_history[var_name][t, self.who_dies] = (
                        self.state_now[var_name]
                    )

        # Other shocks
        self.get_shocks()
        for var_name in self.shock_vars:
            self.shock_history[var_name][t, :] = self.shocks[var_name]

        self.t_sim += 1
        self.t_age = self.t_age + 1  # Age all consumers by one period
        self.t_cycle = self.t_cycle + 1  # Age all consumers within their cycle
        # Reset to zero for those who have reached the end of the cycle
        self.t_cycle[self.t_cycle == self.T_cycle] = 0

    # Flag that shocks can be read rather than simulated
    self.read_shocks = True
def get_mortality(self):
    """
    Apply mortality / agent turnover for the current period.

    When shocks are simulated (``read_shocks`` False), ``sim_death`` picks
    the agents to replace and ``sim_birth`` re-initializes them.

    When shocks are replayed, ``shock_history["who_dies"]`` for period
    ``t_sim`` is used instead, and ``sim_birth`` is not called. The replaced
    agents' per-agent states are restored from ``newborn_init_history``, and
    their ``t_age`` and ``t_cycle`` are set to zero. A warning is issued for
    every state with no stored newborn values.

    In both cases the Boolean death array is stored in ``self.who_dies``.

    Returns
    -------
    None
    """
    if not self.read_shocks:
        who_dies = self.sim_death()
        self.sim_birth(who_dies)
        self.who_dies = who_dies
        return None

    who_dies = self.shock_history["who_dies"][self.t_sim, :]
    if who_dies.any():
        saved = self.newborn_init_history
        for var_name, current in self.state_now.items():
            if var_name not in saved:
                warn(
                    "The option for reading shocks was activated but "
                    + "the model requires state "
                    + var_name
                    + ", not contained in "
                    + "newborn_init_history."
                )
                continue
            # Only per-agent arrays are restored; aggregates are not set by
            # newborns
            per_agent = (
                isinstance(current, np.ndarray) and len(current) == self.AgentCount
            )
            if per_agent:
                current[who_dies] = saved[var_name][self.t_sim, who_dies]
        # Replaced agents start over at age zero
        self.t_age[who_dies] = 0
        self.t_cycle[who_dies] = 0
    self.who_dies = who_dies
    return None
def sim_death(self):
    """
    Decide which agents "die" (are replaced) this period.

    The base implementation has no replacement events: nobody dies.
    Subclasses override this to introduce mortality.

    Returns
    -------
    who_dies : np.array
        Boolean array of size ``self.AgentCount``; True marks agents to be
        replaced. Always all False here.
    """
    return np.full(self.AgentCount, False)
def sim_birth(self, which_agents):  # pragma: nocover
    """
    Create new agents for the simulation; must be defined by subclasses.

    Parameters
    ----------
    which_agents : np.array(Bool)
        Boolean array of size ``self.AgentCount`` marking the agents to be
        "born".

    Raises
    ------
    Exception
        Always, since the base AgentType has no birth rule.
    """
    message = "AgentType subclass must define method sim_birth!"
    raise Exception(message)
def get_shocks(self):  # pragma: nocover
    """
    Draw this period's shock variables.

    The base implementation draws nothing and returns ``None``; subclasses
    of AgentType may override it.
    """
def read_shocks_from_history(self):
    """
    Load this period's shocks from the pre-specified shock history.

    For each variable X named in ``self.shock_vars``, ``self.shocks[X]`` is
    set to ``self.shock_history[X][self.t_sim, :]``.

    This is only called when ``self.read_shocks`` is True. That flag is set by
    ``make_shock_history()``, or can be set manually after storing a
    "handcrafted" ``shock_history``.

    Returns
    -------
    None
    """
    period = self.t_sim
    for var_name in self.shock_vars:
        self.shocks[var_name] = self.shock_history[var_name][period, :]
def get_states(self):
    """
    Update the state variables for the current period.

    Calls ``transition()`` and assigns its results, in order, to the keys of
    ``self.state_now``. Any keys beyond the number of returned values (e.g.
    'post-states') are left untouched.

    Returns
    -------
    None
    """
    new_states = self.transition()
    for position, var in enumerate(self.state_now):
        # Temporary handling of 'post-states': stop once transition() runs out
        if position >= len(new_states):
            break
        self.state_now[var] = new_states[position]
def transition(self):  # pragma: nocover
    """
    Compute new values of the endogenous states.

    [Eventually, to match the dolo spec, this would take:
    exogenous_prev, endogenous_prev, controls, exogenous, parameters]

    Returns
    -------
    endogenous_state : tuple
        New values of the endogenous states; empty in the base class.
    """
    return tuple()
def get_controls(self):  # pragma: nocover
    """
    Determine this period's control variables, probably from current states.

    The base implementation does nothing and returns ``None``; subclasses of
    AgentType may override it.
    """
def get_poststates(self):
    """
    Determine this period's post-decision state variables.

    These would typically be computed from current states and controls, and
    possibly from market-level events or shocks. The base implementation does
    nothing and returns ``None``; subclasses of AgentType may override it.
    """
def symulate(self, T=None):
    """
    Run the new simulation structure and expose its history.

    ``T`` is passed to the stored simulator's ``simulate`` method, and the
    simulator's ``history`` is then stored in ``self.hystory``.
    """
    simulator = self._simulator
    simulator.simulate(T)
    self.hystory = simulator.history
def describe_model(self, display=True):
    """
    Show information about this agent's model, based on its model file.

    Useful for learning outcome variable names for tracking during
    simulation or for sequence space Jacobians. The simulator is built with
    ``initialize_sym`` on first use; ``display`` is forwarded to its
    ``describe`` method.
    """
    needs_simulator = not hasattr(self, "_simulator")
    if needs_simulator:
        self.initialize_sym()
    self._simulator.describe(display=display)
def simulate(self, sim_periods=None):
    """
    Simulate this agent type for a given number of periods.

    Histories of the variables named in ``self.track_vars`` are recorded in
    ``self.history[var_name]``, one row per simulated period. A tracked name
    is looked up in ``state_now``, then ``shocks``, then ``controls``, and
    finally as an attribute of self.

    ``who_dies`` observed at the start of period t describes deaths at the
    end of period t-1. From the second simulated period on it is therefore
    stored in row t-1.

    Parameters
    ----------
    sim_periods : int or None
        Number of periods to simulate. Default is all remaining periods
        (``T_sim - t_sim``).

    Returns
    -------
    None
        Results are written into ``self.history``.

    Raises
    ------
    Exception
        If ``initialize_sim`` has not been run, ``T_sim`` is missing, or
        ``sim_periods`` exceeds ``T_sim``.
    """
    if not hasattr(self, "t_sim"):
        raise Exception(
            "It seems that the simulation variables were not initialized before calling "
            + "simulate(). Call initialize_sim() to initialize the variables before calling simulate() again."
        )

    if not hasattr(self, "T_sim"):
        raise Exception(
            "This agent type instance must have the attribute T_sim set to a positive integer. "
            + "Set T_sim to match the largest dataset you might simulate, and run this agent's "
            + "initialize_sim() method before running simulate() again."
        )

    if sim_periods is not None and self.T_sim < sim_periods:
        raise Exception(
            "To simulate, sim_periods must not exceed the maximum data set size "
            + "T_sim. Either increase the attribute T_sim of this agent type instance "
            + "and call the initialize_sim() method again, or set sim_periods <= T_sim."
        )

    # Floating point "errors" such as 1/0 -> inf have well-defined results,
    # so numpy is told not to complain about them.
    with np.errstate(all="ignore"):
        if sim_periods is None:
            sim_periods = self.T_sim - self.t_sim

        for _ in range(sim_periods):
            self.sim_one_period()

            for var_name in self.track_vars:
                if var_name in self.state_now:
                    self.history[var_name][self.t_sim, :] = self.state_now[var_name]
                elif var_name in self.shocks:
                    self.history[var_name][self.t_sim, :] = self.shocks[var_name]
                elif var_name in self.controls:
                    self.history[var_name][self.t_sim, :] = self.controls[var_name]
                elif var_name == "who_dies" and self.t_sim > 0:
                    # Deaths observed now happened at the end of the
                    # previous period; record them in that period's row.
                    self.history[var_name][self.t_sim - 1, :] = self.who_dies
                else:
                    self.history[var_name][self.t_sim, :] = getattr(self, var_name)
            self.t_sim += 1
def clear_history(self):
    """
    Reset the histories of all tracked variables to NaN.

    Each name in ``self.track_vars`` gets a fresh ``(T_sim, AgentCount)``
    float array filled with NaN in ``self.history``. Other entries of
    ``self.history`` are left as they are.

    Returns
    -------
    None
    """
    self.history.update(
        {
            var_name: np.full((self.T_sim, self.AgentCount), np.nan)
            for var_name in self.track_vars
        }
    )
def make_basic_SSJ(self, shock, outcomes, grids, **kwargs):
    """
    Construct sequence space Jacobian matrices for outcomes w.r.t. a shock.

    Returns the Jacobians of the specified outcomes with respect to the
    specified "shock" variable. This "basic" method only works for "one
    period infinite horizon" models (cycles=0, T_cycle=1). All arguments are
    forwarded to ``HARK.SSJutils.make_basic_SSJ_matrices``; see its
    documentation for details.
    """
    return make_basic_SSJ_matrices(self, shock, outcomes, grids, **kwargs)
def calc_impulse_response_manually(self, shock, outcomes, grids, **kwargs):
    """
    Compute the impulse response(s) to a perturbation of a shock parameter.

    The perturbation is applied in period t=s, which amounts to computing one
    column of the sequence space Jacobian manually. This "basic" method only
    works for "one period infinite horizon" models (cycles=0, T_cycle=1).
    All arguments are forwarded to
    ``HARK.SSJutils.calc_shock_response_manually``; see its documentation for
    details.
    """
    return calc_shock_response_manually(self, shock, outcomes, grids, **kwargs)
def solve_agent(agent, verbose, from_solution=None, from_t=None):
    """
    Solve the dynamic model for one agent type using backwards induction.

    Iterates on "cycles" of the agent's model either a given number of times
    or, for an infinite horizon model (``agent.cycles == 0``), until
    successive cycle solutions are within ``agent.tolerance`` of each other.
    Infinite horizon solving stops after 5000 cycles regardless.

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    verbose : boolean
        If True, solution progress is printed to screen.
    from_solution : Solution
        If different from None, used as the starting point of backward
        induction instead of ``agent.solution_terminal``.
    from_t : int or None
        If not None, the period the solver should start from, i.e. the time
        index that ``from_solution`` represents. Usually used together with
        ``from_solution``. Only compatible with cycles=1; it is reset to None
        for any other number of cycles.

    Returns
    -------
    solution : [Solution]
        A list of solutions to the one period problems that the agent will
        encounter in his "lifetime".
    """
    # Check whether this is an (in)finite horizon problem
    cycles_left = agent.cycles
    infinite_horizon = cycles_left == 0

    if from_solution is None:
        solution_last = agent.solution_terminal
    else:
        solution_last = from_solution
    # from_t only makes sense for single-cycle models, whether or not a
    # starting solution was supplied
    if agent.cycles != 1:
        from_t = None

    # Include the terminal solution unless it is a pseudo-terminal period
    solution = []
    if not agent.pseudo_terminal:
        solution.insert(0, deepcopy(solution_last))

    go = True
    completed_cycles = 0
    max_cycles = 5000  # escape clause for non-converging infinite horizon models
    if verbose:
        t_last = time()
    while go:
        # Solve a cycle of the model, recording it if horizon is finite
        solution_cycle = solve_one_cycle(agent, solution_last, from_t)
        if not infinite_horizon:
            solution = solution_cycle + solution

        # Check for termination: identical solutions across cycle
        # iterations, or running out of cycles
        solution_now = solution_cycle[0]
        if infinite_horizon:
            if completed_cycles > 0:
                solution_distance = distance_metric(solution_now, solution_last)
                # Expose progress so users can check whether the solution
                # is ready
                agent.solution_distance = solution_distance
                agent.completed_cycles = completed_cycles
                go = (
                    solution_distance > agent.tolerance
                    and completed_cycles < max_cycles
                )
            else:
                # Assume solution does not converge after only one cycle
                solution_distance = 100.0
                go = True
        else:
            cycles_left -= 1
            go = cycles_left > 0

        # Update the "last period solution"
        solution_last = solution_now
        completed_cycles += 1

        # Display progress if requested
        if verbose:
            t_now = time()
            if infinite_horizon:
                print(
                    "Finished cycle #"
                    + str(completed_cycles)
                    + " in "
                    + str(t_now - t_last)
                    + " seconds, solution distance = "
                    + str(solution_distance)
                )
            else:
                print(
                    "Finished cycle #"
                    + str(completed_cycles)
                    + " of "
                    + str(agent.cycles)
                    + " in "
                    + str(t_now - t_last)
                    + " seconds."
                )
            t_last = t_now

    # Record the last cycle if horizon is infinite (solution is still empty;
    # pseudo_terminal=False is impossible for infinite horizon)
    if infinite_horizon:
        solution = solution_cycle

    return solution
def _period_solver_and_args(agent, k):
    """
    Return the one-period solver used in period ``k`` and its argument names.

    ``agent.solve_one_period`` may be a single callable or an indexable
    collection of per-period solvers. Argument names come from the solver's
    ``solver_args`` attribute when it has one, otherwise from
    ``get_arg_names``.
    """
    solvers = agent.solve_one_period
    solver = solvers[k] if hasattr(solvers, "__getitem__") else solvers
    if hasattr(solver, "solver_args"):
        arg_names = solver.solver_args
    else:
        arg_names = get_arg_names(solver)
    return solver, arg_names


def solve_one_cycle(agent, solution_last, from_t):
    """
    Solve one "cycle" of the dynamic model for one agent type.

    Iterates over the periods within the agent's cycle. For each period it
    gathers the period-specific inputs and passes them to the single period
    solver(s).

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    solution_last : Solution
        Solution of the period that comes after the end of the sequence of
        one period problems. This might be the terminal period solution, a
        "pseudo terminal" solution, or the solution to the earliest period
        of the succeeding cycle.
    from_t : int or None
        If not None, the period the solver should start from; it represents
        the time index that ``solution_last`` is from.

    Returns
    -------
    solution_cycle : [Solution]
        A list of one period solutions for one "cycle" of the AgentType's
        microeconomic model.
    """
    # Agents carrying a Parameters object are read through it; otherwise
    # inputs are taken from time_inv / time_vary attributes.
    uses_parameters = hasattr(agent, "parameters") and isinstance(
        agent.parameters, Parameters
    )

    if uses_parameters:
        T = agent.parameters._length if from_t is None else from_t
    else:
        # With nothing time-varying, a cycle has a single period
        if len(agent.time_vary) > 0:
            T = agent.T_cycle if from_t is None else from_t
        else:
            T = 1
        solve_dict = {name: agent.__dict__[name] for name in agent.time_inv}
        solve_dict.update(dict.fromkeys(agent.time_vary))

    # Single-cycle models walk backward from T-1 to 0; otherwise period 0 is
    # solved first, followed by T-1 down to 1.
    wrap_order = [0] + list(range(T - 1, 0, -1))
    periods = range(T - 1, -1, -1) if agent.cycles == 1 else wrap_order

    solved = []
    solution_next = solution_last
    for k in periods:
        solver, arg_names = _period_solver_and_args(agent, k)

        if uses_parameters:
            period_pars = agent.parameters[k]
            kwargs = {
                name: (solution_next if name == "solution_next" else period_pars[name])
                for name in arg_names
            }
        else:
            # Refresh only the time-varying inputs this solver actually takes
            for name in agent.time_vary:
                if name in arg_names:
                    solve_dict[name] = agent.__dict__[name][k]
            solve_dict["solution_next"] = solution_next
            kwargs = {name: solve_dict[name] for name in arg_names}

        solution_t = solver(**kwargs)
        solved.append(solution_t)
        solution_next = solution_t

    # Solutions were produced back to front; return them in period order
    solved.reverse()
    return solved
def make_one_period_oo_solver(solver_class):
    """
    Returns a function that solves a single period consumption-saving
    problem by instantiating an object-oriented solver class.

    Parameters
    ----------
    solver_class : Solver
        A class of Solver to be used. Its constructor receives the keyword
        arguments passed to the returned function. Instances must provide a
        ``solve()`` method and may optionally provide ``prepare_to_solve()``.

    Returns
    -------
    solver_function : function
        A function for solving one period of a problem. It carries two
        attributes:

        - ``solver_class``: the input class.
        - ``solver_args``: the argument names of ``solver_class.__init__``
          with the first one (presumably ``self``) dropped.

        ``solver_args`` lets callers decide which keyword arguments to pass.
    """

    def one_period_solver(**kwds):
        solver = solver_class(**kwds)

        # Not every Solver class defines a setup step, so call it only if
        # present (not ideal; better if this were defined in all Solver classes)
        if hasattr(solver, "prepare_to_solve"):
            solver.prepare_to_solve()

        solution_now = solver.solve()
        return solution_now

    one_period_solver.solver_class = solver_class
    # Constructor argument names, skipping the first (instance) argument.
    # This can be revisited once it is possible to export parameters.
    one_period_solver.solver_args = get_arg_names(solver_class.__init__)[1:]

    return one_period_solver
2114# ========================================================================
2115# ========================================================================
class Market(Model):
    """
    A superclass to represent a central clearinghouse of information. Used for
    dynamic general equilibrium models to solve the "macroeconomic" model as a
    layer on top of the "microeconomic" models of one or more AgentTypes.

    Parameters
    ----------
    agents : [AgentType]
        A list of all the AgentTypes in this market.
    sow_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be "sown" to the agents in the market. Aggregate state, etc.
    reap_vars : [string]
        Names of variables to be collected ("reaped") from agents in the market
        to be used in the "aggregate market process".
    const_vars : [string]
        Names of inputs to the "aggregate market process" that do not come from
        agents-- they are constant or simply parameters inherent to the process.
        They are stored in the dict ``self.const_vars`` (values start as None);
        their current values are passed to ``mill_rule`` by ``mill``.
    track_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be tracked as a "history" so that a new dynamic rule can be calculated.
        This is often a subset of sow_vars.
    dyn_vars : [string]
        Names of variables that constitute a "dynamic rule".
    mill_rule : function
        A function that takes inputs named in reap_vars and const_vars and
        returns a tuple the same size and order as sow_vars. The "aggregate
        market process" that transforms individual agent actions/states/data
        into aggregate data to be sent back to agents.
    calc_dynamics : function
        A function that takes inputs named in track_vars and returns an object
        with attributes named in dyn_vars. Looks at histories of aggregate
        variables and generates a new "dynamic rule" for agents to believe and
        act on.
    distributions : [string]
        Names of attributes of the Market holding a distribution (or a list of
        distributions) whose ``reset()`` method is called by ``reset``.
    act_T : int
        The number of times that the "aggregate market process" should be run
        in order to generate a history of aggregate variables.
    tolerance: float
        Minimum acceptable distance between "dynamic rules" to consider the
        Market solution process converged. Distance is a user-defined metric.
    seed : int
        Seed for the Market's internal random number generator ``self.RNG``.
    **kwds
        Any other parameters, passed to ``assign_parameters``.
    """

    def __init__(
        self,
        agents=None,
        sow_vars=None,
        reap_vars=None,
        const_vars=None,
        track_vars=None,
        dyn_vars=None,
        mill_rule=None,
        calc_dynamics=None,
        distributions=None,
        act_T=1000,
        tolerance=0.000001,
        seed=0,
        **kwds,
    ):
        super().__init__()
        self.agents = agents if agents is not None else list()  # NOQA

        self.reap_vars = reap_vars if reap_vars is not None else list()  # NOQA
        # Most recent harvest of each reap variable; filled in by reap()
        self.reap_state = {var: [] for var in self.reap_vars}

        self.sow_vars = sow_vars if sow_vars is not None else list()  # NOQA
        # dictionaries for tracking initial and current values
        # of the sow variables.
        self.sow_init = {var: None for var in self.sow_vars}
        self.sow_state = {var: None for var in self.sow_vars}

        const_vars = const_vars if const_vars is not None else list()  # NOQA
        # Constant inputs to mill_rule; values start as None and are passed
        # to mill_rule as-is by mill()
        self.const_vars = {var: None for var in const_vars}

        self.track_vars = track_vars if track_vars is not None else list()  # NOQA
        self.dyn_vars = dyn_vars if dyn_vars is not None else list()  # NOQA
        self.distributions = distributions if distributions is not None else list()  # NOQA

        if mill_rule is not None:  # To prevent overwriting of method-based mill_rules
            self.mill_rule = mill_rule
        if calc_dynamics is not None:  # Ditto for calc_dynamics
            self.calc_dynamics = calc_dynamics
        self.act_T = act_T  # NOQA
        self.tolerance = tolerance  # NOQA
        self.seed = seed
        self.max_loops = 1000  # NOQA
        self.history = {}
        self.assign_parameters(**kwds)
        self.RNG = np.random.default_rng(self.seed)

        self.print_parallel_error_once = True
        # Print the error associated with calling the parallel method
        # "solve_agents" one time. If set to false, the error will never
        # print. See "solve_agents" for why this prints once or never.

    def solve_agents(self):
        """
        Solves the microeconomic problem for all AgentTypes in this market.

        Tries the parallel multi_thread_commands first; if that raises, prints
        a warning (at most once, controlled by print_parallel_error_once) and
        solves the agents serially instead.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        try:
            multi_thread_commands(self.agents, ["solve()"])
        except Exception as err:
            if self.print_parallel_error_once:
                # Set flag to False so this is only printed once.
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.Market.solve_agents() ",
                    "so using the serial version instead. This will likely be slower. "
                    "The multi_thread_commands() functions failed with the following error:",
                    "\n",
                    sys.exc_info()[0],
                    ":",
                    err,
                )  # sys.exc_info()[0])
            multi_thread_commands_fake(self.agents, ["solve()"])

    def solve(self):
        """
        "Solves" the market by finding a "dynamic rule" that governs the aggregate
        market state such that when agents believe in these dynamics, their actions
        collectively generate the same dynamic rule.

        Iterates solve_agents -> make_history -> update_dynamics until the
        distance between successive dynamic rules falls below self.tolerance
        or self.max_loops iterations have been completed. The final rule is
        stored in self.dynamics.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        go = True
        max_loops = self.max_loops  # Failsafe against infinite solution loop
        completed_loops = 0
        old_dynamics = None

        while go:  # Loop until the dynamic process converges or we hit the loop cap
            self.solve_agents()  # Solve each AgentType's micro problem
            self.make_history()  # "Run" the model while tracking aggregate variables
            new_dynamics = self.update_dynamics()  # Find a new aggregate dynamic rule

            # Check to see if the dynamic rule has converged (if this is not the first loop)
            if completed_loops > 0:
                distance = new_dynamics.distance(old_dynamics)
            else:
                distance = 1000000.0

            # Move to the next loop if the terminal conditions are not met
            old_dynamics = new_dynamics
            completed_loops += 1
            go = distance >= self.tolerance and completed_loops < max_loops

        self.dynamics = new_dynamics  # Store the final dynamic rule in self

    def reap(self):
        """
        Collects the variables named in reap_vars from the ``state_now``
        dictionary of each AgentType in the market, storing one list per
        variable in ``self.reap_state``. An agent whose state_now lacks a
        variable contributes nothing to that variable's list.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var in self.reap_state:
            harvest = []

            for agent in self.agents:
                # TODO: generalized variable lookup across namespaces
                if var in agent.state_now:
                    harvest.append(agent.state_now[var])

            self.reap_state[var] = harvest

    def sow(self):
        """
        Distributes the current values in ``self.sow_state`` to each AgentType
        in the market.

        For each sow variable and agent:

        - If the name is a key of the agent's ``state_now``, the value is
          written there.
        - If the name is a key of the agent's ``shocks``, the value is written
          there; otherwise it is set as an attribute of the agent.

        Note that the attribute fallback depends only on the ``shocks`` check,
        so a variable found in ``state_now`` but not in ``shocks`` is written
        both to ``state_now`` and as an attribute.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for sow_var in self.sow_state:
            for this_type in self.agents:
                if sow_var in this_type.state_now:
                    this_type.state_now[sow_var] = self.sow_state[sow_var]
                if sow_var in this_type.shocks:
                    this_type.shocks[sow_var] = self.sow_state[sow_var]
                else:
                    setattr(this_type, sow_var, self.sow_state[sow_var])

    def mill(self):
        """
        Processes the variables collected from agents (``self.reap_state``),
        together with ``self.const_vars``, using the function mill_rule. The
        i-th element of mill_rule's output is stored in ``self.sow_state``
        under the i-th sow variable name.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Make a dictionary of inputs for the mill_rule
        mill_dict = copy(self.reap_state)
        mill_dict.update(self.const_vars)

        # Run the mill_rule and store its output in self
        product = self.mill_rule(**mill_dict)

        for i, sow_var in enumerate(self.sow_state):
            self.sow_state[sow_var] = product[i]

    def cultivate(self):
        """
        Has each AgentType in agents perform their market_action method, using
        variables sown from the market (and maybe also "private" variables).
        The market_action method should store new results in the agent's
        state_now under the names in reap_vars, to be reaped later.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for this_type in self.agents:
            this_type.market_action()

    def reset(self):
        """
        Reset the state of the market (attributes in sow_vars, etc) to some
        user-defined initial state, and erase the histories of tracked variables.
        Also resets the distributions named in self.distributions so that
        draws can be reproduced, and resets every AgentType in the market.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Reset internal RNG and distributions; missing names are skipped
        for name in self.distributions:
            if not hasattr(self, name):
                continue
            dstn = getattr(self, name)
            if isinstance(dstn, list):
                for D in dstn:
                    D.reset()
            else:
                dstn.reset()

        # Reset the history of tracked variables
        self.history = {var_name: [] for var_name in self.track_vars}

        # Set the sow variables to their initial levels
        for var_name in self.sow_state:
            self.sow_state[var_name] = self.sow_init[var_name]

        # Reset each AgentType in the market
        for this_type in self.agents:
            this_type.reset()

    def store(self):
        """
        Record the current value of each variable X named in track_vars in the
        dictionary field history[X]. The value is looked up in sow_state, then
        reap_state, then const_vars, and finally as an attribute of self.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var_name in self.track_vars:
            if var_name in self.sow_state:
                value_now = self.sow_state[var_name]
            elif var_name in self.reap_state:
                value_now = self.reap_state[var_name]
            elif var_name in self.const_vars:
                value_now = self.const_vars[var_name]
            else:
                value_now = getattr(self, var_name)

            self.history[var_name].append(value_now)

    def make_history(self):
        """
        Runs a loop of sow-->cultivate-->reap-->mill act_T times, tracking the
        evolution of variables X named in track_vars in dictionary fields
        history[X]. The market is reset before the loop starts.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        self.reset()  # Initialize the state of the market
        for t in range(self.act_T):
            self.sow()  # Distribute aggregated information/state to agents
            self.cultivate()  # Agents take action
            self.reap()  # Collect individual data from agents
            self.mill()  # Process individual data into aggregate data
            self.store()  # Record variables of interest

    def update_dynamics(self):
        """
        Calculates a new "aggregate dynamic rule" using the history of variables
        named in track_vars, and distributes this rule to AgentTypes in agents.

        The argument names of calc_dynamics (other than ``self``) are looked up
        in self.history, so each must be a tracked variable.

        Parameters
        ----------
        none

        Returns
        -------
        dynamics : instance
            The new "aggregate dynamic rule" that agents believe in and act on.
            Should have attributes named in dyn_vars.
        """
        # Make a dictionary of inputs for the dynamics calculator
        arg_names = list(get_arg_names(self.calc_dynamics))
        if "self" in arg_names:
            arg_names.remove("self")
        update_dict = {name: self.history[name] for name in arg_names}
        # Calculate a new dynamic rule and distribute it to the agents in agent_list
        dynamics = self.calc_dynamics(**update_dict)  # User-defined dynamics calculator
        for var_name in self.dyn_vars:
            this_obj = getattr(dynamics, var_name)
            for this_type in self.agents:
                setattr(this_type, var_name, this_obj)
        return dynamics
def distribute_params(agent, param_name, param_count, distribution):
    """
    Create ``param_count`` copies of an agent that differ in one parameter.

    The distribution is discretized into ``param_count`` points. Copy ``j``
    gets the j-th atom as its value of ``param_name``. Its ``AgentCount`` is
    the original agent's ``AgentCount`` times the j-th probability mass,
    truncated to an integer.

    Parameters
    ----------
    agent: AgentType
        An agent to clone. It is deep-copied; the original is not modified
        by this function.
    param_name : string
        Name of the parameter to be assigned.
    param_count : int
        Number of different values the parameter will take on.
    distribution : Distribution
        A 1-D distribution.

    Returns
    -------
    agent_set : [AgentType]
        A list of param_count agents, ex ante heterogeneous with
        respect to param_name. The AgentCount of the original
        will be split between the agents of the returned
        list in proportion to the given distribution.
    """
    param_dist = distribution.discretize(N=param_count)

    agent_set = [deepcopy(agent) for _ in range(param_count)]

    for j, new_agent in enumerate(agent_set):
        # int() truncates, so the copies' AgentCounts may sum to slightly
        # less than the original AgentCount.
        new_agent.assign_parameters(
            **{"AgentCount": int(agent.AgentCount * param_dist.pmv[j])}
        )
        new_agent.assign_parameters(**{param_name: param_dist.atoms[0, j]})

    return agent_set
@dataclass
class AgentPopulation:
    """
    A class for representing a population of ex-ante heterogeneous agents.

    Parameters that differ across agents ("distributed" parameters) may be
    given in three forms:

    - a list of lists, with one inner list per agent;
    - a ``Distribution``, which must be discretized with
      ``approx_distributions`` before agents can be created;
    - a ``DataArray`` whose first dimension is ``"agent"``.
    """

    agent_type: AgentType  # type of agent in the population
    parameters: dict  # dictionary of parameters
    seed: int = 0  # random seed
    time_var: List[str] = field(init=False)
    time_inv: List[str] = field(init=False)
    distributed_params: List[str] = field(init=False)
    agent_type_count: Optional[int] = field(init=False)
    term_age: Optional[int] = field(init=False)
    continuous_distributions: Dict[str, Distribution] = field(init=False)
    discrete_distributions: Dict[str, Distribution] = field(init=False)
    population_parameters: List[Dict[str, Union[List[float], float]]] = field(
        init=False
    )
    agents: List[AgentType] = field(init=False)
    agent_database: pd.DataFrame = field(init=False)
    solution: List[Any] = field(init=False)

    @staticmethod
    def _is_nested_list(param):
        """Return True if ``param`` is a non-empty list whose first item is a list."""
        return isinstance(param, list) and len(param) > 0 and isinstance(param[0], list)

    @staticmethod
    def _is_agent_array(param):
        """Return True if ``param`` is a DataArray whose first dimension is "agent"."""
        return (
            isinstance(param, DataArray)
            and len(param.dims) > 0
            and param.dims[0] == "agent"
        )

    def __post_init__(self):
        """
        Initialize the population of agents, determine distributed parameters,
        and infer `agent_type_count` and `term_age`.
        """
        # create a dummy agent and obtain its time-varying
        # and time-invariant attributes
        dummy_agent = self.agent_type()
        self.time_var = dummy_agent.time_vary
        self.time_inv = dummy_agent.time_inv

        # create list of distributed parameters
        # these are parameters that differ across agents
        self.distributed_params = [
            key
            for key, param in self.parameters.items()
            if self._is_nested_list(param)
            or isinstance(param, Distribution)
            or self._is_agent_array(param)
        ]

        self.__infer_counts__()

        self.print_parallel_error_once = True
        # Print warning once if parallel simulation fails

    def __infer_counts__(self):
        """
        Infer `agent_type_count` and `term_age` from the parameters.
        If parameters include a `Distribution` type, a list of lists,
        or a `DataArray` with `agent` as the first dimension, then
        the AgentPopulation contains ex-ante heterogenous agents.
        Both counts are set to None (with a warning) while any parameter is
        still an un-approximated `Distribution`.
        """

        # infer agent_type_count from distributed parameters
        agent_type_count = 1
        for key in self.distributed_params:
            param = self.parameters[key]
            if isinstance(param, Distribution):
                agent_type_count = None
                warn(
                    "Cannot infer agent_type_count from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif isinstance(param, list):
                agent_type_count = max(agent_type_count, len(param))
            elif self._is_agent_array(param):
                agent_type_count = max(agent_type_count, param.shape[0])

        self.agent_type_count = agent_type_count

        # infer term_age from all parameters
        term_age = 1
        for param in self.parameters.values():
            if isinstance(param, Distribution):
                term_age = None
                warn(
                    "Cannot infer term_age from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif self._is_nested_list(param):
                term_age = max(term_age, len(param[0]))
            elif (
                isinstance(param, DataArray)
                and len(param.dims) > 0
                and param.dims[-1] == "age"
            ):
                term_age = max(term_age, param.shape[-1])

        self.term_age = term_age

    def approx_distributions(self, approx_params: dict):
        """
        Approximate continuous distributions with discrete ones. If the initial
        parameters include a `Distribution` type, then the AgentPopulation is
        not ready to solve, and stands for an abstract population. To solve the
        AgentPopulation, we need discretization parameters for each continuous
        distribution. This method approximates the continuous distributions with
        discrete ones, and updates the parameters dictionary.

        Parameters
        ----------
        approx_params : dict
            Maps each distributed parameter name to a dict of keyword arguments
            for that Distribution's ``discretize`` method.

        Raises
        ------
        ValueError
            If a key does not name a distributed parameter given as a Distribution.
        """
        self.continuous_distributions = {}
        self.discrete_distributions = {}

        for key, args in approx_params.items():
            param = self.parameters[key]
            if key in self.distributed_params and isinstance(param, Distribution):
                self.continuous_distributions[key] = param
                self.discrete_distributions[key] = param.discretize(**args)
            else:
                raise ValueError(
                    f"Warning: parameter {key} is not a Distribution found "
                    f"in agent type {self.agent_type}"
                )

        if not self.discrete_distributions:
            # Nothing to approximate; just refresh the inferred counts.
            self.__infer_counts__()
            return

        if len(self.discrete_distributions) > 1:
            joint_dist = combine_indep_dstns(*self.discrete_distributions.values())
        else:
            joint_dist = list(self.discrete_distributions.values())[0]

        # Row i of the joint atoms holds the values of the i-th parameter
        for i, key in enumerate(self.discrete_distributions):
            self.parameters[key] = DataArray(joint_dist.atoms[i], dims=("agent"))

        self.__infer_counts__()

    def _time_varying_value(self, param, agent):
        """
        Return agent number ``agent``'s per-period value(s) of a time-varying
        parameter. Scalars (including agent-specific scalars) are repeated
        `term_age` times; recognized per-age data is converted to a list;
        values of any other type are passed through unchanged.
        """
        if isinstance(param, (int, float)):
            return [param] * self.term_age
        if isinstance(param, list):
            return param[agent] if self._is_nested_list(param) else param
        if self._is_agent_array(param):
            if len(param.dims) > 1 and param.dims[-1] == "age":
                return param[agent].values.tolist()
            # agent-specific but age-constant: repeat like a plain scalar
            return [param[agent].item()] * self.term_age
        if (
            isinstance(param, DataArray)
            and len(param.dims) > 0
            and param.dims[0] == "age"
        ):
            return param.values.tolist()
        return param

    def __parse_parameters__(self) -> None:
        """
        Creates distributed dictionaries of parameters for each ex-ante
        heterogeneous agent in the parameterized population. The parameters
        are stored in a list of dictionaries, where each dictionary contains
        the parameters for one agent. Expands parameters that vary over time
        to a list of length `term_age`.
        """

        population_parameters = []  # container for dictionaries of each agent subgroup
        for agent in range(self.agent_type_count):
            agent_parameters = {}
            for key, param in self.parameters.items():
                if key in self.time_var:
                    # parameters that vary over time have to be repeated
                    agent_parameters[key] = self._time_varying_value(param, agent)

                elif key in self.time_inv:
                    if isinstance(param, (int, float)):
                        agent_parameters[key] = param
                    elif isinstance(param, list):
                        agent_parameters[key] = (
                            param[agent] if self._is_nested_list(param) else param
                        )
                    elif self._is_agent_array(param):
                        agent_parameters[key] = param[agent].item()

                else:
                    if isinstance(param, (int, float)):
                        agent_parameters[key] = param  # assume time inv
                    elif isinstance(param, list):
                        # list of lists: assume agent vary; flat list: assume time vary
                        agent_parameters[key] = (
                            param[agent] if self._is_nested_list(param) else param
                        )
                    elif isinstance(param, DataArray):
                        if self._is_agent_array(param):
                            if len(param.dims) > 1 and param.dims[-1] == "age":
                                agent_parameters[key] = param[agent].values.tolist()
                            else:
                                agent_parameters[key] = param[agent].item()
                        elif len(param.dims) > 0 and param.dims[0] == "age":
                            agent_parameters[key] = param.values.tolist()

            population_parameters.append(agent_parameters)

        self.population_parameters = population_parameters

    def create_distributed_agents(self):
        """
        Parses the parameters dictionary and creates a list of agents with the
        appropriate parameters. Also sets the seed for each agent, drawn from
        an RNG seeded with the population's seed.
        """

        self.__parse_parameters__()

        rng = np.random.default_rng(self.seed)

        self.agents = [
            self.agent_type(seed=rng.integers(0, 2**31 - 1), **agent_dict)
            for agent_dict in self.population_parameters
        ]

    def create_database(self):
        """
        Optionally creates a pandas DataFrame with the parameters for each agent.
        """
        database = pd.DataFrame(self.population_parameters)
        database["agents"] = self.agents

        self.agent_database = database

    def solve(self):
        """
        Solves each agent of the population serially.
        """

        # see Market class for an example of how to solve distributed agents in parallel

        for agent in self.agents:
            agent.solve()

    def unpack_solutions(self):
        """
        Unpacks the solutions of each agent into an attribute of the population.
        """
        self.solution = [agent.solution for agent in self.agents]

    def initialize_sim(self):
        """
        Initializes the simulation for each agent.
        """
        for agent in self.agents:
            agent.initialize_sim()

    def simulate(self, num_jobs=None):
        """
        Simulates each agent of the population.

        Parameters
        ----------
        num_jobs : int, optional
            Number of parallel jobs to use. Defaults to using all available
            cores when ``None``. Falls back to serial execution if parallel
            processing fails.
        """
        try:
            multi_thread_commands(self.agents, ["simulate()"], num_jobs)
        except Exception as err:
            if getattr(self, "print_parallel_error_once", False):
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.AgentPopulation.simulate() ",
                    "so using the serial version instead. This will likely be slower. ",
                    "The multi_thread_commands() function failed with the following error:\n",
                    sys.exc_info()[0],
                    ":",
                    err,
                )
            multi_thread_commands_fake(self.agents, ["simulate()"], num_jobs)

    def __iter__(self):
        """
        Allows for iteration over the agents in the population.
        """
        return iter(self.agents)

    def __getitem__(self, idx):
        """
        Allows for indexing into the population.
        """
        return self.agents[idx]
2790###############################################################################
def multi_thread_commands_fake(
    agent_list: List, command_list: List, num_jobs=None
) -> None:
    """
    Run every command in ``command_list`` on every AgentType in ``agent_list``,
    one agent after another, in the current process.

    This is a serial stand-in with the same call signature as
    multi_thread_commands, so that multithreading can be switched off
    without changing any calling code.

    Parameters
    ----------
    agent_list : [AgentType]
        Agents on which the commands are run, in list order.
    command_list : [string]
        Names of zero-argument methods to call on each agent, in order. A
        trailing "()" is optional.
    num_jobs : None
        Ignored; present only to mirror multi_thread_commands.

    Returns
    -------
    none
    """
    for this_agent in agent_list:
        for command in command_list:
            # "solve()" and "solve" both refer to the method "solve"
            method_name = command[:-2] if command.endswith("()") else command
            getattr(this_agent, method_name)()
def multi_thread_commands(agent_list: List, command_list: List, num_jobs=None) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    using a multithreaded system. Each command should be a method of that AgentType subclass.

    Each agent in agent_list is replaced, in place, by the agent object returned
    from its parallel job.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
    command_list : [string]
        A list of commands to run for each AgentType in agent_list.
    num_jobs : int, optional
        Number of parallel jobs. Defaults to the smaller of the number of
        AgentTypes in agent_list and the number of available CPU cores.

    Returns
    -------
    None
    """
    # With zero or one agent there is nothing to parallelize. This also
    # avoids passing n_jobs=0 to joblib (which it rejects) for an empty list.
    if len(agent_list) <= 1:
        multi_thread_commands_fake(agent_list, command_list)
        return None

    # Default number of parallel jobs is the smaller of number of AgentTypes in
    # the input and the number of available cores.
    if num_jobs is None:
        num_jobs = min(len(agent_list), multiprocessing.cpu_count())

    # Send each command in command_list to each of the types in agent_list to be run
    agent_list_out = Parallel(n_jobs=num_jobs)(
        delayed(run_commands)(agent, command_list) for agent in agent_list
    )

    # Replace the original types with the output from the parallel call
    # (presumably copies when joblib runs the jobs in worker processes)
    for j, updated_agent in enumerate(agent_list_out):
        agent_list[j] = updated_agent
def run_commands(agent: Any, command_list: List) -> Any:
    """
    Call each named method of ``agent`` in order, then hand the agent back.

    Parameters
    ----------
    agent : AgentType
        The AgentType instance whose methods are called.
    command_list : [string]
        Names of zero-argument methods of the agent's class, in call order.
        A trailing "()" is optional.

    Returns
    -------
    agent : AgentType
        The very object that was passed in, after all commands have run.
    """
    for command in command_list:
        # "simulate()" and "simulate" both refer to the method "simulate"
        name = command[:-2] if command.endswith("()") else command
        bound_method = getattr(agent, name)
        bound_method()
    return agent