Coverage for HARK / core.py: 94%
948 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-07 05:16 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-07 05:16 +0000
1"""
2High-level functions and classes for solving a wide variety of economic models.
3The "core" of HARK is a framework for "microeconomic" and "macroeconomic"
4models. A micro model concerns the dynamic optimization problem for some type
5of agents, where agents take the inputs to their problem as exogenous. A macro
6model adds an additional layer, endogenizing some of the inputs to the micro
7problem by finding a general equilibrium dynamic rule.
8"""
10# Import basic modules
11import inspect
12import sys
13from collections import namedtuple
14from copy import copy, deepcopy
15from dataclasses import dataclass, field
16from time import time
17from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
18from warnings import warn
19import multiprocessing
20from joblib import Parallel, delayed
22import numpy as np
23import pandas as pd
24from xarray import DataArray
26from HARK.distributions import (
27 Distribution,
28 IndexDistribution,
29 combine_indep_dstns,
30)
31from HARK.utilities import NullFunc, get_arg_names, get_it_from
32from HARK.simulator import make_simulator_from_agent
33from HARK.SSJutils import (
34 make_basic_SSJ_matrices,
35 calc_shock_response_manually,
36)
37from HARK.metric import MetricObject
# Public API of this module: the names exported by ``from HARK.core import *``.
# NullFunc is re-exported from HARK.utilities (imported above); the remaining
# entries not defined in this part of the file (e.g. AgentPopulation,
# multi_thread_commands, make_one_period_oo_solver, distribute_params) are
# presumably defined further down in this module -- verify when editing.
__all__ = [
    "AgentType",
    "Market",
    "Parameters",
    "Model",
    "AgentPopulation",
    "multi_thread_commands",
    "multi_thread_commands_fake",
    "NullFunc",
    "make_one_period_oo_solver",
    "distribute_params",
]
class Parameters:
    """
    A smart container for model parameters that handles age-varying dynamics.

    This class stores parameters as an internal dictionary and manages their
    age-varying properties, providing both attribute-style and dictionary-style
    access. It is designed to handle the time-varying dynamics of parameters
    in economic models.

    Attributes
    ----------
    _length : int
        The number of periods in the model cycle (the T_cycle value, possibly
        inferred from the first multi-element list/tuple parameter).
    _invariant_params : Set[str]
        A set of parameter names that are invariant over time.
    _varying_params : Set[str]
        A set of parameter names that vary over time.
    _parameters : Dict[str, Any]
        The internal dictionary storing all parameters.
    _frozen : bool
        When True, ``__setitem__`` (and therefore attribute-style assignment,
        ``update`` and ``set_many``) raises RuntimeError.
    _namedtuple_cache : Optional[type]
        The namedtuple class most recently built by ``to_namedtuple``. It is
        rebuilt whenever the parameter names no longer match its fields.
    """

    __slots__ = (
        "_length",
        "_invariant_params",
        "_varying_params",
        "_parameters",
        "_frozen",
        "_namedtuple_cache",
    )

    def __init__(self, **parameters: Any) -> None:
        """
        Initialize a Parameters object and parse the age-varying dynamics of parameters.

        Parameters
        ----------
        T_cycle : int, optional
            The number of time periods in the model cycle (default: 1).
            Must be >= 1.
        frozen : bool, optional
            If True, the Parameters object will be immutable after initialization
            (default: False).
        _time_inv : List[str], optional
            List of parameter names to explicitly mark as time-invariant,
            overriding automatic inference.
        _time_vary : List[str], optional
            List of parameter names to explicitly mark as time-varying,
            overriding automatic inference.
        **parameters : Any
            Any number of parameters in the form key=value.

        Raises
        ------
        ValueError
            If T_cycle is less than 1.

        Notes
        -----
        Automatic time-variance inference rules:
        - Scalars (int, float, bool, None) are time-invariant
        - NumPy arrays are time-invariant (use lists/tuples for time-varying)
        - Single-element lists/tuples [x] are unwrapped to x and time-invariant
        - Multi-element lists/tuples are time-varying if length matches T_cycle
        - 2D arrays with first dimension matching T_cycle are time-varying
        - Distributions and Callables are time-invariant

        Use _time_inv or _time_vary to override automatic inference when needed.
        """
        # Extract special parameters
        self._length: int = parameters.pop("T_cycle", 1)
        frozen: bool = parameters.pop("frozen", False)
        time_inv_override: List[str] = parameters.pop("_time_inv", [])
        time_vary_override: List[str] = parameters.pop("_time_vary", [])

        # Validate T_cycle
        if self._length < 1:
            raise ValueError(f"T_cycle must be >= 1, got {self._length}")

        # Initialize internal state. T_cycle is stored but not added to either
        # the invariant or varying set, so age slices (self[age]) do not copy it.
        self._invariant_params: Set[str] = set()
        self._varying_params: Set[str] = set()
        self._parameters: Dict[str, Any] = {"T_cycle": self._length}
        self._frozen: bool = False  # Set to False initially to allow setup
        self._namedtuple_cache: Optional[type] = None

        # Set parameters using automatic inference
        for key, value in parameters.items():
            self[key] = value

        # Apply explicit overrides (names that were not set are ignored)
        for param in time_inv_override:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)

        for param in time_vary_override:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)

        # Freeze if requested
        self._frozen = frozen

    def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
        """
        Access parameters by age index or parameter name.

        If item_or_key is an integer (a Python int or a NumPy integer scalar),
        returns a Parameters object with the parameters that apply to that age.
        This includes all invariant parameters and the `item_or_key`th element
        of all age-varying list/tuple/array parameters; age-varying parameters
        of any other type are passed through unchanged. If item_or_key is a
        string, it returns the value of the parameter with that name.

        Parameters
        ----------
        item_or_key : Union[int, str]
            Age index or parameter name.

        Returns
        -------
        Union[Parameters, Any]
            A new Parameters object for the specified age, or the value of the
            specified parameter.

        Raises
        ------
        ValueError:
            If the age index is out of bounds.
        KeyError:
            If the parameter name is not found.
        TypeError:
            If the key is neither an integer nor a string.
        """
        # np.integer is accepted so ages coming from NumPy arrays/ranges work.
        if isinstance(item_or_key, (int, np.integer)):
            if item_or_key < 0 or item_or_key >= self._length:
                raise ValueError(
                    f"Age {item_or_key} is out of bounds (valid: 0-{self._length - 1})."
                )
            age = int(item_or_key)

            params = {key: self._parameters[key] for key in self._invariant_params}
            for key in self._varying_params:
                value = self._parameters[key]
                if isinstance(value, (list, tuple, np.ndarray)):
                    params[key] = value[age]
                else:
                    params[key] = value
            return Parameters(**params)
        if isinstance(item_or_key, str):
            return self._parameters[item_or_key]
        raise TypeError("Key must be an integer (age) or string (parameter name).")

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Set parameter values, automatically inferring time variance.

        If the parameter is a scalar, numpy array, boolean, distribution, callable
        or None, it is assumed to be invariant over time. If the parameter is a
        list or tuple, it is assumed to be varying over time. If the parameter
        is a list or tuple of length greater than 1, the length of the list or
        tuple must match the `_length` attribute of the Parameters object.

        2D numpy arrays with first dimension matching T_cycle are treated as
        time-varying parameters.

        Parameters
        ----------
        key : str
            Name of the parameter.
        value : Any
            Value of the parameter.

        Raises
        ------
        ValueError:
            If the parameter name is not a string or if the value type is unsupported.
            If the parameter value is inconsistent with the current model length.
        RuntimeError:
            If the Parameters object is frozen.
        """
        if self._frozen:
            raise RuntimeError("Cannot modify frozen Parameters object")

        if not isinstance(key, str):
            raise ValueError(f"Parameter name must be a string, got {type(key)}")

        # Check for 2D numpy arrays with time-varying first dimension
        if isinstance(value, np.ndarray) and value.ndim >= 2:
            if value.shape[0] == self._length:
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            else:
                self._invariant_params.add(key)
                self._varying_params.discard(key)
        elif isinstance(
            value,
            (
                int,
                float,
                np.ndarray,
                type(None),
                Distribution,
                bool,
                Callable,
                MetricObject,
            ),
        ):
            self._invariant_params.add(key)
            self._varying_params.discard(key)
        elif isinstance(value, (list, tuple)):
            if len(value) == 1:
                # Single-element sequences are unwrapped and treated as invariant
                value = value[0]
                self._invariant_params.add(key)
                self._varying_params.discard(key)
            elif self._length is None or self._length == 1:
                # The first multi-element sequence sets the cycle length.
                # NOTE(review): self._parameters["T_cycle"] is not updated here,
                # so it can disagree with _length afterwards -- confirm intent.
                self._length = len(value)
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            elif len(value) == self._length:
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            else:
                raise ValueError(
                    f"Parameter {key} must have length 1 or {self._length}, not {len(value)}"
                )
        else:
            raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")

        self._parameters[key] = value

    def __iter__(self) -> Iterator[str]:
        """Allow iteration over parameter names."""
        return iter(self._parameters)

    def __len__(self) -> int:
        """Return the number of parameters."""
        return len(self._parameters)

    def keys(self) -> Iterator[str]:
        """Return a view of parameter names."""
        return self._parameters.keys()

    def values(self) -> Iterator[Any]:
        """Return a view of parameter values."""
        return self._parameters.values()

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Return a view of parameter (name, value) pairs."""
        return self._parameters.items()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a plain dictionary.

        Returns
        -------
        Dict[str, Any]
            A dictionary containing all parameters.
        """
        return dict(self._parameters)

    def to_namedtuple(self) -> namedtuple:
        """
        Convert parameters to a namedtuple.

        The namedtuple class is cached for efficiency on repeated calls, and is
        rebuilt whenever the set (or order) of parameter names has changed
        since it was created, so parameters added later are included.

        Returns
        -------
        namedtuple
            A namedtuple containing all parameters.
        """
        fields = tuple(self._parameters)
        cached = self._namedtuple_cache
        # A stale cached class would reject newly added names as unexpected
        # keyword arguments, so rebuild when its fields no longer match.
        if cached is None or cached._fields != fields:
            cached = namedtuple("Parameters", fields)
            self._namedtuple_cache = cached
        return cached(**self.to_dict())

    def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
        """
        Update parameters from another Parameters object or dictionary.

        Each value is assigned through ``__setitem__``, so time variance is
        re-inferred for it on this object.

        Parameters
        ----------
        other : Union[Parameters, Dict[str, Any]]
            The source of parameters to update from.

        Raises
        ------
        TypeError
            If the input is neither a Parameters object nor a dictionary.
        """
        if isinstance(other, Parameters):
            for key, value in other._parameters.items():
                self[key] = value
        elif isinstance(other, dict):
            for key, value in other.items():
                self[key] = value
        else:
            raise TypeError(
                "Update source must be a Parameters object or a dictionary."
            )

    def __repr__(self) -> str:
        """Return a detailed string representation of the Parameters object."""
        return (
            f"Parameters(_length={self._length}, "
            f"_invariant_params={self._invariant_params}, "
            f"_varying_params={self._varying_params}, "
            f"_parameters={self._parameters})"
        )

    def __str__(self) -> str:
        """Return a simple string representation of the Parameters object."""
        return f"Parameters({str(self._parameters)})"

    def __getattr__(self, name: str) -> Any:
        """
        Allow attribute-style access to parameters.

        Only called when normal attribute lookup fails. Names starting with an
        underscore are resolved through ``object.__getattribute__`` so that
        internal slots (and dunder probes such as those made by ``copy``) are
        never looked up in the parameter dictionary.

        Parameters
        ----------
        name : str
            Name of the parameter to access.

        Returns
        -------
        Any
            The value of the specified parameter.

        Raises
        ------
        AttributeError:
            If the parameter name is not found.
        """
        if name.startswith("_"):
            return super().__getattribute__(name)
        try:
            return self._parameters[name]
        except KeyError:
            raise AttributeError(f"'Parameters' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Allow attribute-style setting of parameters.

        Underscore-prefixed names set internal slots directly; any other name
        is routed through ``__setitem__`` as a parameter assignment.

        Parameters
        ----------
        name : str
            Name of the parameter to set.
        value : Any
            Value to set for the parameter.
        """
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            self[name] = value

    def __contains__(self, item: str) -> bool:
        """Check if a parameter exists in the Parameters object."""
        return item in self._parameters

    def copy(self) -> "Parameters":
        """
        Create a deep copy of the Parameters object.

        Returns
        -------
        Parameters
            A new Parameters object with the same contents.
        """
        return deepcopy(self)

    def add_to_time_vary(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_vary.
            Names that are not existing parameters trigger a warning instead.
        """
        for param in params:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_vary."
                )

    def add_to_time_inv(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_inv.
            Names that are not existing parameters trigger a warning instead.
        """
        for param in params:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_inv."
                )

    def del_from_time_vary(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-varying set.

        The parameter values themselves are kept; a name removed from both sets
        is simply not carried into age slices (``self[age]``).

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_vary.
        """
        for param in params:
            self._varying_params.discard(param)

    def del_from_time_inv(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-invariant set.

        The parameter values themselves are kept; a name removed from both sets
        is simply not carried into age slices (``self[age]``).

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_inv.
        """
        for param in params:
            self._invariant_params.discard(param)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a parameter value, returning a default if not found.

        Parameters
        ----------
        key : str
            The parameter name.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The parameter value or the default.
        """
        return self._parameters.get(key, default)

    def set_many(self, **kwargs: Any) -> None:
        """
        Set multiple parameters at once.

        Parameters
        ----------
        **kwargs : Keyword arguments representing parameter names and values.
        """
        for key, value in kwargs.items():
            self[key] = value

    def is_time_varying(self, key: str) -> bool:
        """
        Check if a parameter is time-varying.

        Parameters
        ----------
        key : str
            The parameter name.

        Returns
        -------
        bool
            True if the parameter is time-varying, False otherwise.
        """
        return key in self._varying_params

    def at_age(self, age: int) -> "Parameters":
        """
        Get parameters for a specific age.

        This is an alternative to integer indexing (params[age]) that is more
        explicit and avoids potential confusion with dictionary-style access.

        Parameters
        ----------
        age : int
            The age index to retrieve parameters for.

        Returns
        -------
        Parameters
            A new Parameters object with parameters for the specified age.

        Raises
        ------
        ValueError
            If the age index is out of bounds.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97], sigma=2.0)
        >>> age_1_params = params.at_age(1)
        >>> age_1_params.beta
        0.96
        """
        return self[age]

    def validate(self) -> None:
        """
        Validate parameter consistency.

        Checks that all time-varying list/tuple/array parameters have length
        (or first dimension) matching T_cycle. Time-varying parameters of other
        types are not checked. This is useful after manual modifications or
        when parameters are set programmatically.

        Raises
        ------
        ValueError
            If any time-varying parameter has incorrect length; the message
            lists every offending parameter.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97])
        >>> params.validate()  # Passes
        >>> params.add_to_time_vary("beta")
        >>> params.validate()  # Still passes
        """
        errors = []
        for param in self._varying_params:
            value = self._parameters[param]
            if isinstance(value, (list, tuple)):
                if len(value) != self._length:
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )
            elif isinstance(value, np.ndarray):
                if value.ndim == 0:
                    errors.append(
                        f"Parameter '{param}' is a 0-dimensional array (scalar), "
                        "which should not be time-varying"
                    )
                elif value.ndim == 1:
                    if len(value) != self._length:
                        errors.append(
                            f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                        )
                elif value.shape[0] != self._length:
                    # ndim >= 2: the first axis is the time axis
                    errors.append(
                        f"Parameter '{param}' has first dimension {value.shape[0]}, expected {self._length}"
                    )

        if errors:
            raise ValueError(
                "Parameter validation failed:\n" + "\n".join(f"  - {e}" for e in errors)
            )
class Model:
    """
    A class with special handling of parameters assignment.

    Instances keep a ``parameters`` dictionary whose entries are mirrored as
    attributes, and a ``constructors`` dictionary mapping the names of
    constructed objects to the callables (or ``get_it_from`` specifications,
    or ``None``) used by ``construct`` to build them.
    """

    def __init__(self):
        # Keep parameters/constructors that already exist on the instance
        # (e.g. assigned before this initializer runs); otherwise start empty.
        if not hasattr(self, "parameters"):
            self.parameters = {}
        if not hasattr(self, "constructors"):
            self.constructors = {}

    def assign_parameters(self, **kwds):
        """
        Assign an arbitrary number of attributes to this agent.

        Parameters
        ----------
        **kwds : keyword arguments
            Any number of keyword arguments of the form key=value.
            Each value will be assigned to the attribute named in self,
            and also stored in self.parameters.

        Returns
        -------
        None
        """
        self.parameters.update(kwds)
        for key in kwds:
            setattr(self, key, kwds[key])

    def get_parameter(self, name):
        """
        Returns a parameter of this model

        Parameters
        ----------
        name : str
            The name of the parameter to get

        Returns
        -------
        value : The value of the parameter
        """
        return self.parameters[name]

    def __eq__(self, other):
        # Instances of the same type are equal when their parameter dicts are.
        if isinstance(other, type(self)):
            return self.parameters == other.parameters

        return NotImplemented

    def __str__(self):
        type_ = type(self)
        module = type_.__module__
        qualname = type_.__qualname__

        s = f"<{module}.{qualname} object at {hex(id(self))}.\n"
        s += "Parameters:"

        for p in self.parameters:
            s += f"\n{p}: {self.parameters[p]}"

        s += ">"
        return s

    def describe(self):
        """Return the same text as ``str(self)``."""
        return self.__str__()

    def del_param(self, param_name):
        """
        Deletes a parameter from this instance, removing it both from the object's
        namespace (if it's there) and the parameters dictionary (likewise).

        Parameters
        ----------
        param_name : str
            A string naming a parameter or data to be deleted from this instance.
            Removes information from self.parameters dictionary and own namespace.

        Returns
        -------
        None
        """
        if param_name in self.parameters:
            del self.parameters[param_name]
        if hasattr(self, param_name):
            delattr(self, param_name)

    def construct(self, *args, force=False):
        """
        Top-level method for building constructed inputs. If called without any
        inputs, construct builds each of the objects named in the keys of the
        constructors dictionary; it draws inputs for the constructors from the
        parameters dictionary and adds its results to the same. If passed one or
        more strings as arguments, the method builds only the named keys. The
        method will do multiple "passes" over the requested keys, as some cons-
        tructors require inputs built by other constructors. If any requested
        constructors failed to build due to missing data, those keys (and the
        missing data) will be named in self._missing_key_data. Other errors are
        recorded in the dictionary attribute _constructor_errors.

        This method tries to "start from scratch" by removing prior constructed
        objects, holding them in a backup dictionary during construction. This
        is done so that dependencies among constructors are resolved properly,
        without mistakenly relying on "old information". A backup value is used
        if a constructor function is set to None (i.e. "don't do anything"), or
        if the construct method fails to produce a new object. This includes
        the case where an exception propagates out of this method: backups for
        every requested key that was not rebuilt are restored before raising.

        Parameters
        ----------
        *args : str, optional
            Keys of self.constructors that are requested to be constructed.
            If no arguments are passed, *all* elements of the dictionary are implied.
        force : bool, optional
            When True, the method will force its way past any errors, including
            missing constructors, missing arguments for constructors, and errors
            raised during execution of constructors. Information about all such
            errors is stored in the dictionary attributes described above. When
            False (default), any errors or exception will be raised.

        Returns
        -------
        None

        Raises
        ------
        KeyError
            If a requested key has no constructor and force is False.
        ValueError
            If some requested objects could not be built and force is False.
        Exception
            Any exception raised by a constructor is re-raised when force is
            False.
        """
        # Set up the requested work
        if len(args) > 0:
            keys = args
        else:
            keys = list(self.constructors.keys())
        N_keys = len(keys)
        keys_complete = np.zeros(N_keys, dtype=bool)
        if N_keys == 0:
            return  # Do nothing if there are no constructed objects

        # Remove pre-existing constructed objects, preventing "incomplete" updates,
        # but store the current values in a backup dictionary in case something fails
        backup = {}
        for key in keys:
            if hasattr(self, key):
                backup[key] = getattr(self, key)
            self.del_param(key)

        # Get the dictionary of constructor errors
        if not hasattr(self, "_constructor_errors"):
            self._constructor_errors = {}
        errors = self._constructor_errors

        # As long as the work isn't complete and we made some progress on the last
        # pass, repeatedly perform passes of trying to construct objects
        any_keys_incomplete = np.any(np.logical_not(keys_complete))
        go = any_keys_incomplete
        try:
            while go:
                anything_accomplished_this_pass = False  # Nothing done yet!
                missing_key_data = []  # Keep this up-to-date on each pass

                # Loop over keys to be constructed
                for i, key in enumerate(keys):
                    if keys_complete[i]:
                        continue  # This key has already been built

                    # Get this key's constructor function
                    try:
                        constructor = self.constructors[key]
                    except KeyError as not_found:
                        errors[key] = "No constructor found for " + str(not_found)
                        if force:
                            continue
                        raise KeyError("No constructor found for " + key) from None

                    # If this constructor is None, do nothing and mark it as completed;
                    # this includes restoring the previous value if it exists
                    if constructor is None:
                        if key in backup:
                            setattr(self, key, backup[key])
                            self.parameters[key] = backup[key]
                        keys_complete[i] = True
                        anything_accomplished_this_pass = True  # We did something!
                        continue

                    if isinstance(constructor, get_it_from):
                        # SPECIAL: get_it_from reads the key out of a parent attribute
                        try:
                            parent = getattr(self, constructor.name)
                        except AttributeError:
                            parent = None
                            query = None
                            any_missing = True
                            missing_args = [constructor.name]
                        else:
                            query = key
                            any_missing = False
                            missing_args = []
                        temp_dict = {"parent": parent, "query": query}
                    else:
                        # Gather the constructor's arguments from attributes first,
                        # then from self.parameters; arguments with defaults may be absent
                        args_needed = get_arg_names(constructor)
                        has_no_default = {
                            k: v.default is inspect.Parameter.empty
                            for k, v in inspect.signature(constructor).parameters.items()
                        }
                        temp_dict = {}
                        any_missing = False
                        missing_args = []
                        for this_arg in args_needed:
                            if hasattr(self, this_arg):
                                temp_dict[this_arg] = getattr(self, this_arg)
                                continue
                            try:
                                temp_dict[this_arg] = self.parameters[this_arg]
                            except KeyError:
                                if has_no_default[this_arg]:
                                    # Record missing key-data pair
                                    any_missing = True
                                    missing_key_data.append((key, this_arg))
                                    missing_args.append(this_arg)

                    if any_missing:
                        errors[key] = "Missing required arguments: " + ", ".join(
                            missing_args
                        )
                        self.del_param(key)
                        # Never raise exceptions here, as the arguments might be
                        # filled in by another constructor on a later pass
                        continue

                    # All required data was found: run the constructor and store
                    # the result in parameters (and on self)
                    try:
                        temp = constructor(**temp_dict)
                    except Exception as problem:
                        errors[key] = str(type(problem)) + ": " + str(problem)
                        self.del_param(key)
                        if force:
                            continue
                        raise
                    setattr(self, key, temp)
                    self.parameters[key] = temp
                    errors.pop(key, None)
                    keys_complete[i] = True
                    anything_accomplished_this_pass = True  # We did something!

                # Check whether another pass should be performed
                any_keys_incomplete = np.any(np.logical_not(keys_complete))
                go = any_keys_incomplete and anything_accomplished_this_pass
        except Exception:
            # Don't leave the instance stripped of its previous objects when an
            # error propagates to the caller; put back what was not rebuilt.
            self._restore_backups(keys, keys_complete, backup)
            raise

        # Store missing key-data pairs and exit
        self._missing_key_data = missing_key_data
        self._constructor_errors = errors
        if any_keys_incomplete:
            self._restore_backups(keys, keys_complete, backup)
            unbuilt = [k for k, done in zip(keys, keys_complete) if not done]
            msg = "Did not construct these objects: " + ", ".join(unbuilt)
            if not force:
                raise ValueError(msg)
        return

    def _restore_backups(self, keys, keys_complete, backup):
        """
        Reinstate the pre-construct value (attribute and self.parameters entry)
        of every requested key that was not rebuilt and has a backup.
        """
        for key, done in zip(keys, keys_complete):
            if not done and key in backup:
                setattr(self, key, backup[key])
                self.parameters[key] = backup[key]

    def describe_constructors(self, *args):
        """
        Prints to screen a string describing this instance's constructed objects,
        including their names, the function that constructs them, the names of
        those functions inputs, and whether those inputs are present.

        Each line starts with a check mark when the object is present and "X"
        when it is not; argument lines use "*" for absent arguments that have
        a default value. Constructors set to None are reported as such rather
        than inspected.

        Parameters
        ----------
        *args : str, optional
            Optional list of strings naming constructed inputs to be described.
            If none are passed, all constructors are described.

        Returns
        -------
        None
        """
        if len(args) > 0:
            keys = args
        else:
            keys = list(self.constructors.keys())
        yes = "\u2713"
        no = "X"
        maybe = "*"
        noyes = [no, yes]

        out = ""
        for key in keys:
            has_val = hasattr(self, key) or (key in self.parameters)
            status = noyes[int(has_val)]

            try:
                constructor = self.constructors[key]
            except KeyError:
                out += status + " " + key + " : NO CONSTRUCTOR FOUND\n"
                continue

            # construct() treats a None constructor as "keep the existing value";
            # it has no name or signature to inspect.
            if constructor is None:
                out += status + " " + key + " : None (keeps existing value)\n"
                continue

            if isinstance(constructor, get_it_from):
                out += status + " " + key + " : get it from " + constructor.name + "\n"
                continue

            # Callable objects (e.g. functools.partial) may lack __name__
            func_name = getattr(constructor, "__name__", repr(constructor))
            out += status + " " + key + " : " + func_name + "\n"

            # Get constructor argument names
            arg_names = get_arg_names(constructor)
            has_no_default = {
                k: v.default is inspect.Parameter.empty
                for k, v in inspect.signature(constructor).parameters.items()
            }

            # Check whether each argument exists
            for this_arg in arg_names:
                if hasattr(self, this_arg) or this_arg in self.parameters:
                    symb = yes
                elif not has_no_default[this_arg]:
                    symb = maybe
                else:
                    symb = no
                out += "  " + symb + " " + this_arg + "\n"

        # Print the string to screen
        print(out)
        return

    # This is a "synonym" method so that old calls to update() still work
    def update(self, *args):
        self.construct(*args)
969class AgentType(Model):
970 """
971 A superclass for economic agents in the HARK framework. Each model should
972 specify its own subclass of AgentType, inheriting its methods and overwriting
973 as necessary. Critically, every subclass of AgentType should define class-
974 specific static values of the attributes time_vary and time_inv as lists of
975 strings. Each element of time_vary is the name of a field in AgentSubType
976 that varies over time in the model. Each element of time_inv is the name of
977 a field in AgentSubType that is constant over time in the model.
979 Parameters
980 ----------
981 solution_terminal : Solution
982 A representation of the solution to the terminal period problem of
983 this AgentType instance, or an initial guess of the solution if this
984 is an infinite horizon problem.
985 cycles : int
986 The number of times the sequence of periods is experienced by this
987 AgentType in their "lifetime". cycles=1 corresponds to a lifecycle
988 model, with a certain sequence of one period problems experienced
989 once before terminating. cycles=0 corresponds to an infinite horizon
990 model, with a sequence of one period problems repeating indefinitely.
991 pseudo_terminal : bool
992 Indicates whether solution_terminal isn't actually part of the
993 solution to the problem (as a known solution to the terminal period
994 problem), but instead represents a "scrap value"-style termination.
995 When True, solution_terminal is not included in the solution; when
996 False, solution_terminal is the last element of the solution.
997 tolerance : float
998 Maximum acceptable "distance" between successive solutions to the
999 one period problem in an infinite horizon (cycles=0) model in order
1000 for the solution to be considered as having "converged". Inoperative
1001 when cycles>0.
1002 verbose : int
1003 Level of output to be displayed by this instance, default is 1.
1004 quiet : bool
1005 Indicator for whether this instance should operate "quietly", default False.
1006 seed : int
1007 A seed for this instance's random number generator.
1008 construct : bool
1009 Indicator for whether this instance's construct() method should be run
1010 when initialized (default True). When False, an instance of the class
1011 can be created even if not all of its attributes can be constructed.
1012 use_defaults : bool
1013 Indicator for whether this instance should use the values in the class'
1014 default dictionary to fill in parameters and constructors for those not
1015 provided by the user (default True). Setting this to False is useful for
1016 situations where the user wants to be absolutely sure that they know what
1017 is being passed to the class initializer, without resorting to defaults.
1019 Attributes
1020 ----------
1021 AgentCount : int
1022 The number of agents of this type to use in simulation.
1024 state_vars : list of string
1025 The string labels for this AgentType's model state variables.
1026 """
1028 time_vary_ = []
1029 time_inv_ = []
1030 shock_vars_ = []
1031 state_vars = []
1032 poststate_vars = []
1033 distributions = []
1034 default_ = {"params": {}, "solver": NullFunc()}
    def __init__(
        self,
        solution_terminal=None,
        pseudo_terminal=True,
        tolerance=0.000001,
        verbose=1,
        quiet=False,
        seed=0,
        construct=True,
        use_defaults=True,
        **kwds,
    ):
        """
        Initialize the agent from the class defaults and user-supplied keywords.

        Parameters
        ----------
        solution_terminal : optional
            Terminal-period solution (or initial guess); NullFunc() is used
            when None.
        pseudo_terminal : bool, optional
            Stored as self.pseudo_terminal (default True).
        tolerance : float, optional
            Stored as self.tolerance (default 0.000001).
        verbose : int, optional
            Stored as self.verbose (default 1).
        quiet : bool, optional
            Stored as self.quiet (default False).
        seed : int, optional
            Stored as self.seed (default 0) before reset_rng() is called.
        construct : bool, optional
            When True (default), self.construct() runs after all parameters
            have been assigned.
        use_defaults : bool, optional
            When True (default), start from a deep copy of
            self.default_["params"] (including its "constructors" entry) and
            let kwds override it; when False, only kwds are used.
        **kwds
            Additional parameters, assigned via assign_parameters(). A
            "constructors" entry is merged into the default constructors
            rather than replacing them.

        Notes
        -----
        assign_parameters() runs after the explicit attributes below are set,
        so an entry of the same name in the parameter dictionary (which can
        only come from the class defaults, since kwds cannot repeat the named
        arguments) overwrites the corresponding attribute.
        """
        super().__init__()
        params = deepcopy(self.default_["params"]) if use_defaults else {}
        params.update(kwds)

        # Correctly handle constructors that have been passed in kwds:
        # user-supplied constructors are layered over (a copy of) the defaults
        if "constructors" in self.default_["params"].keys() and use_defaults:
            constructors = deepcopy(self.default_["params"]["constructors"])
        else:
            constructors = {}
        if "constructors" in kwds.keys():
            constructors.update(kwds["constructors"])
        params["constructors"] = constructors

        # Set model file name if possible
        try:
            self.model_file = copy(self.default_["model"])
        except (KeyError, TypeError):
            # Fallback to None if "model" key is missing or invalid for copying
            self.model_file = None

        if solution_terminal is None:
            solution_terminal = NullFunc()

        self.solve_one_period = self.default_["solver"]  # NOQA
        self.solution_terminal = solution_terminal  # NOQA
        self.pseudo_terminal = pseudo_terminal  # NOQA
        self.tolerance = tolerance  # NOQA
        self.verbose = verbose
        self.quiet = quiet
        self.seed = seed  # NOQA
        self.track_vars = []  # NOQA
        # One slot per state variable, all empty until simulation fills them
        self.state_now = {sv: None for sv in self.state_vars}
        self.state_prev = self.state_now.copy()
        self.controls = {}
        self.shocks = {}
        self.read_shocks = False  # NOQA
        self.shock_history = {}
        self.newborn_init_history = {}
        self.history = {}
        # Must follow the explicit assignments above (see Notes in docstring)
        self.assign_parameters(**params)  # NOQA
        # reset_rng is defined outside this view; presumably (re)seeds the
        # instance RNG from self.seed -- confirm
        self.reset_rng()  # NOQA
        self.bilt = {}
        if construct:
            self.construct()

        # Add instance-level lists and objects: copies of the class-level lists,
        # so per-instance add_to_time_vary/add_to_time_inv calls do not mutate
        # the class attributes. Assigned after construct(), so they replace any
        # values set on the instance during construction.
        self.time_vary = deepcopy(self.time_vary_)
        self.time_inv = deepcopy(self.time_inv_)
        self.shock_vars = deepcopy(self.shock_vars_)
1098 def add_to_time_vary(self, *params):
1099 """
1100 Adds any number of parameters to time_vary for this instance.
1102 Parameters
1103 ----------
1104 params : string
1105 Any number of strings naming attributes to be added to time_vary
1107 Returns
1108 -------
1109 None
1110 """
1111 for param in params:
1112 if param not in self.time_vary:
1113 self.time_vary.append(param)
1115 def add_to_time_inv(self, *params):
1116 """
1117 Adds any number of parameters to time_inv for this instance.
1119 Parameters
1120 ----------
1121 params : string
1122 Any number of strings naming attributes to be added to time_inv
1124 Returns
1125 -------
1126 None
1127 """
1128 for param in params:
1129 if param not in self.time_inv:
1130 self.time_inv.append(param)
1132 def del_from_time_vary(self, *params):
1133 """
1134 Removes any number of parameters from time_vary for this instance.
1136 Parameters
1137 ----------
1138 params : string
1139 Any number of strings naming attributes to be removed from time_vary
1141 Returns
1142 -------
1143 None
1144 """
1145 for param in params:
1146 if param in self.time_vary:
1147 self.time_vary.remove(param)
1149 def del_from_time_inv(self, *params):
1150 """
1151 Removes any number of parameters from time_inv for this instance.
1153 Parameters
1154 ----------
1155 params : string
1156 Any number of strings naming attributes to be removed from time_inv
1158 Returns
1159 -------
1160 None
1161 """
1162 for param in params:
1163 if param in self.time_inv:
1164 self.time_inv.remove(param)
1166 def unpack(self, parameter):
1167 """
1168 Unpacks an attribute from a solution object for easier access.
1169 After the model has been solved, its components (like consumption function)
1170 reside in the attributes of each element of `ThisType.solution` (e.g. `cFunc`).
1171 This method creates a (time varying) attribute of the given attribute name
1172 that contains a list of elements accessible by `ThisType.parameter`.
1174 Parameters
1175 ----------
1176 parameter: str
1177 Name of the attribute to unpack from the solution
1179 Returns
1180 -------
1181 none
1182 """
1183 # Use list comprehension for better performance instead of loop with append
1184 setattr(
1185 self,
1186 parameter,
1187 [solution_t.__dict__[parameter] for solution_t in self.solution],
1188 )
1189 self.add_to_time_vary(parameter)
1191 def solve(
1192 self,
1193 verbose=False,
1194 presolve=True,
1195 postsolve=True,
1196 from_solution=None,
1197 from_t=None,
1198 ):
1199 """
1200 Solve the model for this instance of an agent type by backward induction.
1201 Loops through the sequence of one period problems, passing the solution
1202 from period t+1 to the problem for period t.
1204 Parameters
1205 ----------
1206 verbose : bool, optional
1207 If True, solution progress is printed to screen. Default False.
1208 presolve : bool, optional
1209 If True (default), the pre_solve method is run before solving.
1210 postsolve : bool, optional
1211 If True (default), the post_solve method is run after solving.
1212 from_solution: Solution
1213 If different from None, will be used as the starting point of backward
1214 induction, instead of self.solution_terminal.
1215 from_t : int or None
1216 If not None, indicates which period of the model the solver should start
1217 from. It should usually only be used in combination with from_solution.
1218 Stands for the time index that from_solution represents, and thus is
1219 only compatible with cycles=1 and will be reset to None otherwise.
1221 Returns
1222 -------
1223 none
1224 """
1226 # Ignore floating point "errors". Numpy calls it "errors", but really it's excep-
1227 # tions with well-defined answers such as 1.0/0.0 that is np.inf, -1.0/0.0 that is
1228 # -np.inf, np.inf/np.inf is np.nan and so on.
1229 with np.errstate(
1230 divide="ignore", over="ignore", under="ignore", invalid="ignore"
1231 ):
1232 if presolve:
1233 self.pre_solve() # Do pre-solution stuff
1234 self.solution = solve_agent(
1235 self,
1236 verbose,
1237 from_solution,
1238 from_t,
1239 ) # Solve the model by backward induction
1240 if postsolve:
1241 self.post_solve() # Do post-solution stuff
1243 def reset_rng(self):
1244 """
1245 Reset the random number generator and all distributions for this type.
1246 Type-checking for lists is to handle the following three cases:
1248 1) The target is a single distribution object
1249 2) The target is a list of distribution objects (probably time-varying)
1250 3) The target is a nested list of distributions, as in ConsMarkovModel.
1251 """
1252 self.RNG = np.random.default_rng(self.seed)
1253 for name in self.distributions:
1254 if not hasattr(self, name):
1255 continue
1257 dstn = getattr(self, name)
1258 if isinstance(dstn, list):
1259 for D in dstn:
1260 if isinstance(D, list):
1261 for d in D:
1262 d.reset()
1263 else:
1264 D.reset()
1265 else:
1266 dstn.reset()
1268 def check_elements_of_time_vary_are_lists(self):
1269 """
1270 A method to check that elements of time_vary are lists.
1271 """
1272 for param in self.time_vary:
1273 if not hasattr(self, param):
1274 continue
1275 if not isinstance(
1276 getattr(self, param),
1277 (IndexDistribution,),
1278 ):
1279 assert type(getattr(self, param)) == list, (
1280 param
1281 + " is not a list or time varying distribution,"
1282 + " but should be because it is in time_vary"
1283 )
1285 def check_restrictions(self):
1286 """
1287 A method to check that various restrictions are met for the model class.
1288 """
1289 return
1291 def pre_solve(self):
1292 """
1293 A method that is run immediately before the model is solved, to check inputs or to prepare
1294 the terminal solution, perhaps.
1296 Parameters
1297 ----------
1298 none
1300 Returns
1301 -------
1302 none
1303 """
1304 self.check_restrictions()
1305 self.check_elements_of_time_vary_are_lists()
1306 return None
1308 def post_solve(self):
1309 """
1310 A method that is run immediately after the model is solved, to finalize
1311 the solution in some way. Does nothing here.
1313 Parameters
1314 ----------
1315 none
1317 Returns
1318 -------
1319 none
1320 """
1321 return None
1323 def initialize_sym(self, **kwargs):
1324 """
1325 Use the new simulator structure to build a simulator from the agents'
1326 attributes, storing it in a private attribute.
1327 """
1328 self.reset_rng() # ensure seeds are set identically each time
1329 self._simulator = make_simulator_from_agent(self, **kwargs)
1330 self._simulator.reset()
    def initialize_sim(self):
        """
        Prepares this AgentType for a new simulation. Resets the internal random number generator,
        makes initial states for all agents (using sim_birth), clears histories of tracked variables.

        If self.read_shocks is True and self.newborn_init_history is non-empty,
        the initial idiosyncratic states produced by sim_birth are overwritten
        with the first row of the stored newborn initial conditions.

        Parameters
        ----------
        None

        Returns
        -------
        None

        Raises
        ------
        Exception
            If T_sim is not set, or if T_sim is not positive.
        """
        if not hasattr(self, "T_sim"):
            raise Exception(
                "To initialize simulation variables it is necessary to first "
                + "set the attribute T_sim to the largest number of observations "
                + "you plan to simulate for each agent including re-births."
            )
        elif self.T_sim <= 0:
            raise Exception(
                "T_sim represents the largest number of observations "
                + "that can be simulated for an agent, and must be a positive number."
            )

        self.reset_rng()
        self.t_sim = 0
        all_agents = np.ones(self.AgentCount, dtype=bool)
        blank_array = np.empty(self.AgentCount)
        blank_array[:] = np.nan
        # Each state variable gets its own NaN-filled array of size AgentCount.
        for var in self.state_vars:
            self.state_now[var] = copy(blank_array)

        # Number of periods since agent entry
        self.t_age = np.zeros(self.AgentCount, dtype=int)
        # Which cycle period each agent is on
        self.t_cycle = np.zeros(self.AgentCount, dtype=int)
        # Every agent is "born" at the start of the simulation.
        self.sim_birth(all_agents)

        # If we are asked to use existing shocks and a set of initial conditions
        # exist, use them
        if self.read_shocks and bool(self.newborn_init_history):
            for var_name in self.state_now:
                # Check that we are actually given a value for the variable
                if var_name in self.newborn_init_history.keys():
                    # Copy only array-like idiosyncratic states. Aggregates should
                    # not be set by newborns
                    idio = (
                        isinstance(self.state_now[var_name], np.ndarray)
                        and len(self.state_now[var_name]) == self.AgentCount
                    )
                    if idio:
                        # NOTE(review): if the stored history is a NumPy array,
                        # this binds state_now[var_name] to a view of its first
                        # row rather than a copy, so in-place writes to the
                        # state would also modify the stored history -- confirm
                        # that no caller relies on mutating it in place.
                        self.state_now[var_name] = self.newborn_init_history[var_name][
                            0
                        ]

                else:
                    warn(
                        "The option for reading shocks was activated but "
                        + "the model requires state "
                        + var_name
                        + ", not contained in "
                        + "newborn_init_history."
                    )

        # Reset the T_sim x AgentCount history arrays for all tracked variables.
        self.clear_history()
        return None
1400 def sim_one_period(self):
1401 """
1402 Simulates one period for this type. Calls the methods get_mortality(), get_shocks() or
1403 read_shocks, get_states(), get_controls(), and get_poststates(). These should be defined for
1404 AgentType subclasses, except get_mortality (define its components sim_death and sim_birth
1405 instead) and read_shocks.
1407 Parameters
1408 ----------
1409 None
1411 Returns
1412 -------
1413 None
1414 """
1415 if not hasattr(self, "solution"):
1416 raise Exception(
1417 "Model instance does not have a solution stored. To simulate, it is necessary"
1418 " to run the `solve()` method first."
1419 )
1421 # Mortality adjusts the agent population
1422 self.get_mortality() # Replace some agents with "newborns"
1424 # state_{t-1}
1425 for var in self.state_now:
1426 self.state_prev[var] = self.state_now[var]
1428 if isinstance(self.state_now[var], np.ndarray):
1429 self.state_now[var] = np.empty(self.AgentCount)
1430 else:
1431 # Probably an aggregate variable. It may be getting set by the Market.
1432 pass
1434 if self.read_shocks: # If shock histories have been pre-specified, use those
1435 self.read_shocks_from_history()
1436 else: # Otherwise, draw shocks as usual according to subclass-specific method
1437 self.get_shocks()
1438 self.get_states() # Determine each agent's state at decision time
1439 self.get_controls() # Determine each agent's choice or control variables based on states
1440 self.get_poststates() # Calculate variables that come *after* decision-time
1442 # Advance time for all agents
1443 self.t_age = self.t_age + 1 # Age all consumers by one period
1444 self.t_cycle = self.t_cycle + 1 # Age all consumers within their cycle
1445 self.t_cycle[self.t_cycle == self.T_cycle] = (
1446 0 # Resetting to zero for those who have reached the end
1447 )
1449 def make_shock_history(self):
1450 """
1451 Makes a pre-specified history of shocks for the simulation. Shock variables should be named
1452 in self.shock_vars, a list of strings that is subclass-specific. This method runs a subset
1453 of the standard simulation loop by simulating only mortality and shocks; each variable named
1454 in shock_vars is stored in a T_sim x AgentCount array in history dictionary self.history[X].
1455 Automatically sets self.read_shocks to True so that these pre-specified shocks are used for
1456 all subsequent calls to simulate().
1458 Parameters
1459 ----------
1460 None
1462 Returns
1463 -------
1464 None
1465 """
1466 # Re-initialize the simulation
1467 self.initialize_sim()
1469 # Make blank history arrays for each shock variable (and mortality)
1470 for var_name in self.shock_vars:
1471 self.shock_history[var_name] = (
1472 np.zeros((self.T_sim, self.AgentCount)) + np.nan
1473 )
1474 self.shock_history["who_dies"] = np.zeros(
1475 (self.T_sim, self.AgentCount), dtype=bool
1476 )
1478 # Also make blank arrays for the draws of newborns' initial conditions
1479 for var_name in self.state_vars:
1480 self.newborn_init_history[var_name] = (
1481 np.zeros((self.T_sim, self.AgentCount)) + np.nan
1482 )
1484 # Record the initial condition of the newborns created by
1485 # initialize_sim -> sim_births
1486 for var_name in self.state_vars:
1487 # Check whether the state is idiosyncratic or an aggregate
1488 idio = (
1489 isinstance(self.state_now[var_name], np.ndarray)
1490 and len(self.state_now[var_name]) == self.AgentCount
1491 )
1492 if idio:
1493 self.newborn_init_history[var_name][self.t_sim] = self.state_now[
1494 var_name
1495 ]
1496 else:
1497 # Aggregate state is a scalar. Assign it to every agent.
1498 self.newborn_init_history[var_name][self.t_sim, :] = self.state_now[
1499 var_name
1500 ]
1502 # Make and store the history of shocks for each period
1503 for t in range(self.T_sim):
1504 # Deaths
1505 self.get_mortality()
1506 self.shock_history["who_dies"][t, :] = self.who_dies
1508 # Initial conditions of newborns
1509 if self.who_dies.any():
1510 for var_name in self.state_vars:
1511 # Check whether the state is idiosyncratic or an aggregate
1512 idio = (
1513 isinstance(self.state_now[var_name], np.ndarray)
1514 and len(self.state_now[var_name]) == self.AgentCount
1515 )
1516 if idio:
1517 self.newborn_init_history[var_name][t, self.who_dies] = (
1518 self.state_now[var_name][self.who_dies]
1519 )
1520 else:
1521 self.newborn_init_history[var_name][t, self.who_dies] = (
1522 self.state_now[var_name]
1523 )
1525 # Other Shocks
1526 self.get_shocks()
1527 for var_name in self.shock_vars:
1528 self.shock_history[var_name][t, :] = self.shocks[var_name]
1530 self.t_sim += 1
1531 self.t_age = self.t_age + 1 # Age all consumers by one period
1532 self.t_cycle = self.t_cycle + 1 # Age all consumers within their cycle
1533 self.t_cycle[self.t_cycle == self.T_cycle] = (
1534 0 # Resetting to zero for those who have reached the end
1535 )
1537 # Flag that shocks can be read rather than simulated
1538 self.read_shocks = True
    def get_mortality(self):
        """
        Simulates mortality or agent turnover according to some model-specific rules named sim_death
        and sim_birth (methods of an AgentType subclass). sim_death takes no arguments and returns
        a Boolean array of size AgentCount, indicating which agents of this type have "died" and
        must be replaced. sim_birth takes such a Boolean array as an argument and generates initial
        post-decision states for those agent indices.

        When self.read_shocks is True, sim_death and sim_birth are not called.
        Instead:
        - deaths are read from self.shock_history["who_dies"][self.t_sim, :];
        - newborns' idiosyncratic states are copied from
          self.newborn_init_history;
        - newborns' t_age and t_cycle are reset to 0 here.

        In both cases the Boolean death indicator is stored in self.who_dies.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        if self.read_shocks:
            who_dies = self.shock_history["who_dies"][self.t_sim, :]
            # Instead of simulating births, assign the saved newborn initial conditions
            if who_dies.any():
                for var_name in self.state_now:
                    if var_name in self.newborn_init_history.keys():
                        # Copy only array-like idiosyncratic states. Aggregates should
                        # not be set by newborns
                        idio = (
                            isinstance(self.state_now[var_name], np.ndarray)
                            and len(self.state_now[var_name]) == self.AgentCount
                        )
                        if idio:
                            self.state_now[var_name][who_dies] = (
                                self.newborn_init_history[var_name][
                                    self.t_sim, who_dies
                                ]
                            )

                    else:
                        warn(
                            "The option for reading shocks was activated but "
                            + "the model requires state "
                            + var_name
                            + ", not contained in "
                            + "newborn_init_history."
                        )

                # Reset ages of newborns
                self.t_age[who_dies] = 0
                self.t_cycle[who_dies] = 0
        else:
            # Model-specific turnover: sim_birth re-initializes the dead agents.
            who_dies = self.sim_death()
            self.sim_birth(who_dies)
        # Expose this period's deaths (e.g. for make_shock_history / tracking).
        self.who_dies = who_dies
        return None
