Coverage for HARK / core.py: 96%

1092 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-10 06:19 +0000

1""" 

2High-level functions and classes for solving a wide variety of economic models. 

3The "core" of HARK is a framework for "microeconomic" and "macroeconomic" 

4models. A micro model concerns the dynamic optimization problem for some type 

5of agents, where agents take the inputs to their problem as exogenous. A macro 

6model adds an additional layer, endogenizing some of the inputs to the micro 

7problem by finding a general equilibrium dynamic rule. 

8""" 

9 

10# Import basic modules 

11import inspect 

12import sys 

13from collections import namedtuple 

14from copy import copy, deepcopy 

15from dataclasses import dataclass, field 

16from time import time 

17from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union 

18from warnings import warn 

19import multiprocessing 

20from joblib import Parallel, delayed 

21from pandas import DataFrame 

22 

23import numpy as np 

24import pandas as pd 

25from xarray import DataArray 

26 

27from HARK.distributions import ( 

28 Distribution, 

29 IndexDistribution, 

30 combine_indep_dstns, 

31) 

32from HARK.utilities import NullFunc, get_arg_names, get_it_from 

33from HARK.simulator import make_simulator_from_agent 

34from HARK.SSJutils import ( 

35 make_basic_SSJ_matrices, 

36 make_flat_LC_SSJ_matrices, 

37 calc_shock_response_manually, 

38) 

39from HARK.metric import MetricObject, distance_metric 

40 

# Names exported by ``from HARK.core import *``.
__all__ = [
    # Classes
    "AgentType",
    "Market",
    "Parameters",
    "Model",
    "AgentPopulation",
    # Functions and helpers (NullFunc is imported from HARK.utilities above)
    "multi_thread_commands",
    "multi_thread_commands_fake",
    "NullFunc",
    "make_one_period_oo_solver",
    "distribute_params",
]

53 

54 

class Parameters:
    """
    A smart container for model parameters that handles age-varying dynamics.

    This class stores parameters as an internal dictionary and manages their
    age-varying properties, providing both attribute-style and dictionary-style
    access. It is designed to handle the time-varying dynamics of parameters
    in economic models.

    Attributes
    ----------
    _length : int
        The number of periods (ages) in the model cycle, i.e. ``T_cycle``.
        Whenever it is inferred from a time-varying sequence, the ``"T_cycle"``
        entry of ``_parameters`` is updated to match.
    _invariant_params : Set[str]
        A set of parameter names that are invariant over time.
    _varying_params : Set[str]
        A set of parameter names that vary over time.
    _parameters : Dict[str, Any]
        The internal dictionary storing all parameters.
    _frozen : bool
        When True, item/attribute assignment of parameters raises RuntimeError.
    _namedtuple_cache : Optional[type]
        Cached namedtuple class used by ``to_namedtuple``; rebuilt whenever the
        set or order of parameter names changes.
    """

    __slots__ = (
        "_length",
        "_invariant_params",
        "_varying_params",
        "_parameters",
        "_frozen",
        "_namedtuple_cache",
    )

    def __init__(self, **parameters: Any) -> None:
        """
        Initialize a Parameters object and parse the age-varying dynamics of parameters.

        Parameters
        ----------
        T_cycle : int, optional
            The number of time periods in the model cycle (default: 1).
            Must be >= 1.
        frozen : bool, optional
            If True, the Parameters object will be immutable after initialization
            (default: False).
        _time_inv : List[str], optional
            List of parameter names to explicitly mark as time-invariant,
            overriding automatic inference.
        _time_vary : List[str], optional
            List of parameter names to explicitly mark as time-varying,
            overriding automatic inference.
        **parameters : Any
            Any number of parameters in the form key=value.

        Raises
        ------
        ValueError
            If T_cycle is less than 1.

        Notes
        -----
        Automatic time-variance inference rules:
        - Scalars (int, float, bool, None) are time-invariant
        - NumPy arrays are time-invariant (use lists/tuples for time-varying)
        - Empty lists/tuples are stored as-is and are time-invariant
        - Single-element lists/tuples [x] are unwrapped to x and time-invariant
        - Multi-element lists/tuples are time-varying if length matches T_cycle
          (when T_cycle is 1, the first such sequence sets T_cycle)
        - 2D arrays with first dimension matching T_cycle are time-varying
        - Distributions and Callables are time-invariant

        Use _time_inv or _time_vary to override automatic inference when needed.
        """
        # Extract special parameters
        self._length: int = parameters.pop("T_cycle", 1)
        frozen: bool = parameters.pop("frozen", False)
        time_inv_override: List[str] = parameters.pop("_time_inv", [])
        time_vary_override: List[str] = parameters.pop("_time_vary", [])

        # Validate T_cycle
        if self._length < 1:
            raise ValueError(f"T_cycle must be >= 1, got {self._length}")

        # Initialize internal state
        self._invariant_params: Set[str] = set()
        self._varying_params: Set[str] = set()
        self._parameters: Dict[str, Any] = {"T_cycle": self._length}
        self._frozen: bool = False  # Set to False initially to allow setup
        self._namedtuple_cache: Optional[type] = None

        # Set parameters using automatic inference
        for key, value in parameters.items():
            self[key] = value

        # Apply explicit overrides (unknown names are ignored)
        for param in time_inv_override:
            if param in self._parameters:
                self._classify(param, varying=False)

        for param in time_vary_override:
            if param in self._parameters:
                self._classify(param, varying=True)

        # Freeze if requested
        self._frozen = frozen

    def _classify(self, key: str, varying: bool) -> None:
        """Record ``key`` as time-varying (``varying=True``) or time-invariant."""
        if varying:
            self._varying_params.add(key)
            self._invariant_params.discard(key)
        else:
            self._invariant_params.add(key)
            self._varying_params.discard(key)

    def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
        """
        Access parameters by age index or parameter name.

        If item_or_key is an integer (Python ``int`` or NumPy integer), returns
        a Parameters object with the parameters that apply to that age. This
        includes all invariant parameters and the `item_or_key`th element of all
        age-varying parameters. If item_or_key is a string, it returns the value
        of the parameter with that name.

        Parameters
        ----------
        item_or_key : Union[int, str]
            Age index or parameter name.

        Returns
        -------
        Union[Parameters, Any]
            A new Parameters object for the specified age, or the value of the
            specified parameter.

        Raises
        ------
        ValueError:
            If the age index is out of bounds.
        KeyError:
            If the parameter name is not found.
        TypeError:
            If the key is neither an integer nor a string.
        """
        # NumPy integer scalars (e.g. from np.arange) are accepted as ages too.
        if isinstance(item_or_key, (int, np.integer)):
            age = int(item_or_key)
            if age < 0 or age >= self._length:
                raise ValueError(
                    f"Age {item_or_key} is out of bounds (valid: 0-{self._length - 1})."
                )

            params = {key: self._parameters[key] for key in self._invariant_params}
            for key in self._varying_params:
                value = self._parameters[key]
                # Sequences/arrays are sliced at this age; anything else explicitly
                # marked time-varying is passed through unchanged.
                params[key] = (
                    value[age]
                    if isinstance(value, (list, tuple, np.ndarray))
                    else value
                )
            return Parameters(**params)
        if isinstance(item_or_key, str):
            return self._parameters[item_or_key]
        raise TypeError("Key must be an integer (age) or string (parameter name).")

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Set parameter values, automatically inferring time variance.

        If the parameter is a scalar, numpy array, boolean, distribution, callable
        or None, it is assumed to be invariant over time. An empty list or tuple
        is stored as-is and is time-invariant. A one-element list or tuple is
        unwrapped and treated as time-invariant. A longer list or tuple is
        time-varying and its length must match the `_length` attribute of the
        Parameters object. If `_length` is still 1, the sequence's length
        becomes the new `_length`, and the stored ``T_cycle`` is updated to match.

        2D numpy arrays with first dimension matching T_cycle are treated as
        time-varying parameters.

        Parameters
        ----------
        key : str
            Name of the parameter.
        value : Any
            Value of the parameter.

        Raises
        ------
        ValueError:
            If the parameter name is not a string or if the value type is unsupported.
            If the parameter value is inconsistent with the current model length.
        RuntimeError:
            If the Parameters object is frozen.
        """
        if self._frozen:
            raise RuntimeError("Cannot modify frozen Parameters object")

        if not isinstance(key, str):
            raise ValueError(f"Parameter name must be a string, got {type(key)}")

        if isinstance(value, np.ndarray) and value.ndim >= 2:
            # Arrays whose leading axis has one entry per period vary over time
            self._classify(key, varying=value.shape[0] == self._length)
        elif isinstance(
            value,
            (
                int,
                float,
                np.ndarray,
                type(None),
                Distribution,
                bool,
                Callable,
                MetricObject,
            ),
        ):
            self._classify(key, varying=False)
        elif isinstance(value, (list, tuple)):
            n_values = len(value)
            if n_values == 0:
                # An empty sequence has nothing to index by age; keep it as an
                # invariant value rather than shrinking the cycle length to 0.
                self._classify(key, varying=False)
            elif n_values == 1:
                value = value[0]
                self._classify(key, varying=False)
            elif self._length is None or self._length == 1:
                # Infer the cycle length from this sequence and keep the stored
                # T_cycle entry consistent with it.
                self._length = n_values
                self._parameters["T_cycle"] = n_values
                self._classify(key, varying=True)
            elif n_values == self._length:
                self._classify(key, varying=True)
            else:
                raise ValueError(
                    f"Parameter {key} must have length 1 or {self._length}, not {len(value)}"
                )
        else:
            raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")

        self._parameters[key] = value

    def __iter__(self) -> Iterator[str]:
        """Allow iteration over parameter names."""
        return iter(self._parameters)

    def __len__(self) -> int:
        """Return the number of parameters."""
        return len(self._parameters)

    def keys(self) -> Iterator[str]:
        """Return a view of parameter names."""
        return self._parameters.keys()

    def values(self) -> Iterator[Any]:
        """Return a view of parameter values."""
        return self._parameters.values()

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Return a view of parameter (name, value) pairs."""
        return self._parameters.items()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a plain dictionary.

        Returns
        -------
        Dict[str, Any]
            A (shallow) copy of the dictionary containing all parameters.
        """
        return dict(self._parameters)

    def to_namedtuple(self) -> namedtuple:
        """
        Convert parameters to a namedtuple.

        The namedtuple class is cached for efficiency on repeated calls, and is
        rebuilt whenever parameters have been added since it was created.

        Returns
        -------
        namedtuple
            A namedtuple containing all parameters.
        """
        fields = tuple(self._parameters)
        cached = self._namedtuple_cache
        # A stale cache (built before parameters were added) would reject the
        # new field names, so rebuild it whenever the field list has changed.
        if cached is None or cached._fields != fields:
            cached = namedtuple("Parameters", fields)
            self._namedtuple_cache = cached
        return cached(**self._parameters)

    def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
        """
        Update parameters from another Parameters object or dictionary.

        Each value goes through ``__setitem__``, so time variance is re-inferred
        rather than copied from ``other``.

        Parameters
        ----------
        other : Union[Parameters, Dict[str, Any]]
            The source of parameters to update from.

        Raises
        ------
        TypeError
            If the input is neither a Parameters object nor a dictionary.
        """
        if isinstance(other, Parameters):
            source = other._parameters
        elif isinstance(other, dict):
            source = other
        else:
            raise TypeError(
                "Update source must be a Parameters object or a dictionary."
            )
        for key, value in source.items():
            self[key] = value

    def __repr__(self) -> str:
        """Return a detailed string representation of the Parameters object."""
        return (
            f"Parameters(_length={self._length}, "
            f"_invariant_params={self._invariant_params}, "
            f"_varying_params={self._varying_params}, "
            f"_parameters={self._parameters})"
        )

    def __str__(self) -> str:
        """Return a simple string representation of the Parameters object."""
        return f"Parameters({str(self._parameters)})"

    def __getattr__(self, name: str) -> Any:
        """
        Allow attribute-style access to parameters.

        Only called when normal attribute lookup fails. Names starting with an
        underscore are resolved normally (so unset slots raise AttributeError
        instead of recursing).

        Parameters
        ----------
        name : str
            Name of the parameter to access.

        Returns
        -------
        Any
            The value of the specified parameter.

        Raises
        ------
        AttributeError:
            If the parameter name is not found.
        """
        if name.startswith("_"):
            return super().__getattribute__(name)
        try:
            return self._parameters[name]
        except KeyError:
            raise AttributeError(f"'Parameters' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Allow attribute-style setting of parameters.

        Names starting with an underscore set the internal slots directly;
        all other names are routed through ``__setitem__``.

        Parameters
        ----------
        name : str
            Name of the parameter to set.
        value : Any
            Value to set for the parameter.
        """
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            self[name] = value

    def __contains__(self, item: str) -> bool:
        """Check if a parameter exists in the Parameters object."""
        return item in self._parameters

    def copy(self) -> "Parameters":
        """
        Create a deep copy of the Parameters object.

        Returns
        -------
        Parameters
            A new Parameters object with the same contents.
        """
        return deepcopy(self)

    def add_to_time_vary(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_vary.
            Names that are not existing parameters trigger a warning.
        """
        for param in params:
            if param in self._parameters:
                self._classify(param, varying=True)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_vary."
                )

    def add_to_time_inv(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_inv.
            Names that are not existing parameters trigger a warning.
        """
        for param in params:
            if param in self._parameters:
                self._classify(param, varying=False)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_inv."
                )

    def del_from_time_vary(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_vary.
        """
        for param in params:
            self._varying_params.discard(param)

    def del_from_time_inv(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_inv.
        """
        for param in params:
            self._invariant_params.discard(param)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a parameter value, returning a default if not found.

        Parameters
        ----------
        key : str
            The parameter name.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The parameter value or the default.
        """
        return self._parameters.get(key, default)

    def set_many(self, **kwargs: Any) -> None:
        """
        Set multiple parameters at once.

        Parameters
        ----------
        **kwargs : Keyword arguments representing parameter names and values.
        """
        for key, value in kwargs.items():
            self[key] = value

    def is_time_varying(self, key: str) -> bool:
        """
        Check if a parameter is time-varying.

        Parameters
        ----------
        key : str
            The parameter name.

        Returns
        -------
        bool
            True if the parameter is time-varying, False otherwise.
        """
        return key in self._varying_params

    def at_age(self, age: int) -> "Parameters":
        """
        Get parameters for a specific age.

        This is an alternative to integer indexing (params[age]) that is more
        explicit and avoids potential confusion with dictionary-style access.

        Parameters
        ----------
        age : int
            The age index to retrieve parameters for.

        Returns
        -------
        Parameters
            A new Parameters object with parameters for the specified age.

        Raises
        ------
        ValueError
            If the age index is out of bounds.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97], sigma=2.0)
        >>> age_1_params = params.at_age(1)
        >>> age_1_params.beta
        0.96
        """
        return self[age]

    def validate(self) -> None:
        """
        Validate parameter consistency.

        Checks that all time-varying list/tuple/array parameters have length
        (or first dimension) matching T_cycle. Time-varying values of other
        types are not checked. This is useful after manual modifications or
        when parameters are set programmatically.

        Raises
        ------
        ValueError
            If any time-varying parameter has incorrect length.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97])
        >>> params.validate()  # Passes
        >>> params.add_to_time_vary("beta")
        >>> params.validate()  # Still passes
        """
        errors = []
        for param in self._varying_params:
            value = self._parameters[param]
            if isinstance(value, (list, tuple)):
                if len(value) != self._length:
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )
            elif isinstance(value, np.ndarray):
                if value.ndim == 0:
                    errors.append(
                        f"Parameter '{param}' is a 0-dimensional array (scalar), "
                        "which should not be time-varying"
                    )
                elif value.ndim >= 2:
                    if value.shape[0] != self._length:
                        errors.append(
                            f"Parameter '{param}' has first dimension {value.shape[0]}, expected {self._length}"
                        )
                elif len(value) != self._length:  # ndim == 1
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )

        if errors:
            raise ValueError(
                "Parameter validation failed:\n" + "\n".join(f" - {e}" for e in errors)
            )

619 

620class Model: 

621 """ 

