Coverage for HARK / core.py: 96%

1088 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-08 05:31 +0000

1""" 

2High-level functions and classes for solving a wide variety of economic models. 

3The "core" of HARK is a framework for "microeconomic" and "macroeconomic" 

4models. A micro model concerns the dynamic optimization problem for some type 

5of agents, where agents take the inputs to their problem as exogenous. A macro 

6model adds an additional layer, endogenizing some of the inputs to the micro 

7problem by finding a general equilibrium dynamic rule. 

8""" 

9 

10# Import basic modules 

11import inspect 

12import sys 

13from collections import namedtuple 

14from copy import copy, deepcopy 

15from dataclasses import dataclass, field 

16from time import time 

17from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union 

18from warnings import warn 

19import multiprocessing 

20from joblib import Parallel, delayed 

21from pandas import DataFrame 

22 

23import numpy as np 

24import pandas as pd 

25from xarray import DataArray 

26 

27from HARK.distributions import ( 

28 Distribution, 

29 IndexDistribution, 

30 combine_indep_dstns, 

31) 

32from HARK.utilities import NullFunc, get_arg_names, get_it_from 

33from HARK.simulator import make_simulator_from_agent 

34from HARK.SSJutils import ( 

35 make_basic_SSJ_matrices, 

36 calc_shock_response_manually, 

37) 

38from HARK.metric import MetricObject, distance_metric 

39 

# Public API of this module: the names exported by a star-import of it.
__all__ = [
    "AgentType",
    "Market",
    "Parameters",
    "Model",
    "AgentPopulation",
    "multi_thread_commands",
    "multi_thread_commands_fake",
    "NullFunc",
    "make_one_period_oo_solver",
    "distribute_params",
]

52 

53 

class Parameters:
    """
    A smart container for model parameters that handles age-varying dynamics.

    This class stores parameters as an internal dictionary and manages their
    age-varying properties, providing both attribute-style and dictionary-style
    access. It is designed to handle the time-varying dynamics of parameters
    in economic models.

    Attributes
    ----------
    _length : int
        The terminal age of the agents in the model.
    _invariant_params : Set[str]
        A set of parameter names that are invariant over time.
    _varying_params : Set[str]
        A set of parameter names that vary over time.
    _parameters : Dict[str, Any]
        The internal dictionary storing all parameters.
    _frozen : bool
        When True, setting a parameter (by item or by attribute) raises
        RuntimeError.
    _namedtuple_cache : Optional[type]
        The namedtuple class most recently built by ``to_namedtuple``.
    """

    __slots__ = (
        "_length",
        "_invariant_params",
        "_varying_params",
        "_parameters",
        "_frozen",
        "_namedtuple_cache",
    )

    def __init__(self, **parameters: Any) -> None:
        """
        Initialize a Parameters object and parse the age-varying dynamics of parameters.

        Parameters
        ----------
        T_cycle : int, optional
            The number of time periods in the model cycle (default: 1).
            Must be >= 1.
        frozen : bool, optional
            If True, the Parameters object will be immutable after initialization
            (default: False).
        _time_inv : List[str], optional
            List of parameter names to explicitly mark as time-invariant,
            overriding automatic inference.
        _time_vary : List[str], optional
            List of parameter names to explicitly mark as time-varying,
            overriding automatic inference.
        **parameters : Any
            Any number of parameters in the form key=value.

        Raises
        ------
        ValueError
            If T_cycle is less than 1.

        Notes
        -----
        Automatic time-variance inference rules:
        - Scalars (int, float, bool, None) are time-invariant
        - NumPy arrays are time-invariant (use lists/tuples for time-varying)
        - Single-element lists/tuples [x] are unwrapped to x and time-invariant
        - Multi-element lists/tuples are time-varying if length matches T_cycle
        - 2D arrays with first dimension matching T_cycle are time-varying
        - Distributions and Callables are time-invariant

        Use _time_inv or _time_vary to override automatic inference when needed.
        Names listed in the overrides that are not among the parameters are
        silently ignored.
        """
        # Extract special parameters
        self._length: int = parameters.pop("T_cycle", 1)
        frozen: bool = parameters.pop("frozen", False)
        time_inv_override: List[str] = parameters.pop("_time_inv", [])
        time_vary_override: List[str] = parameters.pop("_time_vary", [])

        # Validate T_cycle
        if self._length < 1:
            raise ValueError(f"T_cycle must be >= 1, got {self._length}")

        # Initialize internal state
        self._invariant_params: Set[str] = set()
        self._varying_params: Set[str] = set()
        self._parameters: Dict[str, Any] = {"T_cycle": self._length}
        self._frozen: bool = False  # Stay unfrozen during setup so items can be set
        self._namedtuple_cache: Optional[type] = None

        # Set parameters using automatic inference
        for key, value in parameters.items():
            self[key] = value

        # Apply explicit overrides
        for param in time_inv_override:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)

        for param in time_vary_override:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)

        # Freeze if requested
        self._frozen = frozen

    def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
        """
        Access parameters by age index or parameter name.

        If item_or_key is an integer, returns a Parameters object with the parameters
        that apply to that age. This includes all invariant parameters and the
        `item_or_key`th element of all age-varying parameters. If item_or_key is a
        string, it returns the value of the parameter with that name.

        Parameters
        ----------
        item_or_key : Union[int, str]
            Age index or parameter name.

        Returns
        -------
        Union[Parameters, Any]
            A new Parameters object for the specified age, or the value of the
            specified parameter.

        Raises
        ------
        ValueError:
            If the age index is out of bounds.
        KeyError:
            If the parameter name is not found.
        TypeError:
            If the key is neither an integer nor a string.
        """
        if isinstance(item_or_key, int):
            if item_or_key < 0 or item_or_key >= self._length:
                raise ValueError(
                    f"Age {item_or_key} is out of bounds (valid: 0-{self._length - 1})."
                )

            params = {key: self._parameters[key] for key in self._invariant_params}
            for key in self._varying_params:
                value = self._parameters[key]
                # A parameter forced into the varying set may hold a
                # non-sequence value; it is passed through unchanged.
                if isinstance(value, (list, tuple, np.ndarray)):
                    value = value[item_or_key]
                params[key] = value
            return Parameters(**params)
        elif isinstance(item_or_key, str):
            return self._parameters[item_or_key]
        else:
            raise TypeError("Key must be an integer (age) or string (parameter name).")

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Set parameter values, automatically inferring time variance.

        If the parameter is a scalar, numpy array, boolean, distribution, callable
        or None, it is assumed to be invariant over time. If the parameter is a
        list or tuple, it is assumed to be varying over time. If the parameter
        is a list or tuple of length greater than 1, the length of the list or
        tuple must match the `_length` attribute of the Parameters object.

        2D numpy arrays with first dimension matching T_cycle are treated as
        time-varying parameters.

        Parameters
        ----------
        key : str
            Name of the parameter.
        value : Any
            Value of the parameter.

        Raises
        ------
        ValueError:
            If the parameter name is not a string or if the value type is unsupported.
            If the parameter value is inconsistent with the current model length.
        RuntimeError:
            If the Parameters object is frozen.
        """
        if self._frozen:
            raise RuntimeError("Cannot modify frozen Parameters object")

        if not isinstance(key, str):
            raise ValueError(f"Parameter name must be a string, got {type(key)}")

        # Check for 2D numpy arrays with time-varying first dimension
        if isinstance(value, np.ndarray) and value.ndim >= 2:
            if value.shape[0] == self._length:
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            else:
                self._invariant_params.add(key)
                self._varying_params.discard(key)
        elif isinstance(
            value,
            (
                int,
                float,
                np.ndarray,
                type(None),
                Distribution,
                bool,
                Callable,
                MetricObject,
            ),
        ):
            self._invariant_params.add(key)
            self._varying_params.discard(key)
        elif isinstance(value, (list, tuple)):
            if len(value) == 1:
                # Single-element sequences are unwrapped and stored as invariant
                value = value[0]
                self._invariant_params.add(key)
                self._varying_params.discard(key)
            elif self._length is None or self._length == 1:
                # The first multi-element sequence fixes the model length.
                # NOTE: only _length changes here; the stored "T_cycle" entry
                # in _parameters is left as it was.
                self._length = len(value)
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            elif len(value) == self._length:
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            else:
                raise ValueError(
                    f"Parameter {key} must have length 1 or {self._length}, not {len(value)}"
                )
        else:
            raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")

        self._parameters[key] = value

    def __iter__(self) -> Iterator[str]:
        """Allow iteration over parameter names."""
        return iter(self._parameters)

    def __len__(self) -> int:
        """Return the number of parameters."""
        return len(self._parameters)

    def keys(self) -> Iterator[str]:
        """Return a view of parameter names."""
        return self._parameters.keys()

    def values(self) -> Iterator[Any]:
        """Return a view of parameter values."""
        return self._parameters.values()

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Return a view of parameter (name, value) pairs."""
        return self._parameters.items()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a plain dictionary.

        Returns
        -------
        Dict[str, Any]
            A new (shallow-copied) dictionary containing all parameters.
        """
        return dict(self._parameters)

    def to_namedtuple(self) -> namedtuple:
        """
        Convert parameters to a namedtuple.

        The namedtuple class is cached for efficiency on repeated calls. The
        cache is rebuilt whenever the set (or order) of parameter names no
        longer matches the cached class, so parameters added after an earlier
        call are included rather than causing a TypeError.

        Returns
        -------
        namedtuple
            A namedtuple containing all parameters.
        """
        fields = tuple(self._parameters)
        nt_class = self._namedtuple_cache
        if nt_class is None or nt_class._fields != fields:
            nt_class = namedtuple("Parameters", fields)
            self._namedtuple_cache = nt_class
        return nt_class(**self._parameters)

    def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
        """
        Update parameters from another Parameters object or dictionary.

        Every item is set through item assignment, so time-variance inference
        is applied to each value.

        Parameters
        ----------
        other : Union[Parameters, Dict[str, Any]]
            The source of parameters to update from.

        Raises
        ------
        TypeError
            If the input is neither a Parameters object nor a dictionary.
        """
        if isinstance(other, Parameters):
            for key, value in other._parameters.items():
                self[key] = value
        elif isinstance(other, dict):
            for key, value in other.items():
                self[key] = value
        else:
            raise TypeError(
                "Update source must be a Parameters object or a dictionary."
            )

    def __repr__(self) -> str:
        """Return a detailed string representation of the Parameters object."""
        return (
            f"Parameters(_length={self._length}, "
            f"_invariant_params={self._invariant_params}, "
            f"_varying_params={self._varying_params}, "
            f"_parameters={self._parameters})"
        )

    def __str__(self) -> str:
        """Return a simple string representation of the Parameters object."""
        return f"Parameters({str(self._parameters)})"

    def __getattr__(self, name: str) -> Any:
        """
        Allow attribute-style access to parameters.

        Only called when normal attribute lookup fails.

        Parameters
        ----------
        name : str
            Name of the parameter to access.

        Returns
        -------
        Any
            The value of the specified parameter.

        Raises
        ------
        AttributeError:
            If the parameter name is not found.
        """
        if name.startswith("_"):
            # Internal names are never parameters; defer to normal lookup,
            # which raises AttributeError for an unset slot.
            return super().__getattribute__(name)
        try:
            return self._parameters[name]
        except KeyError:
            # Hide the internal KeyError; callers only care about the attribute
            raise AttributeError(
                f"'Parameters' object has no attribute '{name}'"
            ) from None

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Allow attribute-style setting of parameters.

        Names beginning with an underscore are stored as regular (slot)
        attributes; every other name is stored as a parameter.

        Parameters
        ----------
        name : str
            Name of the parameter to set.
        value : Any
            Value to set for the parameter.
        """
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            self[name] = value

    def __contains__(self, item: str) -> bool:
        """Check if a parameter exists in the Parameters object."""
        return item in self._parameters

    def copy(self) -> "Parameters":
        """
        Create a deep copy of the Parameters object.

        Returns
        -------
        Parameters
            A new Parameters object with the same contents.
        """
        return deepcopy(self)

    def add_to_time_vary(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-varying set.

        A warning is issued for each name that is not an existing parameter.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_vary.
        """
        for param in params:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_vary."
                )

    def add_to_time_inv(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-invariant set.

        A warning is issued for each name that is not an existing parameter.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_inv.
        """
        for param in params:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_inv."
                )

    def del_from_time_vary(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_vary.
        """
        for param in params:
            self._varying_params.discard(param)

    def del_from_time_inv(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_inv.
        """
        for param in params:
            self._invariant_params.discard(param)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a parameter value, returning a default if not found.

        Parameters
        ----------
        key : str
            The parameter name.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The parameter value or the default.
        """
        return self._parameters.get(key, default)

    def set_many(self, **kwargs: Any) -> None:
        """
        Set multiple parameters at once.

        Parameters
        ----------
        **kwargs : Keyword arguments representing parameter names and values.
        """
        for key, value in kwargs.items():
            self[key] = value

    def is_time_varying(self, key: str) -> bool:
        """
        Check if a parameter is time-varying.

        Parameters
        ----------
        key : str
            The parameter name.

        Returns
        -------
        bool
            True if the parameter is time-varying, False otherwise.
        """
        return key in self._varying_params

    def at_age(self, age: int) -> "Parameters":
        """
        Get parameters for a specific age.

        This is an alternative to integer indexing (params[age]) that is more
        explicit and avoids potential confusion with dictionary-style access.

        Parameters
        ----------
        age : int
            The age index to retrieve parameters for.

        Returns
        -------
        Parameters
            A new Parameters object with parameters for the specified age.

        Raises
        ------
        ValueError
            If the age index is out of bounds.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97], sigma=2.0)
        >>> age_1_params = params.at_age(1)
        >>> age_1_params.beta
        0.96
        """
        return self[age]

    def validate(self) -> None:
        """
        Validate parameter consistency.

        Checks that all time-varying parameters have length matching T_cycle.
        This is useful after manual modifications or when parameters are set
        programmatically. Time-varying values that are neither lists, tuples
        nor numpy arrays are not checked.

        Raises
        ------
        ValueError
            If any time-varying parameter has incorrect length. All problems
            found are reported together in one message.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97])
        >>> params.validate()  # Passes
        >>> params.add_to_time_vary("beta")
        >>> params.validate()  # Still passes
        """
        errors = []
        for param in self._varying_params:
            value = self._parameters[param]
            if isinstance(value, (list, tuple)):
                if len(value) != self._length:
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )
            elif isinstance(value, np.ndarray):
                if value.ndim == 0:
                    errors.append(
                        f"Parameter '{param}' is a 0-dimensional array (scalar), "
                        "which should not be time-varying"
                    )
                elif value.ndim >= 2:
                    if value.shape[0] != self._length:
                        errors.append(
                            f"Parameter '{param}' has first dimension {value.shape[0]}, expected {self._length}"
                        )
                elif len(value) != self._length:
                    # Only 1-dimensional arrays reach this branch
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )

        if errors:
            raise ValueError(
                "Parameter validation failed:\n" + "\n".join(f"  - {e}" for e in errors)
            )

617 

618 

619class Model: 

620 """ 

