Coverage for HARK/core.py: 83%

925 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-02 05:14 +0000

1""" 

2High-level functions and classes for solving a wide variety of economic models. 

3The "core" of HARK is a framework for "microeconomic" and "macroeconomic" 

4models. A micro model concerns the dynamic optimization problem for some type 

5of agents, where agents take the inputs to their problem as exogenous. A macro 

6model adds an additional layer, endogenizing some of the inputs to the micro 

7problem by finding a general equilibrium dynamic rule. 

8""" 

9 

10# Set logging and define basic functions 

11import inspect 

12import logging 

13import sys 

14from collections import namedtuple 

15from copy import copy, deepcopy 

16from dataclasses import dataclass, field 

17from time import time 

18from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union 

19from warnings import warn 

20import multiprocessing 

21from joblib import Parallel, delayed 

22 

23import numpy as np 

24import pandas as pd 

25from xarray import DataArray 

26 

27from HARK.distributions import ( 

28 Distribution, 

29 IndexDistribution, 

30 combine_indep_dstns, 

31) 

32from HARK.utilities import NullFunc, get_arg_names, get_it_from 

33from HARK.simulator import make_simulator_from_agent 

34from HARK.SSJutils import ( 

35 make_basic_SSJ_matrices, 

36 calc_shock_response_manually, 

37) 

38 

# Package-wide logger shared by every HARK object; the helpers below only
# ever adjust this one instance.
_log = logging.getLogger("HARK")
# Plain "%(message)s" output on the root handler (basicConfig is a no-op if
# the root logger was already configured by the host application).
logging.basicConfig(format="%(message)s")
# Default to showing errors only; callers loosen this via verbose()/warnings().
_log.setLevel(logging.ERROR)

42 

43 

def disable_logging():
    """Switch the "HARK" logger off entirely; undo with ``enable_logging``."""
    logger = _log
    logger.disabled = True

46 

47 

def enable_logging():
    """Switch the "HARK" logger back on after ``disable_logging``."""
    logger = _log
    logger.disabled = False

50 

51 

def warnings():
    """Set the "HARK" logger threshold to WARNING so warnings are shown."""
    warning_level = logging.WARNING
    _log.setLevel(warning_level)

54 

55 

def quiet():
    """Set the "HARK" logger threshold to ERROR so only errors are shown."""
    error_level = logging.ERROR
    _log.setLevel(error_level)

58 

59 

def verbose():
    """Set the "HARK" logger threshold to INFO so informational messages show."""
    info_level = logging.INFO
    _log.setLevel(info_level)

62 

63 

def set_verbosity_level(level):
    """
    Set the "HARK" logger threshold to an arbitrary level.

    Parameters
    ----------
    level : int or str
        Any value accepted by ``logging.Logger.setLevel``.
    """
    _log.setLevel(level=level)

66 

67 