622 A class with special handling of parameters assignment. 

623 """ 

624 

625 def __init__(self): 

626 if not hasattr(self, "parameters"): 

627 self.parameters = {} 

628 if not hasattr(self, "constructors"): 

629 self.constructors = {} 

630 

631 def assign_parameters(self, **kwds): 

632 """ 

633 Assign an arbitrary number of attributes to this agent. 

634 

635 Parameters 

636 ---------- 

637 **kwds : keyword arguments 

638 Any number of keyword arguments of the form key=value. 

639 Each value will be assigned to the attribute named in self. 

640 

641 Returns 

642 ------- 

643 None 

644 """ 

645 self.parameters.update(kwds) 

646 for key in kwds: 

647 setattr(self, key, kwds[key]) 

648 

649 def get_parameter(self, name): 

650 """ 

651 Returns a parameter of this model 

652 

653 Parameters 

654 ---------- 

655 name : str 

656 The name of the parameter to get 

657 

658 Returns 

659 ------- 

660 value : The value of the parameter 

661 """ 

662 return self.parameters[name] 

663 

664 def __eq__(self, other): 

665 if isinstance(other, type(self)): 

666 return self.parameters == other.parameters 

667 

668 return NotImplemented 

669 

670 def __str__(self): 

671 type_ = type(self) 

672 module = type_.__module__ 

673 qualname = type_.__qualname__ 

674 

675 s = f"<{module}.{qualname} object at {hex(id(self))}.\n" 

676 s += "Parameters:" 

677 

678 for p in self.parameters: 

679 s += f"\n{p}: {self.parameters[p]}" 

680 

681 s += ">" 

682 return s 

683 

684 def describe(self): 

685 return self.__str__() 

686 

687 def del_param(self, param_name): 

688 """ 

689 Deletes a parameter from this instance, removing it both from the object's 