621 A class with special handling of parameters assignment. 

622 """ 

623 

624 def __init__(self): 

625 if not hasattr(self, "parameters"): 

626 self.parameters = {} 

627 if not hasattr(self, "constructors"): 

628 self.constructors = {} 

629 

630 def assign_parameters(self, **kwds): 

631 """ 

632 Assign an arbitrary number of attributes to this agent. 

633 

634 Parameters 

635 ---------- 

636 **kwds : keyword arguments 

637 Any number of keyword arguments of the form key=value. 

638 Each value will be assigned to the attribute named in self. 

639 

640 Returns 

641 ------- 

642 None 

643 """ 

644 self.parameters.update(kwds) 

645 for key in kwds: 

646 setattr(self, key, kwds[key]) 

647 

648 def get_parameter(self, name): 

649 """ 

650 Returns a parameter of this model 

651 

652 Parameters 

653 ---------- 

654 name : str 

655 The name of the parameter to get 

656 

657 Returns 

658 ------- 

659 value : The value of the parameter 

660 """ 

661 return self.parameters[name] 

662 

663 def __eq__(self, other): 

664 if isinstance(other, type(self)): 

665 return self.parameters == other.parameters 

666 

667 return NotImplemented 

668 

669 def __str__(self): 

670 type_ = type(self) 

671 module = type_.__module__ 

672 qualname = type_.__qualname__ 

673 

674 s = f"<{module}.{qualname} object at {hex(id(self))}.\n" 

675 s += "Parameters:" 

676 

677 for p in self.parameters: 

678 s += f"\n{p}: {self.parameters[p]}" 

679 

680 s += ">" 

681 return s 

682 

683 def describe(self): 

684 return self.__str__() 

685 

686 def del_param(self, param_name): 

687 """ 

688 Deletes a parameter from this instance, removing it both from the object's 