class Parameters:
    """
    A smart container for model parameters that handles age-varying dynamics.

    This class stores parameters as an internal dictionary and manages their
    age-varying properties, providing both attribute-style and dictionary-style
    access. It is designed to handle the time-varying dynamics of parameters
    in economic models.

    Attributes
    ----------
    _length : int
        The number of ages (periods) in the model; valid age indices are
        ``0 .. _length - 1``. Initialized from the ``T_cycle`` keyword
        (default 1) and inferred from the first multi-element list/tuple
        parameter while it is still 1.
    _invariant_params : Set[str]
        A set of parameter names that are invariant over time.
    _varying_params : Set[str]
        A set of parameter names that vary over time.
    _parameters : Dict[str, Any]
        The internal dictionary storing all parameters, including "T_cycle".
    """

    __slots__ = ("_length", "_invariant_params", "_varying_params", "_parameters")

    def __init__(self, **parameters: Any) -> None:
        """
        Initialize a Parameters object and parse the age-varying dynamics of parameters.

        Parameters
        ----------
        **parameters : Any
            Any number of parameters in the form key=value. The special key
            ``T_cycle`` sets the number of ages (default 1).
        """
        self._length: int = parameters.pop("T_cycle", 1)
        self._invariant_params: Set[str] = set()
        self._varying_params: Set[str] = set()
        self._parameters: Dict[str, Any] = {"T_cycle": self._length}

        for key, value in parameters.items():
            self[key] = value

    def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
        """
        Access parameters by age index or parameter name.

        If item_or_key is an integer, returns a Parameters object with the parameters
        that apply to that age. This includes all invariant parameters and the
        `item_or_key`th element of all age-varying parameters. If item_or_key is a
        string, it returns the value of the parameter with that name.

        Parameters
        ----------
        item_or_key : Union[int, str]
            Age index or parameter name.

        Returns
        -------
        Union[Parameters, Any]
            A new Parameters object for the specified age, or the value of the
            specified parameter.

        Raises
        ------
        ValueError:
            If the age index is at or beyond the number of ages.
        KeyError:
            If the parameter name is not found.
        TypeError:
            If the key is neither an integer nor a string.
        """
        if isinstance(item_or_key, int):
            # NOTE(review): only the upper bound is checked; negative ints index
            # varying sequences from the end.
            if item_or_key >= self._length:
                raise ValueError(
                    f"Age {item_or_key} is out of bounds (max: {self._length - 1})."
                )

            params = {key: self._parameters[key] for key in self._invariant_params}
            params.update(
                {
                    # A parameter marked varying may still hold a scalar (e.g.
                    # via add_to_time_vary); pass scalars through unchanged.
                    key: (
                        self._parameters[key][item_or_key]
                        if isinstance(self._parameters[key], (list, tuple, np.ndarray))
                        else self._parameters[key]
                    )
                    for key in self._varying_params
                }
            )
            return Parameters(**params)
        elif isinstance(item_or_key, str):
            return self._parameters[item_or_key]
        else:
            raise TypeError("Key must be an integer (age) or string (parameter name).")

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Set parameter values, automatically inferring time variance.

        If the parameter is a scalar (Python or numpy), numpy array, boolean,
        distribution, callable or None, it is assumed to be invariant over time.
        If the parameter is a list or tuple, it is assumed to be varying over
        time; a single-element sequence is unwrapped and treated as invariant,
        and an empty sequence is stored unchanged as invariant. A list or tuple
        of length greater than 1 must match the `_length` attribute, unless
        `_length` is still 1, in which case it sets `_length` (and "T_cycle").

        Parameters
        ----------
        key : str
            Name of the parameter.
        value : Any
            Value of the parameter.

        Raises
        ------
        ValueError:
            If the parameter name is not a string or if the value type is unsupported.
            If the parameter value is inconsistent with the current model length.
        """
        if not isinstance(key, str):
            raise ValueError(f"Parameter name must be a string, got {type(key)}")

        if isinstance(value, (list, tuple)):
            if len(value) == 0:
                # An empty sequence carries no per-age information; keep it as
                # an invariant value instead of letting it set the length to 0.
                self._invariant_params.add(key)
                self._varying_params.discard(key)
            elif len(value) == 1:
                value = value[0]
                self._invariant_params.add(key)
                self._varying_params.discard(key)
            elif self._length is None or self._length == 1:
                self._length = len(value)
                # Keep the public "T_cycle" entry in sync with the inferred length.
                self._parameters["T_cycle"] = self._length
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            elif len(value) == self._length:
                self._varying_params.add(key)
                self._invariant_params.discard(key)
            else:
                raise ValueError(
                    f"Parameter {key} must have length 1 or {self._length}, not {len(value)}"
                )
        elif isinstance(
            value,
            (
                int,
                float,
                np.number,
                np.bool_,
                np.ndarray,
                type(None),
                Distribution,
                bool,
                Callable,
            ),
        ):
            self._invariant_params.add(key)
            self._varying_params.discard(key)
        else:
            raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")

        self._parameters[key] = value

    def __iter__(self) -> Iterator[str]:
        """Allow iteration over parameter names."""
        return iter(self._parameters)

    def __len__(self) -> int:
        """Return the number of parameters."""
        return len(self._parameters)

    def keys(self) -> Iterator[str]:
        """Return a view of parameter names."""
        return self._parameters.keys()

    def values(self) -> Iterator[Any]:
        """Return a view of parameter values."""
        return self._parameters.values()

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Return a view of parameter (name, value) pairs."""
        return self._parameters.items()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a plain dictionary.

        Returns
        -------
        Dict[str, Any]
            A shallow copy of the dictionary containing all parameters.
        """
        return dict(self._parameters)

    def to_namedtuple(self) -> namedtuple:
        """
        Convert parameters to a namedtuple.

        Returns
        -------
        namedtuple
            A namedtuple containing all parameters.
        """
        return namedtuple("Parameters", self.keys())(**self.to_dict())

    def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
        """
        Update parameters from another Parameters object or dictionary.

        Parameters
        ----------
        other : Union[Parameters, Dict[str, Any]]
            The source of parameters to update from. Each entry goes through
            ``__setitem__``, so time variance is re-inferred.

        Raises
        ------
        TypeError
            If the input is neither a Parameters object nor a dictionary.
        """
        if isinstance(other, Parameters):
            for key, value in other._parameters.items():
                self[key] = value
        elif isinstance(other, dict):
            for key, value in other.items():
                self[key] = value
        else:
            raise TypeError(
                "Update source must be a Parameters object or a dictionary."
            )

    def __repr__(self) -> str:
        """Return a detailed string representation of the Parameters object."""
        return (
            f"Parameters(_length={self._length}, "
            f"_invariant_params={self._invariant_params}, "
            f"_varying_params={self._varying_params}, "
            f"_parameters={self._parameters})"
        )

    def __str__(self) -> str:
        """Return a simple string representation of the Parameters object."""
        return f"Parameters({str(self._parameters)})"

    def __getattr__(self, name: str) -> Any:
        """
        Allow attribute-style access to parameters.

        Parameters
        ----------
        name : str
            Name of the parameter to access.

        Returns
        -------
        Any
            The value of the specified parameter.

        Raises
        ------
        AttributeError:
            If the parameter name is not found.
        """
        # Underscore names are internal slots (or protocol probes such as those
        # made by copy/pickle); never look them up among the parameters.
        if name.startswith("_"):
            return super().__getattribute__(name)
        try:
            return self._parameters[name]
        except KeyError:
            raise AttributeError(f"'Parameters' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Allow attribute-style setting of parameters.

        Parameters
        ----------
        name : str
            Name of the parameter to set.
        value : Any
            Value to set for the parameter.
        """
        # Underscore names are the internal slots; everything else is a parameter.
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            self[name] = value

    def __contains__(self, item: str) -> bool:
        """Check if a parameter exists in the Parameters object."""
        return item in self._parameters

    def copy(self) -> "Parameters":
        """
        Create a deep copy of the Parameters object.

        Returns
        -------
        Parameters
            A new Parameters object with the same contents.
        """
        return deepcopy(self)

    def add_to_time_vary(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_vary.
            Unknown names trigger a warning and are skipped.
        """
        for param in params:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_vary."
                )

    def add_to_time_inv(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_inv.
            Unknown names trigger a warning and are skipped.
        """
        for param in params:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_inv."
                )

    def del_from_time_vary(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_vary.
        """
        for param in params:
            self._varying_params.discard(param)

    def del_from_time_inv(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_inv.
        """
        for param in params:
            self._invariant_params.discard(param)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a parameter value, returning a default if not found.

        Parameters
        ----------
        key : str
            The parameter name.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The parameter value or the default.
        """
        return self._parameters.get(key, default)

    def set_many(self, **kwargs: Any) -> None:
        """
        Set multiple parameters at once.

        Parameters
        ----------
        **kwargs : Keyword arguments representing parameter names and values.
        """
        for key, value in kwargs.items():
            self[key] = value

    def is_time_varying(self, key: str) -> bool:
        """
        Check if a parameter is time-varying.

        Parameters
        ----------
        key : str
            The parameter name.

        Returns
        -------
        bool
            True if the parameter is time-varying, False otherwise.
        """
        return key in self._varying_params

453 

454 

class Model:
    """
    A class with special handling of parameters assignment.

    Every assigned parameter is stored both in the ``parameters`` dictionary
    and as an attribute on the instance. The ``constructors`` dictionary maps
    names of derived objects to the functions that build them (see
    ``construct``).
    """

    def __init__(self):
        # Only create empty containers when they are not already present
        # (e.g. provided by a subclass or class attribute).
        if not hasattr(self, "parameters"):
            self.parameters = {}
        if not hasattr(self, "constructors"):
            self.constructors = {}

    def assign_parameters(self, **kwds):
        """
        Assign an arbitrary number of attributes to this agent.

        Parameters
        ----------
        **kwds : keyword arguments
            Any number of keyword arguments of the form key=value.
            Each value will be assigned to the attribute named in self
            and recorded in self.parameters.

        Returns
        -------
        None
        """
        self.parameters.update(kwds)
        for key in kwds:
            setattr(self, key, kwds[key])

    def get_parameter(self, name):
        """
        Returns a parameter of this model

        Parameters
        ----------
        name : str
            The name of the parameter to get

        Returns
        -------
        value : The value of the parameter

        Raises
        ------
        KeyError
            If the name is not in self.parameters.
        """
        return self.parameters[name]

    def __eq__(self, other):
        # Two models of the same type are equal when their parameters match.
        if isinstance(other, type(self)):
            return self.parameters == other.parameters

        return NotImplemented

    def __str__(self):
        type_ = type(self)
        module = type_.__module__
        qualname = type_.__qualname__

        s = f"<{module}.{qualname} object at {hex(id(self))}.\n"
        s += "Parameters:"

        for p in self.parameters:
            s += f"\n{p}: {self.parameters[p]}"

        s += ">"
        return s

    def describe(self):
        """Return the same multi-line description as ``str(self)``."""
        return self.__str__()

    def del_param(self, param_name):
        """
        Deletes a parameter from this instance, removing it both from the object's
        namespace (if it's there) and the parameters dictionary (likewise).

        Parameters
        ----------
        param_name : str
            A string naming a parameter or data to be deleted from this instance.
            Removes information from self.parameters dictionary and own namespace.
            A name that only resolves through the class (e.g. a class-level
            default) is left untouched rather than raising AttributeError.

        Returns
        -------
        None
        """
        if param_name in self.parameters:
            del self.parameters[param_name]
        if hasattr(self, param_name):
            try:
                delattr(self, param_name)
            except AttributeError:
                # Visible via the class, not the instance; nothing to delete here.
                pass

    def construct(self, *args, force=False):
        """
        Top-level method for building constructed inputs. If called without any
        inputs, construct builds each of the objects named in the keys of the
        constructors dictionary; it draws inputs for the constructors from the
        parameters dictionary and adds its results to the same. If passed one or
        more strings as arguments, the method builds only the named keys. The
        method will do multiple "passes" over the requested keys, as some cons-
        tructors require inputs built by other constructors. If any requested
        constructors failed to build due to missing data, those keys (and the
        missing data) will be named in self._missing_key_data. Other errors are
        recorded in the dictionary attribute _constructor_errors.

        This method tries to "start from scratch" by removing prior constructed
        objects, holding them in a backup dictionary during construction. This
        is done so that dependencies among constructors are resolved properly,
        without mistakenly relying on "old information". A backup value is used
        if a constructor function is set to None (i.e. "don't do anything"), or
        if the construct method fails to produce a new object -- including when
        it raises because force is False.

        Parameters
        ----------
        *args : str, optional
            Keys of self.constructors that are requested to be constructed.
            If no arguments are passed, *all* elements of the dictionary are implied.
        force : bool, optional
            When True, the method will force its way past any errors, including
            missing constructors, missing arguments for constructors, and errors
            raised during execution of constructors. Information about all such
            errors is stored in the dictionary attributes described above. When
            False (default), any errors or exception will be raised.

        Returns
        -------
        None
        """
        # Set up the requested work
        if len(args) > 0:
            keys = args
        else:
            keys = list(self.constructors.keys())
        N_keys = len(keys)
        keys_complete = np.zeros(N_keys, dtype=bool)
        if N_keys == 0:
            return  # Do nothing if there are no constructed objects

        # Remove pre-existing constructed objects, preventing "incomplete" updates,
        # but store the current values in a backup dictionary in case something fails
        backup = {}
        for key in keys:
            if hasattr(self, key):
                backup[key] = getattr(self, key)
            self.del_param(key)

        # Get the dictionary of constructor errors
        if not hasattr(self, "_constructor_errors"):
            self._constructor_errors = {}
        errors = self._constructor_errors

        def restore_incomplete():
            # Put back the backed-up value of every requested key that has not
            # been rebuilt, so a failure never leaves an object missing.
            for idx, name in enumerate(keys):
                if not keys_complete[idx] and name in backup:
                    setattr(self, name, backup[name])
                    self.parameters[name] = backup[name]

        # As long as the work isn't complete and we made some progress on the last
        # pass, repeatedly perform passes of trying to construct objects
        any_keys_incomplete = True
        missing_key_data = []
        go = True
        while go:
            anything_accomplished_this_pass = False  # Nothing done yet!
            missing_key_data = []  # Keep this up-to-date on each pass

            for i, key in enumerate(keys):
                if keys_complete[i]:
                    continue  # This key has already been built

                # Get this key's constructor function
                try:
                    constructor = self.constructors[key]
                except KeyError as not_found:
                    errors[key] = "No constructor found for " + str(not_found)
                    if force:
                        continue
                    restore_incomplete()
                    raise ValueError("No constructor found for " + key) from None

                # If this constructor is None, do nothing and mark it as completed;
                # this includes restoring the previous value if it exists
                if constructor is None:
                    if key in backup:
                        setattr(self, key, backup[key])
                        self.parameters[key] = backup[key]
                    keys_complete[i] = True
                    anything_accomplished_this_pass = True  # We did something!
                    continue

                # SPECIAL: if the constructor is get_it_from, handle it separately
                if isinstance(constructor, get_it_from):
                    try:
                        parent = getattr(self, constructor.name)
                    except AttributeError:
                        parent = None
                        query = None
                        any_missing = True
                        missing_args = [constructor.name]
                        missing_key_data.append((key, constructor.name))
                    else:
                        query = key
                        any_missing = False
                        missing_args = []
                    temp_dict = {"parent": parent, "query": query}

                # Get the names of arguments for this constructor and try to gather them
                else:
                    args_needed = get_arg_names(constructor)
                    has_no_default = {
                        k: v.default is inspect.Parameter.empty
                        for k, v in inspect.signature(constructor).parameters.items()
                    }
                    temp_dict = {}
                    any_missing = False
                    missing_args = []
                    for this_arg in args_needed:
                        if hasattr(self, this_arg):
                            temp_dict[this_arg] = getattr(self, this_arg)
                        elif this_arg in self.parameters:
                            temp_dict[this_arg] = self.parameters[this_arg]
                        elif has_no_default[this_arg]:
                            # Record missing key-data pair; defaulted args are
                            # simply left for the constructor to fill in
                            any_missing = True
                            missing_key_data.append((key, this_arg))
                            missing_args.append(this_arg)

                # If all of the required data was found, run the constructor and
                # store the result in parameters (and on self)
                if not any_missing:
                    try:
                        temp = constructor(**temp_dict)
                    except Exception as problem:
                        errors[key] = str(type(problem)) + ": " + str(problem)
                        self.del_param(key)
                        if force:
                            continue
                        restore_incomplete()
                        raise
                    setattr(self, key, temp)
                    self.parameters[key] = temp
                    if key in errors:
                        del errors[key]
                    keys_complete[i] = True
                    anything_accomplished_this_pass = True  # We did something!
                else:
                    errors[key] = "Missing required arguments: " + ", ".join(
                        missing_args
                    )
                    self.del_param(key)
                    # Never raise exceptions here, as the arguments might be filled in later

            # Check whether another pass should be performed
            any_keys_incomplete = not keys_complete.all()
            go = any_keys_incomplete and anything_accomplished_this_pass

        # Store missing key-data pairs and exit
        self._missing_key_data = missing_key_data
        self._constructor_errors = errors
        if any_keys_incomplete:
            incomplete = [k for k, done in zip(keys, keys_complete) if not done]
            restore_incomplete()
            msg = "Did not construct these objects: " + ", ".join(incomplete)
            if not force:
                raise ValueError(msg)
        return

    def describe_constructors(self, *args):
        """
        Prints to screen a string describing this instance's constructed objects,
        including their names, the function that constructs them, the names of
        those functions inputs, and whether those inputs are present.

        Parameters
        ----------
        *args : str, optional
            Optional list of strings naming constructed inputs to be described.
            If none are passed, all constructors are described.

        Returns
        -------
        None
        """
        if len(args) > 0:
            keys = args
        else:
            keys = list(self.constructors.keys())
        yes = "\u2713"
        no = "X"
        maybe = "*"
        noyes = [no, yes]

        out = ""
        for key in keys:
            has_val = hasattr(self, key) or (key in self.parameters)
            mark = noyes[int(has_val)]

            # Look up this key's constructor fresh on every iteration; a missing
            # key and an explicit None both count as "no constructor".
            constructor = self.constructors.get(key)
            if isinstance(constructor, get_it_from):
                out += mark + " " + key + " : get it from " + constructor.name + "\n"
                continue
            if constructor is None or not hasattr(constructor, "__name__"):
                out += mark + " " + key + " : NO CONSTRUCTOR FOUND\n"
                continue
            out += mark + " " + key + " : " + constructor.__name__ + "\n"

            # Get constructor argument names
            arg_names = get_arg_names(constructor)
            has_no_default = {
                k: v.default is inspect.Parameter.empty
                for k, v in inspect.signature(constructor).parameters.items()
            }

            # Check whether each argument exists; "*" marks absent-but-defaulted
            for this_arg in arg_names:
                if hasattr(self, this_arg) or this_arg in self.parameters:
                    symb = yes
                elif not has_no_default[this_arg]:
                    symb = maybe
                else:
                    symb = no
                out += " " + symb + " " + this_arg + "\n"

        # Print the string to screen
        print(out)
        return

    # This is a "synonym" method so that old calls to update() still work
    def update(self, *args):
        self.construct(*args)

802 

803 

804class AgentType(Model): 

805 """ 

806 A superclass for economic agents in the HARK framework. Each model should 

807 specify its own subclass of AgentType, inheriting its methods and overwriting 

808 as necessary. Critically, every subclass of AgentType should define class- 

809 specific static values of the attributes time_vary and time_inv as lists of 

810 strings. Each element of time_vary is the name of a field in AgentSubType 

811 that varies over time in the model. Each element of time_inv is the name of 

812 a field in AgentSubType that is constant over time in the model. 

813 

814 Parameters 

815 ---------- 

816 solution_terminal : Solution 

817 A representation of the solution to the terminal period problem of 

818 this AgentType instance, or an initial guess of the solution if this 

819 is an infinite horizon problem. 

820 cycles : int 

821 The number of times the sequence of periods is experienced by this 

822 AgentType in their "lifetime". cycles=1 corresponds to a lifecycle 

823 model, with a certain sequence of one period problems experienced 

824 once before terminating. cycles=0 corresponds to an infinite horizon 

825 model, with a sequence of one period problems repeating indefinitely. 

826 pseudo_terminal : bool 

827 Indicates whether solution_terminal isn't actually part of the 

828 solution to the problem (as a known solution to the terminal period 

829 problem), but instead represents a "scrap value"-style termination. 

830 When True, solution_terminal is not included in the solution; when 

831 False, solution_terminal is the last element of the solution. 

832 tolerance : float 

833 Maximum acceptable "distance" between successive solutions to the 

834 one period problem in an infinite horizon (cycles=0) model in order 

835 for the solution to be considered as having "converged". Inoperative 

836 when cycles>0. 

837 verbose : int 

838 Level of output to be displayed by this instance, default is 1. 

839 quiet : bool 

840 Indicator for whether this instance should operate "quietly", default False. 

841 seed : int 

842 A seed for this instance's random number generator. 

843 construct : bool 

844 Indicator for whether this instance's construct() method should be run 

845 when initialized (default True). When False, an instance of the class 

846 can be created even if not all of its attributes can be constructed. 

847 use_defaults : bool 

848 Indicator for whether this instance should use the values in the class' 

849 default dictionary to fill in parameters and constructors for those not 

850 provided by the user (default True). Setting this to False is useful for 

851 situations where the user wants to be absolutely sure that they know what 

852 is being passed to the class initializer, without resorting to defaults. 

853 

854 Attributes 

855 ---------- 

856 AgentCount : int 

857 The number of agents of this type to use in simulation. 

858 

859 state_vars : list of string 

860 The string labels for this AgentType's model state variables. 

861 """ 

862 

863 time_vary_ = [] 

864 time_inv_ = [] 

865 shock_vars_ = [] 

866 state_vars = [] 

867 poststate_vars = [] 

868 distributions = [] 

869 default_ = {"params": {}, "solver": NullFunc()} 

870 

def __init__(
    self,
    solution_terminal=None,
    pseudo_terminal=True,
    tolerance=0.000001,
    verbose=1,
    quiet=False,
    seed=0,
    construct=True,
    use_defaults=True,
    **kwds,
):
    """
    Initialize an agent by merging default and user-supplied parameters.

    Parameters
    ----------
    solution_terminal : Solution, optional
        Terminal-period solution (or initial guess); replaced by NullFunc()
        when None.
    pseudo_terminal : bool
        Whether solution_terminal is a scrap value excluded from the solution.
    tolerance : float
        Convergence tolerance for infinite horizon (cycles=0) solving.
    verbose : int
        Output level; also sets the shared "HARK" logger level to
        ``(4 - verbose) * 10`` (verbose=1 -> 30, i.e. logging.WARNING).
    quiet : bool
        Stored on the instance as ``self.quiet``.
    seed : int
        Stored as ``self.seed``; presumably consumed by ``reset_rng`` (defined
        outside this view) -- TODO confirm.
    construct : bool
        When True, run ``self.construct()`` after parameters are assigned.
    use_defaults : bool
        When True, start from a deep copy of ``self.default_["params"]``
        (including its constructors); otherwise start from an empty dict.
    **kwds
        Parameter values that override the defaults. A ``constructors`` entry
        is merged into (not substituted for) the default constructors.
    """
    # Model.__init__ ensures self.parameters / self.constructors exist
    super().__init__()
    # Deep copy so instances never mutate the class-level default dictionary;
    # user keywords take precedence over defaults
    params = deepcopy(self.default_["params"]) if use_defaults else {}
    params.update(kwds)

    # Correctly handle constructors that have been passed in kwds
    # (user constructors are layered on top of the default constructors)
    if "constructors" in self.default_["params"].keys() and use_defaults:
        constructors = deepcopy(self.default_["params"]["constructors"])
    else:
        constructors = {}
    if "constructors" in kwds.keys():
        constructors.update(kwds["constructors"])
    params["constructors"] = constructors

    # Set model file name if possible
    try:
        self.model_file = copy(self.default_["model"])
    except (KeyError, TypeError):
        # Fallback to None if "model" key is missing or invalid for copying
        self.model_file = None

    if solution_terminal is None:
        solution_terminal = NullFunc()

    self.solve_one_period = self.default_["solver"]  # NOQA
    self.solution_terminal = solution_terminal  # NOQA
    self.pseudo_terminal = pseudo_terminal  # NOQA
    self.tolerance = tolerance  # NOQA
    self.verbose = verbose
    self.quiet = quiet
    # Global side effect: this changes the level of the module-wide "HARK"
    # logger, so the most recently created agent decides the verbosity.
    set_verbosity_level((4 - verbose) * 10)
    self.seed = seed  # NOQA
    self.track_vars = []  # NOQA
    # One slot per class-declared state variable, empty until simulation
    self.state_now = {sv: None for sv in self.state_vars}
    self.state_prev = self.state_now.copy()
    self.controls = {}
    self.shocks = {}
    self.read_shocks = False  # NOQA
    self.shock_history = {}
    self.newborn_init_history = {}
    self.history = {}
    # Stores every parameter both in self.parameters and as an attribute
    self.assign_parameters(**params)  # NOQA
    self.reset_rng()  # NOQA
    self.bilt = {}
    if construct:
        self.construct()

    # Add instance-level lists and objects. Deep copies of the class-level
    # lists, because add_to_time_vary/add_to_time_inv append in place and must
    # not leak into other instances.
    # NOTE(review): these are assigned after construct(), so they do not exist
    # yet while constructors run.
    self.time_vary = deepcopy(self.time_vary_)
    self.time_inv = deepcopy(self.time_inv_)
    self.shock_vars = deepcopy(self.shock_vars_)

933 

def add_to_time_vary(self, *params):
    """
    Append parameter names to this instance's time_vary list.

    Names already listed are skipped, so the list never gains duplicates;
    the list is modified in place.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be added to time_vary

    Returns
    -------
    None
    """
    registry = self.time_vary
    for name in params:
        if name in registry:
            continue
        registry.append(name)

950 

def add_to_time_inv(self, *params):
    """
    Append parameter names to this instance's time_inv list.

    Names already listed are skipped, so the list never gains duplicates;
    the list is modified in place.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be added to time_inv

    Returns
    -------
    None
    """
    registry = self.time_inv
    for name in params:
        if name in registry:
            continue
        registry.append(name)

967 

def del_from_time_vary(self, *params):
    """
    Drop parameter names from this instance's time_vary list, in place.

    Names that are not present are ignored.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be removed from time_vary

    Returns
    -------
    None
    """
    for name in params:
        try:
            self.time_vary.remove(name)
        except ValueError:
            pass  # not listed; nothing to remove

984 

def del_from_time_inv(self, *params):
    """
    Drop parameter names from this instance's time_inv list, in place.

    Names that are not present are ignored.

    Parameters
    ----------
    params : string
        Any number of strings naming attributes to be removed from time_inv

    Returns
    -------
    None
    """
    for name in params:
        try:
            self.time_inv.remove(name)
        except ValueError:
            pass  # not listed; nothing to remove

1001 

def unpack(self, parameter):
    """
    Gather one attribute from every per-period solution into a list on self.

    After solving, objects such as the consumption function live on each
    element of ``self.solution`` (e.g. ``self.solution[t].cFunc``). This
    method collects that attribute from every solution, in order, into a new
    list stored as ``self.<parameter>``. It then registers the name as
    time-varying.

    Parameters
    ----------
    parameter : str
        Name of the attribute to gather; it is read from each solution
        object's instance ``__dict__``.

    Returns
    -------
    None
    """
    collected = []
    setattr(self, parameter, collected)
    for solution_t in self.solution:
        collected.append(solution_t.__dict__[parameter])
    self.add_to_time_vary(parameter)

1023 

def solve(self, verbose=False, presolve=True, from_solution=None, from_t=None):
    """
    Solve this agent type's model by backward induction.

    The sequence of one-period problems is solved from the last period to the
    first. Each period's solution is fed into the problem of the period
    before it, and the resulting list is stored in ``self.solution``.

    Parameters
    ----------
    verbose : bool, optional
        Print solution progress to screen when True. Default False.
    presolve : bool, optional
        Run ``self.pre_solve()`` before solving when True (the default).
    from_solution : Solution, optional
        When not None, used as the starting point of backward induction in
        place of ``self.solution_terminal``.
    from_t : int or None
        Time index that ``from_solution`` represents, i.e. the period the
        solver starts from. Normally used together with ``from_solution``.
        It is only compatible with cycles=1 and is reset to None otherwise.

    Returns
    -------
    None
    """
    # NumPy's floating point "errors" (1.0/0.0 -> inf, inf/inf -> nan, ...)
    # have well-defined results, so they are silenced for the whole solve.
    quiet_fp = np.errstate(
        divide="ignore", over="ignore", under="ignore", invalid="ignore"
    )
    with quiet_fp:
        if presolve:
            self.pre_solve()
        self.solution = solve_agent(self, verbose, from_solution, from_t)
        self.post_solve()

1065 

def reset_rng(self):
    """
    Reset this type's random number generator and every distribution it owns.

    ``self.RNG`` is rebuilt from ``self.seed``. Then each attribute named in
    ``self.distributions`` (if it exists) is reset. Such an attribute may be:

    1) a single distribution object,
    2) a list of distribution objects (probably time-varying), or
    3) a nested list of distributions, as in ConsMarkovModel.
    """
    self.RNG = np.random.default_rng(self.seed)
    for attr_name in self.distributions:
        if not hasattr(self, attr_name):
            continue
        target = getattr(self, attr_name)
        if not isinstance(target, list):
            target.reset()
            continue
        for item in target:
            # An inner list holds several distributions; otherwise treat the
            # item itself as a single distribution.
            members = item if isinstance(item, list) else (item,)
            for member in members:
                member.reset()

1090 

def check_elements_of_time_vary_are_lists(self):
    """
    Check that every attribute named in ``time_vary`` holds per-period values.

    Each name in ``self.time_vary`` that exists as an attribute must hold a
    list (one entry per period) or an ``IndexDistribution`` (a time-varying
    distribution). Names without a matching attribute are skipped.

    Raises
    ------
    AssertionError
        If an existing time_vary attribute is neither a list nor an
        IndexDistribution. Raised explicitly rather than via ``assert`` so
        the check still runs when Python is started with ``-O``.
    """
    for param in self.time_vary:
        if not hasattr(self, param):
            continue
        value = getattr(self, param)
        # Lists (including list subclasses) pass immediately; only other
        # values need the IndexDistribution test.
        if isinstance(value, list) or isinstance(value, IndexDistribution):
            continue
        raise AssertionError(
            param
            + " is not a list or time varying distribution,"
            + " but should be because it is in time_vary"
        )

1107 

def check_restrictions(self):
    """
    Validate model-specific restrictions on this instance's inputs.

    The base implementation checks nothing. Subclasses can override it to
    reject inconsistent parameters. The base ``pre_solve`` calls it.
    """
    return None

1113 

def pre_solve(self):
    """
    Hook run immediately before the model is solved.

    The base implementation only validates inputs. It runs
    ``check_restrictions()`` and then
    ``check_elements_of_time_vary_are_lists()``. Subclasses may override this
    to do more, e.g. prepare the terminal solution.

    Returns
    -------
    None
    """
    # Model-specific restrictions first, then the generic requirement that
    # time-varying inputs are stored per period.
    self.check_restrictions()
    self.check_elements_of_time_vary_are_lists()

1130 

def post_solve(self):
    """
    Hook run immediately after the model is solved.

    Intended for subclasses that need to finalize the solution in some way.
    The base implementation does nothing.

    Returns
    -------
    None
    """
    pass

1145 

def initialize_sym(self, **kwargs):
    """
    Build a simulator with the new simulator structure and store it privately.

    The simulator is created from this agent's attributes via
    ``make_simulator_from_agent`` (keyword arguments are forwarded), kept in
    ``self._simulator``, and reset.
    """
    # Reset seeds first so every freshly built simulator starts from
    # identical draws.
    self.reset_rng()
    simulator = make_simulator_from_agent(self, **kwargs)
    self._simulator = simulator
    simulator.reset()

1154 

def initialize_sim(self):
    """
    Prepare this AgentType for a new simulation.

    The method performs these steps:

    - Resets the random number generator and sets ``t_sim`` to 0.
    - Replaces every ``None`` entry of ``state_now`` with a NaN array of
      length AgentCount.
    - Zeroes ``t_age`` and ``t_cycle``.
    - Creates initial states for all agents via ``sim_birth``.
    - When ``read_shocks`` is True and ``newborn_init_history`` is non-empty,
      overwrites each idiosyncratic state (an array with one entry per agent)
      with a copy of row 0 of its recorded newborn history. A warning is
      issued for each state that has no recorded history.
    - Clears the tracked-variable history.

    Parameters
    ----------
    None

    Returns
    -------
    None

    Raises
    ------
    Exception
        If ``T_sim`` is not set or is not positive.
    """
    if not hasattr(self, "T_sim"):
        raise Exception(
            "To initialize simulation variables it is necessary to first "
            + "set the attribute T_sim to the largest number of observations "
            + "you plan to simulate for each agent including re-births."
        )
    elif self.T_sim <= 0:
        raise Exception(
            "T_sim represents the largest number of observations "
            + "that can be simulated for an agent, and must be a positive number."
        )

    self.reset_rng()
    self.t_sim = 0
    all_agents = np.ones(self.AgentCount, dtype=bool)
    blank_array = np.empty(self.AgentCount)
    blank_array[:] = np.nan
    for var in self.state_now:
        if self.state_now[var] is None:
            self.state_now[var] = copy(blank_array)

    self.t_age = np.zeros(self.AgentCount, dtype=int)  # periods since agent entry
    self.t_cycle = np.zeros(self.AgentCount, dtype=int)  # period within the cycle
    self.sim_birth(all_agents)

    # If we are asked to use existing shocks and a set of initial conditions
    # exists, use them.
    if self.read_shocks and bool(self.newborn_init_history):
        for var_name in self.state_now:
            # Check that we are actually given a value for the variable
            if var_name in self.newborn_init_history:
                # Copy only array-like idiosyncratic states. Aggregates should
                # not be set by newborns.
                idio = (
                    isinstance(self.state_now[var_name], np.ndarray)
                    and len(self.state_now[var_name]) == self.AgentCount
                )
                if idio:
                    # Store a copy, not the row itself. The row is a view into
                    # (or the very object held by) newborn_init_history, so any
                    # later in-place update of this state would silently
                    # rewrite the recorded history.
                    self.state_now[var_name] = copy(
                        self.newborn_init_history[var_name][0]
                    )
            else:
                warn(
                    "The option for reading shocks was activated but "
                    + "the model requires state "
                    + var_name
                    + ", not contained in "
                    + "newborn_init_history."
                )

    self.clear_history()
    return None

1227 

def sim_one_period(self):
    """
    Advance the simulation of this type by one period.

    The period proceeds as follows:

    1. Mortality (``get_mortality``) replaces some agents with newborns.
    2. Current states are shifted into ``state_prev``.
    3. Shocks are drawn (``get_shocks``) or replayed
       (``read_shocks_from_history``) depending on ``self.read_shocks``.
    4. ``get_states``, ``get_controls`` and ``get_poststates`` run in that
       order.
    5. Every agent's age and within-cycle period are advanced.

    Subclasses are expected to define ``get_shocks``, ``get_states``,
    ``get_controls`` and ``get_poststates``; mortality is customized via
    ``sim_death`` / ``sim_birth``.

    Returns
    -------
    None

    Raises
    ------
    Exception
        If the model has not been solved yet (no ``solution`` attribute).
    """
    if not hasattr(self, "solution"):
        raise Exception(
            "Model instance does not have a solution stored. To simulate, it is necessary"
            " to run the `solve()` method first."
        )

    self.get_mortality()

    # Keep the outgoing state as state_{t-1}. Idiosyncratic arrays get a
    # fresh buffer; anything else is probably an aggregate that a Market may
    # set, so it is left in place.
    for name, current in self.state_now.items():
        self.state_prev[name] = current
        if isinstance(current, np.ndarray):
            self.state_now[name] = np.empty(self.AgentCount)

    draw_shocks = self.read_shocks_from_history if self.read_shocks else self.get_shocks
    draw_shocks()

    self.get_states()  # states at decision time
    self.get_controls()  # choices given those states
    self.get_poststates()  # variables determined after the decision

    # Everyone ages one period; the cycle counter wraps to 0 at T_cycle.
    self.t_age = self.t_age + 1
    next_cycle = self.t_cycle + 1
    next_cycle[next_cycle == self.T_cycle] = 0
    self.t_cycle = next_cycle

1276 

def make_shock_history(self):
    """
    Make a pre-specified history of shocks for the simulation.

    Runs a subset of the standard simulation loop: only mortality and shocks
    are simulated, for T_sim periods. The following are recorded, each as a
    ``T_sim x AgentCount`` array:

    - for every name X in ``self.shock_vars``: ``self.shock_history[X]``;
    - deaths: ``self.shock_history["who_dies"]`` (boolean);
    - for every name in ``self.state_vars``: the initial conditions given to
      newborns, in ``self.newborn_init_history``.

    Afterwards ``self.read_shocks`` is set to True, so all subsequent calls
    to ``simulate()`` replay these shocks.

    Calling this method again rebuilds the history from scratch. Shock
    reading is switched off first, so mortality and newborn states are
    simulated rather than read back from the history being rebuilt.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    # After a previous call read_shocks is True. Left on, initialize_sim
    # would load the old newborn states and get_mortality would read deaths
    # from the freshly zeroed who_dies array below, yielding a history with
    # no deaths at all.
    self.read_shocks = False

    # Re-initialize the simulation
    self.initialize_sim()

    def is_idiosyncratic(var_name):
        # One value per agent -> idiosyncratic; anything else is an
        # aggregate shared by every agent.
        value = self.state_now[var_name]
        return isinstance(value, np.ndarray) and len(value) == self.AgentCount

    shape = (self.T_sim, self.AgentCount)

    # Blank history arrays for each shock variable, for mortality, and for
    # the draws of newborns' initial conditions.
    for var_name in self.shock_vars:
        self.shock_history[var_name] = np.full(shape, np.nan)
    self.shock_history["who_dies"] = np.zeros(shape, dtype=bool)
    for var_name in self.state_vars:
        self.newborn_init_history[var_name] = np.full(shape, np.nan)

    # Record the initial conditions of the newborns created by
    # initialize_sim -> sim_birth. An idiosyncratic state fills the row
    # agent-by-agent; an aggregate scalar is broadcast to every agent.
    for var_name in self.state_vars:
        self.newborn_init_history[var_name][self.t_sim, :] = self.state_now[var_name]

    # Make and store the history of shocks for each period
    for t in range(self.T_sim):
        # Deaths
        self.get_mortality()
        who_dies = self.who_dies
        self.shock_history["who_dies"][t, :] = who_dies

        # Initial conditions of this period's newborns
        if who_dies.any():
            for var_name in self.state_vars:
                value = self.state_now[var_name]
                if is_idiosyncratic(var_name):
                    value = value[who_dies]
                self.newborn_init_history[var_name][t, who_dies] = value

        # Other shocks
        self.get_shocks()
        for var_name in self.shock_vars:
            self.shock_history[var_name][t, :] = self.shocks[var_name]

        self.t_sim += 1
        self.t_age = self.t_age + 1  # Age all consumers by one period
        self.t_cycle = self.t_cycle + 1  # Age all consumers within their cycle
        # Reset to zero for those who have reached the end of the cycle
        self.t_cycle[self.t_cycle == self.T_cycle] = 0

    # Flag that shocks can be read rather than simulated
    self.read_shocks = True