690 namespace (if it's there) and the parameters dictionary (likewise). 

691 

692 Parameters 

693 ---------- 

694 param_name : str 

695 A string naming a parameter or data to be deleted from this instance. 

696 Removes information from self.parameters dictionary and own namespace. 

697 

698 Returns 

699 ------- 

700 None 

701 """ 

702 if param_name in self.parameters: 

703 del self.parameters[param_name] 

704 if hasattr(self, param_name): 

705 delattr(self, param_name) 

706 

707 def _gather_constructor_args(self, key, constructor): 

708 """ 

709 Gather all arguments needed to call the given constructor. 

710 

711 Handles both the special ``get_it_from`` case and normal callables. 

712 For a normal callable the method inspects the function signature to 

713 find every required argument, then resolves each argument from the 

714 instance namespace (``self.<arg>``) or from ``self.parameters``. 

715 Arguments that have a default value are silently skipped when they 

716 cannot be found; required arguments that cannot be resolved are 

717 recorded as missing. 

718 

719 Parameters 

720 ---------- 

721 key : str 

722 The name of the constructed object (used to record missing pairs). 

723 constructor : callable or get_it_from 

724 The constructor to be called. 

725 

726 Returns 

727 ------- 

728 temp_dict : dict 

729 Keyword arguments to pass to the constructor. 

730 any_missing : bool 

731 True when at least one required argument could not be resolved. 

732 missing_args : list of str 

733 Names of unresolved required arguments. 

734 missing_key_data : list of tuple 

735 ``(key, arg)`` pairs for every unresolved required argument. 

736 """ 

737 missing_key_data = [] 

738 

739 # SPECIAL: if the constructor is get_it_from, handle it separately 

740 if isinstance(constructor, get_it_from): 

741 try: 

742 parent = getattr(self, constructor.name) 

743 query = key 

744 any_missing = False 

745 missing_args = [] 

746 except AttributeError: 

747 parent = None 

748 query = None 

749 any_missing = True 

750 missing_args = [constructor.name] 

751 temp_dict = {"parent": parent, "query": query} 

752 return temp_dict, any_missing, missing_args, missing_key_data 

753 

754 # Normal constructor: inspect signature and gather arguments 

755 args_needed = get_arg_names(constructor) 

756 has_no_default = { 

757 k: v.default is inspect.Parameter.empty 

758 for k, v in inspect.signature(constructor).parameters.items() 

759 } 

760 temp_dict = {} 

761 any_missing = False 

762 missing_args = [] 

763 for this_arg in args_needed: 

764 if hasattr(self, this_arg): 

765 temp_dict[this_arg] = getattr(self, this_arg) 

766 else: 

767 try: 

768 temp_dict[this_arg] = self.parameters[this_arg] 

769 except KeyError: 

770 if has_no_default[this_arg]: 

771 # Record missing key-data pair 

772 any_missing = True 

773 missing_key_data.append((key, this_arg)) 

774 missing_args.append(this_arg) 

775 

776 return temp_dict, any_missing, missing_args, missing_key_data 

777 

778 def _attempt_construct(self, key, i, keys_complete, backup, errors, force): 

779 """ 

780 Attempt to construct the object for a single key. 

781 

782 The method looks up the constructor, gathers its arguments, runs it, 

783 and records any errors. It also handles the ``None``-constructor case 

784 (restore from backup) and the missing-args case (record and defer). 

785 

786 Parameters 

787 ---------- 

788 key : str 

789 The name of the object to construct. 

790 i : int 

791 Index of *key* inside the ``keys`` array (used to update 

792 ``keys_complete``). 

793 keys_complete : np.ndarray of bool 

794 Boolean array indicating which keys have been completed; mutated 

795 in place when this key succeeds. 

796 backup : dict 

797 Dictionary of pre-construction attribute values. 

798 errors : dict 

799 ``self._constructor_errors``; mutated in place. 

800 force : bool 

801 When True, swallow exceptions and continue; when False, re-raise. 

802 

803 Returns 

804 ------- 

805 accomplished : bool 

806 True when the key was completed (constructor ran or was None). 

807 missing_key_data : list of tuple 

808 ``(key, arg)`` pairs recorded for missing required arguments. 

809 """ 

810 missing_key_data = [] 

811 

812 # Look up the constructor for this key 

813 try: 

814 constructor = self.constructors[key] 

815 except Exception as not_found: 

816 errors[key] = "No constructor found for " + str(not_found) 

817 if force: 

818 return False, missing_key_data 

819 else: 

820 raise KeyError("No constructor found for " + key) from None 

821 

822 # If the constructor is None, restore from backup and mark complete 

823 if constructor is None: 

824 if key in backup.keys(): 

825 setattr(self, key, backup[key]) 

826 self.parameters[key] = backup[key] 

827 keys_complete[i] = True 

828 return True, missing_key_data 

829 

830 # Gather arguments for the constructor 

831 temp_dict, any_missing, missing_args, missing_key_data = ( 

832 self._gather_constructor_args(key, constructor) 

833 ) 

834 

835 # If all required data was found, run the constructor and store the result 

836 if not any_missing: 

837 try: 

838 temp = constructor(**temp_dict) 

839 except Exception as problem: 

840 errors[key] = str(type(problem)) + ": " + str(problem) 

841 self.del_param(key) 

842 if force: 

843 return False, missing_key_data 

844 else: 

845 raise 

846 setattr(self, key, temp) 

847 self.parameters[key] = temp 

848 if key in errors: 

849 del errors[key] 

850 keys_complete[i] = True 

851 return True, missing_key_data 

852 

853 # Some required arguments were missing; record and defer 

854 msg = "Missing required arguments: " + ", ".join(missing_args) 

855 errors[key] = msg 

856 self.del_param(key) 

857 # Never raise exceptions here, as the arguments might be filled in later 

858 return False, missing_key_data 

859 

860 def _construct_pass(self, keys, keys_complete, backup, errors, force): 

861 """ 

862 Perform one full sweep over all incomplete keys. 

863 

864 Calls ``_attempt_construct`` for every key that has not yet been 

865 completed and accumulates results. 

866 

867 Parameters 

868 ---------- 

869 keys : sequence of str 

870 All keys requested for construction. 

871 keys_complete : np.ndarray of bool 

872 Boolean array indicating which keys have been completed; mutated 

873 in place by ``_attempt_construct``. 

874 backup : dict 

875 Dictionary of pre-construction attribute values. 

876 errors : dict 

877 ``self._constructor_errors``; mutated in place. 

878 force : bool 

879 Passed through to ``_attempt_construct``. 

880 

881 Returns 

882 ------- 

883 anything_accomplished : bool 

884 True when at least one key was completed during this pass. 

885 missing_key_data : list of tuple 

886 Accumulated ``(key, arg)`` pairs for all unresolved required 

887 arguments across every incomplete key. 

888 """ 

889 anything_accomplished = False 

890 missing_key_data = [] 

891 

892 for i, key in enumerate(keys): 

893 if keys_complete[i]: 

894 continue # This key has already been built 

895 

896 accomplished, key_missing = self._attempt_construct( 

897 key, i, keys_complete, backup, errors, force 

898 ) 

899 if accomplished: 

900 anything_accomplished = True 

901 missing_key_data.extend(key_missing) 

902 

903 return anything_accomplished, missing_key_data 

904 

905 def construct(self, *args, force=False): 

906 """ 

907 Top-level method for building constructed inputs. If called without any 

908 inputs, construct builds each of the objects named in the keys of the 

909 constructors dictionary; it draws inputs for the constructors from the 

910 parameters dictionary and adds its results to the same. If passed one or 

911 more strings as arguments, the method builds only the named keys. The 

912 method will do multiple "passes" over the requested keys, as some cons- 

913 tructors require inputs built by other constructors. If any requested 

914 constructors failed to build due to missing data, those keys (and the 

915 missing data) will be named in self._missing_key_data. Other errors are 

916 recorded in the dictionary attribute _constructor_errors. 

917 

918 This method tries to "start from scratch" by removing prior constructed 

919 objects, holding them in a backup dictionary during construction. This 

920 is done so that dependencies among constructors are resolved properly, 

921 without mistakenly relying on "old information". A backup value is used 

922 if a constructor function is set to None (i.e. "don't do anything"), or 

923 if the construct method fails to produce a new object. 

924 

925 Parameters 

926 ---------- 

927 *args : str, optional 

928 Keys of self.constructors that are requested to be constructed. 

929 If no arguments are passed, *all* elements of the dictionary are implied. 

930 force : bool, optional 

931 When True, the method will force its way past any errors, including 

932 missing constructors, missing arguments for constructors, and errors 

933 raised during execution of constructors. Information about all such 

934 errors is stored in the dictionary attributes described above. When 

935 False (default), any errors or exception will be raised. 

936 

937 Returns 

938 ------- 

939 None 

940 """ 

941 # Set up the requested work 

942 if len(args) > 0: 

943 keys = args 

944 else: 

945 keys = list(self.constructors.keys()) 

946 N_keys = len(keys) 

947 keys_complete = np.zeros(N_keys, dtype=bool) 

948 if N_keys == 0: 

949 return # Do nothing if there are no constructed objects 

950 

951 # Remove pre-existing constructed objects, preventing "incomplete" updates, 

952 # but store the current values in a backup dictionary in case something fails 

953 backup = {} 

954 for key in keys: 

955 if hasattr(self, key): 

956 backup[key] = getattr(self, key) 

957 self.del_param(key) 

958 

959 # Get the dictionary of constructor errors 

960 if not hasattr(self, "_constructor_errors"): 

961 self._constructor_errors = {} 

962 errors = self._constructor_errors 

963 

964 # As long as the work isn't complete and we made some progress on the last 

965 # pass, repeatedly perform passes of trying to construct objects 

966 any_keys_incomplete = np.any(np.logical_not(keys_complete)) 

967 go = any_keys_incomplete 

968 while go: 

969 anything_accomplished, missing_key_data = self._construct_pass( 

970 keys, keys_complete, backup, errors, force 

971 ) 

972 any_keys_incomplete = np.any(np.logical_not(keys_complete)) 

973 go = any_keys_incomplete and anything_accomplished 

974 

975 # Store missing key-data pairs and exit 

976 self._missing_key_data = missing_key_data 

977 self._constructor_errors = errors 

978 if any_keys_incomplete: 

979 msg = "Did not construct these objects:" 

980 for i in range(N_keys): 

981 if keys_complete[i]: 

982 continue 

983 msg += " " + keys[i] + "," 

984 key = keys[i] 

985 if key in backup.keys(): 

986 setattr(self, key, backup[key]) 

987 self.parameters[key] = backup[key] 

988 msg = msg[:-1] 

989 if not force: 

990 raise ValueError(msg) 

991 return 

992 

def describe_constructors(self, *args):
    """
    Print a summary of this instance's constructed objects.

    For every requested key one line is printed: a check mark or X telling
    whether a value is currently present, the key, and its source -- the
    name of its constructor function, the object it is fetched from (for
    ``get_it_from`` entries), or a note that no constructor exists. For an
    ordinary constructor, each of its argument names is then listed with a
    check mark if this instance has it, ``*`` if it is missing but has a
    default value, or ``X`` if it is missing and has no default.

    Parameters
    ----------
    *args : str, optional
        Names of constructed objects to describe. When none are passed,
        every key in ``self.constructors`` is described.

    Returns
    -------
    None
    """
    check, cross, optional = "\u2713", "X", "*"
    targets = args if args else list(self.constructors.keys())

    def _have(name):
        # Present means: an attribute of self, or an entry in self.parameters
        return hasattr(self, name) or (name in self.parameters)

    pieces = []
    for key in targets:
        status = check if _have(key) else cross
        try:
            constructor = self.constructors[key]
        except KeyError:
            pieces.append(status + " " + key + " : NO CONSTRUCTOR FOUND\n")
            continue

        if isinstance(constructor, get_it_from):
            pieces.append(
                status + " " + key + " : get it from " + constructor.name + "\n"
            )
            continue

        pieces.append(status + " " + key + " : " + constructor.__name__ + "\n")

        # An argument is "required" when its signature default is empty
        arg_names = get_arg_names(constructor)
        required = {
            pname: param.default is inspect.Parameter.empty
            for pname, param in inspect.signature(constructor).parameters.items()
        }
        for arg in arg_names:
            if _have(arg):
                symbol = check
            elif required[arg]:
                symbol = cross
            else:
                symbol = optional
            pieces.append(" " + symbol + " " + arg + "\n")

    print("".join(pieces))

1071 

# Kept as a "synonym" so that older code calling update() keeps working.
def update(self, *args, **kwargs):
    """
    Backward-compatible alias for construct().

    Every positional and keyword argument is forwarded to construct()
    unchanged.

    Returns
    -------
    None
    """
    self.construct(*args, **kwargs)

1075 

1076 

class AgentType(Model):
    """
    A superclass for economic agents in the HARK framework. Each model should
    specify its own subclass of AgentType, inheriting its methods and overwriting
    them as necessary. Critically, every subclass of AgentType should define
    class-specific static values of the attributes time_vary_ and time_inv_ as
    lists of strings; __init__ deep-copies them into the instance-level lists
    time_vary and time_inv. Each element of time_vary is the name of an
    attribute of the agent that varies over time in the model. Each element of
    time_inv is the name of an attribute that is constant over time in the model.

    Parameters
    ----------
    solution_terminal : Solution
        A representation of the solution to the terminal period problem of
        this AgentType instance, or an initial guess of the solution if this
        is an infinite horizon problem. Defaults to NullFunc() when None.
    cycles : int
        (Passed through **kwds, not an explicit argument.) The number of times
        the sequence of periods is experienced by this AgentType in their
        "lifetime". cycles=1 corresponds to a lifecycle model, with a certain
        sequence of one period problems experienced once before terminating.
        cycles=0 corresponds to an infinite horizon model, with a sequence of
        one period problems repeating indefinitely.
    pseudo_terminal : bool
        Indicates whether solution_terminal isn't actually part of the
        solution to the problem (as a known solution to the terminal period
        problem), but instead represents a "scrap value"-style termination.
        When True (default), solution_terminal is not included in the solution;
        when False, solution_terminal is the last element of the solution.
    tolerance : float
        Maximum acceptable "distance" between successive solutions to the
        one period problem in an infinite horizon (cycles=0) model in order
        for the solution to be considered as having "converged". Inoperative
        when cycles>0. Default 1e-6.
    verbose : int
        Level of output to be displayed by this instance, default is 1.
    quiet : bool
        Indicator for whether this instance should operate "quietly", default False.
    seed : int
        A seed for this instance's random number generator (default 0).
    construct : bool
        Indicator for whether this instance's construct() method should be run
        when initialized (default True). When False, an instance of the class
        can be created even if not all of its attributes can be constructed.
    use_defaults : bool
        Indicator for whether this instance should use the values in the class'
        default dictionary to fill in parameters and constructors for those not
        provided by the user (default True). Setting this to False is useful for
        situations where the user wants to be absolutely sure that they know what
        is being passed to the class initializer, without resorting to defaults.
    **kwds
        Any other model parameters. They are layered on top of
        default_["params"] (when use_defaults is True) and passed to
        assign_parameters. A "constructors" entry is merged into the default
        constructor dictionary rather than replacing it.

    Attributes
    ----------
    AgentCount : int
        The number of agents of this type to use in simulation.

    state_vars : list of string
        The string labels for this AgentType's model state variables.
    """

    # NOTE: the lists below are class-level (shared) objects; subclasses are
    # expected to override them rather than mutate them in place.
    # Names of time-varying attributes; deep-copied to self.time_vary in __init__.
    time_vary_ = []
    # Names of time-invariant attributes; deep-copied to self.time_inv in __init__.
    time_inv_ = []
    # Names of shock variables; deep-copied to self.shock_vars in __init__.
    shock_vars_ = []
    # State variable names; used as the keys of self.state_now.
    state_vars = []
    # Presumably post-decision state variable names -- TODO confirm usage.
    poststate_vars = []
    # Attribute names copied from a Market object by get_market_params().
    market_vars = []
    # Attribute names whose distribution objects are reset by reset_rng().
    distributions = []
    # Class defaults: "params" seeds the parameter dict and "solver" becomes
    # solve_one_period; __init__ also reads optional "track_vars" and "model".
    default_ = {"params": {}, "solver": NullFunc()}

1144 

def __init__(
    self,
    solution_terminal=None,
    pseudo_terminal=True,
    tolerance=0.000001,
    verbose=1,
    quiet=False,
    seed=0,
    construct=True,
    use_defaults=True,
    **kwds,
):
    """
    Build an agent from the class defaults (optionally) plus user keywords.

    The named arguments are documented on the class. Any additional keywords
    are treated as model parameters: they are layered on top of a deep copy
    of ``default_["params"]`` (or on an empty dict when ``use_defaults`` is
    False) and handed to ``assign_parameters``. A ``constructors`` entry is
    merged into the default constructor dictionary rather than replacing it.
    When ``construct`` is True, ``construct()`` runs after the parameters are
    assigned and the RNG is reset. The instance-level ``time_vary``,
    ``time_inv`` and ``shock_vars`` lists are created last.
    """
    super().__init__()
    defaults = self.default_

    # Parameters: defaults first (deep-copied so the class dict is never
    # mutated), then user-supplied keywords override them.
    params = {}
    if use_defaults:
        params = deepcopy(defaults["params"])
    params.update(kwds)

    # Constructors are merged key-by-key instead of being replaced wholesale
    merged_constructors = {}
    if "constructors" in defaults["params"].keys() and use_defaults:
        merged_constructors = deepcopy(defaults["params"]["constructors"])
    if "constructors" in kwds:
        merged_constructors.update(kwds["constructors"])
    params["constructors"] = merged_constructors

    # Variables to record during simulation
    self.track_vars = (
        copy(defaults["track_vars"])
        if ("track_vars" in defaults.keys() and use_defaults)
        else []
    )

    # Optional model file name; a missing or uncopyable entry becomes None
    try:
        self.model_file = copy(defaults["model"])
    except (KeyError, TypeError):
        self.model_file = None

    if solution_terminal is None:
        solution_terminal = NullFunc()

    self.solve_one_period = defaults["solver"]  # NOQA
    self.solution_terminal = solution_terminal  # NOQA
    self.pseudo_terminal = pseudo_terminal  # NOQA
    self.tolerance = tolerance  # NOQA
    self.verbose = verbose
    self.quiet = quiet
    self.seed = seed  # NOQA
    self.state_now = dict.fromkeys(self.state_vars)
    self.state_prev = self.state_now.copy()
    self.controls = {}
    self.shocks = {}
    self.read_shocks = False  # NOQA
    self.shock_history = {}
    self.newborn_init_history = {}
    self.history = {}
    self.assign_parameters(**params)  # NOQA
    self.reset_rng()  # NOQA
    self.bilt = {}
    if construct:
        self.construct()

    # Instance-level copies so per-instance edits never touch the class lists
    self.time_vary = deepcopy(self.time_vary_)
    self.time_inv = deepcopy(self.time_inv_)
    self.shock_vars = deepcopy(self.shock_vars_)

1211 

def add_to_time_vary(self, *params):
    """
    Append parameter names to this instance's ``time_vary`` list.

    Names that are already present are skipped, so no duplicates are
    introduced. New names are appended in the order given.

    Parameters
    ----------
    *params : str
        Names of attributes that should be treated as time-varying.

    Returns
    -------
    None
    """
    registry = self.time_vary
    for name in params:
        if name in registry:
            continue
        registry.append(name)

1228 

def add_to_time_inv(self, *params):
    """
    Append parameter names to this instance's ``time_inv`` list.

    Names that are already present are skipped, so no duplicates are
    introduced. New names are appended in the order given.

    Parameters
    ----------
    *params : str
        Names of attributes that should be treated as time-invariant.

    Returns
    -------
    None
    """
    registry = self.time_inv
    for name in params:
        if name in registry:
            continue
        registry.append(name)

1245 

def del_from_time_vary(self, *params):
    """
    Remove parameter names from this instance's ``time_vary`` list.

    Names that are not present are ignored. Each requested name removes
    only its first occurrence, as ``list.remove`` does.

    Parameters
    ----------
    *params : str
        Names of attributes to stop treating as time-varying.

    Returns
    -------
    None
    """
    registry = self.time_vary
    for name in params:
        if name not in registry:
            continue
        registry.remove(name)

1262 

def del_from_time_inv(self, *params):
    """
    Remove parameter names from this instance's ``time_inv`` list.

    Names that are not present are ignored. Each requested name removes
    only its first occurrence, as ``list.remove`` does.

    Parameters
    ----------
    *params : str
        Names of attributes to stop treating as time-invariant.

    Returns
    -------
    None
    """
    registry = self.time_inv
    for name in params:
        if name not in registry:
            continue
        registry.remove(name)

1279 

def unpack(self, name):
    """
    Unpack an attribute from the solution objects for easier access.

    After the model has been solved, its components (like the consumption
    function) reside in each element of `ThisType.solution` (e.g. `cFunc`).
    This method creates a time-varying attribute called `name` holding the
    list of those components, one per solution period, and adds `name` to
    time_vary.

    Each solution element is handled on its own. Dict-like elements
    (including dict subclasses such as OrderedDict) are indexed by key, and
    any other element is read from its instance ``__dict__``. Previously only
    ``type(solution[0]) is dict`` was checked, which broke on dict subclasses
    and on mixed lists.

    Parameters
    ----------
    name: str
        Name of the attribute to unpack from the solution.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If some solution element does not contain `name`.
    """

    def _fetch(soln_t):
        if isinstance(soln_t, dict):
            return soln_t[name]
        return soln_t.__dict__[name]

    setattr(self, name, [_fetch(soln_t) for soln_t in self.solution])
    self.add_to_time_vary(name)

1303 

def solve(
    self,
    verbose=False,
    presolve=True,
    postsolve=True,
    from_solution=None,
    from_t=None,
):
    """
    Solve this agent's model by backward induction.

    Walks through the sequence of one period problems, passing the solution
    from period t+1 to the problem for period t, and stores the result in
    ``self.solution``.

    Parameters
    ----------
    verbose : bool, optional
        If True, solution progress is printed to screen. Default False.
    presolve : bool, optional
        If True (default), pre_solve() is called first.
    postsolve : bool, optional
        If True (default), post_solve() is called afterwards.
    from_solution: Solution
        If not None, used as the starting point of backward induction
        instead of self.solution_terminal.
    from_t : int or None
        If not None, the period the solver should start from. It should
        usually be used together with from_solution, as the time index that
        from_solution represents; it is only compatible with cycles=1 and is
        reset to None otherwise.

    Returns
    -------
    None
    """
    # NumPy floating-point "errors" (e.g. 1.0/0.0 -> inf, inf/inf -> nan)
    # have well-defined results here, so all four categories are silenced.
    with np.errstate(all="ignore"):
        if presolve:
            self.pre_solve()
        self.solution = solve_agent(self, verbose, from_solution, from_t)
        if postsolve:
            self.post_solve()

1355 

def reset_rng(self):
    """
    Re-seed this type's random number generator and reset its distributions.

    ``self.RNG`` is replaced by ``np.random.default_rng(self.seed)``. Then,
    for every attribute named in ``self.distributions`` that exists on this
    instance, ``reset()`` is called on:

    1) the attribute itself, when it is a single distribution object;
    2) each element, when it is a list of distributions (probably time-varying);
    3) each inner element, when it is a nested list (as in ConsMarkovModel).
    """
    self.RNG = np.random.default_rng(self.seed)
    for attr in self.distributions:
        if hasattr(self, attr):
            target = getattr(self, attr)
            outer = target if isinstance(target, list) else [target]
            for item in outer:
                inner = item if isinstance(item, list) else [item]
                for dist in inner:
                    dist.reset()

1380 

def check_elements_of_time_vary_are_lists(self):
    """
    Check that every existing attribute named in time_vary is a list (or a
    list subclass) or an IndexDistribution. Names with no matching attribute
    are skipped.

    Raises
    ------
    AssertionError
        If an attribute named in time_vary is neither a list nor an
        IndexDistribution. The exception is raised explicitly rather than via
        an ``assert`` statement, so the check still runs under ``python -O``.
        AssertionError is kept for compatibility with existing callers.
    """
    for param in self.time_vary:
        if not hasattr(self, param):
            continue
        value = getattr(self, param)
        # The list check comes first; a value cannot be both kinds
        if isinstance(value, list) or isinstance(value, IndexDistribution):
            continue
        raise AssertionError(
            param
            + " is not a list or time varying distribution,"
            + " but should be because it is in time_vary"
        )

1397 

def check_restrictions(self):
    """
    Hook for checking that model-class-specific restrictions are met.

    The base implementation accepts everything and returns None; subclasses
    may override it to validate their parameters.
    """
    pass

1403 

def pre_solve(self):
    """
    Run input checks immediately before the model is solved.

    Calls check_restrictions() and then check_elements_of_time_vary_are_lists().
    Subclasses may extend this, for example to prepare the terminal solution.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    self.check_restrictions()
    self.check_elements_of_time_vary_are_lists()