1593 def sim_death(self):
1594 """
1595 Determines which agents in the current population "die" or should be replaced. Takes no
1596 inputs, returns a Boolean array of size self.AgentCount, which has True for agents who die
1597 and False for those that survive. Returns all False by default, must be overwritten by a
1598 subclass to have replacement events.
1600 Parameters
1601 ----------
1602 None
1604 Returns
1605 -------
1606 who_dies : np.array
1607 Boolean array of size self.AgentCount indicating which agents die and are replaced.
1608 """
1609 who_dies = np.zeros(self.AgentCount, dtype=bool)
1610 return who_dies
1612 def sim_birth(self, which_agents): # pragma: nocover
1613 """
1614 Makes new agents for the simulation. Takes a boolean array as an input, indicating which
1615 agent indices are to be "born". Does nothing by default, must be overwritten by a subclass.
1617 Parameters
1618 ----------
1619 which_agents : np.array(Bool)
1620 Boolean array of size self.AgentCount indicating which agents should be "born".
1622 Returns
1623 -------
1624 None
1625 """
1626 raise Exception("AgentType subclass must define method sim_birth!")
1628 def get_shocks(self): # pragma: nocover
1629 """
1630 Gets values of shock variables for the current period. Does nothing by default, but can
1631 be overwritten by subclasses of AgentType.
1633 Parameters
1634 ----------
1635 None
1637 Returns
1638 -------
1639 None
1640 """
1641 return None
1643 def read_shocks_from_history(self):
1644 """
1645 Reads values of shock variables for the current period from history arrays.
1646 For each variable X named in self.shock_vars, this attribute of self is
1647 set to self.history[X][self.t_sim,:].
1649 This method is only ever called if self.read_shocks is True. This can
1650 be achieved by using the method make_shock_history() (or manually after
1651 storing a "handcrafted" shock history).
1653 Parameters
1654 ----------
1655 None
1657 Returns
1658 -------
1659 None
1660 """
1661 for var_name in self.shock_vars:
1662 self.shocks[var_name] = self.shock_history[var_name][self.t_sim, :]
1664 def get_states(self):
1665 """
1666 Gets values of state variables for the current period.
1667 By default, calls transition function and assigns values
1668 to the state_now dictionary.
1670 Parameters
1671 ----------
1672 None
1674 Returns
1675 -------
1676 None
1677 """
1678 new_states = self.transition()
1680 for i, var in enumerate(self.state_now):
1681 # a hack for now to deal with 'post-states'
1682 if i < len(new_states):
1683 self.state_now[var] = new_states[i]
1685 def transition(self): # pragma: nocover
1686 """
1688 Parameters
1689 ----------
1690 None
1692 [Eventually, to match dolo spec:
1693 exogenous_prev, endogenous_prev, controls, exogenous, parameters]
1695 Returns
1696 -------
1698 endogenous_state: ()
1699 Tuple with new values of the endogenous states
1700 """
1701 return ()
1703 def get_controls(self): # pragma: nocover
1704 """
1705 Gets values of control variables for the current period, probably by using current states.
1706 Does nothing by default, but can be overwritten by subclasses of AgentType.
1708 Parameters
1709 ----------
1710 None
1712 Returns
1713 -------
1714 None
1715 """
1716 return None
1718 def get_poststates(self):
1719 """
1720 Gets values of post-decision state variables for the current period,
1721 probably by current
1722 states and controls and maybe market-level events or shock variables.
1723 Does nothing by
1724 default, but can be overwritten by subclasses of AgentType.
1726 Parameters
1727 ----------
1728 None
1730 Returns
1731 -------
1732 None
1733 """
1734 return None
1736 def symulate(self, T=None):
1737 """
1738 Run the new simulation structure, with history results written to the
1739 hystory attribute of self.
1740 """
1741 self._simulator.simulate(T)
1742 self.hystory = self._simulator.history
1744 def describe_model(self, display=True):
1745 """
1746 Print to screen information about this agent's model, based on its model
1747 file. This is useful for learning about outcome variable names for tracking
1748 during simulation, or for use with sequence space Jacobians.
1749 """
1750 if not hasattr(self, "_simulator"):
1751 self.initialize_sym()
1752 self._simulator.describe(display=display)
1754 def simulate(self, sim_periods=None):
1755 """
1756 Simulates this agent type for a given number of periods. Defaults to self.T_sim,
1757 or all remaining periods to simulate (T_sim - t_sim). Records histories of
1758 attributes named in self.track_vars in self.history[varname].
1760 Parameters
1761 ----------
1762 sim_periods : int or None
1763 Number of periods to simulate. Default is all remaining periods (usually T_sim).
1765 Returns
1766 -------
1767 history : dict
1768 The history tracked during the simulation.
1769 """
1770 if not hasattr(self, "t_sim"):
1771 raise Exception(
1772 "It seems that the simulation variables were not initialize before calling "
1773 + "simulate(). Call initialize_sim() to initialize the variables before calling simulate() again."
1774 )
1776 if not hasattr(self, "T_sim"):
1777 raise Exception(
1778 "This agent type instance must have the attribute T_sim set to a positive integer."
1779 + "Set T_sim to match the largest dataset you might simulate, and run this agent's"
1780 + "initialize_sim() method before running simulate() again."
1781 )
1783 if sim_periods is not None and self.T_sim < sim_periods:
1784 raise Exception(
1785 "To simulate, sim_periods has to be larger than the maximum data set size "
1786 + "T_sim. Either increase the attribute T_sim of this agent type instance "
1787 + "and call the initialize_sim() method again, or set sim_periods <= T_sim."
1788 )
1790 # Ignore floating point "errors". Numpy calls it "errors", but really it's excep-
1791 # tions with well-defined answers such as 1.0/0.0 that is np.inf, -1.0/0.0 that is
1792 # -np.inf, np.inf/np.inf is np.nan and so on.
1793 with np.errstate(
1794 divide="ignore", over="ignore", under="ignore", invalid="ignore"
1795 ):
1796 if sim_periods is None:
1797 sim_periods = self.T_sim - self.t_sim
1799 for t in range(sim_periods):
1800 self.sim_one_period()
1802 for var_name in self.track_vars:
1803 if var_name in self.state_now:
1804 self.history[var_name][self.t_sim, :] = self.state_now[var_name]
1805 elif var_name in self.shocks:
1806 self.history[var_name][self.t_sim, :] = self.shocks[var_name]
1807 elif var_name in self.controls:
1808 self.history[var_name][self.t_sim, :] = self.controls[var_name]
1809 else:
1810 if var_name == "who_dies" and self.t_sim > 1:
1811 self.history[var_name][self.t_sim - 1, :] = getattr(
1812 self, var_name
1813 )
1814 else:
1815 self.history[var_name][self.t_sim, :] = getattr(
1816 self, var_name
1817 )
1818 self.t_sim += 1
1820 def clear_history(self):
1821 """
1822 Clears the histories of the attributes named in self.track_vars.
1824 Parameters
1825 ----------
1826 None
1828 Returns
1829 -------
1830 None
1831 """
1832 for var_name in self.track_vars:
1833 self.history[var_name] = np.empty((self.T_sim, self.AgentCount))
1834 self.history[var_name].fill(np.nan)
1836 def make_basic_SSJ(self, shock, outcomes, grids, **kwargs):
1837 """
1838 Construct and return sequence space Jacobian matrices for specified outcomes
1839 with respect to specified "shock" variable. This "basic" method only works
1840 for "one period infinite horizon" models (cycles=0, T_cycle=1). See documen-
1841 tation for simulator.make_basic_SSJ_matrices for more information.
1842 """
1843 return make_basic_SSJ_matrices(self, shock, outcomes, grids, **kwargs)
1845 def calc_impulse_response_manually(self, shock, outcomes, grids, **kwargs):
1846 """
1847 Calculate and return the impulse response(s) of a perturbation to the shock
1848 parameter in period t=s, essentially computing one column of the sequence
1849 space Jacobian matrix manually. This "basic" method only works for "one
1850 period infinite horizon" models (cycles=0, T_cycle=1). See documentation
1851 for simulator.calc_shock_response_manually for more information.
1852 """
1853 return calc_shock_response_manually(self, shock, outcomes, grids, **kwargs)
def solve_agent(agent, verbose, from_solution=None, from_t=None):
    """
    Solve the dynamic model for one agent type using backwards induction. This
    function iterates on "cycles" of an agent's model either a given number of
    times or until solution convergence if an infinite horizon model is used
    (with agent.cycles = 0).

    For infinite horizon models, iteration stops once the distance between the
    first-period solutions of successive cycles is at most agent.tolerance, or
    after max_cycles (5000) cycles. Hitting that limit stops the loop without
    any warning. The last distance and cycle count are stored on the agent as
    solution_distance and completed_cycles.

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem
        is to be solved.
    verbose : boolean
        If True, solution progress is printed to screen (when cycles != 1).
    from_solution: Solution
        If different from None, will be used as the starting point of backward
        induction, instead of self.solution_terminal
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. It should usually only be used in combination with from_solution.
        Stands for the time index that from_solution represents, and thus is
        only compatible with cycles=1 and will be reset to None otherwise.

    Returns
    -------
    solution : [Solution]
        A list of solutions to the one period problems that the agent will
        encounter in his "lifetime". For finite horizons, the cycles are
        concatenated (earliest first), followed by a copy of the terminal
        solution when agent.pseudo_terminal is False. For infinite horizons,
        only the last cycle solved is returned.
    """
    # Check to see whether this is an (in)finite horizon problem
    cycles_left = agent.cycles  # NOQA
    infinite_horizon = cycles_left == 0  # NOQA

    if from_solution is None:
        solution_last = agent.solution_terminal  # NOQA
    else:
        solution_last = from_solution
    if agent.cycles != 1:
        from_t = None

    # Initialize the solution, which includes the terminal solution if it's not a pseudo-terminal period
    solution = []
    if not agent.pseudo_terminal:
        solution.insert(0, deepcopy(solution_last))

    # Initialize the process, then loop over cycles
    go = True  # NOQA
    completed_cycles = 0  # NOQA
    max_cycles = 5000  # NOQA - escape clause
    if verbose:
        t_last = time()
    while go:
        # Solve a cycle of the model, recording it if horizon is finite
        solution_cycle = solve_one_cycle(agent, solution_last, from_t)
        if not infinite_horizon:
            # Earlier cycles are prepended, so the list stays in time order.
            solution = solution_cycle + solution

        # Check for termination: identical solutions across
        # cycle iterations or run out of cycles
        solution_now = solution_cycle[0]
        if infinite_horizon:
            if completed_cycles > 0:
                solution_distance = solution_now.distance(solution_last)
                agent.solution_distance = (
                    solution_distance  # Add these attributes so users can
                )
                agent.completed_cycles = (
                    completed_cycles  # query them to see if solution is ready
                )
                go = (
                    solution_distance > agent.tolerance
                    and completed_cycles < max_cycles
                )
            else:  # Assume solution does not converge after only one cycle
                solution_distance = 100.0
                go = True
        else:
            cycles_left += -1
            go = cycles_left > 0

        # Update the "last period solution"
        solution_last = solution_now
        completed_cycles += 1

        # Display progress if requested
        if verbose:
            t_now = time()
            if infinite_horizon:
                print(
                    "Finished cycle #"
                    + str(completed_cycles)
                    + " in "
                    + str(t_now - t_last)
                    + " seconds, solution distance = "
                    + str(solution_distance)
                )
            else:
                print(
                    "Finished cycle #"
                    + str(completed_cycles)
                    + " of "
                    + str(agent.cycles)
                    + " in "
                    + str(t_now - t_last)
                    + " seconds."
                )
            t_last = t_now

    # Record the last cycle if horizon is infinite (solution is still empty!)
    if infinite_horizon:
        solution = (
            solution_cycle  # PseudoTerminal=False impossible for infinite horizon
        )

    return solution