1367 

def get_mortality(self):
    """
    Replace "dead" agents with newborns for the current period.

    When shocks are simulated, ``sim_death()`` returns a Boolean array of
    size AgentCount flagging the agents to replace. That array is handed to
    ``sim_birth()``, which creates their initial states.

    When ``self.read_shocks`` is True, the deaths for period ``t_sim`` come
    from ``self.shock_history["who_dies"]`` instead. Newborns' idiosyncratic
    states are copied from ``self.newborn_init_history`` (warning for any
    state that has no recorded history), and their ages and cycle positions
    are reset to 0.

    In both cases the death indicator is stored in ``self.who_dies``.

    Returns
    -------
    None
    """
    if not self.read_shocks:
        who_dies = self.sim_death()
        self.sim_birth(who_dies)
        self.who_dies = who_dies
        return None

    period = self.t_sim
    who_dies = self.shock_history["who_dies"][period, :]
    if who_dies.any():
        for var_name in self.state_now:
            if var_name not in self.newborn_init_history:
                warn(
                    "The option for reading shocks was activated but "
                    + "the model requires state "
                    + var_name
                    + ", not contained in "
                    + "newborn_init_history."
                )
                continue
            current = self.state_now[var_name]
            # Only per-agent arrays are overwritten; aggregates are not set
            # by newborns.
            per_agent = (
                isinstance(current, np.ndarray) and len(current) == self.AgentCount
            )
            if per_agent:
                current[who_dies] = self.newborn_init_history[var_name][
                    period, who_dies
                ]
        # Newborns start at age zero and at the beginning of the cycle.
        self.t_age[who_dies] = 0
        self.t_cycle[who_dies] = 0
    self.who_dies = who_dies
    return None