1420 

def post_solve(self):
    """
    Hook run immediately after the model is solved.

    Subclasses may override it to finalize the solution; the base version
    does nothing.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    pass

1435 

def get_market_params(self, mkt, construct=True):
    """
    Pull the data named in ``market_vars`` from a Market into this agent.

    A shallow copy of each ``getattr(mkt, name)`` is assigned as a parameter
    (and attribute) of self. By default construct() is run afterwards,
    because market parameters often carry information needed to "complete"
    the microeconomic problem. Market.give_agent_params() calls this method
    for all agents.

    Parameters
    ----------
    mkt : Market
        Market to which this AgentType belongs.
    construct : bool
        Whether to rebuild constructed attributes after fetching the data
        (default True).

    Returns
    -------
    None
    """
    fetched = {name: copy(getattr(mkt, name)) for name in self.market_vars}
    self.assign_parameters(**fetched)
    if construct:
        self.construct()

1464 

def initialize_sym(self, **kwargs):
    """
    Build a simulator for this agent with the new simulator structure.

    The RNG is reset first, so seeds are identical every time. The simulator
    made by make_simulator_from_agent (with any keyword arguments forwarded)
    is stored in the private attribute ``_simulator`` and then reset.
    """
    self.reset_rng()
    simulator = make_simulator_from_agent(self, **kwargs)
    self._simulator = simulator
    simulator.reset()

1473 

def find_target(self, target_var, force_list=False, **kwargs):
    r"""
    Find the locally stable "target" level of a named variable.

    The target satisfies $E[\Delta x] = 0$, with $E[\Delta x-\epsilon] > 0$
    and $E[\Delta x+\epsilon] < 0$. A single real value is returned when
    exactly one target exists, the full collection when there are several,
    and np.nan when none is found. Pass force_list=True to always get the
    collection back unchanged. Extra keyword arguments go to
    HARK.simulator.find_target_state; see its documentation for options.

    Raises
    ------
    AttributeError
        If the model has not been solved yet.
    """
    if not hasattr(self, "solution"):
        raise AttributeError("Model must be solved before using find_target!")
    simulator = make_simulator_from_agent(self)
    found = simulator.find_target_state(target_var, **kwargs)
    if force_list or len(found) > 1:
        return found
    return found[0] if len(found) == 1 else np.nan

1494 

def initialize_sim(self):
    """
    Prepare this AgentType for a new simulation.

    The steps are:
    - validate T_sim;
    - reset the RNG and set t_sim to 0;
    - give every state variable a fresh all-NaN array of length AgentCount;
    - zero t_age and t_cycle, then call sim_birth for all agents.

    If read_shocks is True and newborn_init_history is non-empty, each
    idiosyncratic state is then overwritten with a *copy* of row 0 of its
    stored newborn history. A warning is issued for any state missing from
    that history. Finally, clear_history() is called.

    Parameters
    ----------
    None

    Returns
    -------
    None

    Raises
    ------
    Exception
        If T_sim is missing or not positive.
    """
    if not hasattr(self, "T_sim"):
        raise Exception(
            "To initialize simulation variables it is necessary to first "
            + "set the attribute T_sim to the largest number of observations "
            + "you plan to simulate for each agent including re-births."
        )
    elif self.T_sim <= 0:
        raise Exception(
            "T_sim represents the largest number of observations "
            + "that can be simulated for an agent, and must be a positive number."
        )

    self.reset_rng()
    self.t_sim = 0
    N = self.AgentCount
    all_agents = np.ones(N, dtype=bool)
    # Each state gets its own NaN array (to be filled in by sim_birth)
    for var in self.state_vars:
        self.state_now[var] = np.full(N, np.nan)

    # Number of periods since agent entry
    self.t_age = np.zeros(N, dtype=int)
    # Which cycle period each agent is on
    self.t_cycle = np.zeros(N, dtype=int)
    self.sim_birth(all_agents)

    # If we are asked to use existing shocks and a set of initial conditions
    # exists, use them
    if self.read_shocks and bool(self.newborn_init_history):
        for var_name in self.state_now:
            # Check that we are actually given a value for the variable
            if var_name in self.newborn_init_history:
                # Copy only array-like idiosyncratic states. Aggregates should
                # not be set by newborns
                idio = (
                    isinstance(self.state_now[var_name], np.ndarray)
                    and len(self.state_now[var_name]) == N
                )
                if idio:
                    # Copy the row: binding a view would let later in-place
                    # updates of state_now write back into the stored history.
                    self.state_now[var_name] = copy(
                        self.newborn_init_history[var_name][0]
                    )
            else:
                warn(
                    "The option for reading shocks was activated but "
                    + "the model requires state "
                    + var_name
                    + ", not contained in "
                    + "newborn_init_history."
                )

    self.clear_history()

1561 

1562 def _export_single_var_by_time(self, history, var, t, dtype): 

1563 """ 

1564 Mode 1a: single variable, by_age=False. 

1565 

1566 Returns a DataFrame whose columns are simulation periods (t_sim) and 

1567 whose rows are agent indices. When t is None all T_sim periods are 

1568 included; when t is an array only those periods are included. 

1569 """ 

1570 try: 

1571 data = history[var] 

1572 except KeyError: 

1573 raise KeyError("Variable named " + var + " not found in simulated data!") 

1574 

1575 if t is None: 

1576 cols = [str(i) for i in range(self.T_sim)] 

1577 df = DataFrame(data=data.T, columns=cols, dtype=dtype) 

1578 else: 

1579 cols = [str(t_val) for t_val in t] 

1580 df = DataFrame(data=data[t, :].T, columns=cols, dtype=dtype) 

1581 return df 

1582 

1583 def _export_single_var_by_age(self, history, var, age, t, dtype, sym): 

1584 """ 

1585 Mode 1b: single variable, by_age=True. 

1586 

1587 Returns a DataFrame whose columns are within-agent model ages (t_age) 

1588 and whose rows are individual agent lifetimes. Observations after 

1589 death (or before birth) are NaN. When t is None all ages up to 

1590 max(t_age) are included; when t is an array only those ages are used. 

1591 The sym flag controls newborn detection: age == 0 when sym is True, 

1592 age == 1 when sym is False. 

1593 """ 

1594 try: 

1595 data = history[var] 

1596 except KeyError: 

1597 raise KeyError("Variable named " + var + " not found in simulated data!") 

1598 

1599 # Determine which ages to include and mark qualifying observations 

1600 if t is None: 

1601 age_set = np.arange(np.max(age) + 1) 

1602 in_age_set = np.ones_like(data, dtype=bool) 

1603 else: 

1604 age_set = t 

1605 in_age_set = np.zeros_like(data, dtype=bool) 

1606 for j in age_set: 

1607 these = age == j 

1608 in_age_set[these] = True 

1609 

1610 # Locate newborns to determine the number of individual lifetimes (rows) 

1611 newborns = age == 0 if sym else age == 1 

1612 T = age_set.size # number of age columns 

1613 N = np.sum(newborns) # number of agent lifetimes (rows) 

1614 out = np.full((N, T), np.nan) 

1615 

1616 # Extract each individual's sequence and place it into the output array 

1617 n = 0 

1618 for i in range(self.AgentCount): 

1619 data_i = data[:, i] 

1620 births = np.where(newborns[:, i])[0] 

1621 K = births.size 

1622 for k in range(K): 

1623 start = births[k] 

1624 stop = births[k + 1] if (k < K - 1) else self.T_sim 

1625 use = in_age_set[start:stop, i] 

1626 temp = data_i[start:stop][use] 

1627 out[n, : temp.size] = temp 

1628 n += 1 

1629 

1630 cols = [str(a) for a in age_set] 

1631 df = DataFrame(data=out, columns=cols, dtype=dtype) 

1632 return df 

1633 

1634 def _export_single_t_by_time(self, history, var_list, t, dtype): 

1635 """ 

1636 Mode 2a: single time period, by_age=False. 

1637 

1638 Returns a DataFrame with one row per agent and one column per variable, 

1639 drawn from the absolute simulation period t (i.e. history[name][t, :]). 

1640 """ 

1641 K = len(var_list) 

1642 N = self.AgentCount 

1643 out = np.full((N, K), np.nan) 

1644 for k, name in enumerate(var_list): 

1645 out[:, k] = history[name][t, :] 

1646 df = DataFrame(data=out, columns=var_list, dtype=dtype) 

1647 return df 

1648 

1649 def _export_single_t_by_age(self, history, var_list, age, t, dtype): 

1650 """ 

1651 Mode 2b: single time period, by_age=True. 

1652 

1653 Returns a DataFrame with one row per agent-period at which t_age == t 

1654 and one column per variable. 

