Coverage for HARK / core.py: 94%

948 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-07 05:16 +0000

1""" 

2High-level functions and classes for solving a wide variety of economic models. 

3The "core" of HARK is a framework for "microeconomic" and "macroeconomic" 

4models. A micro model concerns the dynamic optimization problem for some type 

5of agents, where agents take the inputs to their problem as exogenous. A macro 

6model adds an additional layer, endogenizing some of the inputs to the micro 

7problem by finding a general equilibrium dynamic rule. 

8""" 

9 

10# Import basic modules 

11import inspect 

12import sys 

13from collections import namedtuple 

14from copy import copy, deepcopy 

15from dataclasses import dataclass, field 

16from time import time 

17from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union 

18from warnings import warn 

19import multiprocessing 

20from joblib import Parallel, delayed 

21 

22import numpy as np 

23import pandas as pd 

24from xarray import DataArray 

25 

26from HARK.distributions import ( 

27 Distribution, 

28 IndexDistribution, 

29 combine_indep_dstns, 

30) 

31from HARK.utilities import NullFunc, get_arg_names, get_it_from 

32from HARK.simulator import make_simulator_from_agent 

33from HARK.SSJutils import ( 

34 make_basic_SSJ_matrices, 

35 calc_shock_response_manually, 

36) 

37from HARK.metric import MetricObject 

38 

# Public names exported by ``from <this module> import *``.
# Parameters, Model and AgentType are defined in this module; NullFunc is
# re-exported from HARK.utilities. The remaining names are presumably defined
# further down this file (not visible here) -- confirm when editing the list.
__all__ = [
    "AgentType",
    "Market",
    "Parameters",
    "Model",
    "AgentPopulation",
    "multi_thread_commands",
    "multi_thread_commands_fake",
    "NullFunc",
    "make_one_period_oo_solver",
    "distribute_params",
]

51 

52 