689 namespace (if it's there) and the parameters dictionary (likewise). 

690 

691 Parameters 

692 ---------- 

693 param_name : str 

694 A string naming a parameter or data to be deleted from this instance. 

695 Removes information from self.parameters dictionary and own namespace. 

696 

697 Returns 

698 ------- 

699 None 

700 """ 

701 if param_name in self.parameters: 

702 del self.parameters[param_name] 

703 if hasattr(self, param_name): 

704 delattr(self, param_name) 

705 

706 def _gather_constructor_args(self, key, constructor): 

707 """ 

708 Gather all arguments needed to call the given constructor. 

709 

710 Handles both the special ``get_it_from`` case and normal callables. 

711 For a normal callable the method inspects the function signature to 

712 find every required argument, then resolves each argument from the 

713 instance namespace (``self.<arg>``) or from ``self.parameters``. 

714 Arguments that have a default value are silently skipped when they 

715 cannot be found; required arguments that cannot be resolved are 

716 recorded as missing. 

717 

718 Parameters 

719 ---------- 

720 key : str 

721 The name of the constructed object (used to record missing pairs). 

722 constructor : callable or get_it_from 

723 The constructor to be called. 

724 

725 Returns 

726 ------- 

727 temp_dict : dict 

728 Keyword arguments to pass to the constructor. 

729 any_missing : bool 

730 True when at least one required argument could not be resolved. 

731 missing_args : list of str 

732 Names of unresolved required arguments. 

733 missing_key_data : list of tuple 

734 ``(key, arg)`` pairs for every unresolved required argument. 

735 """ 

736 missing_key_data = [] 

737 

738 # SPECIAL: if the constructor is get_it_from, handle it separately 

739 if isinstance(constructor, get_it_from): 

740 try: 

741 parent = getattr(self, constructor.name) 

742 query = key 

743 any_missing = False 

744 missing_args = [] 

745 except AttributeError: 

746 parent = None 

747 query = None 

748 any_missing = True 

749 missing_args = [constructor.name] 

750 temp_dict = {"parent": parent, "query": query} 

751 return temp_dict, any_missing, missing_args, missing_key_data 

752 

753 # Normal constructor: inspect signature and gather arguments 

754 args_needed = get_arg_names(constructor) 

755 has_no_default = { 

756 k: v.default is inspect.Parameter.empty 

757 for k, v in inspect.signature(constructor).parameters.items() 

758 } 

759 temp_dict = {} 

760 any_missing = False 

761 missing_args = [] 

762 for this_arg in args_needed: 

763 if hasattr(self, this_arg): 

764 temp_dict[this_arg] = getattr(self, this_arg) 

765 else: 

766 try: 

767 temp_dict[this_arg] = self.parameters[this_arg] 

768 except KeyError: 

769 if has_no_default[this_arg]: 

770 # Record missing key-data pair 

771 any_missing = True 

772 missing_key_data.append((key, this_arg)) 

773 missing_args.append(this_arg) 

774 

775 return temp_dict, any_missing, missing_args, missing_key_data 

776 

777 def _attempt_construct(self, key, i, keys_complete, backup, errors, force): 

778 """ 

779 Attempt to construct the object for a single key. 

780 

781 The method looks up the constructor, gathers its arguments, runs it, 

782 and records any errors. It also handles the ``None``-constructor case 

783 (restore from backup) and the missing-args case (record and defer). 

784 

785 Parameters 

786 ---------- 

787 key : str 

788 The name of the object to construct. 

789 i : int 

790 Index of *key* inside the ``keys`` array (used to update 

791 ``keys_complete``). 

792 keys_complete : np.ndarray of bool 

793 Boolean array indicating which keys have been completed; mutated 

794 in place when this key succeeds. 

795 backup : dict 

796 Dictionary of pre-construction attribute values. 

797 errors : dict 

798 ``self._constructor_errors``; mutated in place. 

799 force : bool 

800 When True, swallow exceptions and continue; when False, re-raise. 

801 

802 Returns 

803 ------- 

804 accomplished : bool 

805 True when the key was completed (constructor ran or was None). 

806 missing_key_data : list of tuple 

807 ``(key, arg)`` pairs recorded for missing required arguments. 

808 """ 

809 missing_key_data = [] 

810 

811 # Look up the constructor for this key 

812 try: 

813 constructor = self.constructors[key] 

814 except Exception as not_found: 

815 errors[key] = "No constructor found for " + str(not_found) 

816 if force: 

817 return False, missing_key_data 

818 else: 

819 raise KeyError("No constructor found for " + key) from None 

820 

821 # If the constructor is None, restore from backup and mark complete 

822 if constructor is None: 

823 if key in backup.keys(): 

824 setattr(self, key, backup[key]) 

825 self.parameters[key] = backup[key] 

826 keys_complete[i] = True 

827 return True, missing_key_data 

828 

829 # Gather arguments for the constructor 

830 temp_dict, any_missing, missing_args, missing_key_data = ( 

831 self._gather_constructor_args(key, constructor) 

832 ) 

833 

834 # If all required data was found, run the constructor and store the result 

835 if not any_missing: 

836 try: 

837 temp = constructor(**temp_dict) 

838 except Exception as problem: 

839 errors[key] = str(type(problem)) + ": " + str(problem) 

840 self.del_param(key) 

841 if force: 

842 return False, missing_key_data 

843 else: 

844 raise 

845 setattr(self, key, temp) 

846 self.parameters[key] = temp 

847 if key in errors: 

848 del errors[key] 

849 keys_complete[i] = True 

850 return True, missing_key_data 

851 

852 # Some required arguments were missing; record and defer 

853 msg = "Missing required arguments: " + ", ".join(missing_args) 

854 errors[key] = msg 

855 self.del_param(key) 

856 # Never raise exceptions here, as the arguments might be filled in later 

857 return False, missing_key_data 

858 

859 def _construct_pass(self, keys, keys_complete, backup, errors, force): 

860 """ 

861 Perform one full sweep over all incomplete keys. 

862 

863 Calls ``_attempt_construct`` for every key that has not yet been 

864 completed and accumulates results. 

865 

866 Parameters 

867 ---------- 

868 keys : sequence of str 

869 All keys requested for construction. 

870 keys_complete : np.ndarray of bool 

871 Boolean array indicating which keys have been completed; mutated 

872 in place by ``_attempt_construct``. 

873 backup : dict 

874 Dictionary of pre-construction attribute values. 

875 errors : dict 

876 ``self._constructor_errors``; mutated in place. 

877 force : bool 

878 Passed through to ``_attempt_construct``. 

879 

880 Returns 

881 ------- 

882 anything_accomplished : bool 

883 True when at least one key was completed during this pass. 

884 missing_key_data : list of tuple 

885 Accumulated ``(key, arg)`` pairs for all unresolved required 

886 arguments across every incomplete key. 

887 """ 

888 anything_accomplished = False 

889 missing_key_data = [] 

890 

891 for i, key in enumerate(keys): 

892 if keys_complete[i]: 

893 continue # This key has already been built 

894 

895 accomplished, key_missing = self._attempt_construct( 

896 key, i, keys_complete, backup, errors, force 

897 ) 

898 if accomplished: 

899 anything_accomplished = True 

900 missing_key_data.extend(key_missing) 

901 

902 return anything_accomplished, missing_key_data 

903 

904 def construct(self, *args, force=False): 

905 """ 

906 Top-level method for building constructed inputs. If called without any 

907 inputs, construct builds each of the objects named in the keys of the 

908 constructors dictionary; it draws inputs for the constructors from the 

909 parameters dictionary and adds its results to the same. If passed one or 

910 more strings as arguments, the method builds only the named keys. The 

911 method will do multiple "passes" over the requested keys, as some cons- 

912 tructors require inputs built by other constructors. If any requested 

913 constructors failed to build due to missing data, those keys (and the 

914 missing data) will be named in self._missing_key_data. Other errors are 

915 recorded in the dictionary attribute _constructor_errors. 

916 

917 This method tries to "start from scratch" by removing prior constructed 

918 objects, holding them in a backup dictionary during construction. This 

919 is done so that dependencies among constructors are resolved properly, 

920 without mistakenly relying on "old information". A backup value is used 

921 if a constructor function is set to None (i.e. "don't do anything"), or 

922 if the construct method fails to produce a new object. 

923 

924 Parameters 

925 ---------- 

926 *args : str, optional 

927 Keys of self.constructors that are requested to be constructed. 

928 If no arguments are passed, *all* elements of the dictionary are implied. 

929 force : bool, optional 

930 When True, the method will force its way past any errors, including 

931 missing constructors, missing arguments for constructors, and errors 

932 raised during execution of constructors. Information about all such 

933 errors is stored in the dictionary attributes described above. When 

934 False (default), any errors or exception will be raised. 

935 

936 Returns 

937 ------- 

938 None 

939 """ 

940 # Set up the requested work 

941 if len(args) > 0: 

942 keys = args 

943 else: 

944 keys = list(self.constructors.keys()) 

945 N_keys = len(keys) 

946 keys_complete = np.zeros(N_keys, dtype=bool) 

947 if N_keys == 0: 

948 return # Do nothing if there are no constructed objects 

949 

950 # Remove pre-existing constructed objects, preventing "incomplete" updates, 

951 # but store the current values in a backup dictionary in case something fails 

952 backup = {} 

953 for key in keys: 

954 if hasattr(self, key): 

955 backup[key] = getattr(self, key) 

956 self.del_param(key) 

957 

958 # Get the dictionary of constructor errors 

959 if not hasattr(self, "_constructor_errors"): 

960 self._constructor_errors = {} 

961 errors = self._constructor_errors 

962 

963 # As long as the work isn't complete and we made some progress on the last 

964 # pass, repeatedly perform passes of trying to construct objects 

965 any_keys_incomplete = np.any(np.logical_not(keys_complete)) 

966 go = any_keys_incomplete 

967 while go: 

968 anything_accomplished, missing_key_data = self._construct_pass( 

969 keys, keys_complete, backup, errors, force 

970 ) 

971 any_keys_incomplete = np.any(np.logical_not(keys_complete)) 

972 go = any_keys_incomplete and anything_accomplished 

973 

974 # Store missing key-data pairs and exit 

975 self._missing_key_data = missing_key_data 

976 self._constructor_errors = errors 

977 if any_keys_incomplete: 

978 msg = "Did not construct these objects:" 

979 for i in range(N_keys): 

980 if keys_complete[i]: 

981 continue 

982 msg += " " + keys[i] + "," 

983 key = keys[i] 

984 if key in backup.keys(): 

985 setattr(self, key, backup[key]) 

986 self.parameters[key] = backup[key] 

987 msg = msg[:-1] 

988 if not force: 

989 raise ValueError(msg) 

990 return 

991 

def describe_constructors(self, *args):
    """
    Print a text summary of this instance's constructed objects.

    For every requested key, one line reports whether the object currently
    exists on the instance (as an attribute or in self.parameters) and the
    name of the function that builds it. Under each ordinary constructor,
    one indented line per argument reports whether that input is present,
    absent but covered by a default value (marked "*"), or missing.

    Parameters
    ----------
    *args : str, optional
        Names of constructed inputs to describe. If none are passed, every
        key of self.constructors is described.

    Returns
    -------
    None
    """
    names = args if len(args) > 0 else list(self.constructors.keys())
    present_mark = "\u2713"
    absent_mark = "X"
    default_mark = "*"

    pieces = []
    for name in names:
        exists = hasattr(self, name) or (name in self.parameters)
        status = present_mark if exists else absent_mark

        try:
            builder = self.constructors[name]
        except KeyError:
            pieces.append(status + " " + name + " : NO CONSTRUCTOR FOUND\n")
            continue

        # Objects copied from a "parent" object have no argument list to show
        if isinstance(builder, get_it_from):
            pieces.append(status + " " + name + " : get it from " + builder.name + "\n")
            continue
        pieces.append(status + " " + name + " : " + builder.__name__ + "\n")

        # Record, for each argument of the constructor, whether it lacks a default
        needed = get_arg_names(builder)
        lacks_default = {
            arg: spec.default is inspect.Parameter.empty
            for arg, spec in inspect.signature(builder).parameters.items()
        }

        # Report on the availability of each argument
        for arg in needed:
            if hasattr(self, arg) or arg in self.parameters:
                mark = present_mark
            elif lacks_default[arg]:
                mark = absent_mark
            else:
                mark = default_mark
            pieces.append(" " + mark + " " + arg + "\n")

    # Print the assembled description to screen
    print("".join(pieces))
    return

1070 

# This is a "synonym" method so that old calls to update() still work
def update(self, *args, **kwargs):
    """
    Backward-compatible alias for construct().

    All positional and keyword arguments are forwarded unchanged to
    self.construct. The result of that call is not returned.
    """
    self.construct(*args, **kwargs)

1074 

1075 

1076class AgentType(Model): 

1077 """ 

1078 A superclass for economic agents in the HARK framework. Each model should 

1079 specify its own subclass of AgentType, inheriting its methods and overwriting 

1080 as necessary. Critically, every subclass of AgentType should define class- 

1081 specific static values of the attributes time_vary and time_inv as lists of 

1082 strings. Each element of time_vary is the name of a field in AgentSubType 

1083 that varies over time in the model. Each element of time_inv is the name of 

1084 a field in AgentSubType that is constant over time in the model. 

1085 

1086 Parameters 

1087 ---------- 

1088 solution_terminal : Solution 

1089 A representation of the solution to the terminal period problem of 

1090 this AgentType instance, or an initial guess of the solution if this 

1091 is an infinite horizon problem. 

1092 cycles : int 

1093 The number of times the sequence of periods is experienced by this 

1094 AgentType in their "lifetime". cycles=1 corresponds to a lifecycle 

1095 model, with a certain sequence of one period problems experienced 

1096 once before terminating. cycles=0 corresponds to an infinite horizon 

1097 model, with a sequence of one period problems repeating indefinitely. 

1098 pseudo_terminal : bool 

1099 Indicates whether solution_terminal isn't actually part of the 

1100 solution to the problem (as a known solution to the terminal period 

1101 problem), but instead represents a "scrap value"-style termination. 

1102 When True, solution_terminal is not included in the solution; when 

1103 False, solution_terminal is the last element of the solution. 

1104 tolerance : float 

1105 Maximum acceptable "distance" between successive solutions to the 

1106 one period problem in an infinite horizon (cycles=0) model in order 

1107 for the solution to be considered as having "converged". Inoperative 

1108 when cycles>0. 

1109 verbose : int 

1110 Level of output to be displayed by this instance, default is 1. 

1111 quiet : bool 

1112 Indicator for whether this instance should operate "quietly", default False. 

1113 seed : int 

1114 A seed for this instance's random number generator. 

1115 construct : bool 

1116 Indicator for whether this instance's construct() method should be run 

1117 when initialized (default True). When False, an instance of the class 

1118 can be created even if not all of its attributes can be constructed. 

1119 use_defaults : bool 

1120 Indicator for whether this instance should use the values in the class' 

1121 default dictionary to fill in parameters and constructors for those not 

1122 provided by the user (default True). Setting this to False is useful for 

1123 situations where the user wants to be absolutely sure that they know what 

1124 is being passed to the class initializer, without resorting to defaults. 

1125 

1126 Attributes 

1127 ---------- 

1128 AgentCount : int 

1129 The number of agents of this type to use in simulation. 

1130 

1131 state_vars : list of string 

1132 The string labels for this AgentType's model state variables. 

1133 """ 

1134 

1135 time_vary_ = [] 

1136 time_inv_ = [] 

1137 shock_vars_ = [] 

1138 state_vars = [] 

1139 poststate_vars = [] 

1140 market_vars = [] 

1141 distributions = [] 

1142 default_ = {"params": {}, "solver": NullFunc()} 

1143 

def __init__(
    self,
    solution_terminal=None,
    pseudo_terminal=True,
    tolerance=0.000001,
    verbose=1,
    quiet=False,
    seed=0,
    construct=True,
    use_defaults=True,
    **kwds,
):
    """
    Set up a new agent type: merge class defaults with user-supplied
    parameters, create the simulation bookkeeping attributes, assign the
    parameters, reset the random number generator, and (optionally) run
    the constructors. See the class docstring for the meaning of each argument.
    """
    super().__init__()
    default_params = self.default_["params"]

    # Start from the class defaults (if requested), then overlay user inputs
    params = deepcopy(default_params) if use_defaults else {}
    params.update(kwds)

    # Constructors are merged key by key rather than replaced wholesale, so
    # a user can override a single constructor and keep the other defaults
    if use_defaults and "constructors" in default_params:
        constructors = deepcopy(default_params["constructors"])
    else:
        constructors = {}
    if "constructors" in kwds:
        constructors.update(kwds["constructors"])
    params["constructors"] = constructors

    # Default list of variables to track during simulation
    if use_defaults and "track_vars" in self.default_:
        self.track_vars = copy(self.default_["track_vars"])
    else:
        self.track_vars = []

    # Name of the model file, when the class defaults provide one
    try:
        self.model_file = copy(self.default_["model"])
    except (KeyError, TypeError):
        # No "model" entry, or an entry that cannot be copied
        self.model_file = None

    if solution_terminal is None:
        solution_terminal = NullFunc()

    # Basic attributes; these are set before assign_parameters is called
    self.solve_one_period = self.default_["solver"]  # NOQA
    self.solution_terminal = solution_terminal  # NOQA
    self.pseudo_terminal = pseudo_terminal  # NOQA
    self.tolerance = tolerance  # NOQA
    self.verbose = verbose
    self.quiet = quiet
    self.seed = seed  # NOQA
    self.state_now = dict.fromkeys(self.state_vars)
    self.state_prev = dict(self.state_now)
    self.controls = {}
    self.shocks = {}
    self.read_shocks = False  # NOQA
    self.shock_history = {}
    self.newborn_init_history = {}
    self.history = {}
    self.assign_parameters(**params)  # NOQA
    self.reset_rng()  # NOQA
    self.bilt = {}
    if construct:
        self.construct()

    # Instance-level copies of the class-level lists, so that changes made
    # on one instance do not affect the class or other instances
    self.time_vary = deepcopy(self.time_vary_)
    self.time_inv = deepcopy(self.time_inv_)
    self.shock_vars = deepcopy(self.shock_vars_)

1210 

def add_to_time_vary(self, *params):
    """
    Register any number of names as time-varying for this instance.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be added to time_vary.
        Names already in the list are not added again.

    Returns
    -------
    None
    """
    registered = self.time_vary
    for name in params:
        if name in registered:
            continue
        registered.append(name)

1227 

def add_to_time_inv(self, *params):
    """
    Register any number of names as time-invariant for this instance.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be added to time_inv.
        Names already in the list are not added again.

    Returns
    -------
    None
    """
    registered = self.time_inv
    for name in params:
        if name in registered:
            continue
        registered.append(name)

1244 

def del_from_time_vary(self, *params):
    """
    Unregister any number of names from time_vary for this instance.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be removed from time_vary.
        Names that are not in the list are silently skipped.

    Returns
    -------
    None
    """
    registered = self.time_vary
    for name in params:
        if name not in registered:
            continue
        registered.remove(name)

1261 

def del_from_time_inv(self, *params):
    """
    Unregister any number of names from time_inv for this instance.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be removed from time_inv.
        Names that are not in the list are silently skipped.

    Returns
    -------
    None
    """
    registered = self.time_inv
    for name in params:
        if name not in registered:
            continue
        registered.remove(name)

1278 

def unpack(self, name):
    """
    Copy one named component out of every element of self.solution.

    The per-period values are gathered into a list, stored on this instance
    as an attribute called `name`, and `name` is added to time_vary. The
    type of the first solution element decides how values are read: by key
    when it is exactly a dict, otherwise from each element's __dict__.

    Parameters
    ----------
    name: str
        Name of the attribute to unpack from the solution

    Returns
    -------
    none
    """
    periods = self.solution
    if type(periods[0]) is dict:
        unpacked = [period[name] for period in periods]
    else:
        unpacked = [period.__dict__[name] for period in periods]
    setattr(self, name, unpacked)
    self.add_to_time_vary(name)

1302 

def solve(
    self,
    verbose=False,
    presolve=True,
    postsolve=True,
    from_solution=None,
    from_t=None,
):
    """
    Solve the model for this instance of an agent type by backward induction.
    Loops through the sequence of one period problems, passing the solution
    from period t+1 to the problem for period t.

    Parameters
    ----------
    verbose : bool, optional
        If True, solution progress is printed to screen. Default False.
    presolve : bool, optional
        If True (default), the pre_solve method is run before solving.
    postsolve : bool, optional
        If True (default), the post_solve method is run after solving.
    from_solution: Solution
        If different from None, will be used as the starting point of backward
        induction, instead of self.solution_terminal.
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. It should usually only be used in combination with from_solution.
        Stands for the time index that from_solution represents, and thus is
        only compatible with cycles=1 and will be reset to None otherwise.

    Returns
    -------
    none
    """
    # Silence numpy's floating point "errors" while solving: operations such
    # as 1.0/0.0 or np.inf/np.inf have well-defined results (inf, nan) and
    # are not problems in this context.
    silenced = dict(divide="ignore", over="ignore", under="ignore", invalid="ignore")
    with np.errstate(**silenced):
        if presolve:
            self.pre_solve()  # checks and preparation before solving

        # Backward induction over the sequence of one period problems
        self.solution = solve_agent(self, verbose, from_solution, from_t)

        if postsolve:
            self.post_solve()  # finalization after solving

1354 

def reset_rng(self):
    """
    Re-seed this type's random number generator and reset its distributions.

    A fresh generator seeded with self.seed is stored as self.RNG. Then, for
    each name in self.distributions that exists as an attribute, reset() is
    called on the distribution(s) found there. Three layouts are handled:

    1) The attribute is a single distribution object
    2) The attribute is a list of distribution objects (probably time-varying)
    3) The attribute is a nested list of distributions, as in ConsMarkovModel.
    """
    self.RNG = np.random.default_rng(self.seed)
    missing = object()
    for attr_name in self.distributions:
        target = getattr(self, attr_name, missing)
        if target is missing:
            continue

        # Case 1: a lone distribution
        if not isinstance(target, list):
            target.reset()
            continue

        # Cases 2 and 3: treat each entry as a group of one or more distributions
        for entry in target:
            group = entry if isinstance(entry, list) else [entry]
            for dstn in group:
                dstn.reset()

1379 

def check_elements_of_time_vary_are_lists(self):
    """
    Verify that every existing attribute named in time_vary is a list.

    Names in time_vary that are not attributes of this instance are skipped,
    as are attributes that are IndexDistribution instances. Any other
    attribute whose type is not exactly list fails an assertion.
    """
    for name in self.time_vary:
        if not hasattr(self, name):
            continue
        value = getattr(self, name)
        if isinstance(value, (IndexDistribution,)):
            continue
        assert type(value) == list, (
            name
            + " is not a list or time varying distribution,"
            + " but should be because it is in time_vary"
        )

1396 

def check_restrictions(self):
    """
    Hook for checking that model-specific restrictions are met.
    This version performs no checks and returns None.
    """
    return None

1402 

def pre_solve(self):
    """
    Run the checks that precede solving the model: check_restrictions
    first, then check_elements_of_time_vary_are_lists.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    self.check_restrictions()
    self.check_elements_of_time_vary_are_lists()