1655 """ 

1656 right_age = age == t 

1657 N = np.sum(right_age) 

1658 K = len(var_list) 

1659 out = np.full((N, K), np.nan) 

1660 for k, name in enumerate(var_list): 

1661 out[:, k] = history[name][right_age] 

1662 df = DataFrame(data=out, columns=var_list, dtype=dtype) 

1663 return df 

1664 

def export_to_df(self, var=None, t=None, by_age=False, dtype=None, sym=False):
    """
    Export an AgentType instance's simulated data to a pandas DataFrame.

    Exactly one of these must hold: var is a single string, OR t is a single
    integer. That choice, combined with by_age, selects one of four modes:

    1a) var is one string, by_age False: T_sim columns, one per absolute
        simulated period t_sim. Each row is one *agent index* of the
        population, with death and replacement occurring within a row. t may
        optionally be an array of periods to include (default all).

    1b) var is one string, by_age True: columns are within-agent model age
        t_age. Each row is one specific agent from model entry (t_age=0) to
        model death; observations after death are NaN. t may optionally be
        an array of ages to include (default all). The number of columns
        depends on max(t_age) and/or t.

    2a) t is an integer, by_age False: one column per simulated variable,
        using the values at absolute simulated period t=t_sim. var may be a
        list of strings naming which variables to include (default all).

    2b) t is an integer, by_age True: one column per simulated variable,
        taken from all agent-periods at which t == t_age (within-agent model
        age). var may be a list of strings naming which variables to include
        (default all).

    Parameters
    ----------
    var : str or [str] or None
        A single string names the one simulated variable to export. A list
        of strings requires t to be an integer and names the variables to
        include. Names must be keys of history (or hystory), i.e. named in
        track_vars. When omitted, every key is included.
    t : int or np.array or None
        An integer selects one period. With by_age False (default) it is
        absolute simulated time t_sim, i.e. the t-th row of history[key].
        With by_age True it is within-agent model age t_age, and all
        agent-periods with t_age==t are included. When var is a single
        string, t may be an array of periods (or ages) to include (default all).
    by_age : bool
        Select observations by within-agent model age t_age instead of
        absolute simulated time t_sim. Requires t_age to be in track_vars.
        dtype must *not* be given together with by_age=True, because NaNs
        would be cast to a type that might not support them.
    dtype : type or None
        Optional data type for the DataFrame. By default the datatype of the
        entries in history or hystory is used.
    sym : bool
        Read simulated data from history (False, default) or hystory (True).
        This option will be deprecated once the legacy simulation methods
        are removed.

    Returns
    -------
    df : pandas.DataFrame
        The requested DataFrame, built from this instance's simulated data.
    """
    # Exactly one of "single variable" / "single period" must hold
    one_var = type(var) is str
    one_period = isinstance(t, (int, np.integer))
    if one_var == one_period:
        raise ValueError(
            "Either var must be a single string, or t must be a single integer!"
        )
    if by_age and dtype is not None:
        raise ValueError(
            "Can't specify dtype when using by_age is True because of potential incompatibility with representing NaN"
        )

    # Choose the simulated-data dictionary (hystory is slated for deprecation)
    history = self.hystory if sym else self.history

    # The age array is looked up once, up front, so a missing t_age fails early
    age = None
    if by_age:
        if "t_age" not in history:
            raise KeyError(
                "t_age must be in track_vars if by_age=True will be used!"
            )
        age = history["t_age"]

    if one_var:
        if by_age:
            return self._export_single_var_by_age(history, var, age, t, dtype, sym)
        return self._export_single_var_by_time(history, var, t, dtype)

    # Single-period modes: build and validate the list of variables
    if var is None:
        var_list = list(history.keys())
    else:
        var_list = copy(var)
        for name in var_list:
            if name in history:
                continue
            raise KeyError(
                "Variable called " + name + " not found in simulation data!"
            )

    if by_age:
        return self._export_single_t_by_age(history, var_list, age, t, dtype)
    return self._export_single_t_by_time(history, var_list, t, dtype)

1779 

def sim_one_period(self):
    """
    Simulate one period for this type.

    The sequence is:
    - get_mortality() replaces some agents with newborns;
    - every state is copied into state_prev, and each array-valued state
      gets a fresh uninitialized array of length AgentCount (non-array
      states are left alone, since the Market may set them);
    - shocks come from read_shocks_from_history() when read_shocks is True,
      and from get_shocks() otherwise;
    - get_states(), get_controls() and get_poststates() run in that order;
    - finally t_age and t_cycle advance by one, and t_cycle wraps to 0 when
      it reaches T_cycle.

    Subclasses should define get_shocks, get_states, get_controls and
    get_poststates. get_mortality is built from sim_death and sim_birth.

    Parameters
    ----------
    None

    Returns
    -------
    None

    Raises
    ------
    Exception
        If the model has not been solved yet.
    """
    if not hasattr(self, "solution"):
        raise Exception(
            "Model instance does not have a solution stored. To simulate, it is necessary"
            " to run the `solve()` method first."
        )

    # Mortality adjusts the agent population
    self.get_mortality()

    # Shift current states into state_{t-1}; only value updates, no new keys
    for name, current in self.state_now.items():
        self.state_prev[name] = current
        if isinstance(current, np.ndarray):
            self.state_now[name] = np.empty(self.AgentCount)

    if self.read_shocks:
        self.read_shocks_from_history()
    else:
        self.get_shocks()

    for step in ("get_states", "get_controls", "get_poststates"):
        getattr(self, step)()

    # Advance time for all agents; wrap the cycle index at T_cycle
    self.t_age = self.t_age + 1
    next_cycle = self.t_cycle + 1
    next_cycle[next_cycle == self.T_cycle] = 0
    self.t_cycle = next_cycle

1828 

def make_shock_history(self):
    """
    Make a pre-specified history of shocks for the simulation.

    Shock variables should be named in self.shock_vars, a list of strings that
    is subclass-specific. This method runs a subset of the standard simulation
    loop by simulating only mortality and shocks.
    - Each variable X named in shock_vars is stored in a T_sim x AgentCount
      array self.shock_history[X].
    - Deaths are stored in self.shock_history["who_dies"].
    - Newborns' initial states are stored in self.newborn_init_history[X], for
      each X in self.state_vars.
    At the end, self.read_shocks is set to True so that these pre-specified
    shocks are used for all subsequent calls to simulate().

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    # Simulate (rather than read) while the history is being built. If
    # read_shocks were still True (e.g. on a second call), get_mortality would
    # read deaths from the blank "who_dies" array created below, and the new
    # history would contain no deaths at all.
    self.read_shocks = False

    # Re-initialize the simulation
    self.initialize_sim()

    shape = (self.T_sim, self.AgentCount)

    # Make blank history arrays for each shock variable (and mortality)
    for var_name in self.shock_vars:
        self.shock_history[var_name] = np.full(shape, np.nan)
    self.shock_history["who_dies"] = np.zeros(shape, dtype=bool)

    # Also make blank arrays for the draws of newborns' initial conditions
    for var_name in self.state_vars:
        self.newborn_init_history[var_name] = np.full(shape, np.nan)

    def is_idiosyncratic(var_name):
        # A state is idiosyncratic when it is an array with one entry per agent;
        # anything else is treated as an aggregate (scalar) state.
        value = self.state_now[var_name]
        return isinstance(value, np.ndarray) and len(value) == self.AgentCount

    # Record the initial condition of the newborns created by
    # initialize_sim -> sim_births. A row assignment stores idiosyncratic
    # arrays element-wise and broadcasts aggregate scalars to every agent.
    for var_name in self.state_vars:
        self.newborn_init_history[var_name][self.t_sim, :] = self.state_now[var_name]

    # Make and store the history of shocks for each period
    for t in range(self.T_sim):
        # Deaths
        self.get_mortality()
        who_dies = self.who_dies
        self.shock_history["who_dies"][t, :] = who_dies

        # Initial conditions of newborns
        if who_dies.any():
            for var_name in self.state_vars:
                value = self.state_now[var_name]
                if is_idiosyncratic(var_name):
                    value = value[who_dies]
                self.newborn_init_history[var_name][t, who_dies] = value

        # Other shocks
        self.get_shocks()
        for var_name in self.shock_vars:
            self.shock_history[var_name][t, :] = self.shocks[var_name]

        self.t_sim += 1
        self.t_age = self.t_age + 1  # Age all consumers by one period
        self.t_cycle = self.t_cycle + 1  # Age all consumers within their cycle
        # Resetting to zero for those who have reached the end
        self.t_cycle[self.t_cycle == self.T_cycle] = 0

    # Flag that shocks can be read rather than simulated
    self.read_shocks = True

1919 

def get_mortality(self):
    """
    Simulate mortality / agent turnover for the current period.

    Without pre-specified shocks, the subclass method sim_death() returns a
    Boolean array of size AgentCount flagging the agents to replace. That array
    is passed to sim_birth(), which generates initial post-decision states for
    those agent indices.

    When self.read_shocks is True, deaths are instead taken from row self.t_sim
    of self.shock_history["who_dies"]. For each dying agent:
    - every idiosyncratic state in self.state_now is overwritten with the stored
      draw from self.newborn_init_history;
    - t_age and t_cycle are reset to zero.
    A warning is issued for any state missing from newborn_init_history.

    In both cases the Boolean mask of replaced agents is saved as self.who_dies.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    if not self.read_shocks:
        dead = self.sim_death()
        self.sim_birth(dead)
        self.who_dies = dead
        return None

    dead = self.shock_history["who_dies"][self.t_sim, :]
    if dead.any():
        # Instead of simulating births, assign the saved newborn initial conditions
        for name, current in self.state_now.items():
            if name not in self.newborn_init_history:
                warn(
                    "The option for reading shocks was activated but "
                    + "the model requires state "
                    + name
                    + ", not contained in "
                    + "newborn_init_history."
                )
                continue
            # Only one-value-per-agent arrays are overwritten in place;
            # aggregate states are not set by newborns.
            per_agent = (
                isinstance(current, np.ndarray) and len(current) == self.AgentCount
            )
            if per_agent:
                current[dead] = self.newborn_init_history[name][self.t_sim, dead]

        # Newborns start at age zero
        self.t_age[dead] = 0
        self.t_cycle[dead] = 0

    self.who_dies = dead
    return None

1972 

def sim_death(self):
    """
    Decide which agents "die" (are replaced) this period.

    The base implementation never replaces anyone. Subclasses should override
    this to model replacement events.

    Parameters
    ----------
    None

    Returns
    -------
    who_dies : np.array
        Boolean array of size self.AgentCount; True marks an agent who dies
        and is replaced, False one who survives. Always all False here.
    """
    # np.full infers a bool dtype from the False fill value.
    return np.full(self.AgentCount, False)

1991 

def sim_birth(self, which_agents):  # pragma: nocover
    """
    Make new agents for the simulation.

    Takes a Boolean array indicating which agent indices are to be "born".
    The base AgentType cannot create agents, so this always raises;
    subclasses must override it.

    Parameters
    ----------
    which_agents : np.array(Bool)
        Boolean array of size self.AgentCount indicating which agents should be "born".

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        Always. NotImplementedError subclasses Exception, so existing
        ``except Exception`` handlers still catch it.
    """
    raise NotImplementedError("AgentType subclass must define method sim_birth!")

2007 

def get_shocks(self):  # pragma: nocover
    """
    Draw the values of the shock variables for the current period.

    This is a no-op hook in AgentType; subclasses can override it to draw
    their own shocks.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    pass

2022 

def read_shocks_from_history(self):
    """
    Load the current period's shocks from the stored shock history.

    For each variable X named in self.shock_vars, self.shocks[X] is set to
    row self.t_sim of self.shock_history[X].

    This method is only ever called if self.read_shocks is True. That flag is
    set by make_shock_history(), or manually after storing a "handcrafted"
    shock history.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    row = self.t_sim
    for name in self.shock_vars:
        stored = self.shock_history[name]
        self.shocks[name] = stored[row, :]

2043 

def get_states(self):
    """
    Update the state variables for the current period.

    Calls self.transition() and writes its returned values, in order, into
    the entries of self.state_now.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    fresh = self.transition()
    # transition() may return fewer values than there are states (a stopgap
    # for 'post-states'). zip stops at the shorter sequence, so the trailing
    # states keep their current values.
    for name, value in zip(self.state_now, fresh):
        self.state_now[name] = value

2064 

def transition(self):  # pragma: nocover
    """
    Compute new values of the endogenous states.

    The base implementation returns an empty tuple.

    Parameters
    ----------
    None

    [Eventually, to match dolo spec:
    exogenous_prev, endogenous_prev, controls, exogenous, parameters]

    Returns
    -------
    endogenous_state: ()
        Tuple with new values of the endogenous states
    """
    return tuple()

2082 

def get_controls(self):  # pragma: nocover
    """
    Determine the control (choice) variables for the current period.

    Subclasses of AgentType can override this, probably using the current
    states. The base implementation does nothing.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    pass

2097 

def get_poststates(self):
    """
    Determine the post-decision state variables for the current period.

    Subclasses of AgentType can override this, probably computing the values
    from current states and controls, and maybe from market-level events or
    shock variables. The base implementation does nothing.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    pass

2115 

def symulate(self, T=None):
    """
    Run the new simulation structure via self._simulator.

    T is passed straight to the simulator's simulate() method. The
    simulator's resulting history is stored in self.hystory.
    """
    sim = self._simulator
    sim.simulate(T)
    self.hystory = sim.history

2123 

def describe_model(self, display=True):
    """
    Describe this agent's model, based on its model file.

    This is useful for learning about outcome variable names for tracking
    during simulation, or for use with sequence space Jacobians. If no
    simulator exists yet (no self._simulator), initialize_sym() is called
    first. The display flag is passed through to the simulator's describe().
    """
    try:
        sim = self._simulator
    except AttributeError:
        self.initialize_sym()
        sim = self._simulator
    sim.describe(display=display)

2133 