1420 

def sim_death(self):
    """
    Decide which agents "die" (are replaced) this period.

    The base class never replaces anyone. Subclasses override this to
    generate replacement events.

    Returns
    -------
    who_dies : np.array
        Boolean array of size ``self.AgentCount``; True marks an agent that
        dies and is replaced. All False here.
    """
    nobody = np.zeros(self.AgentCount, dtype=bool)
    return nobody

1439 

def sim_birth(self, which_agents):
    """
    Create new agents for the simulation.

    Subclasses must override this to fill in initial states for the agents
    flagged in ``which_agents``. The base version only prints a reminder.

    Parameters
    ----------
    which_agents : np.array(Bool)
        Boolean array of size ``self.AgentCount``; True marks agents to be
        "born".

    Returns
    -------
    None
    """
    reminder = "AgentType subclass must define method sim_birth!"
    print(reminder)
    return None

1456 

def get_shocks(self):
    """
    Draw this period's shock variables.

    The base implementation draws nothing; subclasses of AgentType override
    it to populate their shocks.

    Returns
    -------
    None
    """
    pass

1471 

def read_shocks_from_history(self):
    """
    Load this period's shocks from the pre-specified shock history.

    For every name X in ``self.shock_vars``, ``self.shocks[X]`` is set to
    row ``self.t_sim`` of ``self.shock_history[X]``.

    The base ``sim_one_period`` only calls this when ``self.read_shocks`` is
    True. That flag is set by ``make_shock_history()``, or can be set
    manually after storing a hand-crafted shock history.

    Returns
    -------
    None
    """
    period = self.t_sim
    for name in self.shock_vars:
        self.shocks[name] = self.shock_history[name][period, :]