1419 

def post_solve(self):
    """
    Hook that runs immediately after the model is solved, to finalize the
    solution in some way. This version does nothing and returns None.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    return

1434 

def get_market_params(self, mkt, construct=True):
    """
    Copy the attributes named in market_vars from a market onto this agent.

    Each named attribute of mkt is shallow-copied and passed to
    assign_parameters. By default the construct method is then run, because
    the market parameters often have information needed to "complete" the
    microeconomic problem.

    This method is called automatically by the Market.give_agent_params()
    method for all agents.

    Parameters
    ----------
    mkt : Market
        Market to which this AgentType belongs.
    construct : bool
        Indicator for whether constructed attributes should be updated after
        fetching data / parameters from mkt (default True)

    Returns
    -------
    None
    """
    fetched = {name: copy(getattr(mkt, name)) for name in self.market_vars}
    self.assign_parameters(**fetched)
    if construct:
        self.construct()

1463 

def initialize_sym(self, **kwargs):
    """
    Use the new simulator structure to build a simulator from the agents'
    attributes, storing it in a private attribute.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to make_simulator_from_agent.

    Returns
    -------
    None
    """
    self.reset_rng()  # ensure seeds are set identically each time
    # Build the simulator from this agent and keep it as self._simulator
    self._simulator = make_simulator_from_agent(self, **kwargs)
    # Call the new simulator's own reset method before it is used
    self._simulator.reset()

1472 

def find_target(self, target_var, force_list=False, **kwargs):
    r"""
    Find the "target" level of a named variable such that E[\Delta x] = 0,
    with E[\Delta x-eps] > 0 and E[\Delta x+eps] < 0 (locally stable).
    Returns a single real value if there is only one target, and a list if multiple;
    returns np.nan if no target is found. Pass force_list=True to always get a list.
    See documentation for HARK.simulator.find_target_state for more options.

    Parameters
    ----------
    target_var : str
        Name of the variable whose target level is sought.
    force_list : bool
        When True, the collection returned by the simulator is passed back
        as is, regardless of how many targets it holds. Default False.
    **kwargs
        Forwarded unchanged to the simulator's find_target_state method.

    Returns
    -------
    target : float or list
        The target value(s), in the form described above.

    Raises
    ------
    AttributeError
        If this instance has no solution attribute (model not yet solved).
    """
    # NOTE: the docstring is a raw string because "\D" is not a valid escape
    # sequence, which makes a regular string literal emit a SyntaxWarning.
    if not hasattr(self, "solution"):
        raise AttributeError("Model must be solved before using find_target!")

    # A throwaway simulator is built just for this calculation
    temp_simulator = make_simulator_from_agent(self)
    target_vals = temp_simulator.find_target_state(target_var, **kwargs)
    if force_list:
        return target_vals

    # Collapse the result: nan for no target, a scalar for exactly one
    count = len(target_vals)
    if count == 0:
        return np.nan
    if count == 1:
        return target_vals[0]
    return target_vals

1493 

def initialize_sim(self):
    """
    Prepare this AgentType for a new simulation: reset the random number
    generator, set t_sim to zero, blank out the state variables, create all
    agents as newborns (using sim_birth), and clear the tracked histories.
    When read_shocks is set and newborn_init_history is non-empty, the
    initial idiosyncratic states are taken from its first row instead.

    Parameters
    ----------
    None

    Returns
    -------
    None

    Raises
    ------
    Exception
        If T_sim has not been set, or is not positive.
    """
    if not hasattr(self, "T_sim"):
        raise Exception(
            "To initialize simulation variables it is necessary to first "
            "set the attribute T_sim to the largest number of observations "
            "you plan to simulate for each agent including re-births."
        )
    if self.T_sim <= 0:
        raise Exception(
            "T_sim represents the largest number of observations "
            "that can be simulated for an agent, and must be a positive number."
        )

    self.reset_rng()
    self.t_sim = 0
    everyone = np.ones(self.AgentCount, dtype=bool)

    # Each state variable starts as its own all-NaN array
    for state_name in self.state_vars:
        self.state_now[state_name] = np.full(self.AgentCount, np.nan)

    # Periods since entry, and position within the cycle, for every agent
    self.t_age = np.zeros(self.AgentCount, dtype=int)
    self.t_cycle = np.zeros(self.AgentCount, dtype=int)
    self.sim_birth(everyone)

    # With pre-specified shocks and stored initial conditions, overwrite the
    # states just drawn by sim_birth with the stored ones
    if self.read_shocks and bool(self.newborn_init_history):
        for state_name in self.state_now:
            if state_name not in self.newborn_init_history.keys():
                warn(
                    "The option for reading shocks was activated but "
                    + "the model requires state "
                    + state_name
                    + ", not contained in "
                    + "newborn_init_history."
                )
                continue

            # Only array-like idiosyncratic states are copied. Aggregates
            # should not be set by newborns
            current = self.state_now[state_name]
            is_idiosyncratic = (
                isinstance(current, np.ndarray) and len(current) == self.AgentCount
            )
            if is_idiosyncratic:
                self.state_now[state_name] = self.newborn_init_history[state_name][0]

    self.clear_history()

1560 

1561 def _export_single_var_by_time(self, history, var, t, dtype): 

1562 """ 

1563 Mode 1a: single variable, by_age=False. 

1564 

1565 Returns a DataFrame whose columns are simulation periods (t_sim) and 

1566 whose rows are agent indices. When t is None all T_sim periods are 

1567 included; when t is an array only those periods are included. 

1568 """ 

1569 try: 

1570 data = history[var] 

1571 except KeyError: 

1572 raise KeyError("Variable named " + var + " not found in simulated data!") 

1573 

1574 if t is None: 

1575 cols = [str(i) for i in range(self.T_sim)] 

1576 df = DataFrame(data=data.T, columns=cols, dtype=dtype) 

1577 else: 

1578 cols = [str(t_val) for t_val in t] 

1579 df = DataFrame(data=data[t, :].T, columns=cols, dtype=dtype) 

1580 return df 

1581 

1582 def _export_single_var_by_age(self, history, var, age, t, dtype, sym): 

1583 """ 

1584 Mode 1b: single variable, by_age=True. 

1585 

1586 Returns a DataFrame whose columns are within-agent model ages (t_age) 

1587 and whose rows are individual agent lifetimes. Observations after 

1588 death (or before birth) are NaN. When t is None all ages up to 

1589 max(t_age) are included; when t is an array only those ages are used. 

1590 The sym flag controls newborn detection: age == 0 when sym is True, 

1591 age == 1 when sym is False. 

1592 """ 

1593 try: 

1594 data = history[var] 

1595 except KeyError: 

1596 raise KeyError("Variable named " + var + " not found in simulated data!") 

1597 

1598 # Determine which ages to include and mark qualifying observations 

1599 if t is None: 

1600 age_set = np.arange(np.max(age) + 1) 

1601 in_age_set = np.ones_like(data, dtype=bool) 

1602 else: 

1603 age_set = t 

1604 in_age_set = np.zeros_like(data, dtype=bool) 

1605 for j in age_set: 

1606 these = age == j 

1607 in_age_set[these] = True 

1608 

1609 # Locate newborns to determine the number of individual lifetimes (rows) 

1610 newborns = age == 0 if sym else age == 1 

1611 T = age_set.size # number of age columns 

1612 N = np.sum(newborns) # number of agent lifetimes (rows) 

1613 out = np.full((N, T), np.nan) 

1614 

1615 # Extract each individual's sequence and place it into the output array 

1616 n = 0 

1617 for i in range(self.AgentCount): 

1618 data_i = data[:, i] 

1619 births = np.where(newborns[:, i])[0] 

1620 K = births.size 

1621 for k in range(K): 

1622 start = births[k] 

1623 stop = births[k + 1] if (k < K - 1) else self.T_sim 

1624 use = in_age_set[start:stop, i] 

1625 temp = data_i[start:stop][use] 

1626 out[n, : temp.size] = temp 

1627 n += 1 

1628 

1629 cols = [str(a) for a in age_set] 

1630 df = DataFrame(data=out, columns=cols, dtype=dtype) 

1631 return df 

1632 

1633 def _export_single_t_by_time(self, history, var_list, t, dtype): 

1634 """ 

1635 Mode 2a: single time period, by_age=False. 

1636 

1637 Returns a DataFrame with one row per agent and one column per variable, 

1638 drawn from the absolute simulation period t (i.e. history[name][t, :]). 

1639 """ 

1640 K = len(var_list) 

1641 N = self.AgentCount 

1642 out = np.full((N, K), np.nan) 

1643 for k, name in enumerate(var_list): 

1644 out[:, k] = history[name][t, :] 

1645 df = DataFrame(data=out, columns=var_list, dtype=dtype) 

1646 return df 

1647 

1648 def _export_single_t_by_age(self, history, var_list, age, t, dtype): 

1649 """ 

1650 Mode 2b: single time period, by_age=True. 

1651 

1652 Returns a DataFrame with one row per agent-period at which t_age == t 

1653 and one column per variable. 