class Parameters:
    """
    A smart container for model parameters that handles age-varying dynamics.

    This class stores parameters as an internal dictionary and manages their
    age-varying properties, providing both attribute-style and dictionary-style
    access. It is designed to handle the time-varying dynamics of parameters
    in economic models.

    The special key "T_cycle" holds the number of periods in the model cycle.
    It is always present in the internal dictionary and always equal to
    ``_length``. It is never classified as time-varying or time-invariant, so
    the per-age objects returned by integer indexing describe a single period
    (their own T_cycle is 1).

    Attributes
    ----------
    _length : int
        The number of periods in the model cycle (mirrors "T_cycle").
    _invariant_params : Set[str]
        A set of parameter names that are invariant over time.
    _varying_params : Set[str]
        A set of parameter names that vary over time.
    _parameters : Dict[str, Any]
        The internal dictionary storing all parameters, including "T_cycle".
    _frozen : bool
        When True, assigning parameters raises RuntimeError.
    _namedtuple_cache : Optional[type]
        The namedtuple class most recently built by to_namedtuple().
    """

    __slots__ = (
        "_length",
        "_invariant_params",
        "_varying_params",
        "_parameters",
        "_frozen",
        "_namedtuple_cache",
    )

    def __init__(self, **parameters: Any) -> None:
        """
        Initialize a Parameters object and parse the age-varying dynamics of parameters.

        Parameters
        ----------
        T_cycle : int, optional
            The number of time periods in the model cycle (default: 1).
            Must be >= 1. While it is 1, the first multi-element list or
            tuple parameter sets the cycle length (and T_cycle) to its length.
        frozen : bool, optional
            If True, the Parameters object will be immutable after initialization
            (default: False).
        _time_inv : List[str], optional
            List of parameter names to explicitly mark as time-invariant,
            overriding automatic inference.
        _time_vary : List[str], optional
            List of parameter names to explicitly mark as time-varying,
            overriding automatic inference.
        **parameters : Any
            Any number of parameters in the form key=value.

        Raises
        ------
        ValueError
            If T_cycle is less than 1.

        Notes
        -----
        Automatic time-variance inference rules:
        - Scalars (int, float, bool, str, None, NumPy scalars) are time-invariant
        - NumPy arrays with fewer than 2 dimensions are time-invariant
          (use lists/tuples for time-varying)
        - Empty lists/tuples are time-invariant
        - Single-element lists/tuples [x] are unwrapped to x and time-invariant
        - Multi-element lists/tuples are time-varying if length matches T_cycle
        - Arrays with 2+ dimensions whose first dimension matches T_cycle are
          time-varying
        - Distributions, MetricObjects and Callables are time-invariant

        Use _time_inv or _time_vary to override automatic inference when needed.
        """
        # Extract special parameters
        frozen: bool = parameters.pop("frozen", False)
        time_inv_override: List[str] = parameters.pop("_time_inv", [])
        time_vary_override: List[str] = parameters.pop("_time_vary", [])

        # Initialize internal state; stay unfrozen while the object is populated
        self._invariant_params: Set[str] = set()
        self._varying_params: Set[str] = set()
        self._parameters: Dict[str, Any] = {}
        self._frozen: bool = False
        self._namedtuple_cache: Optional[type] = None

        # Validate and store T_cycle first so it leads the parameter dictionary
        self._set_length(parameters.pop("T_cycle", 1))

        # Set parameters using automatic inference
        for key, value in parameters.items():
            self[key] = value

        # Apply explicit overrides
        for param in time_inv_override:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)

        for param in time_vary_override:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)

        # Freeze if requested
        self._frozen = frozen

    def _set_length(self, length: int) -> None:
        """
        Set the cycle length, keeping ``_length`` and the "T_cycle" entry in sync.

        Existing time-varying parameters are not re-checked; call validate()
        after changing the length of a populated object.

        Raises
        ------
        ValueError
            If length is less than 1.
        """
        if length < 1:
            raise ValueError(f"T_cycle must be >= 1, got {length}")
        self._length = length
        self._parameters["T_cycle"] = length
        # T_cycle describes the container itself, so it is never sliced by age
        self._invariant_params.discard("T_cycle")
        self._varying_params.discard("T_cycle")

    def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
        """
        Access parameters by age index or parameter name.

        If item_or_key is an integer (Python or NumPy), returns a Parameters
        object with the parameters that apply to that age. This includes all
        invariant parameters and the `item_or_key`th element of all age-varying
        parameters. If item_or_key is a string, it returns the value of the
        parameter with that name.

        Parameters
        ----------
        item_or_key : Union[int, str]
            Age index or parameter name.

        Returns
        -------
        Union[Parameters, Any]
            A new Parameters object for the specified age, or the value of the
            specified parameter.

        Raises
        ------
        ValueError:
            If the age index is out of bounds.
        KeyError:
            If the parameter name is not found.
        TypeError:
            If the key is neither an integer nor a string.
        """
        if isinstance(item_or_key, str):
            return self._parameters[item_or_key]
        if not isinstance(item_or_key, (int, np.integer)):
            raise TypeError("Key must be an integer (age) or string (parameter name).")

        age = int(item_or_key)
        if age < 0 or age >= self._length:
            raise ValueError(
                f"Age {item_or_key} is out of bounds (valid: 0-{self._length - 1})."
            )

        params = {key: self._parameters[key] for key in self._invariant_params}
        for key in self._varying_params:
            value = self._parameters[key]
            # Only sequences can be indexed by age; anything else applies as-is
            if isinstance(value, (list, tuple, np.ndarray)):
                params[key] = value[age]
            else:
                params[key] = value
        # A single-age object always has its own T_cycle of 1, even if T_cycle
        # was explicitly (re)classified through an override
        params.pop("T_cycle", None)
        return Parameters(**params)

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Set parameter values, automatically inferring time variance.

        Scalars (including NumPy scalars and strings), arrays with fewer than
        2 dimensions, distributions, callables, MetricObjects and None are
        time-invariant. Empty lists/tuples are time-invariant, and
        single-element lists/tuples are unwrapped to their element and treated
        as time-invariant. Longer lists/tuples are time-varying and their length
        must match the `_length` attribute of the Parameters object; while
        `_length` is 1, the first such sequence sets it (and T_cycle).

        2D (or higher) numpy arrays with first dimension matching T_cycle are
        treated as time-varying parameters.

        Setting the key "T_cycle" changes the cycle length instead of storing
        an ordinary parameter.

        Parameters
        ----------
        key : str
            Name of the parameter.
        value : Any
            Value of the parameter.

        Raises
        ------
        ValueError:
            If the parameter name is not a string or if the value type is unsupported.
            If the parameter value is inconsistent with the current model length.
            If T_cycle is set to a value less than 1.
        RuntimeError:
            If the Parameters object is frozen.
        """
        if self._frozen:
            raise RuntimeError("Cannot modify frozen Parameters object")

        if not isinstance(key, str):
            raise ValueError(f"Parameter name must be a string, got {type(key)}")

        if key == "T_cycle":
            self._set_length(value)
            return

        if isinstance(value, np.ndarray) and value.ndim >= 2:
            # Multi-dimensional arrays vary over time only when their leading
            # axis has one entry per period of the cycle
            varying = value.shape[0] == self._length
        elif isinstance(value, (list, tuple)):
            n = len(value)
            if n <= 1:
                # Empty sequences hold no per-period data; single-element
                # sequences are unwrapped to their only element
                if n == 1:
                    value = value[0]
                varying = False
            elif self._length == 1:
                # While the cycle has a single period, the first multi-element
                # sequence sets the cycle length
                self._set_length(n)
                varying = True
            elif n == self._length:
                varying = True
            else:
                raise ValueError(
                    f"Parameter {key} must have length 1 or {self._length}, not {n}"
                )
        elif isinstance(value, (bool, int, float, str, np.generic, type(None))):
            varying = False
        elif isinstance(value, (np.ndarray, Distribution, Callable, MetricObject)):
            varying = False
        else:
            raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")

        if varying:
            self._varying_params.add(key)
            self._invariant_params.discard(key)
        else:
            self._invariant_params.add(key)
            self._varying_params.discard(key)
        self._parameters[key] = value

    def __iter__(self) -> Iterator[str]:
        """Allow iteration over parameter names."""
        return iter(self._parameters)

    def __len__(self) -> int:
        """Return the number of parameters."""
        return len(self._parameters)

    def keys(self) -> Iterator[str]:
        """Return a view of parameter names."""
        return self._parameters.keys()

    def values(self) -> Iterator[Any]:
        """Return a view of parameter values."""
        return self._parameters.values()

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Return a view of parameter (name, value) pairs."""
        return self._parameters.items()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a plain dictionary.

        Returns
        -------
        Dict[str, Any]
            A (shallow) copy of the dictionary containing all parameters.
        """
        return dict(self._parameters)

    def to_namedtuple(self) -> namedtuple:
        """
        Convert parameters to a namedtuple.

        The namedtuple class is cached for efficiency on repeated calls and is
        rebuilt whenever the set of parameter names has changed.

        Returns
        -------
        namedtuple
            A namedtuple containing all parameters.
        """
        fields = tuple(self._parameters)
        cls = self._namedtuple_cache
        if cls is None or cls._fields != fields:
            cls = namedtuple("Parameters", fields)
            self._namedtuple_cache = cls
        return cls(**self._parameters)

    def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
        """
        Update parameters from another Parameters object or dictionary.

        Values are assigned one at a time through item assignment, so the usual
        time-variance inference applies. A "T_cycle" entry in the source
        changes this object's cycle length.

        Parameters
        ----------
        other : Union[Parameters, Dict[str, Any]]
            The source of parameters to update from.

        Raises
        ------
        TypeError
            If the input is neither a Parameters object nor a dictionary.
        """
        if isinstance(other, Parameters):
            source = other._parameters
        elif isinstance(other, dict):
            source = other
        else:
            raise TypeError(
                "Update source must be a Parameters object or a dictionary."
            )
        for key, value in source.items():
            self[key] = value

    def __repr__(self) -> str:
        """Return a detailed string representation of the Parameters object."""
        return (
            f"Parameters(_length={self._length}, "
            f"_invariant_params={self._invariant_params}, "
            f"_varying_params={self._varying_params}, "
            f"_parameters={self._parameters})"
        )

    def __str__(self) -> str:
        """Return a simple string representation of the Parameters object."""
        return f"Parameters({str(self._parameters)})"

    def __getattr__(self, name: str) -> Any:
        """
        Allow attribute-style access to parameters.

        Only called when normal attribute lookup fails. Names starting with an
        underscore are never looked up among the parameters.

        Parameters
        ----------
        name : str
            Name of the parameter to access.

        Returns
        -------
        Any
            The value of the specified parameter.

        Raises
        ------
        AttributeError:
            If the parameter name is not found.
        """
        if name.startswith("_"):
            return super().__getattribute__(name)
        try:
            return self._parameters[name]
        except KeyError:
            raise AttributeError(f"'Parameters' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Allow attribute-style setting of parameters.

        Names starting with an underscore set internal slots directly; any
        other name is routed through item assignment.

        Parameters
        ----------
        name : str
            Name of the parameter to set.
        value : Any
            Value to set for the parameter.
        """
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            self[name] = value

    def __contains__(self, item: str) -> bool:
        """Check if a parameter exists in the Parameters object."""
        return item in self._parameters

    def copy(self) -> "Parameters":
        """
        Create a deep copy of the Parameters object.

        Returns
        -------
        Parameters
            A new Parameters object with the same contents.
        """
        return deepcopy(self)

    def add_to_time_vary(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_vary.
        """
        for param in params:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_vary."
                )

    def add_to_time_inv(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_inv.
        """
        for param in params:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_inv."
                )

    def del_from_time_vary(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_vary.
        """
        for param in params:
            self._varying_params.discard(param)

    def del_from_time_inv(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_inv.
        """
        for param in params:
            self._invariant_params.discard(param)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a parameter value, returning a default if not found.

        Parameters
        ----------
        key : str
            The parameter name.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The parameter value or the default.
        """
        return self._parameters.get(key, default)

    def set_many(self, **kwargs: Any) -> None:
        """
        Set multiple parameters at once.

        Parameters
        ----------
        **kwargs : Keyword arguments representing parameter names and values.
        """
        for key, value in kwargs.items():
            self[key] = value

    def is_time_varying(self, key: str) -> bool:
        """
        Check if a parameter is time-varying.

        Parameters
        ----------
        key : str
            The parameter name.

        Returns
        -------
        bool
            True if the parameter is time-varying, False otherwise.
        """
        return key in self._varying_params

    def at_age(self, age: int) -> "Parameters":
        """
        Get parameters for a specific age.

        This is an alternative to integer indexing (params[age]) that is more
        explicit and avoids potential confusion with dictionary-style access.

        Parameters
        ----------
        age : int
            The age index to retrieve parameters for.

        Returns
        -------
        Parameters
            A new Parameters object with parameters for the specified age.

        Raises
        ------
        ValueError
            If the age index is out of bounds.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97], sigma=2.0)
        >>> age_1_params = params.at_age(1)
        >>> age_1_params.beta
        0.96
        """
        return self[age]

    def validate(self) -> None:
        """
        Validate parameter consistency.

        Checks that all time-varying list/tuple/array parameters have length
        (or first dimension) matching T_cycle. This is useful after manual
        modifications, after changing T_cycle, or when parameters are set
        programmatically. Time-varying parameters of other types are not checked.

        Raises
        ------
        ValueError
            If any time-varying parameter has incorrect length.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97])
        >>> params.validate()  # Passes
        >>> params.add_to_time_vary("beta")
        >>> params.validate()  # Still passes
        """
        errors = []
        for param in self._varying_params:
            value = self._parameters[param]
            if isinstance(value, (list, tuple)):
                if len(value) != self._length:
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )
            elif isinstance(value, np.ndarray):
                if value.ndim == 0:
                    errors.append(
                        f"Parameter '{param}' is a 0-dimensional array (scalar), "
                        "which should not be time-varying"
                    )
                elif value.ndim == 1:
                    if len(value) != self._length:
                        errors.append(
                            f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                        )
                elif value.shape[0] != self._length:
                    errors.append(
                        f"Parameter '{param}' has first dimension {value.shape[0]}, expected {self._length}"
                    )

        if errors:
            raise ValueError(
                "Parameter validation failed:\n" + "\n".join(f"  - {e}" for e in errors)
            )

616 

617 

class Model:
    """
    A class with special handling of parameters assignment.

    Each parameter is held both as an attribute of the instance and as an entry
    in the ``parameters`` dictionary. The ``constructors`` dictionary maps the
    names of constructed objects to the functions that build them from other
    parameters (see ``construct``).
    """

    def __init__(self):
        # Subclasses may already have set these; only create empty defaults
        if not hasattr(self, "parameters"):
            self.parameters = {}
        if not hasattr(self, "constructors"):
            self.constructors = {}

    def assign_parameters(self, **kwds):
        """
        Assign an arbitrary number of attributes to this agent.

        Parameters
        ----------
        **kwds : keyword arguments
            Any number of keyword arguments of the form key=value.
            Each value will be assigned to the attribute named in self, and
            also recorded in self.parameters.

        Returns
        -------
        None
        """
        self.parameters.update(kwds)
        for key, value in kwds.items():
            setattr(self, key, value)

    def get_parameter(self, name):
        """
        Returns a parameter of this model

        Parameters
        ----------
        name : str
            The name of the parameter to get

        Returns
        -------
        value : The value of the parameter (raises KeyError if absent)
        """
        return self.parameters[name]

    def __eq__(self, other):
        # Two models of the same type are equal when their parameters match
        if isinstance(other, type(self)):
            return self.parameters == other.parameters

        return NotImplemented

    def __str__(self):
        type_ = type(self)
        module = type_.__module__
        qualname = type_.__qualname__

        s = f"<{module}.{qualname} object at {hex(id(self))}.\n"
        s += "Parameters:"

        for p in self.parameters:
            s += f"\n{p}: {self.parameters[p]}"

        s += ">"
        return s

    def describe(self):
        """Return the same description as str(self)."""
        return self.__str__()

    def del_param(self, param_name):
        """
        Deletes a parameter from this instance, removing it both from the object's
        namespace (if it's there) and the parameters dictionary (likewise).

        Parameters
        ----------
        param_name : str
            A string naming a parameter or data to be deleted from this instance.
            Removes information from self.parameters dictionary and own namespace.

        Returns
        -------
        None
        """
        if param_name in self.parameters:
            del self.parameters[param_name]
        if hasattr(self, param_name):
            delattr(self, param_name)

    def construct(self, *args, force=False):
        """
        Top-level method for building constructed inputs. If called without any
        inputs, construct builds each of the objects named in the keys of the
        constructors dictionary; it draws inputs for the constructors from the
        parameters dictionary and adds its results to the same. If passed one or
        more strings as arguments, the method builds only the named keys. The
        method will do multiple "passes" over the requested keys, as some cons-
        tructors require inputs built by other constructors. If any requested
        constructors failed to build due to missing data, those keys (and the
        missing data) will be named in self._missing_key_data. Other errors are
        recorded in the dictionary attribute _constructor_errors.

        This method tries to "start from scratch" by removing prior constructed
        objects, holding them in a backup dictionary during construction. This
        is done so that dependencies among constructors are resolved properly,
        without mistakenly relying on "old information". A backup value is used
        if a constructor function is set to None (i.e. "don't do anything"), or
        if the construct method fails to produce a new object -- including when
        an exception is raised part-way through (force=False), in which case all
        objects not yet built are restored before the exception propagates.

        Parameters
        ----------
        *args : str, optional
            Keys of self.constructors that are requested to be constructed.
            If no arguments are passed, *all* elements of the dictionary are implied.
        force : bool, optional
            When True, the method will force its way past any errors, including
            missing constructors, missing arguments for constructors, and errors
            raised during execution of constructors. Information about all such
            errors is stored in the dictionary attributes described above. When
            False (default), any errors or exception will be raised.

        Returns
        -------
        None
        """
        # Set up the requested work
        keys = args if len(args) > 0 else list(self.constructors.keys())
        N_keys = len(keys)
        if N_keys == 0:
            return  # Do nothing if there are no constructed objects
        keys_complete = np.zeros(N_keys, dtype=bool)

        # Remove pre-existing constructed objects, preventing "incomplete" updates,
        # but store the current values in a backup dictionary in case something fails
        backup = {}
        for key in keys:
            if hasattr(self, key):
                backup[key] = getattr(self, key)
            self.del_param(key)

        # Get the dictionary of constructor errors
        if not hasattr(self, "_constructor_errors"):
            self._constructor_errors = {}
        errors = self._constructor_errors

        missing_key_data = []
        try:
            # As long as the work isn't complete and the last pass made some
            # progress, repeatedly perform passes of trying to construct objects
            go = True
            while go:
                anything_accomplished_this_pass = False
                missing_key_data = []  # Rebuilt on each pass

                for i, key in enumerate(keys):
                    if keys_complete[i]:
                        continue  # This key has already been built

                    try:
                        constructor = self.constructors[key]
                    except Exception as not_found:
                        errors[key] = "No constructor found for " + str(not_found)
                        if force:
                            continue
                        raise KeyError("No constructor found for " + key) from None

                    # A None constructor means "keep whatever value was there"
                    if constructor is None:
                        if key in backup:
                            setattr(self, key, backup[key])
                            self.parameters[key] = backup[key]
                        keys_complete[i] = True
                        anything_accomplished_this_pass = True
                        continue

                    # SPECIAL: get_it_from pulls the value out of a parent object
                    if isinstance(constructor, get_it_from):
                        try:
                            temp_dict = {
                                "parent": getattr(self, constructor.name),
                                "query": key,
                            }
                            missing_args = []
                        except AttributeError:
                            temp_dict = {}
                            missing_args = [constructor.name]
                            missing_key_data.append((key, constructor.name))

                    # Otherwise gather the constructor's arguments from self
                    # (attributes first, then the parameters dictionary)
                    else:
                        args_needed = get_arg_names(constructor)
                        has_no_default = {
                            k: v.default is inspect.Parameter.empty
                            for k, v in inspect.signature(constructor).parameters.items()
                        }
                        temp_dict = {}
                        missing_args = []
                        for this_arg in args_needed:
                            if hasattr(self, this_arg):
                                temp_dict[this_arg] = getattr(self, this_arg)
                            elif this_arg in self.parameters:
                                temp_dict[this_arg] = self.parameters[this_arg]
                            elif has_no_default[this_arg]:
                                # Arguments with defaults may simply be omitted
                                missing_key_data.append((key, this_arg))
                                missing_args.append(this_arg)

                    if missing_args:
                        # Never raise here, as the arguments might be filled in later
                        errors[key] = "Missing required arguments: " + ", ".join(
                            missing_args
                        )
                        self.del_param(key)
                        continue

                    # All required data was found: run the constructor and store
                    # the result in parameters (and on self)
                    try:
                        temp = constructor(**temp_dict)
                    except Exception as problem:
                        errors[key] = str(type(problem)) + ": " + str(problem)
                        self.del_param(key)
                        if force:
                            continue
                        raise
                    setattr(self, key, temp)
                    self.parameters[key] = temp
                    errors.pop(key, None)
                    keys_complete[i] = True
                    anything_accomplished_this_pass = True

                go = (not keys_complete.all()) and anything_accomplished_this_pass
        except Exception:
            # Don't leave the object stripped of values it had before this call
            self._restore_unbuilt(keys, keys_complete, backup)
            raise

        # Store missing key-data pairs, restore backups of unbuilt objects, exit
        self._missing_key_data = missing_key_data
        self._constructor_errors = errors
        unbuilt = self._restore_unbuilt(keys, keys_complete, backup)
        if unbuilt and not force:
            raise ValueError("Did not construct these objects: " + ", ".join(unbuilt))

    def _restore_unbuilt(self, keys, keys_complete, backup):
        """
        Restore backed-up values for requested keys that were not built.

        Parameters
        ----------
        keys : sequence of str
            The keys requested in construct().
        keys_complete : np.ndarray of bool
            Completion flags aligned with keys.
        backup : dict
            Values the keys held before construct() removed them.

        Returns
        -------
        list of str
            The keys that were not built, in request order.
        """
        unbuilt = []
        for key, done in zip(keys, keys_complete):
            if done:
                continue
            unbuilt.append(key)
            if key in backup:
                setattr(self, key, backup[key])
                self.parameters[key] = backup[key]
        return unbuilt

    def describe_constructors(self, *args):
        """
        Prints to screen a string describing this instance's constructed objects,
        including their names, the function that constructs them, the names of
        those functions inputs, and whether those inputs are present.

        Each line starts with a check mark when a value is present, "X" when it
        is absent, and (for constructor inputs only) "*" when it is absent but
        the constructor provides a default.

        Parameters
        ----------
        *args : str, optional
            Optional list of strings naming constructed inputs to be described.
            If none are passed, all constructors are described.

        Returns
        -------
        None
        """
        keys = args if len(args) > 0 else list(self.constructors.keys())
        yes = "\u2713"
        no = "X"
        maybe = "*"

        lines = []
        for key in keys:
            has_val = hasattr(self, key) or (key in self.parameters)
            mark = yes if has_val else no

            try:
                constructor = self.constructors[key]
            except KeyError:
                lines.append(mark + " " + key + " : NO CONSTRUCTOR FOUND\n")
                continue

            # construct() treats a None constructor as "keep the existing value"
            if constructor is None:
                lines.append(mark + " " + key + " : None (keeps existing value)\n")
                continue

            if isinstance(constructor, get_it_from):
                lines.append(mark + " " + key + " : get it from " + constructor.name + "\n")
                continue

            # Callable objects such as functools.partial have no __name__
            name = getattr(constructor, "__name__", repr(constructor))
            lines.append(mark + " " + key + " : " + name + "\n")

            # Check whether each constructor argument exists
            has_no_default = {
                k: v.default is inspect.Parameter.empty
                for k, v in inspect.signature(constructor).parameters.items()
            }
            for this_arg in get_arg_names(constructor):
                if hasattr(self, this_arg) or this_arg in self.parameters:
                    symb = yes
                elif not has_no_default[this_arg]:
                    symb = maybe
                else:
                    symb = no
                lines.append("  " + symb + " " + this_arg + "\n")

        # Print the string to screen
        print("".join(lines))

    # This is a "synonym" method so that old calls to update() still work
    def update(self, *args):
        """Synonym for construct(*args), kept for backward compatibility."""
        self.construct(*args)

967 

968 

969class AgentType(Model): 

970 """ 

971 A superclass for economic agents in the HARK framework. Each model should 

972 specify its own subclass of AgentType, inheriting its methods and overwriting 

973 as necessary. Critically, every subclass of AgentType should define class- 

974 specific static values of the attributes time_vary and time_inv as lists of 

975 strings. Each element of time_vary is the name of a field in AgentSubType 

976 that varies over time in the model. Each element of time_inv is the name of 

977 a field in AgentSubType that is constant over time in the model. 

978 

979 Parameters 

980 ---------- 

981 solution_terminal : Solution 

982 A representation of the solution to the terminal period problem of 

983 this AgentType instance, or an initial guess of the solution if this 

984 is an infinite horizon problem. 

985 cycles : int 

986 The number of times the sequence of periods is experienced by this 

987 AgentType in their "lifetime". cycles=1 corresponds to a lifecycle 

988 model, with a certain sequence of one period problems experienced 

989 once before terminating. cycles=0 corresponds to an infinite horizon 

990 model, with a sequence of one period problems repeating indefinitely. 

991 pseudo_terminal : bool 

992 Indicates whether solution_terminal isn't actually part of the 

993 solution to the problem (as a known solution to the terminal period 

994 problem), but instead represents a "scrap value"-style termination. 

995 When True, solution_terminal is not included in the solution; when 

996 False, solution_terminal is the last element of the solution. 

997 tolerance : float 

998 Maximum acceptable "distance" between successive solutions to the 

999 one period problem in an infinite horizon (cycles=0) model in order 

1000 for the solution to be considered as having "converged". Inoperative 

1001 when cycles>0. 

1002 verbose : int 

1003 Level of output to be displayed by this instance, default is 1. 

1004 quiet : bool 

1005 Indicator for whether this instance should operate "quietly", default False. 

1006 seed : int 

1007 A seed for this instance's random number generator. 

1008 construct : bool 

1009 Indicator for whether this instance's construct() method should be run 

1010 when initialized (default True). When False, an instance of the class 

1011 can be created even if not all of its attributes can be constructed. 

1012 use_defaults : bool 

1013 Indicator for whether this instance should use the values in the class' 

1014 default dictionary to fill in parameters and constructors for those not 

1015 provided by the user (default True). Setting this to False is useful for 

1016 situations where the user wants to be absolutely sure that they know what 

1017 is being passed to the class initializer, without resorting to defaults. 

1018 

1019 Attributes 

1020 ---------- 

1021 AgentCount : int 

1022 The number of agents of this type to use in simulation. 

1023 

1024 state_vars : list of string 

1025 The string labels for this AgentType's model state variables. 

1026 """ 

1027 

1028 time_vary_ = [] 

1029 time_inv_ = [] 

1030 shock_vars_ = [] 

1031 state_vars = [] 

1032 poststate_vars = [] 

1033 distributions = [] 

1034 default_ = {"params": {}, "solver": NullFunc()} 

1035 

1036 def __init__( 

1037 self, 

1038 solution_terminal=None, 

1039 pseudo_terminal=True, 

1040 tolerance=0.000001, 

1041 verbose=1, 

1042 quiet=False, 

1043 seed=0, 

1044 construct=True, 

1045 use_defaults=True, 

1046 **kwds, 

1047 ): 

1048 super().__init__() 

1049 params = deepcopy(self.default_["params"]) if use_defaults else {} 

1050 params.update(kwds) 

1051 

1052 # Correctly handle constructors that have been passed in kwds 

1053 if "constructors" in self.default_["params"].keys() and use_defaults: 

1054 constructors = deepcopy(self.default_["params"]["constructors"]) 

1055 else: 

1056 constructors = {} 

1057 if "constructors" in kwds.keys(): 

1058 constructors.update(kwds["constructors"]) 

1059 params["constructors"] = constructors 

1060 

1061 # Set model file name if possible 

1062 try: 

1063 self.model_file = copy(self.default_["model"]) 

1064 except (KeyError, TypeError): 

1065 # Fallback to None if "model" key is missing or invalid for copying 

1066 self.model_file = None 

1067 

1068 if solution_terminal is None: 

1069 solution_terminal = NullFunc() 

1070 

1071 self.solve_one_period = self.default_["solver"] # NOQA 

1072 self.solution_terminal = solution_terminal # NOQA 

1073 self.pseudo_terminal = pseudo_terminal # NOQA 

1074 self.tolerance = tolerance # NOQA 

1075 self.verbose = verbose 

1076 self.quiet = quiet 

1077 self.seed = seed # NOQA 

1078 self.track_vars = [] # NOQA 

1079 self.state_now = {sv: None for sv in self.state_vars} 

1080 self.state_prev = self.state_now.copy() 

1081 self.controls = {} 

1082 self.shocks = {} 

1083 self.read_shocks = False # NOQA 

1084 self.shock_history = {} 

1085 self.newborn_init_history = {} 

1086 self.history = {} 

1087 self.assign_parameters(**params) # NOQA 

1088 self.reset_rng() # NOQA 

1089 self.bilt = {} 

1090 if construct: 

1091 self.construct() 

1092 

1093 # Add instance-level lists and objects 

1094 self.time_vary = deepcopy(self.time_vary_) 

1095 self.time_inv = deepcopy(self.time_inv_) 

1096 self.shock_vars = deepcopy(self.shock_vars_) 

1097 

1098 def add_to_time_vary(self, *params): 

1099 """ 

1100 Adds any number of parameters to time_vary for this instance. 

1101 

1102 Parameters 

1103 ---------- 

1104 params : string 

1105 Any number of strings naming attributes to be added to time_vary 

1106 

1107 Returns 

1108 ------- 

1109 None 

1110 """ 

1111 for param in params: 

1112 if param not in self.time_vary: 

1113 self.time_vary.append(param) 

1114 

1115 def add_to_time_inv(self, *params): 

1116 """ 

1117 Adds any number of parameters to time_inv for this instance. 

1118 

1119 Parameters 

1120 ---------- 

1121 params : string 

1122 Any number of strings naming attributes to be added to time_inv 

1123 

1124 Returns 

1125 ------- 

1126 None 

1127 """ 

1128 for param in params: 

1129 if param not in self.time_inv: 

1130 self.time_inv.append(param) 

1131 

1132 def del_from_time_vary(self, *params): 

1133 """ 

1134 Removes any number of parameters from time_vary for this instance. 

1135 

1136 Parameters 

1137 ---------- 

1138 params : string 

1139 Any number of strings naming attributes to be removed from time_vary 

1140 

1141 Returns 

1142 ------- 

1143 None 

1144 """ 

1145 for param in params: 

1146 if param in self.time_vary: 

1147 self.time_vary.remove(param) 

1148 

1149 def del_from_time_inv(self, *params): 

1150 """ 

1151 Removes any number of parameters from time_inv for this instance. 

1152 

1153 Parameters 

1154 ---------- 

1155 params : string 

1156 Any number of strings naming attributes to be removed from time_inv 

1157 

1158 Returns 

1159 ------- 

1160 None 

1161 """ 

1162 for param in params: 

1163 if param in self.time_inv: 

1164 self.time_inv.remove(param) 

1165 

1166 def unpack(self, parameter): 

1167 """ 

1168 Unpacks an attribute from a solution object for easier access. 

1169 After the model has been solved, its components (like consumption function) 

1170 reside in the attributes of each element of `ThisType.solution` (e.g. `cFunc`). 

1171 This method creates a (time varying) attribute of the given attribute name 

1172 that contains a list of elements accessible by `ThisType.parameter`. 

1173 

1174 Parameters 

1175 ---------- 

1176 parameter: str 

1177 Name of the attribute to unpack from the solution 

1178 

1179 Returns 

1180 ------- 

1181 none 

1182 """ 

1183 # Use list comprehension for better performance instead of loop with append 

1184 setattr( 

1185 self, 

1186 parameter, 

1187 [solution_t.__dict__[parameter] for solution_t in self.solution], 

1188 ) 

1189 self.add_to_time_vary(parameter) 

1190 

1191 def solve( 

1192 self, 

1193 verbose=False, 

1194 presolve=True, 

1195 postsolve=True, 

1196 from_solution=None, 

1197 from_t=None, 

1198 ): 

1199 """ 

1200 Solve the model for this instance of an agent type by backward induction. 

1201 Loops through the sequence of one period problems, passing the solution 

1202 from period t+1 to the problem for period t. 

1203 

1204 Parameters 

1205 ---------- 

1206 verbose : bool, optional 

1207 If True, solution progress is printed to screen. Default False. 

1208 presolve : bool, optional 

1209 If True (default), the pre_solve method is run before solving. 

1210 postsolve : bool, optional 

1211 If True (default), the post_solve method is run after solving. 

1212 from_solution: Solution 

1213 If different from None, will be used as the starting point of backward 

1214 induction, instead of self.solution_terminal. 

1215 from_t : int or None 

1216 If not None, indicates which period of the model the solver should start 

1217 from. It should usually only be used in combination with from_solution. 

1218 Stands for the time index that from_solution represents, and thus is 

1219 only compatible with cycles=1 and will be reset to None otherwise. 

1220 

1221 Returns 

1222 ------- 

1223 none 

1224 """ 

1225 

1226 # Ignore floating point "errors". Numpy calls it "errors", but really it's excep- 

1227 # tions with well-defined answers such as 1.0/0.0 that is np.inf, -1.0/0.0 that is 

1228 # -np.inf, np.inf/np.inf is np.nan and so on. 

1229 with np.errstate( 

1230 divide="ignore", over="ignore", under="ignore", invalid="ignore" 

1231 ): 

1232 if presolve: 

1233 self.pre_solve() # Do pre-solution stuff 

1234 self.solution = solve_agent( 

1235 self, 

1236 verbose, 

1237 from_solution, 

1238 from_t, 

1239 ) # Solve the model by backward induction 

1240 if postsolve: 

1241 self.post_solve() # Do post-solution stuff 

1242 

1243 def reset_rng(self): 

1244 """ 

1245 Reset the random number generator and all distributions for this type. 

1246 Type-checking for lists is to handle the following three cases: 

1247 

1248 1) The target is a single distribution object 