def simulate(self, sim_periods=None):
    """
    Simulate this agent type for a given number of periods.

    Histories of the variables named in self.track_vars are recorded in
    self.history[var_name], one row per simulated period.

    Parameters
    ----------
    sim_periods : int or None
        Number of periods to simulate. Default (None) simulates all remaining
        periods, T_sim - t_sim.

    Returns
    -------
    history : dict
        The history tracked during the simulation (self.history).

    Raises
    ------
    Exception
        If the simulation has not been initialized (no t_sim), if T_sim is not
        set, or if sim_periods exceeds T_sim.
    """
    if not hasattr(self, "t_sim"):
        raise Exception(
            "It seems that the simulation variables were not initialized before calling "
            + "simulate(). Call initialize_sim() to initialize the variables before calling simulate() again."
        )

    if not hasattr(self, "T_sim"):
        raise Exception(
            "This agent type instance must have the attribute T_sim set to a positive integer. "
            + "Set T_sim to match the largest dataset you might simulate, and run this agent's "
            + "initialize_sim() method before running simulate() again."
        )

    if sim_periods is not None and self.T_sim < sim_periods:
        raise Exception(
            "To simulate, sim_periods has to be smaller than or equal to the maximum data set size "
            + "T_sim. Either increase the attribute T_sim of this agent type instance "
            + "and call the initialize_sim() method again, or set sim_periods <= T_sim."
        )

    # Ignore floating point "errors". Numpy calls it "errors", but really it's
    # exceptions with well-defined answers such as 1.0/0.0 that is np.inf,
    # -1.0/0.0 that is -np.inf, np.inf/np.inf is np.nan and so on.
    with np.errstate(divide="ignore", over="ignore", under="ignore", invalid="ignore"):
        if sim_periods is None:
            sim_periods = self.T_sim - self.t_sim

        for _ in range(sim_periods):
            self.sim_one_period()

            for var_name in self.track_vars:
                row = self.t_sim
                if var_name in self.state_now:
                    value = self.state_now[var_name]
                elif var_name in self.shocks:
                    value = self.shocks[var_name]
                elif var_name in self.controls:
                    value = self.controls[var_name]
                else:
                    value = getattr(self, var_name)
                    # NOTE(review): an attribute-based who_dies is written one
                    # row back once t_sim > 1; confirm the intended offset.
                    if var_name == "who_dies" and self.t_sim > 1:
                        row = self.t_sim - 1
                self.history[var_name][row, :] = value
            self.t_sim += 1

    return self.history

2199 

def clear_history(self):
    """
    Reset the histories of the variables named in self.track_vars.

    Each tracked variable gets a fresh (T_sim, AgentCount) float array filled
    with NaN.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    for name in self.track_vars:
        self.history[name] = np.full((self.T_sim, self.AgentCount), np.nan)

2215 

def make_basic_SSJ(self, shock, outcomes, grids, **kwargs):
    """
    Build sequence space Jacobian matrices of outcomes with respect to a shock.

    Returns the matrices for the specified outcomes with respect to the
    specified "shock" variable. This "basic" method only supports two cases:
    - "one period infinite horizon" models (cycles=0, T_cycle=1), via
      HARK.SSJutils.make_basic_SSJ_matrices;
    - life-cycle models (cycles=1), via
      HARK.SSJutils.make_flat_LC_SSJ_matrices.
    Any other configuration raises ValueError. See the documentation of those
    functions for details on the arguments.
    """
    one_period_infinite = (self.cycles == 0) and (self.T_cycle == 1)
    if one_period_infinite:
        builder = make_basic_SSJ_matrices
    elif self.cycles == 1:
        builder = make_flat_LC_SSJ_matrices
    else:
        raise ValueError(
            "Can only make HA-SSJ matrices for infinite horizon or life-cycle models!"
        )
    return builder(self, shock, outcomes, grids, **kwargs)

2232 

def calc_impulse_response_manually(self, shock, outcomes, grids, **kwargs):
    """
    Compute impulse response(s) to a perturbation of the shock parameter.

    The perturbation occurs in period t=s. This essentially computes one
    column of the sequence space Jacobian matrix by hand. Like
    make_basic_SSJ, this "basic" method only works for "one period infinite
    horizon" models (cycles=0, T_cycle=1). All arguments are forwarded to
    HARK.SSJutils.calc_shock_response_manually; see its documentation for
    more information.
    """
    response = calc_shock_response_manually(self, shock, outcomes, grids, **kwargs)
    return response

2242 

2243 

def solve_agent(agent, verbose, from_solution=None, from_t=None):
    """
    Solve the dynamic model for one agent type using backwards induction.

    This function iterates on "cycles" of an agent's model either a given
    number of times or until solution convergence if an infinite horizon model
    is used (with agent.cycles = 0).

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    verbose : boolean
        If True, solution progress is printed to screen after every cycle.
    from_solution: Solution
        If different from None, will be used as the starting point of backward
        induction, instead of agent.solution_terminal.
    from_t : int or None
        If not None, indicates which period of the model the solver should
        start from. It should usually only be used in combination with
        from_solution. Stands for the time index that from_solution represents,
        and thus is only compatible with cycles=1; it is reset to None
        whenever agent.cycles != 1.

    Returns
    -------
    solution : [Solution]
        A list of solutions to the one period problems that the agent will
        encounter in his "lifetime".

    Warns
    -----
    UserWarning
        For an infinite horizon model whose final solution distance is not at
        or below agent.tolerance, i.e. the cycle cap was reached or the
        distance became NaN.
    """
    # Check to see whether this is an (in)finite horizon problem
    cycles_left = agent.cycles
    infinite_horizon = cycles_left == 0

    if from_solution is None:
        solution_last = agent.solution_terminal
    else:
        solution_last = from_solution
    # from_t is only supported for life-cycle models (cycles == 1)
    if agent.cycles != 1:
        from_t = None

    # Initialize the solution, which includes the terminal solution if it's not
    # a pseudo-terminal period
    solution = [] if agent.pseudo_terminal else [deepcopy(solution_last)]

    # Initialize the process, then loop over cycles
    go = True
    completed_cycles = 0
    max_cycles = 5000  # escape clause for infinite horizon models
    if verbose:
        t_last = time()
    while go:
        # Solve a cycle of the model, recording it if horizon is finite
        solution_cycle = solve_one_cycle(agent, solution_last, from_t)
        if not infinite_horizon:
            solution = solution_cycle + solution

        # Check for termination: identical solutions across
        # cycle iterations or run out of cycles
        solution_now = solution_cycle[0]
        if infinite_horizon:
            if completed_cycles > 0:
                solution_distance = distance_metric(solution_now, solution_last)
                # Add these attributes so users can query them to see if the
                # solution is ready
                agent.solution_distance = solution_distance
                agent.completed_cycles = completed_cycles
                go = (
                    solution_distance > agent.tolerance
                    and completed_cycles < max_cycles
                )
            else:  # Assume solution does not converge after only one cycle
                solution_distance = 100.0
                go = True
        else:
            cycles_left += -1
            go = cycles_left > 0

        # Update the "last period solution"
        solution_last = solution_now
        completed_cycles += 1

        # Display progress if requested
        if verbose:
            t_now = time()
            if infinite_horizon:
                print(
                    "Finished cycle #"
                    + str(completed_cycles)
                    + " in "
                    + "{:.6f}".format(t_now - t_last)
                    + " seconds, solution distance = "
                    + str(solution_distance)
                )
            else:
                print(
                    "Finished cycle #"
                    + str(completed_cycles)
                    + " of "
                    + str(agent.cycles)
                    + " in "
                    + str(t_now - t_last)
                    + " seconds."
                )
            t_last = t_now

    if infinite_horizon:
        # The loop above also stops on a NaN distance (NaN > tolerance is
        # False) or at the cycle cap; neither means convergence. "not <=" is
        # True in both cases.
        if not solution_distance <= agent.tolerance:
            warn(
                "Infinite horizon solution stopped after "
                + str(completed_cycles)
                + " cycles without converging: solution distance = "
                + str(solution_distance)
                + ", tolerance = "
                + str(agent.tolerance)
                + "."
            )
        # Record the last cycle (solution is still empty!);
        # pseudo_terminal=False is impossible for infinite horizon
        solution = solution_cycle

    return solution

2359 

2360 

def _cycle_period_order(T, cycles):
    """Return the order in which period indices 0..T-1 are solved in one cycle."""
    # Life-cycle models (cycles == 1) go strictly backward from T-1 to 0.
    # Otherwise period 0 is solved first, then T-1 down to 1.
    if cycles == 1:
        return list(range(T - 1, -1, -1))
    return [0] + list(range(T - 1, 0, -1))


def _period_solver_and_args(agent, k):
    """Return the one-period solver for period k and the argument names it takes."""
    # agent.solve_one_period may be a single solver or an indexable
    # (time-varying) collection of solvers.
    if hasattr(agent.solve_one_period, "__getitem__"):
        solver = agent.solve_one_period[k]
    else:
        solver = agent.solve_one_period

    # Solvers built by make_one_period_oo_solver carry their argument names.
    if hasattr(solver, "solver_args"):
        arg_names = solver.solver_args
    else:
        arg_names = get_arg_names(solver)
    return solver, arg_names


def solve_one_cycle(agent, solution_last, from_t):
    """
    Solve one "cycle" of the dynamic model for one agent type.

    This function iterates over the periods within an agent's cycle, updating
    the time-varying parameters and passing them to the single period
    solver(s).

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    solution_last : Solution
        A representation of the solution of the period that comes after the
        end of the sequence of one period problems. This might be the terminal
        period solution, a "pseudo terminal" solution, or simply the solution
        to the earliest period from the succeeding cycle.
    from_t : int or None
        If not None, indicates which period of the model the solver should
        start from. When used, represents the time index that solution_last
        is from.

    Returns
    -------
    solution_cycle : [Solution]
        A list of one period solutions for one "cycle" of the AgentType's
        microeconomic model.
    """
    # Check if the agent has a 'parameters' attribute of the 'Parameters'
    # class; if so, take advantage of it. Else, use the old method.
    use_parameters = hasattr(agent, "parameters") and isinstance(
        agent.parameters, Parameters
    )

    if use_parameters:
        T = agent.parameters._length if from_t is None else from_t
    else:
        # Number of periods per cycle defaults to 1 if all variables are
        # time invariant
        if len(agent.time_vary) > 0:
            T = agent.T_cycle if from_t is None else from_t
        else:
            T = 1
        solve_dict = {parameter: agent.__dict__[parameter] for parameter in agent.time_inv}
        solve_dict.update({parameter: None for parameter in agent.time_vary})

    # Solutions are collected in solving order, then reversed so that the
    # last period solved ends up first (same as inserting at the front).
    solved = []
    solution_next = solution_last
    for k in _cycle_period_order(T, agent.cycles):
        solve_one_period, these_args = _period_solver_and_args(agent, k)

        # Make a temporary dictionary of inputs for this period
        if use_parameters:
            temp_pars = agent.parameters[k]
            temp_dict = {
                name: solution_next if name == "solution_next" else temp_pars[name]
                for name in these_args
            }
        else:
            # Update time-varying single period inputs
            for name in agent.time_vary:
                if name in these_args:
                    solve_dict[name] = agent.__dict__[name][k]
            solve_dict["solution_next"] = solution_next
            temp_dict = {name: solve_dict[name] for name in these_args}

        # Solve one period, record it, and move to the next period
        solution_t = solve_one_period(**temp_dict)
        solved.append(solution_t)
        solution_next = solution_t

    solved.reverse()
    return solved

2466 

2467 

def make_one_period_oo_solver(solver_class):
    """
    Build a function that solves a single period consumption-saving problem.

    Parameters
    ----------
    solver_class : Solver
        A class of Solver to be used.

    Returns
    -------
    solver_function : function
        The returned function takes keyword arguments only. It:
        - instantiates solver_class with those arguments;
        - calls the instance's prepare_to_solve() if it has one;
        - returns the result of solve().
        It also carries two attributes: solver_class, and solver_args (the
        argument names of solver_class.__init__ with the first one,
        presumably self, dropped).
    """

    def one_period_solver(**kwds):
        instance = solver_class(**kwds)
        # Not every Solver class defines prepare_to_solve; call it only if present.
        if hasattr(instance, "prepare_to_solve"):
            instance.prepare_to_solve()
        return instance.solve()

    one_period_solver.solver_class = solver_class
    # This can be revisited once it is possible to export parameters
    all_init_args = get_arg_names(solver_class.__init__)
    one_period_solver.solver_args = all_init_args[1:]
    return one_period_solver

2496 

2497 

2498# ======================================================================== 

2499# ======================================================================== 

2500 

2501 

2502class Market(Model): 

2503 """ 

2504 A superclass to represent a central clearinghouse of information. Used for 

2505 dynamic general equilibrium models to solve the "macroeconomic" model as a 

2506 layer on top of the "microeconomic" models of one or more AgentTypes. 

2507 

2508 Parameters 

2509 ---------- 

2510 agents : [AgentType] 

2511 A list of all the AgentTypes in this market. 

2512 sow_vars : [string] 

2513 Names of variables generated by the "aggregate market process" that should 

2514 be "sown" to the agents in the market. Aggregate state, etc. 

2515 reap_vars : [string] 

2516 Names of variables to be collected ("reaped") from agents in the market 

2517 to be used in the "aggregate market process". 

2518 const_vars : [string] 

2519 Names of attributes of the Market instance that are used in the "aggregate 

2520 market process" but do not come from agents-- they are constant or simply 

2521 parameters inherent to the process. 

2522 track_vars : [string] 

2523 Names of variables generated by the "aggregate market process" that should 

2524 be tracked as a "history" so that a new dynamic rule can be calculated. 

2525 This is often a subset of sow_vars. 

2526 dyn_vars : [string] 

2527 Names of variables that constitute a "dynamic rule". 