1654 """ 

1655 right_age = age == t 

1656 N = np.sum(right_age) 

1657 K = len(var_list) 

1658 out = np.full((N, K), np.nan) 

1659 for k, name in enumerate(var_list): 

1660 out[:, k] = history[name][right_age] 

1661 df = DataFrame(data=out, columns=var_list, dtype=dtype) 

1662 return df 

1663 

def export_to_df(self, var=None, t=None, by_age=False, dtype=None, sym=False):
    """
    Export an AgentType instance's simulated data to a pandas dataframe object.
    There are four construction modes depending on the arguments passed:

    1a) If exactly one simulated variable is named as var and by_age is False,
        then the dataframe will contain T_sim columns, each representing one
        simulated period in absolute simulation time t_sim. Each row of the
        dataframe will represent one *agent index* of the population, with death
        and replacement occurring within a row. Optionally, argument t can be
        provided as an array to specify which periods to include (default all).

    1b) If exactly one simulated variable is named as var and by_age is True,
        then the dataframe's columns will correspond to within-agent model age
        t_age. Each row of the dataframe will represent one specific agent from
        model entry (t_age=0) to model death. All observations after death will
        be NaN. Optionally, argument t can be provided as an array to specify
        which ages to include (default all). Number of columns in dataframe will
        depend on max(t_age) and/or argument t.

    2a) If an integer is provided as t and by_age is False, then each column of
        the dataframe will represent a different simulated variable, using the
        value for the specified absolute simulated period t=t_sim. Optionally,
        the var argument can be provided as a list of strings naming which
        variables should be included in the dataframe (default all).

    2b) If an integer is provided as t and by_age is True, then each column of
        the dataframe will represent a different simulated variable, taken from
        all agent-periods at which t == t_age, within-agent model age. Optionally,
        the var argument can be provided as a list of strings naming which
        variables should be included in the dataframe (default all).

    In summary, *either* var should be a single string *or* t should be an integer.
    Any other combination of var and t will raise an exception.

    Parameters
    ----------
    var : str or [str] or None
        If a single string is provided, it represents the name of the one simulated
        variable to export. If a list of strings, then the argument t must also be
        provided to indicate which time period the dataframe will represent. Name(s)
        must correspond to a key for history or hystory dictionary (i.e. named in track_vars).
        If not provided, then all keys in history or hystory are included.
    t : int or np.array or None
        If an integer, indicates which one period will be included in the dataframe.
        When by_age is False (default), t refers to absolute simulated time t_sim:
        literally the t-th row of history[key]. When by_age is True, t refers to
        within-agent model age t_age; the dataframe will include all agent-periods
        where the agent has exactly t_age==t. If var is a single string, then t is
        an optional input as an array of periods (or ages) to include (default all).
    by_age : bool
        Indicator for whether observation selection should be on the basis of absolute
        simulated time t_sim or within-agent model age t_age. If True, then t_age
        must be in track_vars so that it appears in the simulated data. Additionally,
        argument dtype should *not* be provided when by_age is True, as this will
        result in NaNs being cast to a datatype that doesn't necessarily support them.
    dtype : type or None
        Optional data type to cast the dataframe. By default, uses the datatype from
        the entry in history or hystory.
    sym : bool
        Indicator for whether the dataframe should look for simulated data in the
        history (False, default) or hystory (True) dictionary attribute. This option
        will be deprecated in the future when legacy simulation methods are removed.

    Returns
    -------
    df : pandas.DataFrame
        The requested dataframe, constructed from this instance's simulated data.

    Raises
    ------
    ValueError
        If neither or both of "var is a single string" and "t is a single
        integer" hold, or if dtype is given together with by_age=True.
    KeyError
        If by_age is True but t_age is not in the simulated data, or if a
        requested variable name is not in the simulated data.
    """
    # Validate arguments. isinstance (rather than an exact type check) is used
    # so that str subclasses such as numpy.str_ count as a single variable name.
    single_var = isinstance(var, str)
    single_t = isinstance(t, (int, np.integer))
    if single_var == single_t:
        raise ValueError(
            "Either var must be a single string, or t must be a single integer!"
        )
    if dtype is not None and by_age:
        raise ValueError(
            "Can't specify dtype when using by_age is True because of potential incompatibility with representing NaN"
        )

    # Get the relevant history dictionary (deprecate in future)
    history = self.hystory if sym else self.history

    # Retrieve age array once if needed (raises a clear error when missing)
    if by_age:
        try:
            age = history["t_age"]
        except KeyError:
            raise KeyError("t_age must be in track_vars if by_age=True will be used!")

    # Modes 1a / 1b: one variable across periods or ages
    if single_var:
        if by_age:
            return self._export_single_var_by_age(history, var, age, t, dtype, sym)
        return self._export_single_var_by_time(history, var, t, dtype)

    # Modes 2a / 2b: one period or age across variables.
    # Build and validate the variable list first.
    if var is None:
        var_list = list(history.keys())
    else:
        var_list = copy(var)
    sim_keys = list(history.keys())
    for name in var_list:
        if name not in sim_keys:
            raise KeyError("Variable called " + name + " not found in simulation data!")

    if by_age:
        return self._export_single_t_by_age(history, var_list, age, t, dtype)
    return self._export_single_t_by_time(history, var_list, t, dtype)

1778 

def sim_one_period(self):
    """
    Simulate one period for this type. Calls, in order, get_mortality(),
    then either read_shocks_from_history() or get_shocks(), then
    get_states(), get_controls(), and get_poststates(); finally advances
    t_age and t_cycle by one period. The called methods should be defined
    for AgentType subclasses, except get_mortality (define its components
    sim_death and sim_birth instead) and read_shocks_from_history.

    Parameters
    ----------
    None

    Returns
    -------
    None

    Raises
    ------
    Exception
        If this instance has no solution attribute.
    """
    if not hasattr(self, "solution"):
        raise Exception(
            "Model instance does not have a solution stored. To simulate, it is necessary"
            " to run the `solve()` method first."
        )

    # Mortality adjusts the agent population: some agents become "newborns"
    self.get_mortality()

    # Current states become state_{t-1}; array-valued states get a fresh,
    # uninitialized array. Non-array states are left alone: they are probably
    # aggregate variables, which may be getting set by the Market.
    for name, value in self.state_now.items():
        self.state_prev[name] = value
        if isinstance(value, np.ndarray):
            self.state_now[name] = np.empty(self.AgentCount)

    # Use pre-specified shock histories if requested, otherwise draw shocks
    # according to the subclass-specific method
    shock_step = self.read_shocks_from_history if self.read_shocks else self.get_shocks
    shock_step()
    self.get_states()  # each agent's state at decision time
    self.get_controls()  # each agent's choice or control variables
    self.get_poststates()  # variables that come *after* decision-time

    # Advance time for all agents
    self.t_age = self.t_age + 1
    advanced = self.t_cycle + 1
    advanced[advanced == self.T_cycle] = 0  # wrap around at the end of the cycle
    self.t_cycle = advanced

1827 

def make_shock_history(self):
    """
    Build and store a pre-specified history of shocks for the simulation.

    The simulation is re-initialized, and then only the mortality and shock
    steps of the simulation loop are run for T_sim periods. Each variable
    named in self.shock_vars is stored in a T_sim x AgentCount array in
    self.shock_history, together with the Boolean array "who_dies". The
    variables named in self.state_vars are stored in self.newborn_init_history:
    the whole row for the starting period, and afterwards only the entries of
    agents flagged in who_dies. Finally, self.read_shocks is set to True so
    that the stored values are read back instead of being drawn again.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    # Start from a freshly initialized simulation
    self.initialize_sim()

    def is_idiosyncratic(name):
        # Agent-level states are arrays with one entry per agent; anything
        # else is treated as an aggregate value.
        value = self.state_now[name]
        return isinstance(value, np.ndarray) and len(value) == self.AgentCount

    # Blank (NaN) histories for every shock, plus an all-False mortality record
    for shock_name in self.shock_vars:
        self.shock_history[shock_name] = np.full(
            (self.T_sim, self.AgentCount), np.nan
        )
    self.shock_history["who_dies"] = np.zeros(
        (self.T_sim, self.AgentCount), dtype=bool
    )

    # Blank (NaN) histories for the initial conditions of newborns
    for state_name in self.state_vars:
        self.newborn_init_history[state_name] = np.full(
            (self.T_sim, self.AgentCount), np.nan
        )

    # Record the states produced by initialize_sim for the starting period
    for state_name in self.state_vars:
        record = self.newborn_init_history[state_name]
        if is_idiosyncratic(state_name):
            record[self.t_sim] = self.state_now[state_name]
        else:
            # Aggregate value: the same value is written for every agent
            record[self.t_sim, :] = self.state_now[state_name]

    # Generate and store mortality and shocks period by period
    for t in range(self.T_sim):
        # Deaths (get_mortality sets self.who_dies)
        self.get_mortality()
        replaced = self.who_dies
        self.shock_history["who_dies"][t, :] = replaced

        # Initial conditions of the agents that were just replaced
        if replaced.any():
            for state_name in self.state_vars:
                record = self.newborn_init_history[state_name]
                if is_idiosyncratic(state_name):
                    record[t, replaced] = self.state_now[state_name][replaced]
                else:
                    record[t, replaced] = self.state_now[state_name]

        # All other shocks
        self.get_shocks()
        for shock_name in self.shock_vars:
            self.shock_history[shock_name][t, :] = self.shocks[shock_name]

        # Advance the clocks; agents at the end of their cycle wrap to zero
        self.t_sim += 1
        self.t_age = self.t_age + 1
        self.t_cycle = self.t_cycle + 1
        self.t_cycle[self.t_cycle == self.T_cycle] = 0

    # From now on shocks are read from the stored history
    self.read_shocks = True

1918 

def get_mortality(self):
    """
    Determine which agents are replaced this period and handle their "birth".

    When self.read_shocks is False, self.sim_death() is called to obtain a
    Boolean array and that array is passed to self.sim_birth(). When
    self.read_shocks is True, the array is read from row t_sim of
    self.shock_history["who_dies"]; for the flagged agents, array-valued
    states of length AgentCount are overwritten with the values saved in
    self.newborn_init_history and t_age / t_cycle are set to zero. In both
    cases the array is stored in self.who_dies.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    if not self.read_shocks:
        # Simulate turnover with the model-specific rules
        who_dies = self.sim_death()
        self.sim_birth(who_dies)
        self.who_dies = who_dies
        return None

    who_dies = self.shock_history["who_dies"][self.t_sim, :]
    if who_dies.any():
        # Instead of simulating births, restore the saved initial conditions
        for var_name in self.state_now:
            if var_name not in self.newborn_init_history:
                warn(
                    "The option for reading shocks was activated but "
                    "the model requires state "
                    + var_name
                    + ", not contained in "
                    "newborn_init_history."
                )
                continue

            # Only per-agent arrays are overwritten; aggregates are left alone
            current = self.state_now[var_name]
            if isinstance(current, np.ndarray) and len(current) == self.AgentCount:
                current[who_dies] = self.newborn_init_history[var_name][
                    self.t_sim, who_dies
                ]

        # Newborns start at age zero, at the beginning of the cycle
        self.t_age[who_dies] = 0
        self.t_cycle[who_dies] = 0

    self.who_dies = who_dies
    return None

1971 

def sim_death(self):
    """
    Determine which agents in the current population "die" and should be
    replaced. This default implementation flags nobody; a subclass must
    override it to generate replacement events.

    Parameters
    ----------
    None

    Returns
    -------
    who_dies : np.array
        Boolean array of size self.AgentCount, True for agents who die and are
        replaced and False for those who survive (all False here).
    """
    return np.full(self.AgentCount, False, dtype=bool)

1990 

def sim_birth(self, which_agents):  # pragma: nocover
    """
    Make new agents for the simulation. Takes a Boolean array as an input,
    indicating which agent indices are to be "born". This base version is only
    a placeholder: it always raises, so a subclass must override it.

    Parameters
    ----------
    which_agents : np.array(Bool)
        Boolean array of size self.AgentCount indicating which agents should be "born".

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        Always, because the base class defines no birth rule. This is a
        subclass of Exception, so handlers written for Exception still work.
    """
    raise NotImplementedError("AgentType subclass must define method sim_birth!")

2006 

def get_shocks(self):  # pragma: nocover
    """
    Hook for obtaining the values of shock variables for the current period.
    The base version does nothing and returns None; it can be overwritten by
    subclasses of AgentType.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """

2021 

def read_shocks_from_history(self):
    """
    Read the values of shock variables for the current period from the stored
    shock history. For each variable X named in self.shock_vars,
    self.shocks[X] is set to self.shock_history[X][self.t_sim, :].

    The visible caller only invokes this method when self.read_shocks is True,
    which make_shock_history() arranges after filling self.shock_history.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    for shock_name in self.shock_vars:
        stored = self.shock_history[shock_name]
        self.shocks[shock_name] = stored[self.t_sim, :]

2042 