1249 2) The target is a list of distribution objects (probably time-varying) 

1250 3) The target is a nested list of distributions, as in ConsMarkovModel. 

1251 """ 

1252 self.RNG = np.random.default_rng(self.seed) 

1253 for name in self.distributions: 

1254 if not hasattr(self, name): 

1255 continue 

1256 

1257 dstn = getattr(self, name) 

1258 if isinstance(dstn, list): 

1259 for D in dstn: 

1260 if isinstance(D, list): 

1261 for d in D: 

1262 d.reset() 

1263 else: 

1264 D.reset() 

1265 else: 

1266 dstn.reset() 

1267 

1268 def check_elements_of_time_vary_are_lists(self): 

1269 """ 

1270 A method to check that elements of time_vary are lists. 

1271 """ 

1272 for param in self.time_vary: 

1273 if not hasattr(self, param): 

1274 continue 

1275 if not isinstance( 

1276 getattr(self, param), 

1277 (IndexDistribution,), 

1278 ): 

1279 assert type(getattr(self, param)) == list, ( 

1280 param 

1281 + " is not a list or time varying distribution," 

1282 + " but should be because it is in time_vary" 

1283 ) 

1284 

1285 def check_restrictions(self): 

1286 """ 

1287 A method to check that various restrictions are met for the model class. 