2528 mill_rule : function 

2529 A function that takes inputs named in reap_vars and returns a tuple the 

2530 same size and order as sow_vars. The "aggregate market process" that 

2531 transforms individual agent actions/states/data into aggregate data to 

2532 be sent back to agents. 

2533 calc_dynamics : function 

2534 A function that takes inputs named in track_vars and returns an object 

2535 with attributes named in dyn_vars. Looks at histories of aggregate 

2536 variables and generates a new "dynamic rule" for agents to believe and 

2537 act on. 

2538 act_T : int 

2539 The number of times that the "aggregate market process" should be run 

2540 in order to generate a history of aggregate variables. 

2541 tolerance: float 

2542 Minimum acceptable distance between "dynamic rules" to consider the 

2543 Market solution process converged. Distance is a user-defined metric. 

2544 """ 

2545 

def __init__(
    self,
    agents=None,
    sow_vars=None,
    reap_vars=None,
    const_vars=None,
    track_vars=None,
    dyn_vars=None,
    mill_rule=None,
    calc_dynamics=None,
    distributions=None,
    act_T=1000,
    tolerance=0.000001,
    seed=0,
    **kwds,
):
    super().__init__()
    self.agents = list() if agents is None else agents

    # Variables collected ("reaped") from agents; each gets its own list.
    self.reap_vars = list() if reap_vars is None else reap_vars
    self.reap_state = {name: [] for name in self.reap_vars}

    # Variables distributed ("sown") to agents, with initial and current values.
    self.sow_vars = list() if sow_vars is None else sow_vars
    self.sow_init = dict.fromkeys(self.sow_vars)
    self.sow_state = dict.fromkeys(self.sow_vars)

    self.const_vars = dict.fromkeys(list() if const_vars is None else const_vars)

    self.track_vars = list() if track_vars is None else track_vars
    self.dyn_vars = list() if dyn_vars is None else dyn_vars
    self.distributions = list() if distributions is None else distributions

    # Only set these when given, so method-based mill_rule / calc_dynamics
    # defined on a subclass are not overwritten.
    if mill_rule is not None:
        self.mill_rule = mill_rule
    if calc_dynamics is not None:
        self.calc_dynamics = calc_dynamics

    self.act_T = act_T
    self.tolerance = tolerance
    self.seed = seed
    self.max_loops = 1000
    self.history = {}
    # Extra keyword parameters are applied after the defaults above, so they
    # can override attributes such as max_loops.
    self.assign_parameters(**kwds)
    self.RNG = np.random.default_rng(self.seed)

    # Print the error associated with calling the parallel method
    # "solve_agents" one time. If set to false, the error will never
    # print. See "solve_agents" for why this prints once or never.
    self.print_parallel_error_once = True

2597 

def give_agent_params(self, construct=True):
    """
    Hand market-level parameters to every AgentType in self.agents.

    Each agent's get_market_params method is called with this market and the
    construct flag.

    Parameters
    ----------
    construct : bool, optional
        Whether agents should run their construct method after fetching market
        data (default True).

    Returns
    -------
    None
    """
    for member in self.agents:
        member.get_market_params(self, construct)

2615 

def solve_agents(self):
    """
    Solve the microeconomic problem of every AgentType in this market.

    First tries multi_thread_commands. If that raises any Exception, the
    agents are solved with multi_thread_commands_fake instead. In that case a
    warning is printed, but only while self.print_parallel_error_once is
    True; the flag is then cleared.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    try:
        multi_thread_commands(self.agents, ["solve()"])
    except Exception as err:
        if self.print_parallel_error_once:
            # Clear the flag first so this message appears at most once.
            self.print_parallel_error_once = False
            print(
                "**** WARNING: could not execute multi_thread_commands in HARK.core.Market.solve_agents() ",
                "so using the serial version instead. This will likely be slower. "
                "The multi_thread_commands() functions failed with the following error:",
                "\n",
                type(err),
                ":",
                err,
            )
        multi_thread_commands_fake(self.agents, ["solve()"])

2644 

def solve(self):
    """
    "Solve" the market by finding a self-consistent "dynamic rule".

    The rule governs the aggregate market state such that, when agents
    believe in these dynamics, their actions collectively generate the same
    dynamic rule. Each loop:
    1. solves the agents' problems (solve_agents);
    2. generates a history (make_history);
    3. computes new dynamics (update_dynamics).
    The loop stops once the distance between successive rules falls below
    self.tolerance, or after self.max_loops iterations. The final rule is
    stored in self.dynamics.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    loop_cap = self.max_loops  # Failsafe against an infinite solution loop
    loops_done = 0
    previous = None

    while True:
        self.solve_agents()  # Solve each AgentType's micro problem
        self.make_history()  # "Run" the model while tracking aggregate variables
        latest = self.update_dynamics()  # Find a new aggregate dynamic rule

        # The first loop has nothing to compare against, so use a large distance.
        gap = 1000000.0 if loops_done == 0 else latest.distance(previous)

        previous = latest
        loops_done += 1
        if not (gap >= self.tolerance and loops_done < loop_cap):
            break

    self.dynamics = latest  # Store the final dynamic rule in self

2681 

def reap(self):
    """
    Collect the variables named in reap_state from the agents in the market.

    For each such variable, self.reap_state[var] becomes a list holding
    agent.state_now[var] for every agent whose state_now contains var, in
    agent order.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    # TODO: generalized variable lookup across namespaces
    for name in self.reap_state:
        self.reap_state[name] = [
            member.state_now[name]
            for member in self.agents
            if name in member.state_now
        ]

2705 

def sow(self):
    """
    Distribute the values in self.sow_state to every AgentType in the market.

    For each sown variable and each agent:
    - if the name is a key of agent.state_now, that entry is overwritten;
    - if the name is a key of agent.shocks, that entry is overwritten;
    - otherwise (not in shocks), the value is stored as an attribute of the
      agent.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    for name, value in self.sow_state.items():
        for member in self.agents:
            if name in member.state_now:
                member.state_now[name] = value
            # NOTE(review): this check is independent of the one above, so a
            # variable in state_now but not in shocks is ALSO set as an
            # attribute; confirm this is intended.
            if name in member.shocks:
                member.shocks[name] = value
            else:
                setattr(member, name, value)

2727 

def mill(self):
    """
    Turn the reaped agent data into new market-level (sown) values.

    Calls self.mill_rule with keyword arguments taken from self.reap_state,
    overlaid with self.const_vars. The i-th element of the result is stored
    as the i-th entry of self.sow_state.

    Parameters
    ----------
    none

    Returns
    -------
    none
    """
    # Build a fresh dict so reap_state itself is not modified.
    rule_inputs = {**self.reap_state, **self.const_vars}
    outputs = self.mill_rule(**rule_inputs)

    names = list(self.sow_state)
    for position in range(len(names)):
        self.sow_state[names[position]] = outputs[position]

2750 

2751 def cultivate(self): 

2752 """ 

2753 Has each AgentType in agents perform their market_action method, using 

2754 variables sown from the market (and maybe also "private" variables). 

2755 The market_action method should store new results in attributes named in 

2756 reap_vars to be reaped later. 

2757 

2758 Parameters 

2759 ---------- 

2760 none 

2761 

2762 Returns 

2763 ------- 

2764 none 

2765 """ 

2766 for this_type in self.agents: 

2767 this_type.market_action() 

2768 

2769 def reset(self): 

2770 """ 

2771 Reset the state of the market (attributes in sow_vars, etc) to some 

2772 user-defined initial state, and erase the histories of tracked variables. 

2773 Also resets the internal RNG so that draws can be reproduced. 

2774 

2775 Parameters 

2776 ---------- 

2777 none 

2778 

2779 Returns 

2780 ------- 

2781 none 

2782 """ 

2783 # Reset internal RNG and distributions 

2784 for name in self.distributions: 

2785 if not hasattr(self, name): 

2786 continue 

2787 dstn = getattr(self, name) 

2788 if isinstance(dstn, list): 

2789 for D in dstn: 

2790 D.reset() 

2791 else: 

2792 dstn.reset() 

2793 

2794 # Reset the history of tracked variables 

2795 self.history = {var_name: [] for var_name in self.track_vars} 

2796 

2797 # Set the sow variables to their initial levels 

2798 for var_name in self.sow_state: 

2799 self.sow_state[var_name] = self.sow_init[var_name] 

2800 

2801 # Reset each AgentType in the market 

2802 for this_type in self.agents: 

2803 this_type.reset() 

2804 

2805 def store(self): 

2806 """ 

2807 Record the current value of each variable X named in track_vars in an 

2808 dictionary field named history[X]. 

2809 

2810 Parameters 

2811 ---------- 

2812 none 

2813 

2814 Returns 

2815 ------- 

2816 none 

2817 """ 

2818 for var_name in self.track_vars: 

2819 if var_name in self.sow_state: 

2820 value_now = self.sow_state[var_name] 

2821 elif var_name in self.reap_state: 

2822 value_now = self.reap_state[var_name] 

2823 elif var_name in self.const_vars: 

2824 value_now = self.const_vars[var_name] 

2825 else: 

2826 value_now = getattr(self, var_name) 

2827 

2828 self.history[var_name].append(value_now) 

2829 

2830 def make_history(self): 

2831 """ 

2832 Runs a loop of sow-->cultivate-->reap-->mill act_T times, tracking the 

2833 evolution of variables X named in track_vars in dictionary fields 

2834 history[X]. 

2835 

2836 Parameters 

2837 ---------- 

2838 none 

2839 

2840 Returns 

2841 ------- 

2842 none 

2843 """ 

2844 self.reset() # Initialize the state of the market 

2845 for t in range(self.act_T): 

2846 self.sow() # Distribute aggregated information/state to agents 

2847 self.cultivate() # Agents take action 

2848 self.reap() # Collect individual data from agents 

2849 self.mill() # Process individual data into aggregate data 

2850 self.store() # Record variables of interest 

2851 

2852 def update_dynamics(self): 

2853 """ 

2854 Calculates a new "aggregate dynamic rule" using the history of variables 

2855 named in track_vars, and distributes this rule to AgentTypes in agents. 

2856 

2857 Parameters 

2858 ---------- 

2859 none 

2860 

2861 Returns 

2862 ------- 

2863 dynamics : instance 

2864 The new "aggregate dynamic rule" that agents believe in and act on. 

2865 Should have attributes named in dyn_vars. 

2866 """ 

2867 # Make a dictionary of inputs for the dynamics calculator 

2868 arg_names = list(get_arg_names(self.calc_dynamics)) 

2869 if "self" in arg_names: 

2870 arg_names.remove("self") 

2871 update_dict = {} 

2872 for name in arg_names: 

2873 update_dict[name] = ( 

2874 self.history[name] if name in self.track_vars else getattr(self, name) 

2875 ) 

2876 # Calculate a new dynamic rule and distribute it to the agents in agent_list 

2877 dynamics = self.calc_dynamics(**update_dict) # User-defined dynamics calculator 

2878 for var_name in self.dyn_vars: 

2879 this_obj = getattr(dynamics, var_name) 

2880 for this_type in self.agents: 

2881 setattr(this_type, var_name, this_obj) 

2882 return dynamics 

2883 

2884 

2885def distribute_params(agent, param_name, param_count, distribution): 

2886 """ 

2887 Distributes heterogeneous values of one parameter to the AgentTypes in self.agents. 

2888 Parameters 

2889 ---------- 

2890 agent: AgentType 

2891 An agent to clone. 

2892 param_name : string 

2893 Name of the parameter to be assigned. 

2894 param_count : int 

2895 Number of different values the parameter will take on. 

2896 distribution : Distribution 

2897 A 1-D distribution. 

2898 

2899 Returns 

2900 ------- 

2901 agent_set : [AgentType] 

2902 A list of param_count agents, ex ante heterogeneous with 

2903 respect to param_name. The AgentCount of the original 

2904 will be split between the agents of the returned 

2905 list in proportion to the given distribution. 

2906 """ 

2907 param_dist = distribution.discretize(N=param_count) 

2908 

2909 agent_set = [deepcopy(agent) for i in range(param_count)] 

2910 

2911 for j in range(param_count): 

2912 agent_set[j].assign_parameters( 

2913 **{"AgentCount": int(agent.AgentCount * param_dist.pmv[j])} 

2914 ) 

2915 agent_set[j].assign_parameters(**{param_name: param_dist.atoms[0, j]}) 

2916 

2917 return agent_set 

2918 

2919 

2920@dataclass 

2921class AgentPopulation: 

2922 """ 

2923 A class for representing a population of ex-ante heterogeneous agents. 

2924 """ 

2925 

2926 agent_type: AgentType # type of agent in the population 

2927 parameters: dict # dictionary of parameters 

2928 seed: int = 0 # random seed 

2929 time_var: List[str] = field(init=False) 

2930 time_inv: List[str] = field(init=False) 

2931 distributed_params: List[str] = field(init=False) 

2932 agent_type_count: Optional[int] = field(init=False) 

2933 term_age: Optional[int] = field(init=False) 

2934 continuous_distributions: Dict[str, Distribution] = field(init=False) 

2935 discrete_distributions: Dict[str, Distribution] = field(init=False) 

2936 population_parameters: List[Dict[str, Union[List[float], float]]] = field( 

2937 init=False 

2938 ) 

2939 agents: List[AgentType] = field(init=False) 

2940 agent_database: pd.DataFrame = field(init=False) 

2941 solution: List[Any] = field(init=False) 

2942 

2943 def __post_init__(self): 

2944 """ 