def get_states(self):
    """
    Update the state variables for the current period by calling
    self.transition() and writing its outputs into the state_now dictionary.

    The outputs are matched to the keys of state_now by position. If
    transition() returns fewer values than there are keys, the remaining keys
    (e.g. "post-states") keep their current values.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    new_states = self.transition()

    # zip stops at the shorter sequence, so trailing keys are left untouched
    for name, value in zip(list(self.state_now), new_states):
        self.state_now[name] = value

2063 

def transition(self):  # pragma: nocover
    """
    Compute new values of the endogenous states. The base version returns an
    empty tuple, so no state is updated unless a subclass overrides it.

    Parameters
    ----------
    None

    [Eventually, to match dolo spec:
    exogenous_prev, endogenous_prev, controls, exogenous, parameters]

    Returns
    -------
    endogenous_state: ()
        Tuple with new values of the endogenous states
    """
    return tuple()

2081 

def get_controls(self):  # pragma: nocover
    """
    Hook for obtaining the values of control variables for the current period,
    probably by using current states. The base version does nothing and
    returns None; it can be overwritten by subclasses of AgentType.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """

2096 

def get_poststates(self):
    """
    Hook for obtaining the values of post-decision state variables for the
    current period, probably from current states and controls and maybe
    market-level events or shock variables. The base version does nothing and
    returns None; it can be overwritten by subclasses of AgentType.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """

2114 

def symulate(self, T=None):
    """
    Run the new simulation structure by calling self._simulator.simulate(T),
    then expose the simulator's history as the hystory attribute of self.

    Parameters
    ----------
    T : int or None
        Passed through unchanged to the simulator's simulate method.

    Returns
    -------
    None
    """
    self._simulator.simulate(T)
    # Re-read the attribute after the run so hystory points at whatever
    # history object the simulator holds once it has finished.
    setattr(self, "hystory", self._simulator.history)

2122 

def describe_model(self, display=True):
    """
    Describe this agent's model via the simulator's describe method. This is
    useful for learning about outcome variable names for tracking during
    simulation, or for use with sequence space Jacobians.

    If the agent has no _simulator attribute yet, initialize_sym() is called
    first.

    Parameters
    ----------
    display : bool
        Passed through as the display keyword of the simulator's describe
        method.

    Returns
    -------
    None
    """
    try:
        self._simulator
    except AttributeError:
        self.initialize_sym()
    self._simulator.describe(display=display)

2132 

def simulate(self, sim_periods=None):
    """
    Simulate this agent type for a given number of periods. Defaults to all
    remaining periods (T_sim - t_sim). After each call to sim_one_period(),
    the current value of every variable named in self.track_vars is written
    into self.history[varname] and t_sim is advanced by one.

    Parameters
    ----------
    sim_periods : int or None
        Number of periods to simulate. Default is all remaining periods (usually T_sim).

    Returns
    -------
    None
        The tracked histories are stored in self.history, not returned.

    Raises
    ------
    Exception
        If the instance has no t_sim attribute (initialize_sim() was not run),
        has no T_sim attribute, or if sim_periods exceeds T_sim.
    """
    if not hasattr(self, "t_sim"):
        raise Exception(
            "It seems that the simulation variables were not initialized before calling "
            "simulate(). Call initialize_sim() to initialize the variables before calling simulate() again."
        )

    if not hasattr(self, "T_sim"):
        raise Exception(
            "This agent type instance must have the attribute T_sim set to a positive integer. "
            "Set T_sim to match the largest dataset you might simulate, and run this agent's "
            "initialize_sim() method before running simulate() again."
        )

    if sim_periods is not None and self.T_sim < sim_periods:
        raise Exception(
            "To simulate, sim_periods cannot be larger than the maximum data set size "
            "T_sim. Either increase the attribute T_sim of this agent type instance "
            "and call the initialize_sim() method again, or set sim_periods <= T_sim."
        )

    # Ignore floating point "errors". Numpy calls it "errors", but really it's
    # exceptions with well-defined answers such as 1.0/0.0 that is np.inf,
    # -1.0/0.0 that is -np.inf, np.inf/np.inf is np.nan and so on.
    with np.errstate(
        divide="ignore", over="ignore", under="ignore", invalid="ignore"
    ):
        if sim_periods is None:
            sim_periods = self.T_sim - self.t_sim

        for _ in range(sim_periods):
            self.sim_one_period()

            # Look each tracked variable up in states, shocks, controls, and
            # finally plain attributes of the agent, in that order.
            for var_name in self.track_vars:
                if var_name in self.state_now:
                    self.history[var_name][self.t_sim, :] = self.state_now[var_name]
                elif var_name in self.shocks:
                    self.history[var_name][self.t_sim, :] = self.shocks[var_name]
                elif var_name in self.controls:
                    self.history[var_name][self.t_sim, :] = self.controls[var_name]
                else:
                    # NOTE(review): who_dies is written one row earlier than
                    # the other variables once t_sim > 1 but not at t_sim == 1;
                    # confirm this threshold is intended.
                    if var_name == "who_dies" and self.t_sim > 1:
                        self.history[var_name][self.t_sim - 1, :] = getattr(
                            self, var_name
                        )
                    else:
                        self.history[var_name][self.t_sim, :] = getattr(
                            self, var_name
                        )
            self.t_sim += 1

2198 

def clear_history(self):
    """
    Reset the histories of the attributes named in self.track_vars. Each
    self.history[varname] is replaced by a new T_sim x AgentCount array filled
    with NaN; other entries of self.history are left as they are.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    for var_name in self.track_vars:
        self.history[var_name] = np.full((self.T_sim, self.AgentCount), np.nan)

2214 

def make_basic_SSJ(self, shock, outcomes, grids, **kwargs):
    """
    Construct and return sequence space Jacobian matrices for specified
    outcomes with respect to specified "shock" variable. This "basic" method
    only works for "one period infinite horizon" models (cycles=0, T_cycle=1).

    This is a thin wrapper: all arguments, plus this agent, are forwarded to
    make_basic_SSJ_matrices (imported from HARK.SSJutils); see its
    documentation for the meaning of each argument and the returned objects.
    """
    ssj_output = make_basic_SSJ_matrices(self, shock, outcomes, grids, **kwargs)
    return ssj_output

2223 

def calc_impulse_response_manually(self, shock, outcomes, grids, **kwargs):
    """
    Calculate and return the impulse response(s) of a perturbation to the
    shock parameter in period t=s, essentially computing one column of the
    sequence space Jacobian matrix manually. This "basic" method only works
    for "one period infinite horizon" models (cycles=0, T_cycle=1).

    This is a thin wrapper: all arguments, plus this agent, are forwarded to
    calc_shock_response_manually (imported from HARK.SSJutils); see its
    documentation for the meaning of each argument and the returned objects.
    """
    response = calc_shock_response_manually(self, shock, outcomes, grids, **kwargs)
    return response

2233 

2234 

def solve_agent(agent, verbose, from_solution=None, from_t=None):
    """
    Solve the dynamic model for one agent type using backwards induction. This
    function iterates on "cycles" of an agent's model either a given number of
    times or until solution convergence if an infinite horizon model is used
    (with agent.cycles = 0).

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem
        is to be solved.
    verbose : boolean
        If True, solution progress is printed to screen after every cycle.
    from_solution: Solution
        If different from None, will be used as the starting point of backward
        induction, instead of agent.solution_terminal.
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. It should usually only be used in combination with from_solution.
        Stands for the time index that from_solution represents, and thus is
        only compatible with cycles=1 and will be reset to None otherwise.

    Returns
    -------
    solution : [Solution]
        A list of solutions to the one period problems that the agent will
        encounter in his "lifetime".
    """
    # cycles == 0 is the convention for an infinite horizon problem
    cycles_left = agent.cycles
    infinite_horizon = cycles_left == 0

    # Starting point of the backward induction
    solution_last = agent.solution_terminal if from_solution is None else from_solution
    if agent.cycles != 1:
        from_t = None

    # The terminal solution is part of the output unless it is pseudo-terminal
    solution = [] if agent.pseudo_terminal else [deepcopy(solution_last)]

    completed_cycles = 0
    max_cycles = 5000  # escape clause for infinite horizon models
    if verbose:
        t_last = time()

    while True:
        # Solve a cycle of the model, recording it if horizon is finite
        solution_cycle = solve_one_cycle(agent, solution_last, from_t)
        if not infinite_horizon:
            solution = solution_cycle + solution

        # Decide whether to continue: convergence across cycles for infinite
        # horizon models, or the number of remaining cycles otherwise
        solution_now = solution_cycle[0]
        if not infinite_horizon:
            cycles_left -= 1
            go = cycles_left > 0
        elif completed_cycles == 0:
            # Assume solution does not converge after only one cycle
            solution_distance = 100.0
            go = True
        else:
            solution_distance = distance_metric(solution_now, solution_last)
            # Stored on the agent so users can query them to see if the
            # solution is ready
            agent.solution_distance = solution_distance
            agent.completed_cycles = completed_cycles
            go = solution_distance > agent.tolerance and completed_cycles < max_cycles

        # Update the "last period solution"
        solution_last = solution_now
        completed_cycles += 1

        # Display progress if requested
        if verbose:
            t_now = time()
            elapsed = t_now - t_last
            if infinite_horizon:
                print(
                    f"Finished cycle #{completed_cycles} in {elapsed}"
                    f" seconds, solution distance = {solution_distance}"
                )
            else:
                print(
                    f"Finished cycle #{completed_cycles} of {agent.cycles}"
                    f" in {elapsed} seconds."
                )
            t_last = t_now

        if not go:
            break

    # Infinite horizon: nothing was accumulated above, so the answer is the
    # last cycle that was solved
    if infinite_horizon:
        solution = solution_cycle

    return solution

2350 

2351 

def solve_one_cycle(agent, solution_last, from_t):
    """
    Solve one "cycle" of the dynamic model for one agent type. This function
    iterates over the periods within an agent's cycle, gathering the inputs
    for each period and passing them to the single period solver(s).

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    solution_last : Solution
        A representation of the solution of the period that comes after the
        end of the sequence of one period problems. This might be the
        terminal period solution, a "pseudo terminal" solution, or simply the
        solution to the earliest period from the succeeding cycle.
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. When used, represents the time index that solution_last is from.

    Returns
    -------
    solution_cycle : [Solution]
        A list of one period solutions for one "cycle" of the AgentType's
        microeconomic model, ordered so that the most recently solved period
        comes first.
    """
    # If the agent carries a Parameters object, read per-period inputs from
    # it; otherwise fall back on the time_inv / time_vary attribute lists.
    use_parameters = hasattr(agent, "parameters") and isinstance(
        agent.parameters, Parameters
    )

    if use_parameters:
        T = agent.parameters._length if from_t is None else from_t
    else:
        # Number of periods per cycle; 1 if nothing varies over time
        if len(agent.time_vary) > 0:
            T = agent.T_cycle if from_t is None else from_t
        else:
            T = 1

        # Time-invariant inputs are fixed; time-varying ones are filled in
        # period by period below
        solve_dict = {name: agent.__dict__[name] for name in agent.time_inv}
        solve_dict.update(dict.fromkeys(agent.time_vary))

    # Order in which period indices are visited: straight backwards when
    # cycles == 1, otherwise index 0 first and then T-1 down to 1
    wrapped_order = [0] + list(range(T - 1, 0, -1))
    period_order = range(T - 1, -1, -1) if agent.cycles == 1 else wrapped_order

    solved = []
    solution_next = solution_last
    for k in period_order:
        # Pick the single period solver for this period (it may be indexable)
        solver = agent.solve_one_period
        if hasattr(solver, "__getitem__"):
            solver = solver[k]

        # Names of the inputs this solver expects
        if hasattr(solver, "solver_args"):
            arg_names = solver.solver_args
        else:
            arg_names = get_arg_names(solver)

        # Collect this period's inputs
        if use_parameters:
            period_pars = agent.parameters[k]
            inputs = {
                name: solution_next if name == "solution_next" else period_pars[name]
                for name in arg_names
            }
        else:
            for name in agent.time_vary:
                if name in arg_names:
                    solve_dict[name] = agent.__dict__[name][k]
            solve_dict["solution_next"] = solution_next
            inputs = {name: solve_dict[name] for name in arg_names}

        # Solve the period; its solution feeds the next one to be solved
        solution_next = solver(**inputs)
        solved.append(solution_next)

    # Solutions were collected in solve order; the latest solved goes first
    solved.reverse()
    return solved

2457 

2458 

def make_one_period_oo_solver(solver_class):
    """
    Returns a function that solves a single period problem by instantiating
    an object-oriented solver class and running it.

    Parameters
    ----------
    solver_class : Solver
        A class of Solver to be used.

    Returns
    -------
    solver_function : function
        A function for solving one period of a problem. It accepts keyword
        arguments only, and carries two attributes: solver_class (the class
        given here) and solver_args (the argument names of the class's
        __init__, with the first one dropped).
    """

    def one_period_solver(**kwds):
        instance = solver_class(**kwds)

        # not ideal; better if this is defined in all Solver classes
        if hasattr(instance, "prepare_to_solve"):
            instance.prepare_to_solve()

        return instance.solve()

    one_period_solver.solver_class = solver_class
    # This can be revisited once it is possible to export parameters
    init_arg_names = get_arg_names(solver_class.__init__)
    one_period_solver.solver_args = init_arg_names[1:]

    return one_period_solver

2487 

2488 

2489# ======================================================================== 

2490# ======================================================================== 

2491 

2492 

2493class Market(Model): 