1288 """ 

1289 return 

1290 

1291 def pre_solve(self): 

1292 """ 

1293 A method that is run immediately before the model is solved, to check inputs or to prepare 

1294 the terminal solution, perhaps. 

1295 

1296 Parameters 

1297 ---------- 

1298 none 

1299 

1300 Returns 

1301 ------- 

1302 none 

1303 """ 

1304 self.check_restrictions() 

1305 self.check_elements_of_time_vary_are_lists() 

1306 return None 

1307 

1308 def post_solve(self): 

1309 """ 

1310 A method that is run immediately after the model is solved, to finalize 

1311 the solution in some way. Does nothing here. 

1312 

1313 Parameters 

1314 ---------- 

1315 none 

1316 

1317 Returns 

1318 ------- 

1319 none 

1320 """ 

1321 return None 

1322 

1323 def initialize_sym(self, **kwargs): 

1324 """ 

1325 Use the new simulator structure to build a simulator from the agents' 

1326 attributes, storing it in a private attribute. 

1327 """ 

1328 self.reset_rng() # ensure seeds are set identically each time 

1329 self._simulator = make_simulator_from_agent(self, **kwargs) 

1330 self._simulator.reset() 

1331 

1332 def initialize_sim(self): 

1333 """ 

1334 Prepares this AgentType for a new simulation. Resets the internal random number generator, 

1335 makes initial states for all agents (using sim_birth), clears histories of tracked variables. 

1336 

1337 Parameters 

1338 ---------- 

1339 None 

1340 

1341 Returns 

1342 ------- 

1343 None 

1344 """ 

1345 if not hasattr(self, "T_sim"): 

1346 raise Exception( 

1347 "To initialize simulation variables it is necessary to first " 

1348 + "set the attribute T_sim to the largest number of observations " 

1349 + "you plan to simulate for each agent including re-births." 

1350 ) 

1351 elif self.T_sim <= 0: 

1352 raise Exception( 

1353 "T_sim represents the largest number of observations " 

1354 + "that can be simulated for an agent, and must be a positive number." 

1355 ) 

1356 

1357 self.reset_rng() 

1358 self.t_sim = 0 

1359 all_agents = np.ones(self.AgentCount, dtype=bool) 

1360 blank_array = np.empty(self.AgentCount) 

1361 blank_array[:] = np.nan 

1362 for var in self.state_vars: 

1363 self.state_now[var] = copy(blank_array) 

1364 

1365 # Number of periods since agent entry 

1366 self.t_age = np.zeros(self.AgentCount, dtype=int) 

1367 # Which cycle period each agent is on 

1368 self.t_cycle = np.zeros(self.AgentCount, dtype=int) 

1369 self.sim_birth(all_agents) 

1370 

1371 # If we are asked to use existing shocks and a set of initial conditions 

1372 # exist, use them 

1373 if self.read_shocks and bool(self.newborn_init_history): 

1374 for var_name in self.state_now: 

1375 # Check that we are actually given a value for the variable 

1376 if var_name in self.newborn_init_history.keys(): 

1377 # Copy only array-like idiosyncratic states. Aggregates should 

1378 # not be set by newborns 

1379 idio = ( 

1380 isinstance(self.state_now[var_name], np.ndarray) 

1381 and len(self.state_now[var_name]) == self.AgentCount 

1382 ) 

1383 if idio: 

1384 self.state_now[var_name] = self.newborn_init_history[var_name][ 

1385 0 

1386 ] 

1387 

1388 else: 

1389 warn( 

1390 "The option for reading shocks was activated but " 

1391 + "the model requires state " 

1392 + var_name 

1393 + ", not contained in " 

1394 + "newborn_init_history." 

1395 ) 

1396 

1397 self.clear_history() 

1398 return None 

1399 

1400 def sim_one_period(self): 

1401 """ 

1402 Simulates one period for this type. Calls the methods get_mortality(), get_shocks() or 

1403 read_shocks, get_states(), get_controls(), and get_poststates(). These should be defined for 

1404 AgentType subclasses, except get_mortality (define its components sim_death and sim_birth 

1405 instead) and read_shocks. 

1406 

1407 Parameters 

1408 ---------- 

1409 None 

1410 

1411 Returns 

1412 ------- 

1413 None 

1414 """ 

1415 if not hasattr(self, "solution"): 

1416 raise Exception( 

1417 "Model instance does not have a solution stored. To simulate, it is necessary" 

1418 " to run the `solve()` method first." 

1419 ) 

1420 

1421 # Mortality adjusts the agent population 

1422 self.get_mortality() # Replace some agents with "newborns" 

1423 

1424 # state_{t-1} 

1425 for var in self.state_now: 

1426 self.state_prev[var] = self.state_now[var] 

1427 

1428 if isinstance(self.state_now[var], np.ndarray): 

1429 self.state_now[var] = np.empty(self.AgentCount) 

1430 else: 

1431 # Probably an aggregate variable. It may be getting set by the Market. 

1432 pass 

1433 

1434 if self.read_shocks: # If shock histories have been pre-specified, use those 

1435 self.read_shocks_from_history() 

1436 else: # Otherwise, draw shocks as usual according to subclass-specific method 

1437 self.get_shocks() 

1438 self.get_states() # Determine each agent's state at decision time 

1439 self.get_controls() # Determine each agent's choice or control variables based on states 

1440 self.get_poststates() # Calculate variables that come *after* decision-time 

1441 

1442 # Advance time for all agents 

1443 self.t_age = self.t_age + 1 # Age all consumers by one period 

1444 self.t_cycle = self.t_cycle + 1 # Age all consumers within their cycle 

1445 self.t_cycle[self.t_cycle == self.T_cycle] = ( 

1446 0 # Resetting to zero for those who have reached the end 

1447 ) 

1448 

1449 def make_shock_history(self): 

1450 """ 

1451 Makes a pre-specified history of shocks for the simulation. Shock variables should be named 

1452 in self.shock_vars, a list of strings that is subclass-specific. This method runs a subset 

1453 of the standard simulation loop by simulating only mortality and shocks; each variable named 

1454 in shock_vars is stored in a T_sim x AgentCount array in history dictionary self.history[X]. 

1455 Automatically sets self.read_shocks to True so that these pre-specified shocks are used for 

1456 all subsequent calls to simulate(). 

1457 

1458 Parameters 

1459 ---------- 

1460 None 

1461 

1462 Returns 

1463 ------- 

1464 None 