def solve_one_cycle(agent, solution_last, from_t):
    """
    Solve one "cycle" of the dynamic model for one agent type. This function
    iterates over the periods within an agent's cycle, updating the time-varying
    parameters and passing them to the single period solver(s).

    Two parameter sources are supported:
    - agent.parameters, if it is a Parameters instance; indexing it by period
      gives that period's parameter mapping;
    - otherwise, agent attributes listed in agent.time_inv (used as-is) and
      agent.time_vary (indexed by period).

    Each solver receives only the arguments it declares: its ``solver_args``
    attribute if present, else the names returned by get_arg_names.
    "solution_next" is filled with the solution of the following period.

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    solution_last : Solution
        A representation of the solution of the period that comes after the
        end of the sequence of one period problems. This might be the term-
        inal period solution, a "pseudo terminal" solution, or simply the
        solution to the earliest period from the succeeding cycle.
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. When used, represents the time index that solution_last is from.

    Returns
    -------
    solution_cycle : [Solution]
        A list of one period solutions for one "cycle" of the AgentType's
        microeconomic model.
    """

    # Check if the agent has a 'Parameters' attribute of the 'Parameters' class
    # if so, take advantage of it. Else, use the old method
    if hasattr(agent, "parameters") and isinstance(agent.parameters, Parameters):
        T = agent.parameters._length if from_t is None else from_t

        # Initialize the solution for this cycle, then iterate on periods
        solution_cycle = []
        solution_next = solution_last

        # With cycles == 1 periods are solved from T-1 down to 0. Otherwise
        # period 0 is solved first (from solution_last), then T-1 down to 1,
        # so the returned list is ordered [1, ..., T-1, 0].
        # NOTE(review): confirm this wrap-around ordering is intended.
        cycles_range = [0] + list(range(T - 1, 0, -1))
        for k in range(T - 1, -1, -1) if agent.cycles == 1 else cycles_range:
            # Update which single period solver to use (if it depends on time)
            if hasattr(agent.solve_one_period, "__getitem__"):
                solve_one_period = agent.solve_one_period[k]
            else:
                solve_one_period = agent.solve_one_period

            if hasattr(solve_one_period, "solver_args"):
                these_args = solve_one_period.solver_args
            else:
                these_args = get_arg_names(solve_one_period)

            # Make a temporary dictionary for this period
            temp_pars = agent.parameters[k]
            temp_dict = {
                name: solution_next if name == "solution_next" else temp_pars[name]
                for name in these_args
            }

            # Solve one period, add it to the solution, and move to the next period
            solution_t = solve_one_period(**temp_dict)
            solution_cycle.insert(0, solution_t)
            solution_next = solution_t

    else:
        # Calculate number of periods per cycle, defaults to 1 if all variables are time invariant
        if len(agent.time_vary) > 0:
            T = agent.T_cycle if from_t is None else from_t
        else:
            T = 1

        # Time-invariant inputs are read once; time-varying slots are filled
        # per period inside the loop below.
        solve_dict = {
            parameter: agent.__dict__[parameter] for parameter in agent.time_inv
        }
        solve_dict.update({parameter: None for parameter in agent.time_vary})

        # Initialize the solution for this cycle, then iterate on periods
        solution_cycle = []
        solution_next = solution_last

        # Same period ordering as in the Parameters branch above.
        cycles_range = [0] + list(range(T - 1, 0, -1))
        for k in range(T - 1, -1, -1) if agent.cycles == 1 else cycles_range:
            # Update which single period solver to use (if it depends on time)
            if hasattr(agent.solve_one_period, "__getitem__"):
                solve_one_period = agent.solve_one_period[k]
            else:
                solve_one_period = agent.solve_one_period

            if hasattr(solve_one_period, "solver_args"):
                these_args = solve_one_period.solver_args
            else:
                these_args = get_arg_names(solve_one_period)

            # Update time-varying single period inputs
            for name in agent.time_vary:
                if name in these_args:
                    solve_dict[name] = agent.__dict__[name][k]
            solve_dict["solution_next"] = solution_next

            # Make a temporary dictionary for this period
            temp_dict = {name: solve_dict[name] for name in these_args}

            # Solve one period, add it to the solution, and move to the next period
            solution_t = solve_one_period(**temp_dict)
            solution_cycle.insert(0, solution_t)
            solution_next = solution_t

    # Return the list of per-period solutions
    return solution_cycle