2494 """ 

2495 A superclass to represent a central clearinghouse of information. Used for 

2496 dynamic general equilibrium models to solve the "macroeconomic" model as a 

2497 layer on top of the "microeconomic" models of one or more AgentTypes. 

2498 

2499 Parameters 

2500 ---------- 

2501 agents : [AgentType] 

2502 A list of all the AgentTypes in this market. 

2503 sow_vars : [string] 

2504 Names of variables generated by the "aggregate market process" that should 

2505 be "sown" to the agents in the market. Aggregate state, etc. 

2506 reap_vars : [string] 

2507 Names of variables to be collected ("reaped") from agents in the market 

2508 to be used in the "aggregate market process". 

2509 const_vars : [string] 

2510 Names of attributes of the Market instance that are used in the "aggregate 

2511 market process" but do not come from agents-- they are constant or simply 

2512 parameters inherent to the process. 

2513 track_vars : [string] 

2514 Names of variables generated by the "aggregate market process" that should 

2515 be tracked as a "history" so that a new dynamic rule can be calculated. 

2516 This is often a subset of sow_vars. 

2517 dyn_vars : [string] 

2518 Names of variables that constitute a "dynamic rule". 

2519 mill_rule : function 

2520 A function that takes inputs named in reap_vars and returns a tuple the 

2521 same size and order as sow_vars. The "aggregate market process" that 

2522 transforms individual agent actions/states/data into aggregate data to 

2523 be sent back to agents. 

2524 calc_dynamics : function 

2525 A function that takes inputs named in track_vars and returns an object 

2526 with attributes named in dyn_vars. Looks at histories of aggregate 

2527 variables and generates a new "dynamic rule" for agents to believe and 

2528 act on. 

2529 act_T : int 

2530 The number of times that the "aggregate market process" should be run 

2531 in order to generate a history of aggregate variables. 

2532 tolerance: float 

2533 Minimum acceptable distance between "dynamic rules" to consider the 

2534 Market solution process converged. Distance is a user-defined metric. 

2535 """ 

2536 

def __init__(
    self,
    agents=None,
    sow_vars=None,
    reap_vars=None,
    const_vars=None,
    track_vars=None,
    dyn_vars=None,
    mill_rule=None,
    calc_dynamics=None,
    distributions=None,
    act_T=1000,
    tolerance=0.000001,
    seed=0,
    **kwds,
):
    # List-valued arguments default to a fresh empty list when None is given.
    super().__init__()
    self.agents = [] if agents is None else agents

    # One (initially empty) list of reaped values per reap variable
    self.reap_vars = [] if reap_vars is None else reap_vars
    self.reap_state = {var: [] for var in self.reap_vars}

    # Dictionaries for tracking initial and current values of the sow
    # variables; every entry starts as None.
    self.sow_vars = [] if sow_vars is None else sow_vars
    self.sow_init = dict.fromkeys(self.sow_vars)
    self.sow_state = dict.fromkeys(self.sow_vars)

    # const_vars is stored as a dictionary keyed by name, values start as None
    self.const_vars = dict.fromkeys([] if const_vars is None else const_vars)

    self.track_vars = [] if track_vars is None else track_vars
    self.dyn_vars = [] if dyn_vars is None else dyn_vars
    self.distributions = [] if distributions is None else distributions

    # Only assign these when given, to prevent overwriting of method-based
    # mill_rule / calc_dynamics
    if mill_rule is not None:
        self.mill_rule = mill_rule
    if calc_dynamics is not None:
        self.calc_dynamics = calc_dynamics

    self.act_T = act_T
    self.tolerance = tolerance
    self.seed = seed
    self.max_loops = 1000
    self.history = {}

    # Extra keywords are assigned before the RNG is built from self.seed
    self.assign_parameters(**kwds)
    self.RNG = np.random.default_rng(self.seed)

    # Print the error associated with calling the parallel method
    # "solve_agents" one time. If set to false, the error will never
    # print. See "solve_agents" for why this prints once or never.
    self.print_parallel_error_once = True

2588 

def give_agent_params(self, construct=True):
    """
    Have each AgentType in self.agents fetch market-level parameters by
    calling its get_market_params method with this market and the construct
    flag.

    Parameters
    ----------
    construct : bool, optional
        Whether agents should run their construct method after fetching market
        data (default True).

    Returns
    -------
    None
    """
    for this_type in self.agents:
        this_type.get_market_params(self, construct)

2606 

def solve_agents(self):
    """
    Solves the microeconomic problem for all AgentTypes in this market.

    The "solve()" command is first dispatched with multi_thread_commands. If
    that raises, a warning is printed (only the first time this happens for
    this market) and the command is run with multi_thread_commands_fake
    instead.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    commands = ["solve()"]
    try:
        multi_thread_commands(self.agents, commands)
    except Exception as err:
        if self.print_parallel_error_once:
            # Set flag to False so this is only printed once.
            self.print_parallel_error_once = False
            print(
                "**** WARNING: could not execute multi_thread_commands in HARK.core.Market.solve_agents() ",
                "so using the serial version instead. This will likely be slower. "
                "The multi_thread_commands() functions failed with the following error:",
                "\n",
                type(err),
                ":",
                err,
            )
        multi_thread_commands_fake(self.agents, commands)

2635 

def solve(self):
    """
    "Solves" the market by finding a "dynamic rule" that governs the aggregate
    market state such that when agents believe in these dynamics, their actions
    collectively generate the same dynamic rule.

    Each loop solves the agents, makes a history, and updates the dynamics.
    Looping stops once the distance between successive rules is no longer at
    least self.tolerance, or once self.max_loops loops have been completed.
    The final rule is stored in self.dynamics.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    max_loops = self.max_loops  # Failsafe against infinite solution loop
    completed_loops = 0
    old_dynamics = None

    while True:
        self.solve_agents()  # Solve each AgentType's micro problem
        self.make_history()  # "Run" the model while tracking aggregate variables
        new_dynamics = self.update_dynamics()  # Find a new aggregate dynamic rule

        # There is nothing to compare against on the first loop, so use a
        # large placeholder distance
        if completed_loops == 0:
            distance = 1000000.0
        else:
            distance = new_dynamics.distance(old_dynamics)

        old_dynamics = new_dynamics
        completed_loops += 1

        # Stop when the rule has converged or the loop cap is reached
        keep_going = distance >= self.tolerance and completed_loops < max_loops
        if not keep_going:
            break

    self.dynamics = new_dynamics  # Store the final dynamic rule in self

2672 

def reap(self):
    """
    Collects the variables named in self.reap_state from each AgentType in
    the market. For every such variable, self.reap_state[var] becomes a list
    holding agent.state_now[var] for each agent whose state_now contains it,
    in the order of self.agents.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    for var in self.reap_state:
        # TODO: generalized variable lookup across namespaces
        self.reap_state[var] = [
            agent.state_now[var] for agent in self.agents if var in agent.state_now
        ]

2696 

def sow(self):
    """
    Distributes the values in self.sow_state from the market to each
    AgentType in self.agents.

    For every sown variable and agent: if the name is a key of the agent's
    state_now, that entry is updated. Separately, if the name is a key of the
    agent's shocks, that entry is updated; otherwise the value is set as a
    plain attribute of the agent.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    for name, value in self.sow_state.items():
        for this_type in self.agents:
            if name in this_type.state_now:
                this_type.state_now[name] = value

            # NOTE(review): this check is independent of the one above, so a
            # variable found in state_now but not in shocks is also set as an
            # attribute; confirm that is intended.
            if name not in this_type.shocks:
                setattr(this_type, name, value)
            else:
                this_type.shocks[name] = value

2718 

def mill(self):
    """
    Processes the variables collected from agents using the function
    mill_rule. The rule is called with the entries of self.reap_state and
    self.const_vars as keyword arguments (const_vars wins on a name clash),
    and its outputs are stored, by position, in the entries of
    self.sow_state.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    # Assemble the keyword inputs without modifying reap_state itself
    rule_inputs = dict(self.reap_state)
    rule_inputs.update(self.const_vars)

    # Run the mill_rule and store its outputs in sow_state, in key order
    product = self.mill_rule(**rule_inputs)
    for position, sow_var in enumerate(self.sow_state):
        self.sow_state[sow_var] = product[position]

2741 

def cultivate(self):
    """
    Has each AgentType in self.agents perform its market_action method, in
    order. The market_action method is expected to use variables sown from
    the market (and maybe also "private" variables) and to store new results
    in the variables named in reap_vars, to be reaped later.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    for agent in self.agents:
        agent.market_action()

2759 

def reset(self):
    """
    Return the market to its initial state.

    Calls reset() on every distribution named in self.distributions that
    exists as an attribute (element by element when the attribute is a list),
    replaces self.history with an empty list for each tracked variable,
    restores each sow variable from sow_init, and finally calls reset() on
    every AgentType in self.agents.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    # Reset the distributions; names without a matching attribute are ignored
    for name in self.distributions:
        if hasattr(self, name):
            dstn = getattr(self, name)
            members = dstn if isinstance(dstn, list) else [dstn]
            for member in members:
                member.reset()

    # Start a fresh, empty history for each tracked variable
    self.history = dict((tracked, []) for tracked in self.track_vars)

    # Restore the sow variables to their initial levels
    for sow_name in self.sow_state:
        self.sow_state[sow_name] = self.sow_init[sow_name]

    # Reset each AgentType in the market
    for agent in self.agents:
        agent.reset()

2795 

def store(self):
    """
    Append the current value of each tracked variable to its history list.

    For each name in track_vars the value is looked up in sow_state, then
    reap_state, then const_vars (first match wins); if none of them has the
    name, it is read as an attribute of self.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    for name in self.track_vars:
        # Search the dictionaries in priority order
        for source in (self.sow_state, self.reap_state, self.const_vars):
            if name in source:
                current = source[name]
                break
        else:
            # Not held in any dictionary: fall back to an attribute of the market
            current = getattr(self, name)

        self.history[name].append(current)

2820 