1492 

def get_states(self):
    """
    Update this period's state variables.

    Calls ``self.transition()`` and assigns its results positionally to the
    entries of ``self.state_now``, in dictionary order.

    Returns
    -------
    None
    """
    new_states = self.transition()
    # A hack for now to deal with "post-states": when transition returns
    # fewer values than there are states, the trailing states are left
    # untouched (zip stops at the shorter sequence).
    for var, value in zip(self.state_now, new_states):
        self.state_now[var] = value

1513 

def transition(self):
    """
    Compute new values of the endogenous states for the current period.

    Parameters
    ----------
    None

    [Eventually, to match the dolo spec:
    exogenous_prev, endogenous_prev, controls, exogenous, parameters]

    Returns
    -------
    endogenous_state : tuple
        New values of the endogenous states, consumed positionally by
        ``get_states``. The base class returns an empty tuple, so no state
        is changed.
    """
    return tuple()

1531 

def get_controls(self):
    """
    Determine this period's control (choice) variables.

    Subclasses of AgentType override this, typically computing controls
    from the current states. The base implementation does nothing.

    Returns
    -------
    None
    """
    pass

1546 

def get_poststates(self):
    """
    Determine this period's post-decision state variables.

    Subclasses of AgentType override this, typically using current states,
    controls, and possibly market-level events or shocks. The base
    implementation does nothing.

    Returns
    -------
    None
    """
    pass

1564 

def symulate(self, T=None):
    """
    Run the new simulator structure and store its history on self.

    ``self._simulator`` (built by ``initialize_sym``) is asked to simulate,
    and its history is written to ``self.hystory``.

    Parameters
    ----------
    T : int or None
        Passed straight through to the simulator's ``simulate`` method.
    """
    sim = self._simulator
    sim.simulate(T)
    self.hystory = sim.history

1572 

def describe_model(self, display=True):
    """
    Describe this agent's model, based on its model file.

    Useful for learning outcome variable names to track during simulation or
    to use with sequence space Jacobians. If no simulator exists yet, one is
    built with ``initialize_sym()`` first.

    Parameters
    ----------
    display : bool, optional
        Forwarded to the simulator's ``describe`` method. Default True.
    """
    needs_simulator = not hasattr(self, "_simulator")
    if needs_simulator:
        self.initialize_sym()
    self._simulator.describe(display=display)

1582 

def simulate(self, sim_periods=None):
    """
    Simulate this agent type for a given number of periods.

    Runs ``sim_one_period()`` repeatedly. After each period, the current
    value of every variable named in ``self.track_vars`` is recorded in row
    ``self.t_sim`` of ``self.history[var_name]``. Values are looked up in
    ``state_now``, then ``shocks``, then ``controls``, and finally as an
    attribute of self.

    Parameters
    ----------
    sim_periods : int or None
        Number of periods to simulate. Default is all remaining periods
        (T_sim - t_sim).

    Returns
    -------
    history : dict
        ``self.history``, the history tracked during the simulation.

    Raises
    ------
    Exception
        If ``initialize_sim()`` has not been run, if T_sim is missing, or if
        ``sim_periods`` exceeds T_sim.
    """
    if not hasattr(self, "t_sim"):
        raise Exception(
            "It seems that the simulation variables were not initialized before calling "
            + "simulate(). Call initialize_sim() to initialize the variables before calling simulate() again."
        )

    if not hasattr(self, "T_sim"):
        raise Exception(
            "This agent type instance must have the attribute T_sim set to a positive integer. "
            + "Set T_sim to match the largest dataset you might simulate, and run this agent's "
            + "initialize_sim() method before running simulate() again."
        )

    if sim_periods is not None and self.T_sim < sim_periods:
        raise Exception(
            "To simulate, sim_periods has to be smaller than or equal to the maximum data set size "
            + "T_sim. Either increase the attribute T_sim of this agent type instance "
            + "and call the initialize_sim() method again, or set sim_periods <= T_sim."
        )

    # Ignore floating point "errors". Numpy calls them "errors", but really
    # they are exceptions with well-defined answers, such as 1.0/0.0 = np.inf,
    # -1.0/0.0 = -np.inf, np.inf/np.inf = np.nan and so on.
    with np.errstate(
        divide="ignore", over="ignore", under="ignore", invalid="ignore"
    ):
        if sim_periods is None:
            sim_periods = self.T_sim - self.t_sim

        for _ in range(sim_periods):
            self.sim_one_period()

            row = self.t_sim
            for var_name in self.track_vars:
                if var_name in self.state_now:
                    self.history[var_name][row, :] = self.state_now[var_name]
                elif var_name in self.shocks:
                    self.history[var_name][row, :] = self.shocks[var_name]
                elif var_name in self.controls:
                    self.history[var_name][row, :] = self.controls[var_name]
                elif var_name == "who_dies" and row > 1:
                    # NOTE(review): who_dies is shifted back one row only once
                    # t_sim > 1. As a result, row 1 is written at both t_sim=1
                    # and t_sim=2, and the final row is never filled. Confirm
                    # whether `row > 0` (or no shift) was intended.
                    self.history[var_name][row - 1, :] = getattr(self, var_name)
                else:
                    self.history[var_name][row, :] = getattr(self, var_name)
            self.t_sim += 1

    return self.history

1648 

def clear_history(self):
    """
    Reset the stored history of every variable named in ``self.track_vars``.

    Each such entry of ``self.history`` becomes a fresh
    ``T_sim x AgentCount`` float array filled with NaN.

    Returns
    -------
    None
    """
    shape = (self.T_sim, self.AgentCount)
    for name in self.track_vars:
        self.history[name] = np.full(shape, np.nan)

1664 

def make_basic_SSJ(self, shock, outcomes, grids, **kwargs):
    """
    Construct sequence space Jacobian matrices for this agent type.

    The Jacobians are those of the specified outcomes with respect to the
    specified "shock" variable. This "basic" method only works for
    "one period infinite horizon" models (cycles=0, T_cycle=1).

    All arguments are forwarded unchanged to
    ``HARK.SSJutils.make_basic_SSJ_matrices``; see its documentation for
    details and for the structure of the returned value.
    """
    jacobians = make_basic_SSJ_matrices(self, shock, outcomes, grids, **kwargs)
    return jacobians

1673 

def calc_impulse_response_manually(self, shock, outcomes, grids, **kwargs):
    """
    Compute impulse responses to a perturbation of the shock parameter.

    The perturbation happens in period t=s, so this essentially computes one
    column of the sequence space Jacobian manually. This "basic" method only
    works for "one period infinite horizon" models (cycles=0, T_cycle=1).

    All arguments are forwarded unchanged to
    ``HARK.SSJutils.calc_shock_response_manually``; see its documentation for
    details.
    """
    responses = calc_shock_response_manually(self, shock, outcomes, grids, **kwargs)
    return responses