def make_one_period_oo_solver(solver_class):
    """
    Returns a function that solves a single period consumption-saving
    problem.

    Parameters
    ----------
    solver_class : Solver
        A class of Solver to be used.

    Returns
    -------
    solver_function : function
        A function for solving one period of a problem. It accepts keyword
        arguments only, passes them to ``solver_class`` to build a solver,
        calls the solver's ``prepare_to_solve()`` method if it has one, and
        returns the result of ``solve()``. The function also carries two
        attributes: ``solver_class`` (the class passed in) and
        ``solver_args`` (the argument names of ``solver_class.__init__``
        with the first one, normally ``self``, dropped).
    """

    def one_period_solver(**kwds):
        solver = solver_class(**kwds)

        # not ideal; better if this is defined in all Solver classes
        if hasattr(solver, "prepare_to_solve"):
            solver.prepare_to_solve()

        solution_now = solver.solve()
        return solution_now

    one_period_solver.solver_class = solver_class
    # This can be revisited once it is possible to export parameters
    one_period_solver.solver_args = get_arg_names(solver_class.__init__)[1:]

    return one_period_solver
2110# ========================================================================
2111# ========================================================================
class Market(Model):
    """
    A superclass to represent a central clearinghouse of information. Used for
    dynamic general equilibrium models to solve the "macroeconomic" model as a
    layer on top of the "microeconomic" models of one or more AgentTypes.

    Parameters
    ----------
    agents : [AgentType]
        A list of all the AgentTypes in this market.
    sow_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be "sown" to the agents in the market. Aggregate state, etc.
    reap_vars : [string]
        Names of variables to be collected ("reaped") from agents in the market
        to be used in the "aggregate market process".
    const_vars : [string]
        Names of attributes of the Market instance that are used in the "aggregate
        market process" but do not come from agents-- they are constant or simply
        parameters inherent to the process.
    track_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be tracked as a "history" so that a new dynamic rule can be calculated.
        This is often a subset of sow_vars.
    dyn_vars : [string]
        Names of variables that constitute a "dynamic rule".
    mill_rule : function
        A function that takes inputs named in reap_vars (and const_vars) and
        returns a tuple the same size and order as sow_vars. The "aggregate
        market process" that transforms individual agent actions/states/data
        into aggregate data to be sent back to agents.
    calc_dynamics : function
        A function that takes inputs named in track_vars and returns an object
        with attributes named in dyn_vars. Looks at histories of aggregate
        variables and generates a new "dynamic rule" for agents to believe and
        act on.
    act_T : int
        The number of times that the "aggregate market process" should be run
        in order to generate a history of aggregate variables.
    tolerance: float
        Minimum acceptable distance between "dynamic rules" to consider the
        Market solution process converged. Distance is a user-defined metric.
    """

    def __init__(
        self,
        agents=None,
        sow_vars=None,
        reap_vars=None,
        const_vars=None,
        track_vars=None,
        dyn_vars=None,
        mill_rule=None,
        calc_dynamics=None,
        act_T=1000,
        tolerance=0.000001,
        **kwds,
    ):
        super().__init__()
        self.agents = agents if agents is not None else list()  # NOQA

        # Values reaped from agents, one list per reap variable (see reap()).
        self.reap_vars = reap_vars if reap_vars is not None else list()  # NOQA
        self.reap_state = {var: [] for var in self.reap_vars}

        self.sow_vars = sow_vars if sow_vars is not None else list()  # NOQA
        # dictionaries for tracking initial and current values
        # of the sow variables.
        self.sow_init = {var: None for var in self.sow_vars}
        self.sow_state = {var: None for var in self.sow_vars}

        # NOTE(review): const_vars values start as None and nothing in this
        # class fills them in, so mill_rule receives None for them unless the
        # dictionary is populated elsewhere -- confirm with subclasses.
        const_vars = const_vars if const_vars is not None else list()  # NOQA
        self.const_vars = {var: None for var in const_vars}

        self.track_vars = track_vars if track_vars is not None else list()  # NOQA
        self.dyn_vars = dyn_vars if dyn_vars is not None else list()  # NOQA

        if mill_rule is not None:  # To prevent overwriting of method-based mill_rules
            self.mill_rule = mill_rule
        if calc_dynamics is not None:  # Ditto for calc_dynamics
            self.calc_dynamics = calc_dynamics
        self.act_T = act_T  # NOQA
        self.tolerance = tolerance  # NOQA
        self.max_loops = 1000  # NOQA
        self.history = {}  # filled with one list per track_var by reset()
        self.assign_parameters(**kwds)

        self.print_parallel_error_once = True
        # Print the error associated with calling the parallel method
        # "solve_agents" one time. If set to false, the error will never
        # print. See "solve_agents" for why this prints once or never.

    def solve_agents(self):
        """
        Solves the microeconomic problem for all AgentTypes in this market.

        The agents are first solved in parallel with multi_thread_commands.
        If that raises any exception, a warning is printed (only while
        print_parallel_error_once is True, which is then cleared) and the
        agents are solved serially with multi_thread_commands_fake.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        try:
            multi_thread_commands(self.agents, ["solve()"])
        except Exception as err:
            if self.print_parallel_error_once:
                # Set flag to False so this is only printed once.
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.Market.solve_agents() ",
                    "so using the serial version instead. This will likely be slower. "
                    "The multi_thread_commands() functions failed with the following error:",
                    "\n",
                    type(err),
                    ":",
                    err,
                )
            multi_thread_commands_fake(self.agents, ["solve()"])

    def solve(self):
        """
        "Solves" the market by finding a "dynamic rule" that governs the aggregate
        market state such that when agents believe in these dynamics, their actions
        collectively generate the same dynamic rule.

        Each loop solves the agents, simulates a history, and computes a new
        dynamic rule. The loop stops when the distance between successive rules
        falls below self.tolerance or after self.max_loops loops. The final
        rule is stored in self.dynamics.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        go = True
        max_loops = self.max_loops  # Failsafe against infinite solution loop
        completed_loops = 0
        old_dynamics = None

        while go:  # Loop until the dynamic process converges or we hit the loop cap
            self.solve_agents()  # Solve each AgentType's micro problem
            self.make_history()  # "Run" the model while tracking aggregate variables
            new_dynamics = self.update_dynamics()  # Find a new aggregate dynamic rule

            # Check to see if the dynamic rule has converged. On the first loop
            # there is nothing to compare against, so it never counts as
            # converged, whatever the tolerance is.
            if completed_loops > 0:
                distance = new_dynamics.distance(old_dynamics)
            else:
                distance = float("inf")

            # Move to the next loop if the terminal conditions are not met
            old_dynamics = new_dynamics
            completed_loops += 1
            go = distance >= self.tolerance and completed_loops < max_loops

        self.dynamics = new_dynamics  # Store the final dynamic rule in self

    def reap(self):
        """
        Collects the variables named in reap_vars from the state_now dictionary
        of each AgentType in the market. For each variable, self.reap_state[var]
        is replaced by a list of the values found. Agents whose state_now lacks
        the variable are skipped, so the list can be shorter than self.agents.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var in self.reap_state:
            harvest = []

            for agent in self.agents:
                # TODO: generalized variable lookup across namespaces
                if var in agent.state_now:
                    harvest.append(agent.state_now[var])

            self.reap_state[var] = harvest

    def sow(self):
        """
        Distributes the current values in self.sow_state to each AgentType in
        the market.

        For each sow variable and agent:
        - if the variable is a key of the agent's state_now, that entry is
          overwritten; and, separately,
        - if the variable is a key of the agent's shocks, that entry is
          overwritten; otherwise it is set as an attribute of the agent.

        So a variable found in state_now but not in shocks is written to both
        state_now and an agent attribute.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for sow_var in self.sow_state:
            for this_type in self.agents:
                if sow_var in this_type.state_now:
                    this_type.state_now[sow_var] = self.sow_state[sow_var]
                if sow_var in this_type.shocks:
                    this_type.shocks[sow_var] = self.sow_state[sow_var]
                else:
                    setattr(this_type, sow_var, self.sow_state[sow_var])

    def mill(self):
        """
        Processes the variables collected from agents using the function
        mill_rule. The reaped values and const_vars are passed as keyword
        arguments. The i-th element of the result is stored as the value of
        the i-th sow variable in self.sow_state (in sow_vars order).

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Make a dictionary of inputs for the mill_rule
        mill_dict = copy(self.reap_state)
        mill_dict.update(self.const_vars)

        # Run the mill_rule and store its output in self
        product = self.mill_rule(**mill_dict)

        for i, sow_var in enumerate(self.sow_state):
            self.sow_state[sow_var] = product[i]

    def cultivate(self):
        """
        Has each AgentType in agents perform their market_action method, using
        variables sown from the market (and maybe also "private" variables).
        The market_action method should store new results in the agent's
        state_now under the names in reap_vars, to be reaped later.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for this_type in self.agents:
            this_type.market_action()

    def reset(self):
        """
        Reset the market: give each tracked variable in track_vars an empty
        history, restore every sow variable to its value in self.sow_init,
        and call reset() on each AgentType in the market.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Reset the history of tracked variables
        self.history = {var_name: [] for var_name in self.track_vars}

        # Set the sow variables to their initial levels
        for var_name in self.sow_state:
            self.sow_state[var_name] = self.sow_init[var_name]

        # Reset each AgentType in the market
        for this_type in self.agents:
            this_type.reset()

    def store(self):
        """
        Record the current value of each variable X named in track_vars by
        appending it to the list self.history[X]. The value is looked up in
        sow_state, then reap_state, then const_vars, and finally as an
        attribute of self. self.history must already have been initialized
        by reset().

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var_name in self.track_vars:
            if var_name in self.sow_state:
                value_now = self.sow_state[var_name]
            elif var_name in self.reap_state:
                value_now = self.reap_state[var_name]
            elif var_name in self.const_vars:
                value_now = self.const_vars[var_name]
            else:
                value_now = getattr(self, var_name)

            self.history[var_name].append(value_now)

    def make_history(self):
        """
        Resets the market, then runs a loop of sow-->cultivate-->reap-->mill
        act_T times, recording the variables X named in track_vars in
        self.history[X] after each pass.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        self.reset()  # Initialize the state of the market
        for _ in range(self.act_T):
            self.sow()  # Distribute aggregated information/state to agents
            self.cultivate()  # Agents take action
            self.reap()  # Collect individual data from agents
            self.mill()  # Process individual data into aggregate data
            self.store()  # Record variables of interest

    def update_dynamics(self):
        """
        Calculates a new "aggregate dynamic rule" using the history of variables
        named in track_vars, and distributes this rule to AgentTypes in agents.

        calc_dynamics is called with keyword arguments taken from self.history,
        one per argument name of calc_dynamics (excluding "self"). Then each
        attribute named in dyn_vars is copied from the result onto every agent.

        Parameters
        ----------
        none

        Returns
        -------
        dynamics : instance
            The new "aggregate dynamic rule" that agents believe in and act on.
            Should have attributes named in dyn_vars.
        """
        # Make a dictionary of inputs for the dynamics calculator
        arg_names = list(get_arg_names(self.calc_dynamics))
        if "self" in arg_names:
            arg_names.remove("self")
        update_dict = {name: self.history[name] for name in arg_names}
        # Calculate a new dynamic rule and distribute it to the agents in agent_list
        dynamics = self.calc_dynamics(**update_dict)  # User-defined dynamics calculator
        for var_name in self.dyn_vars:
            this_obj = getattr(dynamics, var_name)
            for this_type in self.agents:
                setattr(this_type, var_name, this_obj)
        return dynamics
def distribute_params(agent, param_name, param_count, distribution):
    """
    Create ``param_count`` copies of ``agent`` that are ex ante heterogeneous
    in one parameter. Each copy gets one point of a discretization of
    ``distribution``. The original agent is not modified.

    Parameters
    ----------
    agent: AgentType
        An agent to clone (deep-copied once per returned agent).
    param_name : string
        Name of the parameter to be assigned.
    param_count : int
        Number of different values the parameter will take on.
    distribution : Distribution
        A 1-D distribution. Its ``discretize(N=param_count)`` result supplies
        the values (``atoms[0, j]``) and their probabilities (``pmv[j]``).

    Returns
    -------
    agent_set : [AgentType]
        A list of param_count agents, ex ante heterogeneous with
        respect to param_name. The AgentCount of the original
        is split between the agents of the returned list in
        proportion to the given distribution. Each share is
        truncated with int(), so the shares can sum to slightly
        less than the original AgentCount.
    """
    param_dist = distribution.discretize(N=param_count)

    agent_set = [deepcopy(agent) for _ in range(param_count)]

    for j, this_agent in enumerate(agent_set):
        this_agent.assign_parameters(
            AgentCount=int(agent.AgentCount * param_dist.pmv[j])
        )
        this_agent.assign_parameters(**{param_name: param_dist.atoms[0, j]})

    return agent_set
@dataclass
class AgentPopulation:
    """
    A class for representing a population of ex-ante heterogeneous agents.

    ``parameters`` may hold scalars, flat lists, lists of lists (one inner
    list per agent subgroup), ``Distribution`` objects, or ``DataArray``
    objects whose first dimension is ``"agent"`` and/or whose last dimension
    is ``"age"``. Parameters that differ across agent subgroups are listed in
    ``distributed_params``.
    """

    agent_type: AgentType  # type of agent in the population
    parameters: dict  # dictionary of parameters
    seed: int = 0  # random seed
    time_var: List[str] = field(init=False)
    time_inv: List[str] = field(init=False)
    distributed_params: List[str] = field(init=False)
    agent_type_count: Optional[int] = field(init=False)
    term_age: Optional[int] = field(init=False)
    continuous_distributions: Dict[str, Distribution] = field(init=False)
    discrete_distributions: Dict[str, Distribution] = field(init=False)
    population_parameters: List[Dict[str, Union[List[float], float]]] = field(
        init=False
    )
    agents: List[AgentType] = field(init=False)
    agent_database: pd.DataFrame = field(init=False)
    solution: List[Any] = field(init=False)

    @staticmethod
    def _is_nested_list(param) -> bool:
        """Return True if ``param`` is a non-empty list whose first item is a list."""
        # The length check keeps an empty list from raising IndexError.
        return isinstance(param, list) and len(param) > 0 and isinstance(param[0], list)

    def __post_init__(self):
        """
        Initialize the population of agents, determine distributed parameters,
        and infer `agent_type_count` and `term_age`.
        """
        # create a dummy agent and obtain its time-varying
        # and time-invariant attributes
        dummy_agent = self.agent_type()
        self.time_var = dummy_agent.time_vary
        self.time_inv = dummy_agent.time_inv

        # create list of distributed parameters
        # these are parameters that differ across agents
        self.distributed_params = [
            key
            for key, param in self.parameters.items()
            if self._is_nested_list(param)
            or isinstance(param, Distribution)
            or (isinstance(param, DataArray) and param.dims[0] == "agent")
        ]

        self.__infer_counts__()

        self.print_parallel_error_once = True
        # Print warning once if parallel simulation fails

    def __infer_counts__(self):
        """
        Infer `agent_type_count` and `term_age` from the parameters.
        If parameters include a `Distribution` type, a list of lists,
        or a `DataArray` with `agent` as the first dimension, then
        the AgentPopulation contains ex-ante heterogenous agents.
        Either count is set to None (with a warning) when a `Distribution`
        prevents it from being inferred.
        """

        # infer agent_type_count from distributed parameters
        agent_type_count = 1
        for key in self.distributed_params:
            param = self.parameters[key]
            if isinstance(param, Distribution):
                agent_type_count = None
                warn(
                    "Cannot infer agent_type_count from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif isinstance(param, list):
                agent_type_count = max(agent_type_count, len(param))
            elif isinstance(param, DataArray) and param.dims[0] == "agent":
                agent_type_count = max(agent_type_count, param.shape[0])

        self.agent_type_count = agent_type_count

        # infer term_age from all parameters
        term_age = 1
        for param in self.parameters.values():
            if isinstance(param, Distribution):
                term_age = None
                warn(
                    "Cannot infer term_age from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif self._is_nested_list(param):
                term_age = max(term_age, len(param[0]))
            elif isinstance(param, DataArray) and param.dims[-1] == "age":
                term_age = max(term_age, param.shape[-1])

        self.term_age = term_age

    def approx_distributions(self, approx_params: dict):
        """
        Approximate continuous distributions with discrete ones. If the initial
        parameters include a `Distribution` type, then the AgentPopulation is
        not ready to solve, and stands for an abstract population. To solve the
        AgentPopulation, we need discretization parameters for each continuous
        distribution. This method approximates the continuous distributions with
        discrete ones, and updates the parameters dictionary.

        Parameters
        ----------
        approx_params : dict
            Maps a parameter name to a dict of keyword arguments for that
            Distribution's ``discretize`` method.

        Raises
        ------
        ValueError
            If a named parameter is not a Distribution in distributed_params.
        KeyError
            If a named parameter is missing from ``self.parameters``.
        """
        self.continuous_distributions = {}
        self.discrete_distributions = {}

        for key, args in approx_params.items():
            param = self.parameters[key]
            if key in self.distributed_params and isinstance(param, Distribution):
                self.continuous_distributions[key] = param
                self.discrete_distributions[key] = param.discretize(**args)
            else:
                raise ValueError(
                    f"Warning: parameter {key} is not a Distribution found "
                    f"in agent type {self.agent_type}"
                )

        if len(self.discrete_distributions) > 1:
            joint_dist = combine_indep_dstns(*self.discrete_distributions.values())
        else:
            joint_dist = list(self.discrete_distributions.values())[0]

        # Each discretized parameter becomes a 1-D DataArray over "agent".
        for i, key in enumerate(self.discrete_distributions):
            self.parameters[key] = DataArray(joint_dist.atoms[i], dims=("agent"))

        self.__infer_counts__()

    def _time_varying_value(self, param, agent: int):
        """
        Return the per-period values of a time-varying parameter for agent
        subgroup ``agent``.

        Scalars and agent-only DataArrays are repeated ``term_age`` times.
        Flat lists and age-only DataArrays are shared by every subgroup.
        Lists of lists and (agent, age) DataArrays give this subgroup's own
        row. Any other value is passed through unchanged.
        """
        if isinstance(param, (int, float)):
            return [param] * self.term_age
        if isinstance(param, list):
            return param[agent] if self._is_nested_list(param) else param
        if isinstance(param, DataArray):
            if param.dims[0] == "agent":
                if param.dims[-1] == "age":
                    # (agent, age) array: this subgroup's full age profile
                    return param[agent].values.tolist()
                # agent-only array: this subgroup's value for every period
                return [param[agent].item()] * self.term_age
            if param.dims[0] == "age":
                # age-only array: the same profile for every subgroup
                return param.values.tolist()
        return param

    def __parse_parameters__(self) -> None:
        """
        Creates distributed dictionaries of parameters for each ex-ante
        heterogeneous agent in the parameterized population. The parameters
        are stored in a list of dictionaries, where each dictionary contains
        the parameters for one agent. Expands parameters that vary over time
        to a list of length `term_age`.

        Parameters that are not time-varying are forwarded only when they
        are an int/float, a list, or a DataArray whose first dimension is
        "agent" (or "age", for keys not listed in time_inv). Other value
        types are not forwarded.
        """

        population_parameters = []  # container for dictionaries of each agent subgroup
        for agent in range(self.agent_type_count):
            agent_parameters = {}
            for key, param in self.parameters.items():
                if key in self.time_var:
                    agent_parameters[key] = self._time_varying_value(param, agent)
                elif isinstance(param, (int, float)):
                    # scalars are shared by every subgroup
                    agent_parameters[key] = param
                elif isinstance(param, list):
                    # a list of lists holds one entry per subgroup; a flat list
                    # is shared by every subgroup
                    agent_parameters[key] = (
                        param[agent] if self._is_nested_list(param) else param
                    )
                elif isinstance(param, DataArray):
                    is_time_inv = key in self.time_inv
                    if param.dims[0] == "agent":
                        if not is_time_inv and param.dims[-1] == "age":
                            # (agent, age) array: assume agent and time vary
                            agent_parameters[key] = param[agent].values.tolist()
                        else:
                            # this subgroup's entry along the "agent" dimension
                            agent_parameters[key] = param[agent].item()
                    elif param.dims[0] == "age" and not is_time_inv:
                        # age-only array: assume time vary, shared by subgroups
                        agent_parameters[key] = param.values.tolist()

            population_parameters.append(agent_parameters)

        self.population_parameters = population_parameters

    def create_distributed_agents(self):
        """
        Parses the parameters dictionary and creates a list of agents with the
        appropriate parameters. Each agent gets its own seed, drawn from a
        generator seeded with ``self.seed``.
        """

        self.__parse_parameters__()

        rng = np.random.default_rng(self.seed)

        self.agents = [
            self.agent_type(seed=rng.integers(0, 2**31 - 1), **agent_dict)
            for agent_dict in self.population_parameters
        ]

    def create_database(self):
        """
        Optionally creates a pandas DataFrame with the parameters for each agent,
        plus an "agents" column holding the agent objects.
        """
        database = pd.DataFrame(self.population_parameters)
        database["agents"] = self.agents

        self.agent_database = database

    def solve(self):
        """
        Solves each agent of the population serially.
        """

        # see Market class for an example of how to solve distributed agents in parallel

        for agent in self.agents:
            agent.solve()

    def unpack_solutions(self):
        """
        Unpacks the solutions of each agent into an attribute of the population.
        """
        self.solution = [agent.solution for agent in self.agents]

    def initialize_sim(self):
        """
        Initializes the simulation for each agent.
        """
        for agent in self.agents:
            agent.initialize_sim()

    def simulate(self, num_jobs=None):
        """
        Simulates each agent of the population.

        Parameters
        ----------
        num_jobs : int, optional
            Number of parallel jobs to use. Defaults to using all available
            cores when ``None``. Falls back to serial execution if parallel
            processing fails.
        """
        try:
            multi_thread_commands(self.agents, ["simulate()"], num_jobs)
        except Exception as err:
            if getattr(self, "print_parallel_error_once", False):
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.AgentPopulation.simulate() ",
                    "so using the serial version instead. This will likely be slower. ",
                    "The multi_thread_commands() function failed with the following error:\n",
                    sys.exc_info()[0],
                    ":",
                    err,
                )
            multi_thread_commands_fake(self.agents, ["simulate()"], num_jobs)

    def __iter__(self):
        """
        Allows for iteration over the agents in the population.
        """
        return iter(self.agents)

    def __getitem__(self, idx):
        """
        Allows for indexing into the population.
        """
        return self.agents[idx]
2771###############################################################################
def multi_thread_commands_fake(
    agent_list: List, command_list: List, num_jobs=None
) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    in an ordinary, single-threaded loop. Each command should be a method of
    that AgentType subclass. This function exists so as to easily disable
    multithreading, as it uses the same syntax as multi_thread_commands.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
    command_list : [string]
        A list of commands to run for each AgentType. Each command names a
        method that takes no arguments. It may be written either as
        "method()" (legacy syntax) or as the bare "method".
    num_jobs : None
        Dummy input to match syntax of multi_thread_commands. Does nothing.

    Returns
    -------
    none
    """
    # Strip a trailing "()" if present, so bare method names are not
    # truncated by two characters.
    method_names = [
        command[:-2] if command.endswith("()") else command
        for command in command_list
    ]
    for agent in agent_list:
        for method_name in method_names:
            getattr(agent, method_name)()