def make_history(self):
    """
    Reset the market, then run act_T periods of the market loop.

    Each period calls, in order: sow, cultivate, reap, mill and store, so
    that the tracked variables accumulate in self.history.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    self.reset()  # Initialize the state of the market

    # One period of the loop, in execution order:
    #   sow       - distribute aggregated information/state to agents
    #   cultivate - agents take action
    #   reap      - collect individual data from agents
    #   mill      - process individual data into aggregate data
    #   store     - record variables of interest
    period_steps = ("sow", "cultivate", "reap", "mill", "store")

    for _ in range(self.act_T):
        for step_name in period_steps:
            getattr(self, step_name)()

2842 

def update_dynamics(self):
    """
    Compute a new "aggregate dynamic rule" and hand it to every agent.

    The argument names of calc_dynamics (other than "self") determine its
    inputs: a name in track_vars is supplied from self.history, any other
    name is read as an attribute of self. Each attribute of the result named
    in dyn_vars is then set on every AgentType in self.agents.

    Parameters
    ----------
    none

    Returns
    -------
    dynamics : instance
        The object returned by calc_dynamics. It is expected to have the
        attributes named in dyn_vars.
    """
    # Work out which inputs the dynamics calculator asks for
    wanted = [
        arg for arg in get_arg_names(self.calc_dynamics) if arg != "self"
    ]

    calc_inputs = {}
    for arg in wanted:
        if arg in self.track_vars:
            calc_inputs[arg] = self.history[arg]
        else:
            calc_inputs[arg] = getattr(self, arg)

    # User-defined dynamics calculator
    dynamics = self.calc_dynamics(**calc_inputs)

    # Distribute each piece of the new rule to all agents
    for var_name in self.dyn_vars:
        rule_piece = getattr(dynamics, var_name)
        for agent in self.agents:
            setattr(agent, var_name, rule_piece)

    return dynamics

2874 

2875 

def distribute_params(agent, param_name, param_count, distribution):
    """
    Clone an agent into several copies that differ in one parameter.

    The distribution is discretized into param_count points. One deep copy
    of the agent is made per point; each copy is assigned that point's atom
    as param_name and a share of the original AgentCount proportional to the
    point's probability mass.

    Parameters
    ----------
    agent : AgentType
        An agent to clone. It is not modified.
    param_name : string
        Name of the parameter to be assigned.
    param_count : int
        Number of different values the parameter will take on.
    distribution : Distribution
        A 1-D distribution.

    Returns
    -------
    agent_set : [AgentType]
        A list of param_count agents, ex ante heterogeneous with
        respect to param_name. The AgentCount of the original
        is split between the agents of the returned list in
        proportion to the given distribution. Fractional shares are
        truncated, so the counts may sum to less than the original.
    """
    param_dist = distribution.discretize(N=param_count)

    agent_set = []
    for j in range(param_count):
        # The share is a float product, so a result that should be a whole
        # number can land a hair below it (e.g. 49 * (1 / 49)). Rounding off
        # that noise before truncating keeps the type from losing an agent.
        share = agent.AgentCount * param_dist.pmv[j]

        clone = deepcopy(agent)
        clone.assign_parameters(AgentCount=int(round(share, 9)))
        clone.assign_parameters(**{param_name: param_dist.atoms[0, j]})
        agent_set.append(clone)

    return agent_set

2909 

2910 

@dataclass
class AgentPopulation:
    """
    A class for representing a population of ex-ante heterogeneous agents.

    Parameters
    ----------
    agent_type : AgentType
        The agent class. It is called with no arguments to build a dummy
        agent, and later once per parameter set to build the population.
    parameters : dict
        Dictionary of parameters. A value may be a scalar, a list, a list of
        lists (one inner list per agent), a Distribution, or a DataArray
        whose first dimension is "agent" and/or whose last dimension is "age".
    seed : int
        Seed for the generator that draws one seed per agent.
    """

    agent_type: AgentType  # type of agent in the population
    parameters: dict  # dictionary of parameters
    seed: int = 0  # random seed
    time_var: List[str] = field(init=False)
    time_inv: List[str] = field(init=False)
    distributed_params: List[str] = field(init=False)
    agent_type_count: Optional[int] = field(init=False)
    term_age: Optional[int] = field(init=False)
    continuous_distributions: Dict[str, Distribution] = field(init=False)
    discrete_distributions: Dict[str, Distribution] = field(init=False)
    population_parameters: List[Dict[str, Union[List[float], float]]] = field(
        init=False
    )
    agents: List[AgentType] = field(init=False)
    agent_database: pd.DataFrame = field(init=False)
    solution: List[Any] = field(init=False)

    @staticmethod
    def _is_nested_list(param) -> bool:
        """Return True if param is a non-empty list whose first element is a list."""
        return (
            isinstance(param, list) and len(param) > 0 and isinstance(param[0], list)
        )

    @staticmethod
    def _data_array_value(param, agent):
        """
        Extract the value of a DataArray parameter for one agent.

        Returns a ``(handled, value)`` pair. ``handled`` is False when the
        first dimension is neither "agent" nor "age", in which case the
        parameter is left out of the agent's dictionary.
        """
        if param.dims[0] == "agent":
            if len(param.dims) > 1 and param.dims[-1] == "age":
                return True, param[agent].values.tolist()
            return True, param[agent].item()
        if param.dims[0] == "age":
            return True, param.values.tolist()
        return False, None

    def __post_init__(self):
        """
        Initialize the population of agents, determine distributed parameters,
        and infer `agent_type_count` and `term_age`.
        """
        # create a dummy agent and obtain its time-varying
        # and time-invariant attributes
        dummy_agent = self.agent_type()
        self.time_var = dummy_agent.time_vary
        self.time_inv = dummy_agent.time_inv

        # create list of distributed parameters
        # these are parameters that differ across agents
        self.distributed_params = [
            key
            for key, param in self.parameters.items()
            if self._is_nested_list(param)
            or isinstance(param, Distribution)
            or (isinstance(param, DataArray) and param.dims[0] == "agent")
        ]

        self.__infer_counts__()

        # Print warning once if parallel simulation fails
        self.print_parallel_error_once = True

    def __infer_counts__(self):
        """
        Infer `agent_type_count` and `term_age` from the parameters.
        If parameters include a `Distribution` type, a list of lists,
        or a `DataArray` with `agent` as the first dimension, then
        the AgentPopulation contains ex-ante heterogeneous agents.

        Either count is set to None (with a warning) when a `Distribution`
        is encountered, because its size is unknown until it is discretized.
        """

        # infer agent_type_count from distributed parameters
        agent_type_count = 1
        for key in self.distributed_params:
            param = self.parameters[key]
            if isinstance(param, Distribution):
                agent_type_count = None
                warn(
                    "Cannot infer agent_type_count from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif isinstance(param, list):
                agent_type_count = max(agent_type_count, len(param))
            elif isinstance(param, DataArray) and param.dims[0] == "agent":
                agent_type_count = max(agent_type_count, param.shape[0])

        self.agent_type_count = agent_type_count

        # infer term_age from all parameters
        term_age = 1
        for param in self.parameters.values():
            if isinstance(param, Distribution):
                term_age = None
                warn(
                    "Cannot infer term_age from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif self._is_nested_list(param):
                term_age = max(term_age, len(param[0]))
            elif isinstance(param, DataArray) and param.dims[-1] == "age":
                term_age = max(term_age, param.shape[-1])

        self.term_age = term_age

    def approx_distributions(self, approx_params: dict):
        """
        Approximate continuous distributions with discrete ones. If the initial
        parameters include a `Distribution` type, then the AgentPopulation is
        not ready to solve, and stands for an abstract population. To solve the
        AgentPopulation, we need discretization parameters for each continuous
        distribution. This method approximates the continuous distributions with
        discrete ones, and updates the parameters dictionary.

        Parameters
        ----------
        approx_params : dict
            Maps a parameter name to the keyword arguments passed to that
            distribution's `discretize` method. An empty dictionary leaves
            the parameters unchanged.

        Raises
        ------
        ValueError
            If a named parameter is not a distributed `Distribution`.
        """
        self.continuous_distributions = {}
        self.discrete_distributions = {}

        for key, args in approx_params.items():
            param = self.parameters[key]
            if key in self.distributed_params and isinstance(param, Distribution):
                self.continuous_distributions[key] = param
                self.discrete_distributions[key] = param.discretize(**args)
            else:
                raise ValueError(
                    f"Warning: parameter {key} is not a Distribution found "
                    f"in agent type {self.agent_type}"
                )

        # With nothing to discretize there is no joint distribution to build
        if self.discrete_distributions:
            if len(self.discrete_distributions) > 1:
                joint_dist = combine_indep_dstns(*self.discrete_distributions.values())
            else:
                joint_dist = next(iter(self.discrete_distributions.values()))

            # Row i of the joint atoms belongs to the i-th discretized parameter
            for i, key in enumerate(self.discrete_distributions):
                self.parameters[key] = DataArray(joint_dist.atoms[i], dims=("agent"))

        self.__infer_counts__()

    def __parse_parameters__(self) -> None:
        """
        Creates distributed dictionaries of parameters for each ex-ante
        heterogeneous agent in the parameterized population. The parameters
        are stored in a list of dictionaries, where each dictionary contains
        the parameters for one agent. Scalar parameters that vary over time
        are expanded to a list of length `term_age`.

        A parameter whose value is of a type not handled here (anything other
        than an int/float, a list, or a DataArray with a recognized first
        dimension) is left out of the agent's dictionary.
        """

        population_parameters = []  # container for dictionaries of each agent subgroup
        for agent in range(self.agent_type_count):
            agent_parameters = {}
            for key, param in self.parameters.items():
                is_time_var = key in self.time_var
                is_time_inv = (not is_time_var) and key in self.time_inv

                handled, value = False, None
                if isinstance(param, (int, float)):
                    # scalars that vary over time have to be repeated;
                    # everything else is assumed time invariant
                    handled = True
                    value = [param] * self.term_age if is_time_var else param
                elif isinstance(param, list):
                    # a list of lists varies by agent; a flat list is shared
                    handled = True
                    value = param[agent] if self._is_nested_list(param) else param
                elif isinstance(param, DataArray):
                    if is_time_inv:
                        if param.dims[0] == "agent":
                            handled, value = True, param[agent].item()
                    else:
                        # NOTE(review): an agent-only DataArray yields a scalar
                        # here even for time-varying keys, unlike plain scalars
                        # which are repeated term_age times - confirm intended.
                        handled, value = self._data_array_value(param, agent)

                # Only store what was actually resolved for this key, so a
                # value can never leak over from a previously processed key.
                if handled:
                    agent_parameters[key] = value

            population_parameters.append(agent_parameters)

        self.population_parameters = population_parameters

    def create_distributed_agents(self):
        """
        Parses the parameters dictionary and creates a list of agents with the
        appropriate parameters. Also sets the seed for each agent, drawn from
        a generator seeded with `self.seed`.
        """

        self.__parse_parameters__()

        rng = np.random.default_rng(self.seed)

        self.agents = [
            self.agent_type(seed=rng.integers(0, 2**31 - 1), **agent_dict)
            for agent_dict in self.population_parameters
        ]

    def create_database(self):
        """
        Optionally creates a pandas DataFrame with the parameters for each agent.
        Uses `population_parameters` and `agents`, so it must be called after
        `create_distributed_agents`.
        """
        database = pd.DataFrame(self.population_parameters)
        database["agents"] = self.agents

        self.agent_database = database

    def solve(self):
        """
        Solves each agent of the population serially.
        """

        # see Market class for an example of how to solve distributed agents in parallel

        for agent in self.agents:
            agent.solve()

    def unpack_solutions(self):
        """
        Unpacks the solutions of each agent into an attribute of the population.
        """
        self.solution = [agent.solution for agent in self.agents]

    def initialize_sim(self):
        """
        Initializes the simulation for each agent.
        """
        for agent in self.agents:
            agent.initialize_sim()

    def simulate(self, num_jobs=None):
        """
        Simulates each agent of the population.

        Parameters
        ----------
        num_jobs : int, optional
            Number of parallel jobs to use. Defaults to using all available
            cores when ``None``. Falls back to serial execution if parallel
            processing fails.
        """
        try:
            multi_thread_commands(self.agents, ["simulate()"], num_jobs)
        except Exception as err:
            # Deliberate best-effort: report the failure once, then run serially
            if getattr(self, "print_parallel_error_once", False):
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.AgentPopulation.simulate() ",
                    "so using the serial version instead. This will likely be slower. ",
                    "The multi_thread_commands() function failed with the following error:\n",
                    sys.exc_info()[0],
                    ":",
                    err,
                )
            multi_thread_commands_fake(self.agents, ["simulate()"], num_jobs)

    def __iter__(self):
        """
        Allows for iteration over the agents in the population.
        """
        return iter(self.agents)

    def __getitem__(self, idx):
        """
        Allows for indexing into the population.
        """
        return self.agents[idx]

3185 

3186 

3187############################################################################### 

3188 

3189 

def multi_thread_commands_fake(
    agent_list: List, command_list: List, num_jobs=None
) -> None:
    """
    Run every command in command_list on every agent in agent_list, one after
    another in a plain loop. Each command names a method of the agent. The
    signature mirrors multi_thread_commands so the two can be swapped freely.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
    command_list : [string]
        A list of method names to call on each AgentType, written with or
        without a trailing "()".
    num_jobs : None
        Unused; present only to match the syntax of multi_thread_commands.

    Returns
    -------
    none
    """
    for agent in agent_list:
        for command in command_list:
            # Drop an optional trailing "()" to get the bare method name
            method_name = command[:-2] if command.endswith("()") else command
            getattr(agent, method_name)()

3219 

3220 

def multi_thread_commands(agent_list: List, command_list: List, num_jobs=None) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    in parallel using joblib. Each command should be a method of that AgentType
    subclass. The agents returned by the parallel call are written back into
    agent_list by position, so agent_list is modified in place.

    An empty agent_list is a no-op, and a single agent is run serially.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
    command_list : [string]
        A list of commands to run for each AgentType in agent_list.
    num_jobs : int, optional
        Number of parallel jobs. When None, the smaller of the number of
        agents and the number of available cores is used.

    Returns
    -------
    None
    """
    # Nothing to run. This also avoids computing num_jobs = 0 below, which
    # joblib rejects.
    if len(agent_list) == 0:
        return None

    if len(agent_list) == 1:
        multi_thread_commands_fake(agent_list, command_list)
        return None

    # Default number of parallel jobs is the smaller of number of AgentTypes in
    # the input and the number of available cores.
    if num_jobs is None:
        num_jobs = min(len(agent_list), multiprocessing.cpu_count())

    # Send each command in command_list to each of the types in agent_list to be run
    agent_list_out = Parallel(n_jobs=num_jobs)(
        delayed(run_commands)(agent, command_list) for agent in agent_list
    )

    # Replace the original types with the output from the parallel call
    for j, agent_out in enumerate(agent_list_out):
        agent_list[j] = agent_out

3255 

3256 

def run_commands(agent: Any, command_list: List) -> Any:
    """
    Call, in order, each method named in command_list on the given agent.

    Parameters
    ----------
    agent : AgentType
        An instance of AgentType on which the commands will be run.
    command_list : [string]
        A list of method names for the agent to run, written with or without
        a trailing "()".

    Returns
    -------
    agent : AgentType
        The same AgentType instance passed as input, after running the commands.
    """
    for command in command_list:
        # Drop an optional trailing "()" to get the bare method name
        method_name = command[:-2] if command.endswith("()") else command
        getattr(agent, method_name)()
    return agent

3280 return agent