1465 """ 

1466 # Re-initialize the simulation 

1467 self.initialize_sim() 

1468 

1469 # Make blank history arrays for each shock variable (and mortality) 

1470 for var_name in self.shock_vars: 

1471 self.shock_history[var_name] = ( 

1472 np.zeros((self.T_sim, self.AgentCount)) + np.nan 

1473 ) 

1474 self.shock_history["who_dies"] = np.zeros( 

1475 (self.T_sim, self.AgentCount), dtype=bool 

1476 ) 

1477 

1478 # Also make blank arrays for the draws of newborns' initial conditions 

1479 for var_name in self.state_vars: 

1480 self.newborn_init_history[var_name] = ( 

1481 np.zeros((self.T_sim, self.AgentCount)) + np.nan 

1482 ) 

1483 

1484 # Record the initial condition of the newborns created by 

1485 # initialize_sim -> sim_births 

1486 for var_name in self.state_vars: 

1487 # Check whether the state is idiosyncratic or an aggregate 

1488 idio = ( 

1489 isinstance(self.state_now[var_name], np.ndarray) 

1490 and len(self.state_now[var_name]) == self.AgentCount 

1491 ) 

1492 if idio: 

1493 self.newborn_init_history[var_name][self.t_sim] = self.state_now[ 

1494 var_name 

1495 ] 

1496 else: 

1497 # Aggregate state is a scalar. Assign it to every agent. 

1498 self.newborn_init_history[var_name][self.t_sim, :] = self.state_now[ 

1499 var_name 

1500 ] 

1501 

1502 # Make and store the history of shocks for each period 

1503 for t in range(self.T_sim): 

1504 # Deaths 

1505 self.get_mortality() 

1506 self.shock_history["who_dies"][t, :] = self.who_dies 

1507 

1508 # Initial conditions of newborns 

1509 if self.who_dies.any(): 

1510 for var_name in self.state_vars: 

1511 # Check whether the state is idiosyncratic or an aggregate 

1512 idio = ( 

1513 isinstance(self.state_now[var_name], np.ndarray) 

1514 and len(self.state_now[var_name]) == self.AgentCount 

1515 ) 

1516 if idio: 

1517 self.newborn_init_history[var_name][t, self.who_dies] = ( 

1518 self.state_now[var_name][self.who_dies] 

1519 ) 

1520 else: 

1521 self.newborn_init_history[var_name][t, self.who_dies] = ( 

1522 self.state_now[var_name] 

1523 ) 

1524 

1525 # Other Shocks 

1526 self.get_shocks() 

1527 for var_name in self.shock_vars: 

1528 self.shock_history[var_name][t, :] = self.shocks[var_name] 

1529 

1530 self.t_sim += 1 

1531 self.t_age = self.t_age + 1 # Age all consumers by one period 

1532 self.t_cycle = self.t_cycle + 1 # Age all consumers within their cycle 

1533 self.t_cycle[self.t_cycle == self.T_cycle] = ( 

1534 0 # Resetting to zero for those who have reached the end 

1535 ) 

1536 

1537 # Flag that shocks can be read rather than simulated 

1538 self.read_shocks = True 

1539 

1540 def get_mortality(self): 

1541 """ 

1542 Simulates mortality or agent turnover according to some model-specific rules named sim_death 

1543 and sim_birth (methods of an AgentType subclass). sim_death takes no arguments and returns 

1544 a Boolean array of size AgentCount, indicating which agents of this type have "died" and 

1545 must be replaced. sim_birth takes such a Boolean array as an argument and generates initial 

1546 post-decision states for those agent indices. 

1547 

1548 Parameters 

1549 ---------- 

1550 None 

1551 

1552 Returns 

1553 ------- 

1554 None 

1555 """ 

1556 if self.read_shocks: 

1557 who_dies = self.shock_history["who_dies"][self.t_sim, :] 

1558 # Instead of simulating births, assign the saved newborn initial conditions 

1559 if who_dies.any(): 

1560 for var_name in self.state_now: 

1561 if var_name in self.newborn_init_history.keys(): 

1562 # Copy only array-like idiosyncratic states. Aggregates should 

1563 # not be set by newborns 

1564 idio = ( 

1565 isinstance(self.state_now[var_name], np.ndarray) 

1566 and len(self.state_now[var_name]) == self.AgentCount 

1567 ) 

1568 if idio: 

1569 self.state_now[var_name][who_dies] = ( 

1570 self.newborn_init_history[var_name][ 

1571 self.t_sim, who_dies 

1572 ] 

1573 ) 

1574 

1575 else: 

1576 warn( 

1577 "The option for reading shocks was activated but " 

1578 + "the model requires state " 

1579 + var_name 

1580 + ", not contained in " 

1581 + "newborn_init_history." 

1582 ) 

1583 

1584 # Reset ages of newborns 

1585 self.t_age[who_dies] = 0 

1586 self.t_cycle[who_dies] = 0 

1587 else: 

1588 who_dies = self.sim_death() 

1589 self.sim_birth(who_dies) 

1590 self.who_dies = who_dies 

1591 return None 

1592 

1593 def sim_death(self): 

1594 """ 

1595 Determines which agents in the current population "die" or should be replaced. Takes no 

1596 inputs, returns a Boolean array of size self.AgentCount, which has True for agents who die 

1597 and False for those that survive. Returns all False by default, must be overwritten by a 

1598 subclass to have replacement events. 

1599 

1600 Parameters 

1601 ---------- 

1602 None 

1603 

1604 Returns 

1605 ------- 

1606 who_dies : np.array 

1607 Boolean array of size self.AgentCount indicating which agents die and are replaced. 

1608 """ 

1609 who_dies = np.zeros(self.AgentCount, dtype=bool) 

1610 return who_dies 

1611 

1612 def sim_birth(self, which_agents): # pragma: nocover 

1613 """ 

1614 Makes new agents for the simulation. Takes a boolean array as an input, indicating which 

1615 agent indices are to be "born". Does nothing by default, must be overwritten by a subclass. 

1616 

1617 Parameters 

1618 ---------- 

1619 which_agents : np.array(Bool) 

1620 Boolean array of size self.AgentCount indicating which agents should be "born". 

1621 

1622 Returns 

1623 ------- 

1624 None 

1625 """ 

1626 raise Exception("AgentType subclass must define method sim_birth!") 

1627 

1628 def get_shocks(self): # pragma: nocover 

1629 """ 

1630 Gets values of shock variables for the current period. Does nothing by default, but can 

1631 be overwritten by subclasses of AgentType. 

1632 

1633 Parameters 

1634 ---------- 

1635 None 

1636 

1637 Returns 

1638 ------- 

1639 None 

1640 """ 

1641 return None 

1642 

1643 def read_shocks_from_history(self): 

1644 """ 

1645 Reads values of shock variables for the current period from history arrays. 

1646 For each variable X named in self.shock_vars, this attribute of self is 

1647 set to self.history[X][self.t_sim,:]. 

1648 

1649 This method is only ever called if self.read_shocks is True. This can 

1650 be achieved by using the method make_shock_history() (or manually after 

1651 storing a "handcrafted" shock history). 

1652 

1653 Parameters 

1654 ---------- 

1655 None 

1656 

1657 Returns 

1658 ------- 

1659 None 

1660 """ 

1661 for var_name in self.shock_vars: 

1662 self.shocks[var_name] = self.shock_history[var_name][self.t_sim, :] 

1663 

1664 def get_states(self): 

1665 """ 

1666 Gets values of state variables for the current period. 

1667 By default, calls transition function and assigns values 

1668 to the state_now dictionary. 

1669 

1670 Parameters 

1671 ---------- 

1672 None 

1673 

1674 Returns 

1675 ------- 

1676 None 

1677 """ 

1678 new_states = self.transition() 

1679 

1680 for i, var in enumerate(self.state_now): 

1681 # a hack for now to deal with 'post-states' 

1682 if i < len(new_states): 

1683 self.state_now[var] = new_states[i] 

1684 

1685 def transition(self): # pragma: nocover 

1686 """ 

1687 

1688 Parameters 

1689 ---------- 

1690 None 

1691 

1692 [Eventually, to match dolo spec: 

1693 exogenous_prev, endogenous_prev, controls, exogenous, parameters] 

1694 

1695 Returns 

1696 ------- 

1697 

1698 endogenous_state: () 

1699 Tuple with new values of the endogenous states 

1700 """ 

1701 return () 

1702 

1703 def get_controls(self): # pragma: nocover 

1704 """ 

1705 Gets values of control variables for the current period, probably by using current states. 

1706 Does nothing by default, but can be overwritten by subclasses of AgentType. 

1707 

1708 Parameters 

1709 ---------- 

1710 None 

1711 

1712 Returns 

1713 ------- 

1714 None 

1715 """ 

1716 return None 

1717 

1718 def get_poststates(self): 

1719 """ 

1720 Gets values of post-decision state variables for the current period, 

1721 probably by current 

1722 states and controls and maybe market-level events or shock variables. 

1723 Does nothing by 

1724 default, but can be overwritten by subclasses of AgentType. 

1725 

1726 Parameters 

1727 ---------- 

1728 None 

1729 

1730 Returns 

1731 ------- 

1732 None 

1733 """ 

1734 return None 

1735 

1736 def symulate(self, T=None): 

1737 """ 

1738 Run the new simulation structure, with history results written to the 

1739 hystory attribute of self. 

1740 """ 

1741 self._simulator.simulate(T) 

1742 self.hystory = self._simulator.history 

1743 

1744 def describe_model(self, display=True): 

1745 """ 

1746 Print to screen information about this agent's model, based on its model 

1747 file. This is useful for learning about outcome variable names for tracking 

1748 during simulation, or for use with sequence space Jacobians. 

1749 """ 

1750 if not hasattr(self, "_simulator"): 

1751 self.initialize_sym() 

1752 self._simulator.describe(display=display) 

1753 

1754 def simulate(self, sim_periods=None): 

1755 """ 

1756 Simulates this agent type for a given number of periods. Defaults to self.T_sim, 

1757 or all remaining periods to simulate (T_sim - t_sim). Records histories of 

1758 attributes named in self.track_vars in self.history[varname]. 

1759 

1760 Parameters 

1761 ---------- 

1762 sim_periods : int or None 

1763 Number of periods to simulate. Default is all remaining periods (usually T_sim). 

1764 

1765 Returns 

1766 ------- 

1767 history : dict 

1768 The history tracked during the simulation. 

1769 """ 

1770 if not hasattr(self, "t_sim"): 

1771 raise Exception( 

1772 "It seems that the simulation variables were not initialize before calling " 

1773 + "simulate(). Call initialize_sim() to initialize the variables before calling simulate() again." 

1774 ) 

1775 

1776 if not hasattr(self, "T_sim"): 

1777 raise Exception( 

1778 "This agent type instance must have the attribute T_sim set to a positive integer." 

1779 + "Set T_sim to match the largest dataset you might simulate, and run this agent's" 

1780 + "initialize_sim() method before running simulate() again." 

1781 ) 

1782 

1783 if sim_periods is not None and self.T_sim < sim_periods: 

1784 raise Exception( 

1785 "To simulate, sim_periods has to be larger than the maximum data set size " 

1786 + "T_sim. Either increase the attribute T_sim of this agent type instance " 

1787 + "and call the initialize_sim() method again, or set sim_periods <= T_sim." 

1788 ) 

1789 

1790 # Ignore floating point "errors". Numpy calls it "errors", but really it's excep- 

1791 # tions with well-defined answers such as 1.0/0.0 that is np.inf, -1.0/0.0 that is 

1792 # -np.inf, np.inf/np.inf is np.nan and so on. 

1793 with np.errstate( 

1794 divide="ignore", over="ignore", under="ignore", invalid="ignore" 

1795 ): 

1796 if sim_periods is None: 

1797 sim_periods = self.T_sim - self.t_sim 

1798 

1799 for t in range(sim_periods): 

1800 self.sim_one_period() 

1801 

1802 for var_name in self.track_vars: 

1803 if var_name in self.state_now: 

1804 self.history[var_name][self.t_sim, :] = self.state_now[var_name] 

1805 elif var_name in self.shocks: 

1806 self.history[var_name][self.t_sim, :] = self.shocks[var_name] 

1807 elif var_name in self.controls: 

1808 self.history[var_name][self.t_sim, :] = self.controls[var_name] 

1809 else: 

1810 if var_name == "who_dies" and self.t_sim > 1: 

1811 self.history[var_name][self.t_sim - 1, :] = getattr( 

1812 self, var_name 

1813 ) 

1814 else: 

1815 self.history[var_name][self.t_sim, :] = getattr( 

1816 self, var_name 

1817 ) 

1818 self.t_sim += 1 

1819 

1820 def clear_history(self): 

1821 """ 