def multi_thread_commands(agent_list: List, command_list: List, num_jobs=None) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    in parallel using joblib. Each command should be a method of that AgentType subclass.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
        The list is updated in place: each entry is replaced by the agent
        object returned from the parallel call.
    command_list : [string]
        A list of commands to run for each AgentType in agent_list.
    num_jobs : int, optional
        Number of parallel jobs passed to joblib.Parallel. Defaults to the
        smaller of the number of agents and the number of available CPUs.

    Returns
    -------
    None
    """
    # With zero or one agent there is nothing to parallelize. An empty list
    # would otherwise request n_jobs=0, which joblib rejects.
    if len(agent_list) <= 1:
        multi_thread_commands_fake(agent_list, command_list)
        return None

    # Default number of parallel jobs is the smaller of number of AgentTypes in
    # the input and the number of available cores.
    if num_jobs is None:
        num_jobs = min(len(agent_list), multiprocessing.cpu_count())

    # Send each command in command_list to each of the types in agent_list to be run
    agent_list_out = Parallel(n_jobs=num_jobs)(
        delayed(run_commands)(agent, command_list) for agent in agent_list
    )

    # The parallel workers may have operated on copies of the agents, so put
    # the returned agents back into the caller's list.
    for j, agent in enumerate(agent_list_out):
        agent_list[j] = agent
def run_commands(agent: Any, command_list: List) -> Any:
    """
    Executes each command in command_list on a given AgentType. The commands
    should be methods of that AgentType's subclass.

    Parameters
    ----------
    agent : AgentType
        An instance of AgentType on which the commands will be run.
    command_list : [string]
        A list of commands that the agent should run, as methods. Each command
        names a method that takes no arguments, written either as "method()"
        (legacy syntax) or as the bare "method".

    Returns
    -------
    agent : AgentType
        The same AgentType instance passed as input, after running the commands.
    """
    for command in command_list:
        # Strip a trailing "()" if present, so bare method names are not
        # truncated by two characters.
        method_name = command[:-2] if command.endswith("()") else command
        getattr(agent, method_name)()
    return agent