1683 

1684 

def solve_agent(agent, verbose, from_solution=None, from_t=None, max_cycles=5000):
    """
    Solve the dynamic model for one agent type using backward induction.

    Iterates on "cycles" of the agent's model. With a finite horizon it runs
    ``agent.cycles`` times. With an infinite horizon (``agent.cycles == 0``)
    it runs until the distance between successive cycle solutions is no
    greater than ``agent.tolerance``, or ``max_cycles`` is reached.

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    verbose : bool
        If True, solution progress is printed to screen after every cycle.
    from_solution : Solution, optional
        When not None, used as the starting point of backward induction in
        place of ``agent.solution_terminal``.
    from_t : int or None
        Time index that ``from_solution`` represents, i.e. the period the
        solver starts from. Normally used together with ``from_solution``.
        It is only compatible with cycles=1 and is reset to None otherwise.
    max_cycles : int, optional
        Escape clause for infinite horizon problems. Iteration stops once the
        completed cycle count reaches this value even without convergence,
        in which case a warning is issued. Default 5000.

    Returns
    -------
    solution : [Solution]
        A list of solutions to the one period problems that the agent will
        encounter in his "lifetime".
    """
    # Check to see whether this is an (in)finite horizon problem
    cycles_left = agent.cycles
    infinite_horizon = cycles_left == 0

    solution_last = agent.solution_terminal if from_solution is None else from_solution
    if agent.cycles != 1:
        from_t = None

    # The terminal solution is part of the result unless it is a
    # "pseudo terminal" period.
    solution = [] if agent.pseudo_terminal else [deepcopy(solution_last)]

    go = True
    completed_cycles = 0
    if verbose:
        t_last = time()
    while go:
        # Solve a cycle of the model, recording it if horizon is finite
        solution_cycle = solve_one_cycle(agent, solution_last, from_t)
        if not infinite_horizon:
            solution = solution_cycle + solution

        # Check for termination: identical solutions across cycle iterations
        # or running out of cycles.
        solution_now = solution_cycle[0]
        if infinite_horizon:
            if completed_cycles > 0:
                solution_distance = solution_now.distance(solution_last)
                # Exposed on the agent so users can check convergence.
                agent.solution_distance = solution_distance
                agent.completed_cycles = completed_cycles
                go = (
                    solution_distance > agent.tolerance
                    and completed_cycles < max_cycles
                )
            else:
                # Assume the solution does not converge after only one cycle
                solution_distance = 100.0
                go = True
        else:
            cycles_left -= 1
            go = cycles_left > 0

        # Update the "last period solution"
        solution_last = solution_now
        completed_cycles += 1

        # Display progress if requested
        if verbose:
            t_now = time()
            if infinite_horizon:
                print(
                    "Finished cycle #"
                    + str(completed_cycles)
                    + " in "
                    + str(t_now - t_last)
                    + " seconds, solution distance = "
                    + str(solution_distance)
                )
            else:
                print(
                    "Finished cycle #"
                    + str(completed_cycles)
                    + " of "
                    + str(agent.cycles)
                    + " in "
                    + str(t_now - t_last)
                    + " seconds."
                )
            t_last = t_now

    if infinite_horizon:
        # The loop also stops when max_cycles is hit, or when the distance is
        # NaN (NaN > tolerance is False). Neither means convergence, so say so
        # rather than silently returning the last iterate.
        if not solution_distance <= agent.tolerance:
            warn(
                "Infinite horizon solution did not converge after "
                + str(completed_cycles)
                + " cycles: solution distance = "
                + str(solution_distance)
                + ", tolerance = "
                + str(agent.tolerance)
                + "."
            )
        # Record the last cycle (solution is still empty!). PseudoTerminal=False
        # is impossible for an infinite horizon.
        solution = solution_cycle

    return solution

1800 

1801 

def _select_period_solver(agent, k):
    """Return the one-period solver for period k and the argument names it takes."""
    # solve_one_period may be a single callable or a per-period sequence.
    if hasattr(agent.solve_one_period, "__getitem__"):
        solver = agent.solve_one_period[k]
    else:
        solver = agent.solve_one_period
    # Solvers built by make_one_period_oo_solver advertise their arguments.
    if hasattr(solver, "solver_args"):
        arg_names = solver.solver_args
    else:
        arg_names = get_arg_names(solver)
    return solver, arg_names


def _cycle_period_order(agent, T):
    """Return the order in which period indices 0..T-1 are solved in one cycle."""
    if agent.cycles == 1:
        return range(T - 1, -1, -1)
    # Repeated / infinite cycles solve period 0 first, then T-1 down to 1.
    # NOTE(review): rationale for this ordering is not visible here.
    return [0] + list(range(T - 1, 0, -1))


def solve_one_cycle(agent, solution_last, from_t):
    """
    Solve one "cycle" of the dynamic model for one agent type.

    Iterates over the periods within an agent's cycle, collecting each
    period's inputs and passing them to the single period solver(s).
    Inputs come from ``agent.params`` when it is a ``Parameters`` instance,
    and otherwise from the attributes named in ``time_inv`` and
    ``time_vary``.

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved.
    solution_last : Solution
        Solution of the period that comes after the end of this sequence of
        one period problems. This may be the terminal period solution, a
        "pseudo terminal" solution, or the solution to the earliest period
        of the succeeding cycle.
    from_t : int or None
        If not None, the period the solver should start from; it is the time
        index that solution_last is from.

    Returns
    -------
    solution_cycle : [Solution]
        A list of one period solutions for one "cycle" of the AgentType's
        microeconomic model.
    """
    use_params = hasattr(agent, "params") and isinstance(agent.params, Parameters)

    if use_params:
        T = agent.params._length if from_t is None else from_t

        def period_kwargs(k, arg_names, solution_next):
            # Every input except solution_next is taken from period k's params.
            pars = agent.params[k]
            return {
                name: solution_next if name == "solution_next" else pars[name]
                for name in arg_names
            }

    else:
        # Number of periods per cycle; 1 if all variables are time invariant.
        if len(agent.time_vary) > 0:
            T = agent.T_cycle if from_t is None else from_t
        else:
            T = 1

        solve_dict = {name: agent.__dict__[name] for name in agent.time_inv}
        solve_dict.update({name: None for name in agent.time_vary})

        def period_kwargs(k, arg_names, solution_next):
            # Refresh the time-varying inputs this solver uses for period k.
            for name in agent.time_vary:
                if name in arg_names:
                    solve_dict[name] = agent.__dict__[name][k]
            solve_dict["solution_next"] = solution_next
            return {name: solve_dict[name] for name in arg_names}

    # Solve periods backward, each feeding the next one solved.
    solution_cycle = []
    solution_next = solution_last
    for k in _cycle_period_order(agent, T):
        solve_one_period, these_args = _select_period_solver(agent, k)
        solution_t = solve_one_period(**period_kwargs(k, these_args, solution_next))
        solution_cycle.append(solution_t)
        solution_next = solution_t

    # Solutions were gathered in solving order; the result lists them in reverse.
    solution_cycle.reverse()
    return solution_cycle

1907 

1908 

def make_one_period_oo_solver(solver_class):
    """
    Wrap an object-oriented Solver class as a one-period solver function.

    Parameters
    ----------
    solver_class : Solver
        Class of Solver to be used. Its instances are constructed from the
        period's inputs and expose ``solve()``, and optionally
        ``prepare_to_solve()``.

    Returns
    -------
    solver_function : function
        Function for solving one period of a problem. It takes the solver's
        constructor arguments as keywords. It builds an instance, calls
        ``prepare_to_solve()`` when the instance has it, and returns the
        result of ``solve()``. The function carries ``solver_class`` and
        ``solver_args`` attributes.
    """

    def one_period_solver(**kwds):
        instance = solver_class(**kwds)
        # Not ideal: better if every Solver class defined prepare_to_solve.
        if hasattr(instance, "prepare_to_solve"):
            instance.prepare_to_solve()
        return instance.solve()

    one_period_solver.solver_class = solver_class
    # Argument names of __init__ with the first one (presumably ``self``)
    # dropped. This can be revisited once it is possible to export parameters.
    constructor_args = get_arg_names(solver_class.__init__)
    one_period_solver.solver_args = constructor_args[1:]
    return one_period_solver

1937 

1938 

1939# ======================================================================== 

1940# ======================================================================== 

1941 

1942 