1822 Clears the histories of the attributes named in self.track_vars. 

1823 

1824 Parameters 

1825 ---------- 

1826 None 

1827 

1828 Returns 

1829 ------- 

1830 None 

1831 """ 

1832 for var_name in self.track_vars: 

1833 self.history[var_name] = np.empty((self.T_sim, self.AgentCount)) 

1834 self.history[var_name].fill(np.nan) 

1835 

1836 def make_basic_SSJ(self, shock, outcomes, grids, **kwargs): 

1837 """ 

1838 Construct and return sequence space Jacobian matrices for specified outcomes 

1839 with respect to specified "shock" variable. This "basic" method only works 

1840 for "one period infinite horizon" models (cycles=0, T_cycle=1). See documen- 

1841 tation for simulator.make_basic_SSJ_matrices for more information. 

1842 """ 

1843 return make_basic_SSJ_matrices(self, shock, outcomes, grids, **kwargs) 

1844 

1845 def calc_impulse_response_manually(self, shock, outcomes, grids, **kwargs): 

1846 """ 

1847 Calculate and return the impulse response(s) of a perturbation to the shock 

1848 parameter in period t=s, essentially computing one column of the sequence 

1849 space Jacobian matrix manually. This "basic" method only works for "one 

1850 period infinite horizon" models (cycles=0, T_cycle=1). See documentation 

1851 for simulator.calc_shock_response_manually for more information. 

1852 """ 

1853 return calc_shock_response_manually(self, shock, outcomes, grids, **kwargs) 

1854 

1855 

1856def solve_agent(agent, verbose, from_solution=None, from_t=None): 

1857 """ 

1858 Solve the dynamic model for one agent type using backwards induction. This 

1859 function iterates on "cycles" of an agent's model either a given number of 

1860 times or until solution convergence if an infinite horizon model is used 

1861 (with agent.cycles = 0). 

1862 

1863 Parameters 

1864 ---------- 

1865 agent : AgentType 

1866 The microeconomic AgentType whose dynamic problem 

1867 is to be solved. 

1868 verbose : boolean 

1869 If True, solution progress is printed to screen (when cycles != 1). 

1870 from_solution: Solution 

1871 If different from None, will be used as the starting point of backward 

1872 induction, instead of self.solution_terminal 

1873 from_t : int or None 

1874 If not None, indicates which period of the model the solver should start 

1875 from. It should usually only be used in combination with from_solution. 

1876 Stands for the time index that from_solution represents, and thus is 

1877 only compatible with cycles=1 and will be reset to None otherwise. 

1878 

1879 Returns 

1880 ------- 

1881 solution : [Solution] 

1882 A list of solutions to the one period problems that the agent will 

1883 encounter in his "lifetime". 