2945 Initialize the population of agents, determine distributed parameters, 

2946 and infer `agent_type_count` and `term_age`. 

2947 """ 

2948 # create a dummy agent and obtain its time-varying 

2949 # and time-invariant attributes 

2950 dummy_agent = self.agent_type() 

2951 self.time_var = dummy_agent.time_vary 

2952 self.time_inv = dummy_agent.time_inv 

2953 

2954 # create list of distributed parameters 

2955 # these are parameters that differ across agents 

2956 self.distributed_params = [ 

2957 key 

2958 for key, param in self.parameters.items() 

2959 if (isinstance(param, list) and isinstance(param[0], list)) 

2960 or isinstance(param, Distribution) 

2961 or (isinstance(param, DataArray) and param.dims[0] == "agent") 

2962 ] 

2963 

2964 self.__infer_counts__() 

2965 

2966 self.print_parallel_error_once = True 

2967 # Print warning once if parallel simulation fails 

2968 

2969 def __infer_counts__(self): 

2970 """ 

2971 Infer `agent_type_count` and `term_age` from the parameters. 

2972 If parameters include a `Distribution` type, a list of lists, 

2973 or a `DataArray` with `agent` as the first dimension, then 

2974 the AgentPopulation contains ex-ante heterogenous agents. 

2975 """ 

2976 

2977 # infer agent_type_count from distributed parameters 

2978 agent_type_count = 1 

2979 for key in self.distributed_params: 

2980 param = self.parameters[key] 

2981 if isinstance(param, Distribution): 

2982 agent_type_count = None 

2983 warn( 

2984 "Cannot infer agent_type_count from a Distribution. " 

2985 "Please provide approximation parameters." 

2986 ) 

2987 break 

2988 elif isinstance(param, list): 

2989 agent_type_count = max(agent_type_count, len(param)) 

2990 elif isinstance(param, DataArray) and param.dims[0] == "agent": 

2991 agent_type_count = max(agent_type_count, param.shape[0]) 

2992 

2993 self.agent_type_count = agent_type_count 

2994 

2995 # infer term_age from all parameters 

2996 term_age = 1 

2997 for param in self.parameters.values(): 

2998 if isinstance(param, Distribution): 

2999 term_age = None 

3000 warn( 

3001 "Cannot infer term_age from a Distribution. " 

3002 "Please provide approximation parameters." 

3003 ) 

3004 break 

3005 elif isinstance(param, list) and isinstance(param[0], list): 

3006 term_age = max(term_age, len(param[0])) 

3007 elif isinstance(param, DataArray) and param.dims[-1] == "age": 

3008 term_age = max(term_age, param.shape[-1]) 

3009 

3010 self.term_age = term_age 

3011 

3012 def approx_distributions(self, approx_params: dict): 

3013 """ 

3014 Approximate continuous distributions with discrete ones. If the initial 

3015 parameters include a `Distribution` type, then the AgentPopulation is 

3016 not ready to solve, and stands for an abstract population. To solve the 

3017 AgentPopulation, we need discretization parameters for each continuous 

3018 distribution. This method approximates the continuous distributions with 

3019 discrete ones, and updates the parameters dictionary. 

3020 """ 

3021 self.continuous_distributions = {} 

3022 self.discrete_distributions = {} 

3023 

3024 for key, args in approx_params.items(): 

3025 param = self.parameters[key] 

3026 if key in self.distributed_params and isinstance(param, Distribution): 

3027 self.continuous_distributions[key] = param 

3028 self.discrete_distributions[key] = param.discretize(**args) 

3029 else: 

3030 raise ValueError( 

3031 f"Warning: parameter {key} is not a Distribution found " 

3032 f"in agent type {self.agent_type}" 

3033 ) 

3034 

3035 if len(self.discrete_distributions) > 1: 

3036 joint_dist = combine_indep_dstns(*self.discrete_distributions.values()) 

3037 else: 

3038 joint_dist = list(self.discrete_distributions.values())[0] 

3039 

3040 for i, key in enumerate(self.discrete_distributions): 

3041 self.parameters[key] = DataArray(joint_dist.atoms[i], dims=("agent")) 

3042 

3043 self.__infer_counts__() 

3044 

3045 def __parse_parameters__(self) -> None: 

3046 """ 

3047 Creates distributed dictionaries of parameters for each ex-ante 

3048 heterogeneous agent in the parameterized population. The parameters 

3049 are stored in a list of dictionaries, where each dictionary contains 

3050 the parameters for one agent. Expands parameters that vary over time 

3051 to a list of length `term_age`. 

3052 """ 

3053 

3054 population_parameters = [] # container for dictionaries of each agent subgroup 

3055 for agent in range(self.agent_type_count): 

3056 agent_parameters = {} 

3057 for key, param in self.parameters.items(): 

3058 if key in self.time_var: 

3059 # parameters that vary over time have to be repeated 

3060 if isinstance(param, (int, float)): 

3061 parameter_per_t = [param] * self.term_age 

3062 elif isinstance(param, list): 

3063 if isinstance(param[0], list): 

3064 parameter_per_t = param[agent] 

3065 else: 

3066 parameter_per_t = param 

3067 elif isinstance(param, DataArray): 

3068 if param.dims[0] == "agent": 

3069 if len(param.dims) > 1 and param.dims[-1] == "age": 

3070 parameter_per_t = param[agent].values.tolist() 

3071 else: 

3072 parameter_per_t = param[agent].item() 

3073 elif param.dims[0] == "age": 

3074 parameter_per_t = param.values.tolist() 

3075 

3076 agent_parameters[key] = parameter_per_t 

3077 

3078 elif key in self.time_inv: 

3079 if isinstance(param, (int, float)): 

3080 agent_parameters[key] = param 

3081 elif isinstance(param, list): 

3082 if isinstance(param[0], list): 

3083 agent_parameters[key] = param[agent] 

3084 else: 

3085 agent_parameters[key] = param 

3086 elif isinstance(param, DataArray) and param.dims[0] == "agent": 

3087 agent_parameters[key] = param[agent].item() 

3088 

3089 else: 

3090 if isinstance(param, (int, float)): 

3091 agent_parameters[key] = param # assume time inv 

3092 elif isinstance(param, list): 

3093 if isinstance(param[0], list): 

3094 agent_parameters[key] = param[agent] # assume agent vary 

3095 else: 

3096 agent_parameters[key] = param # assume time vary 

3097 elif isinstance(param, DataArray): 

3098 if param.dims[0] == "agent": 

3099 if len(param.dims) > 1 and param.dims[-1] == "age": 

3100 agent_parameters[key] = param[agent].values.tolist() 

3101 else: 

3102 agent_parameters[key] = param[agent].item() 

3103 elif param.dims[0] == "age": 

3104 agent_parameters[key] = param.values.tolist() 

3105 

3106 population_parameters.append(agent_parameters) 

3107 

3108 self.population_parameters = population_parameters 

3109 

3110 def create_distributed_agents(self): 

3111 """ 

3112 Parses the parameters dictionary and creates a list of agents with the 

3113 appropriate parameters. Also sets the seed for each agent. 

3114 """ 

3115 

3116 self.__parse_parameters__() 

3117 

3118 rng = np.random.default_rng(self.seed) 

3119 

3120 self.agents = [ 

3121 self.agent_type(seed=rng.integers(0, 2**31 - 1), **agent_dict) 

3122 for agent_dict in self.population_parameters 

3123 ] 

3124 

3125 def create_database(self): 

3126 """ 

3127 Optionally creates a pandas DataFrame with the parameters for each agent. 

3128 """ 

3129 database = pd.DataFrame(self.population_parameters) 

3130 database["agents"] = self.agents 

3131 

3132 self.agent_database = database 

3133 

3134 def solve(self): 

3135 """ 

3136 Solves each agent of the population serially. 

3137 """ 

3138 

3139 # see Market class for an example of how to solve distributed agents in parallel 

3140 

3141 for agent in self.agents: 

3142 agent.solve() 

3143 

3144 def unpack_solutions(self): 

3145 """ 

3146 Unpacks the solutions of each agent into an attribute of the population. 

3147 """ 

3148 self.solution = [agent.solution for agent in self.agents] 

3149 

3150 def initialize_sim(self): 

3151 """ 

3152 Initializes the simulation for each agent. 

3153 """ 

3154 for agent in self.agents: 

3155 agent.initialize_sim() 

3156 

3157 def simulate(self, num_jobs=None): 

3158 """ 

3159 Simulates each agent of the population. 

3160 

3161 Parameters 

3162 ---------- 

3163 num_jobs : int, optional 

3164 Number of parallel jobs to use. Defaults to using all available 

3165 cores when ``None``. Falls back to serial execution if parallel 

3166 processing fails. 

3167 """ 

3168 try: 

3169 multi_thread_commands(self.agents, ["simulate()"], num_jobs) 

3170 except Exception as err: 

3171 if getattr(self, "print_parallel_error_once", False): 

3172 self.print_parallel_error_once = False 

3173 print( 

3174 "**** WARNING: could not execute multi_thread_commands in HARK.core.AgentPopulation.simulate() ", 

3175 "so using the serial version instead. This will likely be slower. ", 

3176 "The multi_thread_commands() function failed with the following error:\n", 

3177 sys.exc_info()[0], 

3178 ":", 

3179 err, 

3180 ) 

3181 multi_thread_commands_fake(self.agents, ["simulate()"], num_jobs) 

3182 

3183 def __iter__(self): 

3184 """ 

3185 Allows for iteration over the agents in the population. 

3186 """ 

3187 return iter(self.agents) 

3188 

3189 def __getitem__(self, idx): 

3190 """ 

3191 Allows for indexing into the population. 

3192 """ 

3193 return self.agents[idx] 

3194 

3195 

3196############################################################################### 

3197 

3198 

3199def multi_thread_commands_fake( 

3200 agent_list: List, command_list: List, num_jobs=None 

3201) -> None: 

3202 """ 

3203 Executes the list of commands in command_list for each AgentType in agent_list 

3204 in an ordinary, single-threaded loop. Each command should be a method of 

3205 that AgentType subclass. This function exists so as to easily disable 

3206 multithreading, as it uses the same syntax as multi_thread_commands. 

3207 

3208 Parameters 

3209 ---------- 

3210 agent_list : [AgentType] 

3211 A list of instances of AgentType on which the commands will be run. 

3212 command_list : [string] 

3213 A list of commands to run for each AgentType. 

3214 num_jobs : None 

3215 Dummy input to match syntax of multi_thread_commands. Does nothing. 

3216 

3217 Returns 

3218 ------- 

3219 none 

3220 """ 

3221 for agent in agent_list: 

3222 for command in command_list: 

3223 # Can pass method names with or without parentheses 

3224 if command[-2:] == "()": 

3225 getattr(agent, command[:-2])() 

3226 else: 

3227 getattr(agent, command)() 

3228 

3229 

3230def multi_thread_commands(agent_list: List, command_list: List, num_jobs=None) -> None: 

3231 """ 

3232 Executes the list of commands in command_list for each AgentType in agent_list 

3233 using a multithreaded system. Each command should be a method of that AgentType subclass. 

3234 

3235 Parameters 

3236 ---------- 

3237 agent_list : [AgentType] 

3238 A list of instances of AgentType on which the commands will be run. 

3239 command_list : [string] 

3240 A list of commands to run for each AgentType in agent_list. 

3241 

3242 Returns 

3243 ------- 

3244 None 

3245 """ 

3246 if len(agent_list) == 1: 

3247 multi_thread_commands_fake(agent_list, command_list) 

3248 return None 

3249 

3250 # Default number of parallel jobs is the smaller of number of AgentTypes in 

3251 # the input and the number of available cores. 

3252 if num_jobs is None: 

3253 num_jobs = min(len(agent_list), multiprocessing.cpu_count()) 

3254 

3255 # Send each command in command_list to each of the types in agent_list to be run 

3256 agent_list_out = Parallel(n_jobs=num_jobs)( 

3257 delayed(run_commands)(*args) 

3258 for args in zip(agent_list, len(agent_list) * [command_list]) 

3259 ) 

3260 

3261 # Replace the original types with the output from the parallel call 

3262 for j in range(len(agent_list)): 

3263 agent_list[j] = agent_list_out[j] 

3264 

3265 

3266def run_commands(agent: Any, command_list: List) -> Any: 

3267 """ 

3268 Executes each command in command_list on a given AgentType. The commands 

3269 should be methods of that AgentType's subclass. 

3270 

3271 Parameters 

3272 ---------- 

3273 agent : AgentType 

3274 An instance of AgentType on which the commands will be run. 

3275 command_list : [string] 

3276 A list of commands that the agent should run, as methods. 

3277 

3278 Returns 

3279 ------- 

3280 agent : AgentType 

3281 The same AgentType instance passed as input, after running the commands. 

3282 """ 

3283 for command in command_list: 

3284 # Can pass method names with or without parentheses 

3285 if command[-2:] == "()": 

3286 getattr(agent, command[:-2])() 

3287 else: 

3288 getattr(agent, command)() 

3289 return agent