class Market(Model):
    """
    A superclass to represent a central clearinghouse of information. Used for
    dynamic general equilibrium models to solve the "macroeconomic" model as a
    layer on top of the "microeconomic" models of one or more AgentTypes.

    Parameters
    ----------
    agents : [AgentType]
        A list of all the AgentTypes in this market.
    sow_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be "sown" to the agents in the market. Aggregate state, etc.
    reap_vars : [string]
        Names of variables to be collected ("reaped") from agents in the market
        to be used in the "aggregate market process".
    const_vars : [string]
        Names of attributes of the Market instance that are used in the "aggregate
        market process" but do not come from agents-- they are constant or simply
        parameters inherent to the process.
    track_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be tracked as a "history" so that a new dynamic rule can be calculated.
        This is often a subset of sow_vars.
    dyn_vars : [string]
        Names of variables that constitute a "dynamic rule".
    mill_rule : function
        A function that takes inputs named in reap_vars and returns a tuple the
        same size and order as sow_vars. The "aggregate market process" that
        transforms individual agent actions/states/data into aggregate data to
        be sent back to agents.
    calc_dynamics : function
        A function that takes inputs named in track_vars and returns an object
        with attributes named in dyn_vars. Looks at histories of aggregate
        variables and generates a new "dynamic rule" for agents to believe and
        act on.
    act_T : int
        The number of times that the "aggregate market process" should be run
        in order to generate a history of aggregate variables.
    tolerance: float
        Minimum acceptable distance between "dynamic rules" to consider the
        Market solution process converged. Distance is a user-defined metric.
    """

    def __init__(
        self,
        agents=None,
        sow_vars=None,
        reap_vars=None,
        const_vars=None,
        track_vars=None,
        dyn_vars=None,
        mill_rule=None,
        calc_dynamics=None,
        act_T=1000,
        tolerance=0.000001,
        **kwds,
    ):
        super().__init__()
        self.agents = agents if agents is not None else list()  # NOQA

        self.reap_vars = reap_vars if reap_vars is not None else list()  # NOQA
        self.reap_state = {var: [] for var in self.reap_vars}

        self.sow_vars = sow_vars if sow_vars is not None else list()  # NOQA
        # dictionaries for tracking initial and current values
        # of the sow variables.
        self.sow_init = {var: None for var in self.sow_vars}
        self.sow_state = {var: None for var in self.sow_vars}

        const_vars = const_vars if const_vars is not None else list()  # NOQA
        self.const_vars = {var: None for var in const_vars}

        self.track_vars = track_vars if track_vars is not None else list()  # NOQA
        self.dyn_vars = dyn_vars if dyn_vars is not None else list()  # NOQA

        if mill_rule is not None:  # To prevent overwriting of method-based mill_rules
            self.mill_rule = mill_rule
        if calc_dynamics is not None:  # Ditto for calc_dynamics
            self.calc_dynamics = calc_dynamics
        self.act_T = act_T  # NOQA
        self.tolerance = tolerance  # NOQA
        self.max_loops = 1000  # NOQA
        self.history = {}
        self.assign_parameters(**kwds)

        # Print the error associated with calling the parallel method
        # "solve_agents" one time. If set to false, the error will never
        # print. See "solve_agents" for why this prints once or never.
        self.print_parallel_error_once = True

    def solve_agents(self):
        """
        Solves the microeconomic problem for all AgentTypes in this market.

        The agents are first solved in parallel with multi_thread_commands; if
        that raises any exception, a warning is printed (at most once per
        Market) and the agents are solved serially instead.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        try:
            multi_thread_commands(self.agents, ["solve()"])
        except Exception as err:
            if self.print_parallel_error_once:
                # Set flag to False so this is only printed once.
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.Market.solve_agents() ",
                    "so using the serial version instead. This will likely be slower. "
                    "The multi_thread_commands() function failed with the following error:",
                    "\n",
                    type(err),
                    ":",
                    err,
                )
            multi_thread_commands_fake(self.agents, ["solve()"])

    def solve(self):
        """
        "Solves" the market by finding a "dynamic rule" that governs the aggregate
        market state such that when agents believe in these dynamics, their actions
        collectively generate the same dynamic rule.

        Iterates solve_agents -> make_history -> update_dynamics until the
        distance between successive dynamic rules falls below self.tolerance,
        or until self.max_loops iterations have been completed. The final rule
        is stored in self.dynamics.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        go = True
        max_loops = self.max_loops  # Failsafe against infinite solution loop
        completed_loops = 0
        old_dynamics = None

        while go:  # Loop until the dynamic process converges or we hit the loop cap
            self.solve_agents()  # Solve each AgentType's micro problem
            self.make_history()  # "Run" the model while tracking aggregate variables
            new_dynamics = self.update_dynamics()  # Find a new aggregate dynamic rule

            # Check to see if the dynamic rule has converged (if this is not the first loop)
            if completed_loops > 0:
                distance = new_dynamics.distance(old_dynamics)
            else:
                distance = 1000000.0

            # Move to the next loop if the terminal conditions are not met
            old_dynamics = new_dynamics
            completed_loops += 1
            go = distance >= self.tolerance and completed_loops < max_loops

        self.dynamics = new_dynamics  # Store the final dynamic rule in self

    def reap(self):
        """
        Collects the variables named in reap_vars from each AgentType in the
        market, storing them in self.reap_state.

        For each reap variable, self.reap_state[var] becomes a list holding
        agent.state_now[var] for every agent whose state_now contains var.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var in self.reap_state:
            harvest = []

            for agent in self.agents:
                # TODO: generalized variable lookup across namespaces
                if var in agent.state_now:
                    harvest.append(agent.state_now[var])

            self.reap_state[var] = harvest

    def sow(self):
        """
        Distributes the variables named in sow_vars from self.sow_state to each
        AgentType in the market.

        A variable that appears in an agent's state_now is written there. A
        variable that appears in the agent's shocks is written there; otherwise
        it is set as an attribute of the agent.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for sow_var in self.sow_state:
            for this_type in self.agents:
                if sow_var in this_type.state_now:
                    this_type.state_now[sow_var] = self.sow_state[sow_var]
                # NOTE(review): this is a separate `if`, not `elif`, so a variable
                # found in state_now but not in shocks is ALSO set as an
                # attribute. Preserved as-is; confirm whether that is intended.
                if sow_var in this_type.shocks:
                    this_type.shocks[sow_var] = self.sow_state[sow_var]
                else:
                    setattr(this_type, sow_var, self.sow_state[sow_var])

    def mill(self):
        """
        Processes the variables collected from agents using the function mill_rule,
        storing the results in self.sow_state.

        mill_rule is called with the reaped variables and the const_vars as
        keyword arguments; its i-th output is stored under the i-th name in
        sow_vars.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Make a dictionary of inputs for the mill_rule
        mill_dict = copy(self.reap_state)
        mill_dict.update(self.const_vars)

        # Run the mill_rule and store its output in self
        product = self.mill_rule(**mill_dict)

        for i, sow_var in enumerate(self.sow_state):
            self.sow_state[sow_var] = product[i]

    def cultivate(self):
        """
        Has each AgentType in agents perform their market_action method, using
        variables sown from the market (and maybe also "private" variables).
        The market_action method should store new results in attributes named in
        reap_vars to be reaped later.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for this_type in self.agents:
            this_type.market_action()

    def reset(self):
        """
        Reset the state of the market (attributes in sow_vars, etc) to some
        user-defined initial state, and erase the histories of tracked variables.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Reset the history of tracked variables
        self.history = {var_name: [] for var_name in self.track_vars}

        # Set the sow variables to their initial levels
        for var_name in self.sow_state:
            self.sow_state[var_name] = self.sow_init[var_name]

        # Reset each AgentType in the market
        for this_type in self.agents:
            this_type.reset()

    def store(self):
        """
        Record the current value of each variable X named in track_vars in a
        dictionary field named history[X].

        The value is looked up in sow_state, then reap_state, then const_vars,
        and finally as an attribute of self. reset() must have been called
        first so that history contains a list for every tracked variable.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var_name in self.track_vars:
            if var_name in self.sow_state:
                value_now = self.sow_state[var_name]
            elif var_name in self.reap_state:
                value_now = self.reap_state[var_name]
            elif var_name in self.const_vars:
                value_now = self.const_vars[var_name]
            else:
                value_now = getattr(self, var_name)

            self.history[var_name].append(value_now)

    def make_history(self):
        """
        Runs a loop of sow-->cultivate-->reap-->mill act_T times, tracking the
        evolution of variables X named in track_vars in dictionary fields
        history[X].

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        self.reset()  # Initialize the state of the market
        for t in range(self.act_T):
            self.sow()  # Distribute aggregated information/state to agents
            self.cultivate()  # Agents take action
            self.reap()  # Collect individual data from agents
            self.mill()  # Process individual data into aggregate data
            self.store()  # Record variables of interest

    def update_dynamics(self):
        """
        Calculates a new "aggregate dynamic rule" using the history of variables
        named in track_vars, and distributes this rule to AgentTypes in agents.

        Parameters
        ----------
        none

        Returns
        -------
        dynamics : instance
            The new "aggregate dynamic rule" that agents believe in and act on.
            Should have attributes named in dyn_vars.
        """
        # Make a dictionary of inputs for the dynamics calculator; its argument
        # names are looked up in self.history.
        arg_names = list(get_arg_names(self.calc_dynamics))
        if "self" in arg_names:
            arg_names.remove("self")
        update_dict = {name: self.history[name] for name in arg_names}
        # Calculate a new dynamic rule and distribute it to the agents in agent_list
        dynamics = self.calc_dynamics(**update_dict)  # User-defined dynamics calculator
        for var_name in self.dyn_vars:
            this_obj = getattr(dynamics, var_name)
            for this_type in self.agents:
                setattr(this_type, var_name, this_obj)
        return dynamics

2285 

2286 

def distribute_params(agent, param_name, param_count, distribution):
    """
    Create copies of an agent that differ in the value of one parameter.

    The distribution is discretized into ``param_count`` points. The j-th
    (deep) copy of ``agent`` is given the j-th atom as its value of
    ``param_name`` and an ``AgentCount`` equal to the original AgentCount
    times the j-th probability mass, truncated to an integer. The original
    agent is not modified.

    Parameters
    ----------
    agent: AgentType
        An agent to clone.
    param_name : string
        Name of the parameter to be assigned.
    param_count : int
        Number of different values the parameter will take on.
    distribution : Distribution
        A 1-D distribution.

    Returns
    -------
    agent_set : [AgentType]
        A list of param_count agents, ex ante heterogeneous with
        respect to param_name. The AgentCount of the original
        is split between the agents of the returned list in
        proportion to the given distribution; because each share is
        truncated with int(), the counts may sum to slightly less
        than the original AgentCount.
    """
    param_dist = distribution.discretize(N=param_count)

    agent_set = [deepcopy(agent) for _ in range(param_count)]

    for j, this_agent in enumerate(agent_set):
        # AgentCount is assigned before the distributed parameter, so that if
        # param_name is itself "AgentCount" the distributed value wins.
        this_agent.assign_parameters(
            **{"AgentCount": int(agent.AgentCount * param_dist.pmv[j])}
        )
        this_agent.assign_parameters(**{param_name: param_dist.atoms[0, j]})

    return agent_set

2322 

2323 

@dataclass
class AgentPopulation:
    """
    A class for representing a population of ex-ante heterogeneous agents.

    Parameters that differ across agents may be given as a list of lists (one
    inner list per agent), as a `Distribution` (to be discretized later with
    `approx_distributions`), or as an xarray `DataArray` whose first dimension
    is named "agent".
    """

    agent_type: AgentType  # type of agent in the population
    parameters: dict  # dictionary of parameters
    seed: int = 0  # random seed
    time_var: List[str] = field(init=False)
    time_inv: List[str] = field(init=False)
    distributed_params: List[str] = field(init=False)
    agent_type_count: Optional[int] = field(init=False)
    term_age: Optional[int] = field(init=False)
    continuous_distributions: Dict[str, Distribution] = field(init=False)
    discrete_distributions: Dict[str, Distribution] = field(init=False)
    population_parameters: List[Dict[str, Union[List[float], float]]] = field(
        init=False
    )
    agents: List[AgentType] = field(init=False)
    agent_database: pd.DataFrame = field(init=False)
    solution: List[Any] = field(init=False)

    @staticmethod
    def _is_nested_list(param):
        """Return True if param is a non-empty list whose first element is a list."""
        return isinstance(param, list) and len(param) > 0 and isinstance(param[0], list)

    def __post_init__(self):
        """
        Initialize the population of agents, determine distributed parameters,
        and infer `agent_type_count` and `term_age`.
        """
        # create a dummy agent and obtain its time-varying
        # and time-invariant attributes
        dummy_agent = self.agent_type()
        self.time_var = dummy_agent.time_vary
        self.time_inv = dummy_agent.time_inv

        # create list of distributed parameters
        # these are parameters that differ across agents
        self.distributed_params = [
            key
            for key, param in self.parameters.items()
            if self._is_nested_list(param)
            or isinstance(param, Distribution)
            or (
                isinstance(param, DataArray)
                and param.ndim > 0
                and param.dims[0] == "agent"
            )
        ]

        self.__infer_counts__()

        # Print warning once if parallel simulation fails
        self.print_parallel_error_once = True

    def __infer_counts__(self):
        """
        Infer `agent_type_count` and `term_age` from the parameters.
        If parameters include a `Distribution` type, a list of lists,
        or a `DataArray` with `agent` as the first dimension, then
        the AgentPopulation contains ex-ante heterogenous agents.

        Both counts are set to None (with a warning) if a parameter is still
        an undiscretized `Distribution`.
        """

        # infer agent_type_count from distributed parameters
        agent_type_count = 1
        for key in self.distributed_params:
            param = self.parameters[key]
            if isinstance(param, Distribution):
                agent_type_count = None
                warn(
                    "Cannot infer agent_type_count from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif isinstance(param, list):
                agent_type_count = max(agent_type_count, len(param))
            elif (
                isinstance(param, DataArray)
                and param.ndim > 0
                and param.dims[0] == "agent"
            ):
                agent_type_count = max(agent_type_count, param.shape[0])

        self.agent_type_count = agent_type_count

        # infer term_age from all parameters
        term_age = 1
        for param in self.parameters.values():
            if isinstance(param, Distribution):
                term_age = None
                warn(
                    "Cannot infer term_age from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif self._is_nested_list(param):
                term_age = max(term_age, len(param[0]))
            elif (
                isinstance(param, DataArray)
                and param.ndim > 0
                and param.dims[-1] == "age"
            ):
                term_age = max(term_age, param.shape[-1])

        self.term_age = term_age

    def approx_distributions(self, approx_params: dict):
        """
        Approximate continuous distributions with discrete ones. If the initial
        parameters include a `Distribution` type, then the AgentPopulation is
        not ready to solve, and stands for an abstract population. To solve the
        AgentPopulation, we need discretization parameters for each continuous
        distribution. This method approximates the continuous distributions with
        discrete ones, and updates the parameters dictionary.

        Parameters
        ----------
        approx_params : dict
            Maps the name of each distributed `Distribution` parameter to a
            dictionary of keyword arguments for that distribution's
            `discretize` method.

        Raises
        ------
        ValueError
            If a key of approx_params is not a distributed `Distribution`
            parameter.
        """
        self.continuous_distributions = {}
        self.discrete_distributions = {}

        for key, args in approx_params.items():
            param = self.parameters[key]
            if key in self.distributed_params and isinstance(param, Distribution):
                self.continuous_distributions[key] = param
                self.discrete_distributions[key] = param.discretize(**args)
            else:
                raise ValueError(
                    f"Warning: parameter {key} is not a Distribution found "
                    f"in agent type {self.agent_type}"
                )

        if not self.discrete_distributions:
            # Nothing was requested, so there is nothing to replace.
            return

        if len(self.discrete_distributions) > 1:
            joint_dist = combine_indep_dstns(*self.discrete_distributions.values())
        else:
            (joint_dist,) = self.discrete_distributions.values()

        # Row i of the joint atoms holds the values of the i-th distribution
        for i, key in enumerate(self.discrete_distributions):
            self.parameters[key] = DataArray(joint_dist.atoms[i], dims=("agent",))

        self.__infer_counts__()

    def _value_for_agent(self, param, agent, time_varying):
        """
        Extract the value of one population parameter for a single agent.

        Parameters
        ----------
        param : Any
            The population-level value of the parameter.
        agent : int
            Index of the agent subgroup.
        time_varying : bool
            Whether the parameter is in the agent type's `time_vary` list.
            Scalar values of such parameters are repeated `term_age` times.

        Returns
        -------
        value : Any
            Lists of lists and DataArrays whose first dimension is "agent" are
            indexed by `agent`. DataArrays are converted to a Python scalar (if
            0-dimensional after indexing) or a nested Python list. Values of
            any other type are returned unchanged.
        """
        if isinstance(param, DataArray):
            first_dim = param.dims[0] if param.ndim > 0 else None
            if first_dim == "agent":
                param = param[agent]
            elif first_dim not in (None, "age"):
                # Unrecognized layout: pass it through for the agent type to handle
                return param
            if param.ndim > 0:
                return param.values.tolist()
            param = param.item()
        elif isinstance(param, list):
            if self._is_nested_list(param):
                return param[agent]
            return param

        if time_varying and isinstance(param, (int, float)):
            return [param] * self.term_age
        return param

    def __parse_parameters__(self) -> None:
        """
        Creates distributed dictionaries of parameters for each ex-ante
        heterogeneous agent in the parameterized population. The parameters
        are stored in a list of dictionaries, where each dictionary contains
        the parameters for one agent. Expands parameters that vary over time
        to a list of length `term_age`.
        """
        population_parameters = []  # container for dictionaries of each agent subgroup
        for agent in range(self.agent_type_count):
            agent_parameters = {
                key: self._value_for_agent(param, agent, key in self.time_var)
                for key, param in self.parameters.items()
            }
            population_parameters.append(agent_parameters)

        self.population_parameters = population_parameters

    def create_distributed_agents(self):
        """
        Parses the parameters dictionary and creates a list of agents with the
        appropriate parameters. Also sets the seed for each agent, drawing the
        seeds from a generator seeded with self.seed.
        """

        self.__parse_parameters__()

        rng = np.random.default_rng(self.seed)

        self.agents = [
            self.agent_type(seed=rng.integers(0, 2**31 - 1), **agent_dict)
            for agent_dict in self.population_parameters
        ]

    def create_database(self):
        """
        Optionally creates a pandas DataFrame with the parameters for each agent.
        """
        database = pd.DataFrame(self.population_parameters)
        database["agents"] = self.agents

        self.agent_database = database

    def solve(self):
        """
        Solves each agent of the population serially.
        """

        # see Market class for an example of how to solve distributed agents in parallel

        for agent in self.agents:
            agent.solve()

    def unpack_solutions(self):
        """
        Unpacks the solutions of each agent into an attribute of the population.
        """
        self.solution = [agent.solution for agent in self.agents]

    def initialize_sim(self):
        """
        Initializes the simulation for each agent.
        """
        for agent in self.agents:
            agent.initialize_sim()

    def simulate(self, num_jobs=None):
        """
        Simulates each agent of the population.

        Parameters
        ----------
        num_jobs : int, optional
            Number of parallel jobs to use. Defaults to using all available
            cores when ``None``. Falls back to serial execution if parallel
            processing fails.
        """
        try:
            multi_thread_commands(self.agents, ["simulate()"], num_jobs)
        except Exception as err:
            if getattr(self, "print_parallel_error_once", False):
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.AgentPopulation.simulate() ",
                    "so using the serial version instead. This will likely be slower. ",
                    "The multi_thread_commands() function failed with the following error:\n",
                    sys.exc_info()[0],
                    ":",
                    err,
                )
            multi_thread_commands_fake(self.agents, ["simulate()"], num_jobs)

    def __iter__(self):
        """
        Allows for iteration over the agents in the population.
        """
        return iter(self.agents)

    def __getitem__(self, idx):
        """
        Allows for indexing into the population.
        """
        return self.agents[idx]

2600 

2601 

2602############################################################################### 

2603 

2604 

def multi_thread_commands_fake(
    agent_list: List, command_list: List, num_jobs=None
) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    in an ordinary, single-threaded loop. Each command should be a method of
    that AgentType subclass. This function exists so as to easily disable
    multithreading, as it uses the same syntax as multi_thread_commands.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
    command_list : [string]
        A list of commands to run for each AgentType. Each entry is a method
        name, written either with a trailing "()" (e.g. "solve()") or as the
        bare name (e.g. "solve"); the method is called with no arguments.
    num_jobs : None
        Dummy input to match syntax of multi_thread_commands. Does nothing.

    Returns
    -------
    none
    """
    # Strip an optional trailing "()" once, up front; every agent runs the
    # same methods in the same order.
    method_names = [
        command[:-2] if command.endswith("()") else command
        for command in command_list
    ]
    for agent in agent_list:
        for method_name in method_names:
            getattr(agent, method_name)()