1884 """ 

1885 # Check to see whether this is an (in)finite horizon problem 

1886 cycles_left = agent.cycles # NOQA 

1887 infinite_horizon = cycles_left == 0 # NOQA 

1888 

1889 if from_solution is None: 

1890 solution_last = agent.solution_terminal # NOQA 

1891 else: 

1892 solution_last = from_solution 

1893 if agent.cycles != 1: 

1894 from_t = None 

1895 

1896 # Initialize the solution, which includes the terminal solution if it's not a pseudo-terminal period 

1897 solution = [] 

1898 if not agent.pseudo_terminal: 

1899 solution.insert(0, deepcopy(solution_last)) 

1900 

1901 # Initialize the process, then loop over cycles 

1902 go = True # NOQA 

1903 completed_cycles = 0 # NOQA 

1904 max_cycles = 5000 # NOQA - escape clause 

1905 if verbose: 

1906 t_last = time() 

1907 while go: 

1908 # Solve a cycle of the model, recording it if horizon is finite 

1909 solution_cycle = solve_one_cycle(agent, solution_last, from_t) 

1910 if not infinite_horizon: 

1911 solution = solution_cycle + solution 

1912 

1913 # Check for termination: identical solutions across 

1914 # cycle iterations or run out of cycles 

1915 solution_now = solution_cycle[0] 

1916 if infinite_horizon: 

1917 if completed_cycles > 0: 

1918 solution_distance = solution_now.distance(solution_last) 

1919 agent.solution_distance = ( 

1920 solution_distance # Add these attributes so users can 

1921 ) 

1922 agent.completed_cycles = ( 

1923 completed_cycles # query them to see if solution is ready 

1924 ) 

1925 go = ( 

1926 solution_distance > agent.tolerance 

1927 and completed_cycles < max_cycles 

1928 ) 

1929 else: # Assume solution does not converge after only one cycle 

1930 solution_distance = 100.0 

1931 go = True 

1932 else: 

1933 cycles_left += -1 

1934 go = cycles_left > 0 

1935 

1936 # Update the "last period solution" 

1937 solution_last = solution_now 

1938 completed_cycles += 1 

1939 

1940 # Display progress if requested 

1941 if verbose: 

1942 t_now = time() 

1943 if infinite_horizon: 

1944 print( 

1945 "Finished cycle #" 

1946 + str(completed_cycles) 

1947 + " in " 

1948 + str(t_now - t_last) 

1949 + " seconds, solution distance = " 

1950 + str(solution_distance) 

1951 ) 

1952 else: 

1953 print( 

1954 "Finished cycle #" 

1955 + str(completed_cycles) 

1956 + " of " 

1957 + str(agent.cycles) 

1958 + " in " 

1959 + str(t_now - t_last) 

1960 + " seconds." 

1961 ) 

1962 t_last = t_now 

1963 

1964 # Record the last cycle if horizon is infinite (solution is still empty!) 

1965 if infinite_horizon: 

1966 solution = ( 

1967 solution_cycle # PseudoTerminal=False impossible for infinite horizon 

1968 ) 

1969 

1970 return solution 

1971 

1972 

def _period_solver_and_args(agent, k):
    """
    Look up the single period solver used in period ``k`` of an agent's cycle
    together with the names of the keyword arguments it should receive.

    Parameters
    ----------
    agent : AgentType
        The agent whose ``solve_one_period`` attribute is consulted. It is
        either one solver for all periods or an indexable collection of
        per-period solvers.
    k : int
        Index of the period within the cycle.

    Returns
    -------
    solver : callable
        The one period solver for period ``k``.
    arg_names : list of str
        The solver's ``solver_args`` attribute if it has one; otherwise the
        argument names reported by ``get_arg_names``.
    """
    solvers = agent.solve_one_period
    # An indexable solve_one_period means the solver differs by period.
    solver = solvers[k] if hasattr(solvers, "__getitem__") else solvers
    try:
        arg_names = solver.solver_args
    except AttributeError:
        arg_names = get_arg_names(solver)
    return solver, arg_names


def solve_one_cycle(agent, solution_last, from_t):
    """
    Solve one "cycle" of an agent type's dynamic model by backward iteration.

    Walks over the periods of the cycle, gathers each period's inputs, and
    passes them to that period's single period solver. Each solution is fed
    to the next solver call as ``solution_next``.

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    solution_last : Solution
        Solution of the period that follows the last period of this cycle:
        the terminal solution, a "pseudo terminal" solution, or the first
        period solution of the succeeding cycle.
    from_t : int or None
        If not None, the number of periods to solve, counting back from the
        period that ``solution_last`` represents.

    Returns
    -------
    solution_cycle : [Solution]
        One solution per period of this cycle. When ``agent.cycles == 1``
        the list is ordered from period 0 to period T-1.
    """
    # Agents carrying a Parameters container read per-period inputs from it;
    # otherwise fall back to the time_vary / time_inv attribute lists.
    uses_parameters = hasattr(agent, "parameters") and isinstance(
        agent.parameters, Parameters
    )

    if uses_parameters:
        T = agent.parameters._length if from_t is None else from_t
    elif len(agent.time_vary) > 0:
        T = agent.T_cycle if from_t is None else from_t
    else:
        # With no time-varying inputs every period is the same problem.
        T = 1

    if not uses_parameters:
        # Time-invariant inputs are read once, up front, from the agent.
        invariant_inputs = {name: agent.__dict__[name] for name in agent.time_inv}

    if agent.cycles == 1:
        period_order = range(T - 1, -1, -1)
    else:
        # Repeated/infinite cycles start from period 0, then walk back from
        # T-1 down to 1.
        period_order = [0] + list(range(T - 1, 0, -1))

    solved = []
    solution_next = solution_last
    for k in period_order:
        solver, arg_names = _period_solver_and_args(agent, k)

        if uses_parameters:
            period_pars = agent.parameters[k]
            inputs = {
                name: solution_next if name == "solution_next" else period_pars[name]
                for name in arg_names
            }
        else:
            inputs = {}
            for name in arg_names:
                if name == "solution_next":
                    inputs[name] = solution_next
                elif name in agent.time_vary:
                    inputs[name] = agent.__dict__[name][k]
                else:
                    inputs[name] = invariant_inputs[name]

        solution_next = solver(**inputs)
        solved.append(solution_next)

    # Solutions were produced back to front; flip them into period order.
    solved.reverse()
    return solved

2078 

2079 

def make_one_period_oo_solver(solver_class):
    """
    Wrap an object-oriented Solver class as a one period solver function.

    Parameters
    ----------
    solver_class : Solver
        The Solver class to instantiate for each period.

    Returns
    -------
    solver_function : function
        A function that takes the solver's constructor arguments as keywords,
        builds a ``solver_class`` instance, and returns the result of its
        ``solve`` method. The function carries two attributes:
        ``solver_class`` (the wrapped class) and ``solver_args`` (the
        constructor's argument names without the leading ``self``).
    """

    def one_period_solver(**kwds):
        solver = solver_class(**kwds)
        # Not every Solver class defines prepare_to_solve, so call it only
        # when it exists.
        if hasattr(solver, "prepare_to_solve"):
            solver.prepare_to_solve()
        return solver.solve()

    one_period_solver.solver_class = solver_class
    # Drop the first name (self) so the list matches the keywords callers
    # pass; this can be revisited once solvers export their own parameters.
    constructor_args = get_arg_names(solver_class.__init__)
    one_period_solver.solver_args = constructor_args[1:]

    return one_period_solver

2108 

2109 

2110# ======================================================================== 

2111# ======================================================================== 

2112 

2113 

class Market(Model):
    """
    A central clearinghouse of information linking the "microeconomic" and
    "macroeconomic" layers of a model.

    In dynamic general equilibrium models, a Market solves the macroeconomic
    model as a layer on top of one or more AgentTypes. It repeatedly runs an
    "aggregate market process" against those agents and re-estimates a
    "dynamic rule" until the rule stops changing.

    Parameters
    ----------
    agents : [AgentType]
        All the AgentTypes in this market.
    sow_vars : [string]
        Names of variables produced by the aggregate market process that are
        "sown" (distributed) to the agents: aggregate state, prices, etc.
    reap_vars : [string]
        Names of variables "reaped" (collected) from the agents' ``state_now``
        for use in the aggregate market process.
    const_vars : [string]
        Names of inputs to the aggregate market process that do not come from
        agents. Their values are stored in ``self.const_vars`` and start out
        as ``None``.
    track_vars : [string]
        Names of variables whose history is recorded so that a new dynamic
        rule can be calculated. Often a subset of sow_vars.
    dyn_vars : [string]
        Names of the attributes that make up a "dynamic rule"; each is copied
        onto every agent.
    mill_rule : function
        Takes keyword inputs named in reap_vars (plus const_vars) and returns
        a sequence indexed in the same order as sow_vars. If None, a
        ``mill_rule`` method defined on a subclass is used instead.
    calc_dynamics : function
        Takes keyword inputs named like tracked histories and returns an
        object with attributes named in dyn_vars. If None, a subclass method
        is used instead.
    act_T : int
        Number of times the aggregate market process runs to build a history.
    tolerance : float
        Minimum distance between successive dynamic rules for the solution
        process to keep iterating. Distance is a user-defined metric.
    """

    def __init__(
        self,
        agents=None,
        sow_vars=None,
        reap_vars=None,
        const_vars=None,
        track_vars=None,
        dyn_vars=None,
        mill_rule=None,
        calc_dynamics=None,
        act_T=1000,
        tolerance=0.000001,
        **kwds,
    ):
        super().__init__()
        self.agents = list() if agents is None else agents  # NOQA

        self.reap_vars = list() if reap_vars is None else reap_vars  # NOQA
        self.reap_state = {name: [] for name in self.reap_vars}

        self.sow_vars = list() if sow_vars is None else sow_vars  # NOQA
        # Initial and current values of each sow variable, all None to begin.
        self.sow_init = dict.fromkeys(self.sow_vars)
        self.sow_state = dict.fromkeys(self.sow_vars)

        const_names = list() if const_vars is None else const_vars  # NOQA
        # NOTE(review): nothing in this class fills these values in; presumably
        # callers or subclasses assign self.const_vars[name] -- confirm.
        self.const_vars = dict.fromkeys(const_names)

        self.track_vars = list() if track_vars is None else track_vars  # NOQA
        self.dyn_vars = list() if dyn_vars is None else dyn_vars  # NOQA

        # Only override when given, so subclasses that implement mill_rule or
        # calc_dynamics as methods keep them.
        if mill_rule is not None:
            self.mill_rule = mill_rule
        if calc_dynamics is not None:
            self.calc_dynamics = calc_dynamics

        self.act_T = act_T  # NOQA
        self.tolerance = tolerance  # NOQA
        self.max_loops = 1000  # NOQA -- failsafe cap on solve() iterations
        self.history = {}
        self.assign_parameters(**kwds)

        # solve_agents() prints its parallel-failure warning at most once;
        # setting this to False suppresses the warning entirely.
        self.print_parallel_error_once = True

    def solve_agents(self):
        """
        Solve the microeconomic problem of every AgentType in this market.

        The agents are first solved in parallel. If that raises any
        exception, a warning is printed (only the first time) and the agents
        are solved serially instead.

        Returns
        -------
        None
        """
        try:
            multi_thread_commands(self.agents, ["solve()"])
        except Exception as err:
            if self.print_parallel_error_once:
                # Flip the flag first so the warning appears only once.
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.Market.solve_agents() ",
                    "so using the serial version instead. This will likely be slower. "
                    "The multi_thread_commands() functions failed with the following error:",
                    "\n",
                    type(err),
                    ":",
                    err,
                )
            multi_thread_commands_fake(self.agents, ["solve()"])

    def solve(self):
        """
        "Solve" the market by finding a self-consistent dynamic rule.

        Each pass solves the agents, simulates a history, and computes a new
        dynamic rule. Iteration stops once the distance between successive
        rules is below ``tolerance``, or after ``max_loops`` passes. The final
        rule is stored in ``self.dynamics``.

        Returns
        -------
        None
        """
        loop_cap = self.max_loops  # failsafe against an endless loop
        loops_done = 0
        previous_rule = None

        while True:
            self.solve_agents()
            self.make_history()
            current_rule = self.update_dynamics()

            # There is nothing to compare against on the first pass, so use a
            # large placeholder distance.
            if loops_done > 0:
                gap = current_rule.distance(previous_rule)
            else:
                gap = 1000000.0

            previous_rule = current_rule
            loops_done += 1
            # Written as a negated "keep going" test so that a NaN distance
            # stops the loop.
            if not (gap >= self.tolerance and loops_done < loop_cap):
                break

        self.dynamics = current_rule

    def reap(self):
        """
        Collect each variable named in reap_vars from the agents.

        For every reap variable, the values are gathered from the
        ``state_now`` of each agent that has it. The resulting list is stored
        in ``self.reap_state``. Agents without the variable are skipped.

        Returns
        -------
        None
        """
        for name in self.reap_state:
            # TODO: generalized variable lookup across namespaces
            self.reap_state[name] = [
                agent.state_now[name]
                for agent in self.agents
                if name in agent.state_now
            ]

    def sow(self):
        """
        Distribute the current value of each sow variable to every agent.

        The value is written into an agent's ``state_now`` if the variable is
        there, and into its ``shocks`` if it is there. When the variable is
        not in ``shocks``, it is also set as a plain attribute on the agent,
        even if it was already written to ``state_now``.

        Returns
        -------
        None
        """
        for name, value in self.sow_state.items():
            for agent in self.agents:
                if name in agent.state_now:
                    agent.state_now[name] = value
                if name in agent.shocks:
                    agent.shocks[name] = value
                else:
                    setattr(agent, name, value)

    def mill(self):
        """
        Turn the reaped agent data into new aggregate values.

        ``mill_rule`` is called with the reaped variables and the const_vars
        as keyword arguments. Its i-th output is stored as the i-th sow
        variable in ``self.sow_state``.

        Returns
        -------
        None
        """
        rule_inputs = {**self.reap_state, **self.const_vars}
        outputs = self.mill_rule(**rule_inputs)
        for position, name in enumerate(self.sow_state):
            self.sow_state[name] = outputs[position]

    def cultivate(self):
        """
        Have every agent run its ``market_action`` method.

        Agents act on the variables sown to them and should leave results
        where ``reap`` can find them.

        Returns
        -------
        None
        """
        for agent in self.agents:
            agent.market_action()

    def reset(self):
        """
        Return the market to its initial state.

        Clears the history of tracked variables, restores every sow variable
        to its value in ``sow_init``, and calls ``reset`` on every agent.

        Returns
        -------
        None
        """
        self.history = {name: [] for name in self.track_vars}

        for name in self.sow_state:
            self.sow_state[name] = self.sow_init[name]

        for agent in self.agents:
            agent.reset()

    def store(self):
        """
        Append the current value of each tracked variable to its history.

        A tracked variable is looked up first in ``sow_state``, then in
        ``reap_state``, then in ``const_vars``. If it is in none of those, it
        is read as an attribute of the market.

        Returns
        -------
        None
        """
        lookup_order = (self.sow_state, self.reap_state, self.const_vars)
        for name in self.track_vars:
            for source in lookup_order:
                if name in source:
                    current = source[name]
                    break
            else:
                current = getattr(self, name)
            self.history[name].append(current)

    def make_history(self):
        """
        Reset the market, then run the aggregate market process ``act_T``
        times.

        Each step runs sow -> cultivate -> reap -> mill and records the
        tracked variables in ``self.history``.

        Returns
        -------
        None
        """
        self.reset()
        for _ in range(self.act_T):
            self.sow()  # hand aggregate information to the agents
            self.cultivate()  # agents act
            self.reap()  # gather individual data
            self.mill()  # aggregate it
            self.store()  # record tracked variables

    def update_dynamics(self):
        """
        Compute a new aggregate dynamic rule and hand it to the agents.

        ``calc_dynamics`` is called with the recorded history of each of its
        arguments (excluding ``self``). Each attribute named in dyn_vars is
        then copied from the result onto every agent.

        Returns
        -------
        dynamics : instance
            The new "aggregate dynamic rule"; it should have attributes named
            in dyn_vars.
        """
        wanted = list(get_arg_names(self.calc_dynamics))
        if "self" in wanted:
            wanted.remove("self")

        # User-defined dynamics calculator, fed the recorded histories.
        dynamics = self.calc_dynamics(**{name: self.history[name] for name in wanted})

        for name in self.dyn_vars:
            rule_part = getattr(dynamics, name)
            for agent in self.agents:
                setattr(agent, name, rule_part)
        return dynamics

2456 

2457 

def distribute_params(agent, param_name, param_count, distribution):
    """
    Clone an agent into ex ante heterogeneous copies that differ in one
    parameter.

    The distribution is discretized into ``param_count`` points, and one deep
    copy of ``agent`` is made per point. Copy ``j`` receives the j-th atom as
    its value of ``param_name``. Its ``AgentCount`` is the original
    ``AgentCount`` times the j-th probability, truncated with ``int``. The
    original agent is not modified.

    Parameters
    ----------
    agent : AgentType
        The agent to clone.
    param_name : string
        Name of the parameter to be assigned.
    param_count : int
        Number of different values the parameter will take on.
    distribution : Distribution
        A 1-D distribution whose ``discretize(N=...)`` result exposes ``pmv``
        and a 2-D ``atoms`` array.

    Returns
    -------
    agent_set : [AgentType]
        ``param_count`` agents, ex ante heterogeneous with respect to
        ``param_name``.
    """
    discrete = distribution.discretize(N=param_count)

    # Make every copy before assigning anything, so each is an untouched
    # clone of the original agent.
    clones = [deepcopy(agent) for _ in range(param_count)]

    for j, clone in enumerate(clones):
        clone.assign_parameters(AgentCount=int(agent.AgentCount * discrete.pmv[j]))
        clone.assign_parameters(**{param_name: discrete.atoms[0, j]})

    return clones

2491 

2492 

@dataclass
class AgentPopulation:
    """
    A class for representing a population of ex-ante heterogeneous agents.

    Parameters that differ across agent subgroups ("distributed" parameters)
    can be given in three forms:

    - a list of lists, with one inner list per subgroup;
    - an xarray ``DataArray`` whose first dimension is ``"agent"``;
    - a ``Distribution``, which must be discretized with
      ``approx_distributions`` before agents are created.
    """

    agent_type: AgentType  # type of agent in the population
    parameters: dict  # dictionary of parameters
    seed: int = 0  # random seed
    time_var: List[str] = field(init=False)
    time_inv: List[str] = field(init=False)
    distributed_params: List[str] = field(init=False)
    agent_type_count: Optional[int] = field(init=False)
    term_age: Optional[int] = field(init=False)
    continuous_distributions: Dict[str, Distribution] = field(init=False)
    discrete_distributions: Dict[str, Distribution] = field(init=False)
    population_parameters: List[Dict[str, Union[List[float], float]]] = field(
        init=False
    )
    agents: List[AgentType] = field(init=False)
    agent_database: pd.DataFrame = field(init=False)
    solution: List[Any] = field(init=False)

    @staticmethod
    def _is_list_of_lists(param):
        """Return True if ``param`` is a non-empty list whose first item is a list."""
        return isinstance(param, list) and len(param) > 0 and isinstance(param[0], list)

    @staticmethod
    def _leading_dim_is(param, dim_name):
        """Return True if ``param`` is a DataArray whose first dimension is ``dim_name``."""
        return (
            isinstance(param, DataArray)
            and len(param.dims) > 0
            and param.dims[0] == dim_name
        )

    @staticmethod
    def _trailing_dim_is(param, dim_name):
        """Return True if ``param`` is a DataArray whose last dimension is ``dim_name``."""
        return (
            isinstance(param, DataArray)
            and len(param.dims) > 0
            and param.dims[-1] == dim_name
        )

    def __post_init__(self):
        """
        Initialize the population of agents, determine distributed parameters,
        and infer `agent_type_count` and `term_age`.
        """
        # Build a default agent only to read which parameter names it treats
        # as time-varying and which as time-invariant.
        dummy_agent = self.agent_type()
        self.time_var = dummy_agent.time_vary
        self.time_inv = dummy_agent.time_inv

        # Distributed parameters are those that differ across agent subgroups.
        self.distributed_params = [
            key
            for key, param in self.parameters.items()
            if self._is_list_of_lists(param)
            or isinstance(param, Distribution)
            or self._leading_dim_is(param, "agent")
        ]

        self.__infer_counts__()

        self.print_parallel_error_once = True
        # simulate() prints its parallel-failure warning at most once.

    def __infer_counts__(self):
        """
        Infer `agent_type_count` and `term_age` from the parameters.

        If the parameters include a `Distribution`, a list of lists, or a
        `DataArray` with `agent` as its first dimension, the population holds
        ex-ante heterogeneous agents. Both counts are set to None (with a
        warning) while any parameter is still a continuous `Distribution`.
        """
        # The subgroup count is the largest distributed-parameter length.
        agent_type_count = 1
        for key in self.distributed_params:
            param = self.parameters[key]
            if isinstance(param, Distribution):
                agent_type_count = None
                warn(
                    "Cannot infer agent_type_count from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif isinstance(param, list):
                agent_type_count = max(agent_type_count, len(param))
            elif self._leading_dim_is(param, "agent"):
                agent_type_count = max(agent_type_count, param.shape[0])

        self.agent_type_count = agent_type_count

        # The terminal age comes from per-subgroup lists or age-indexed arrays.
        term_age = 1
        for param in self.parameters.values():
            if isinstance(param, Distribution):
                term_age = None
                warn(
                    "Cannot infer term_age from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif self._is_list_of_lists(param):
                term_age = max(term_age, len(param[0]))
            elif self._trailing_dim_is(param, "age"):
                term_age = max(term_age, param.shape[-1])

        self.term_age = term_age

    def approx_distributions(self, approx_params: dict):
        """
        Approximate continuous distributions with discrete ones.

        A population whose parameters include a `Distribution` is abstract
        and cannot be solved until each such distribution is discretized.
        ``approx_params`` maps each distributed parameter name to the keyword
        arguments for its ``discretize`` method. The discretized distributions
        are combined as independent, and each parameter is replaced by a
        DataArray over the ``"agent"`` dimension.

        Raises
        ------
        ValueError
            If a key names a parameter that is not a distributed Distribution.
        """
        self.continuous_distributions = {}
        self.discrete_distributions = {}

        for key, args in approx_params.items():
            param = self.parameters[key]
            if key in self.distributed_params and isinstance(param, Distribution):
                self.continuous_distributions[key] = param
                self.discrete_distributions[key] = param.discretize(**args)
            else:
                raise ValueError(
                    f"Warning: parameter {key} is not a Distribution found "
                    f"in agent type {self.agent_type}"
                )

        if not self.discrete_distributions:
            # Nothing was requested, so the parameters are unchanged.
            return

        if len(self.discrete_distributions) > 1:
            joint_dist = combine_indep_dstns(*self.discrete_distributions.values())
        else:
            joint_dist = next(iter(self.discrete_distributions.values()))

        # Row i of the joint atoms holds the values of the i-th parameter.
        for i, key in enumerate(self.discrete_distributions):
            self.parameters[key] = DataArray(joint_dist.atoms[i], dims=("agent",))

        self.__infer_counts__()

    def _param_for_agent(self, key, param, idx):
        """
        Return the value of parameter ``key`` for agent subgroup ``idx``.

        - Scalars are shared by all subgroups. For time-varying keys they are
          repeated ``term_age`` times.
        - A list of lists yields its ``idx``-th entry; a flat list is shared.
        - A DataArray over ("agent", ..., "age") yields the subgroup's
          age-indexed values as a list.
        - A DataArray led by "agent" (no trailing "age") yields the
          subgroup's scalar, repeated ``term_age`` times for time-varying keys.
        - A DataArray led by "age" yields its values as a list.
        - Any other value is passed through unchanged.
        """
        time_varying = key in self.time_var

        if isinstance(param, (int, float, np.number)):
            return [param] * self.term_age if time_varying else param

        if isinstance(param, list):
            return param[idx] if self._is_list_of_lists(param) else param

        if self._leading_dim_is(param, "agent"):
            own = param[idx]
            if param.dims[-1] == "age":
                return own.values.tolist()
            value = own.item()
            return [value] * self.term_age if time_varying else value

        if self._leading_dim_is(param, "age"):
            return param.values.tolist()

        return param

    def __parse_parameters__(self) -> None:
        """
        Build one parameter dictionary per ex-ante heterogeneous subgroup.

        Results are stored in ``population_parameters``, a list with one
        dictionary per subgroup. Every parameter is included, resolved for
        that subgroup by ``_param_for_agent``. Time-varying scalars are
        expanded to lists of length `term_age`.
        """
        self.population_parameters = [
            {
                key: self._param_for_agent(key, param, idx)
                for key, param in self.parameters.items()
            }
            for idx in range(self.agent_type_count)
        ]

    def create_distributed_agents(self):
        """
        Parses the parameters dictionary and creates a list of agents with the
        appropriate parameters. Also sets the seed for each agent.

        Each agent's seed is drawn from a generator seeded with ``self.seed``,
        so the population is reproducible.
        """

        self.__parse_parameters__()

        rng = np.random.default_rng(self.seed)

        self.agents = [
            self.agent_type(seed=rng.integers(0, 2**31 - 1), **agent_dict)
            for agent_dict in self.population_parameters
        ]

    def create_database(self):
        """
        Optionally creates a pandas DataFrame with the parameters for each agent.

        The DataFrame has one row per subgroup and an ``agents`` column that
        holds the agent objects.
        """
        database = pd.DataFrame(self.population_parameters)
        database["agents"] = self.agents

        self.agent_database = database

    def solve(self):
        """
        Solves each agent of the population serially.
        """

        # see Market class for an example of how to solve distributed agents in parallel

        for agent in self.agents:
            agent.solve()

    def unpack_solutions(self):
        """
        Unpacks the solutions of each agent into an attribute of the population.
        """
        self.solution = [agent.solution for agent in self.agents]

    def initialize_sim(self):
        """
        Initializes the simulation for each agent.
        """
        for agent in self.agents:
            agent.initialize_sim()

    def simulate(self, num_jobs=None):
        """
        Simulates each agent of the population.

        Parameters
        ----------
        num_jobs : int, optional
            Number of parallel jobs to use. Defaults to using all available
            cores when ``None``. Falls back to serial execution if parallel
            processing fails.
        """
        try:
            multi_thread_commands(self.agents, ["simulate()"], num_jobs)
        except Exception as err:
            if getattr(self, "print_parallel_error_once", False):
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.AgentPopulation.simulate() ",
                    "so using the serial version instead. This will likely be slower. ",
                    "The multi_thread_commands() function failed with the following error:\n",
                    sys.exc_info()[0],
                    ":",
                    err,
                )
            multi_thread_commands_fake(self.agents, ["simulate()"], num_jobs)

    def __iter__(self):
        """
        Allows for iteration over the agents in the population.
        """
        return iter(self.agents)

    def __getitem__(self, idx):
        """
        Allows for indexing into the population.
        """
        return self.agents[idx]

2769 

2770 

2771############################################################################### 

2772 

2773 

def multi_thread_commands_fake(
    agent_list: List, command_list: List, num_jobs=None
) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    in an ordinary, single-threaded loop. Each command should be a method of
    that AgentType subclass. This function exists so as to easily disable
    multithreading, as it uses the same syntax as multi_thread_commands.

    Agents are processed in order, and each agent runs every command before
    the next agent starts.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
    command_list : [string]
        A list of commands to run for each AgentType. Each is a method name,
        written either as "method()" or as a bare "method"; the method is
        called with no arguments.
    num_jobs : None
        Dummy input to match syntax of multi_thread_commands. Does nothing.

    Returns
    -------
    none
    """
    for agent in agent_list:
        for command in command_list:
            # Strip the conventional "()" suffix if present; bare method names
            # are accepted as-is.
            method_name = command[:-2] if command.endswith("()") else command
            getattr(agent, method_name)()

2800 

2801 

def multi_thread_commands(agent_list: List, command_list: List, num_jobs=None) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    using a multithreaded system. Each command should be a method of that AgentType subclass.

    A single agent is handled serially, and an empty list is a no-op.
    Otherwise the commands run via joblib's ``Parallel``. Each entry of
    ``agent_list`` is then replaced, in place, by the agent object returned
    from its worker.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
    command_list : [string]
        A list of commands to run for each AgentType in agent_list.
    num_jobs : int or None
        Number of parallel jobs. When None, defaults to the smaller of the
        number of agents and the number of available CPU cores.

    Returns
    -------
    None
    """
    # Nothing to do; also avoids the default below yielding n_jobs=0, which
    # joblib rejects.
    if len(agent_list) == 0:
        return None

    if len(agent_list) == 1:
        multi_thread_commands_fake(agent_list, command_list)
        return None

    if num_jobs is None:
        num_jobs = min(len(agent_list), multiprocessing.cpu_count())

    # Send every command list to every agent to be run.
    agent_list_out = Parallel(n_jobs=num_jobs)(
        delayed(run_commands)(agent, command_list) for agent in agent_list
    )

    # Workers may hand back copies (e.g. with a process-based backend), so
    # write the returned agents back into the caller's list.
    agent_list[:] = agent_list_out

2836 

2837 

def run_commands(agent: Any, command_list: List) -> Any:
    """
    Executes each command in command_list on a given AgentType. The commands
    should be methods of that AgentType's subclass.

    Parameters
    ----------
    agent : AgentType
        An instance of AgentType on which the commands will be run.
    command_list : [string]
        A list of commands that the agent should run, as methods. Each is
        written either as "method()" or as a bare "method"; the method is
        called with no arguments.

    Returns
    -------
    agent : AgentType
        The same AgentType instance passed as input, after running the commands.
    """
    for command in command_list:
        # Strip the conventional "()" suffix if present; bare method names
        # are accepted as-is.
        method_name = command[:-2] if command.endswith("()") else command
        getattr(agent, method_name)()
    return agent

2858 return agent