2631 

2632 

def multi_thread_commands(agent_list: List, command_list: List, num_jobs=None) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    using a multithreaded system. Each command should be a method of that AgentType subclass.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
        It is updated in place with the agents returned by the workers.
    command_list : [string]
        A list of commands to run for each AgentType in agent_list.
    num_jobs : int, optional
        Number of parallel jobs passed to joblib. Defaults to the smaller of
        the number of agents and the number of available CPU cores.

    Returns
    -------
    None
    """
    # With zero or one agent there is nothing to parallelize. Running serially
    # also avoids requesting zero workers from joblib for an empty list.
    if len(agent_list) <= 1:
        multi_thread_commands_fake(agent_list, command_list)
        return None

    # Default number of parallel jobs is the smaller of number of AgentTypes in
    # the input and the number of available cores.
    if num_jobs is None:
        num_jobs = min(len(agent_list), multiprocessing.cpu_count())

    # Send each command in command_list to each of the types in agent_list to be run
    agent_list_out = Parallel(n_jobs=num_jobs)(
        delayed(run_commands)(agent, command_list) for agent in agent_list
    )

    # Replace the original types with the output from the parallel call, since
    # the workers may have operated on copies of the agents.
    for j, agent in enumerate(agent_list_out):
        agent_list[j] = agent

2667 

2668 

def run_commands(agent: Any, command_list: List) -> Any:
    """
    Executes each command in command_list on a given AgentType. The commands
    should be methods of that AgentType's subclass.

    Parameters
    ----------
    agent : AgentType
        An instance of AgentType on which the commands will be run.
    command_list : [string]
        A list of commands that the agent should run, as methods. Each entry
        is a method name, written either with a trailing "()" (e.g.
        "solve()") or as the bare name (e.g. "solve"); the method is called
        with no arguments.

    Returns
    -------
    agent : AgentType
        The same AgentType instance passed as input, after running the commands.
    """
    for command in command_list:
        # Accept both "method()" and "method"
        method_name = command[:-2] if command.endswith("()") else command
        getattr(agent, method_name)()
    return agent