Coverage for HARK / core.py: 96%

967 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-25 05:22 +0000

1""" 

2High-level functions and classes for solving a wide variety of economic models. 

3The "core" of HARK is a framework for "microeconomic" and "macroeconomic" 

4models. A micro model concerns the dynamic optimization problem for some type 

5of agents, where agents take the inputs to their problem as exogenous. A macro 

6model adds an additional layer, endogenizing some of the inputs to the micro 

7problem by finding a general equilibrium dynamic rule. 

8""" 

9 

10# Import basic modules 

11import inspect 

12import sys 

13from collections import namedtuple 

14from copy import copy, deepcopy 

15from dataclasses import dataclass, field 

16from time import time 

17from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union 

18from warnings import warn 

19import multiprocessing 

20from joblib import Parallel, delayed 

21 

22import numpy as np 

23import pandas as pd 

24from xarray import DataArray 

25 

26from HARK.distributions import ( 

27 Distribution, 

28 IndexDistribution, 

29 combine_indep_dstns, 

30) 

31from HARK.utilities import NullFunc, get_arg_names, get_it_from 

32from HARK.simulator import make_simulator_from_agent 

33from HARK.SSJutils import ( 

34 make_basic_SSJ_matrices, 

35 calc_shock_response_manually, 

36) 

37from HARK.metric import MetricObject, distance_metric 

38 

# Public API of this module. "NullFunc" is re-exported from HARK.utilities;
# several other names (e.g. "Market", "AgentPopulation", the multi-thread
# helpers) are presumably defined further down in this module -- not visible
# in this excerpt.
__all__ = [
    # Classes
    "AgentType",
    "Market",
    "Parameters",
    "Model",
    "AgentPopulation",
    # Helper functions and objects
    "multi_thread_commands",
    "multi_thread_commands_fake",
    "NullFunc",
    "make_one_period_oo_solver",
    "distribute_params",
]

51 

52 

class Parameters:
    """
    A smart container for model parameters that handles age-varying dynamics.

    This class stores parameters as an internal dictionary and manages their
    age-varying properties, providing both attribute-style and dictionary-style
    access. It is designed to handle the time-varying dynamics of parameters
    in economic models.

    Attributes
    ----------
    _length : int
        The terminal age of the agents in the model. When the length is
        inferred from a multi-element list or tuple, the ``T_cycle`` entry of
        the parameter dictionary is updated to match.
    _invariant_params : Set[str]
        A set of parameter names that are invariant over time.
    _varying_params : Set[str]
        A set of parameter names that vary over time.
    _parameters : Dict[str, Any]
        The internal dictionary storing all parameters.
    _frozen : bool
        When True, item assignment raises RuntimeError.
    _namedtuple_cache : Optional[type]
        The namedtuple class built by ``to_namedtuple``. It is discarded
        whenever a new parameter name is added, so its fields never go stale.
    """

    __slots__ = (
        "_length",
        "_invariant_params",
        "_varying_params",
        "_parameters",
        "_frozen",
        "_namedtuple_cache",
    )

    def __init__(self, **parameters: Any) -> None:
        """
        Initialize a Parameters object and parse the age-varying dynamics of parameters.

        Parameters
        ----------
        T_cycle : int, optional
            The number of time periods in the model cycle (default: 1).
            Must be >= 1.
        frozen : bool, optional
            If True, the Parameters object will be immutable after initialization
            (default: False).
        _time_inv : List[str], optional
            List of parameter names to explicitly mark as time-invariant,
            overriding automatic inference.
        _time_vary : List[str], optional
            List of parameter names to explicitly mark as time-varying,
            overriding automatic inference.
        **parameters : Any
            Any number of parameters in the form key=value.

        Raises
        ------
        ValueError
            If T_cycle is less than 1.

        Notes
        -----
        Automatic time-variance inference rules:
        - Scalars (int, float, bool, None) are time-invariant
        - NumPy arrays are time-invariant (use lists/tuples for time-varying)
        - Empty lists/tuples are stored as-is and time-invariant
        - Single-element lists/tuples [x] are unwrapped to x and time-invariant
        - Multi-element lists/tuples are time-varying if length matches T_cycle
        - 2D arrays with first dimension matching T_cycle are time-varying
        - Distributions and Callables are time-invariant

        Use _time_inv or _time_vary to override automatic inference when needed.
        """
        # Extract special parameters
        self._length: int = parameters.pop("T_cycle", 1)
        frozen: bool = parameters.pop("frozen", False)
        time_inv_override: List[str] = parameters.pop("_time_inv", [])
        time_vary_override: List[str] = parameters.pop("_time_vary", [])

        # Validate T_cycle
        if self._length < 1:
            raise ValueError(f"T_cycle must be >= 1, got {self._length}")

        # Initialize internal state
        self._invariant_params: Set[str] = set()
        self._varying_params: Set[str] = set()
        self._parameters: Dict[str, Any] = {"T_cycle": self._length}
        self._frozen: bool = False  # Set to False initially to allow setup
        self._namedtuple_cache: Optional[type] = None

        # Set parameters using automatic inference
        for key, value in parameters.items():
            self[key] = value

        # Apply explicit overrides (only for parameters that actually exist)
        for param in time_inv_override:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)

        for param in time_vary_override:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)

        # Freeze if requested
        self._frozen = frozen

    def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]:
        """
        Access parameters by age index or parameter name.

        If item_or_key is an integer (Python or NumPy), returns a Parameters
        object with the parameters that apply to that age. This includes all
        invariant parameters and the `item_or_key`th element of all age-varying
        parameters. If item_or_key is a string, it returns the value of the
        parameter with that name.

        Parameters
        ----------
        item_or_key : Union[int, str]
            Age index or parameter name.

        Returns
        -------
        Union[Parameters, Any]
            A new Parameters object for the specified age, or the value of the
            specified parameter.

        Raises
        ------
        ValueError
            If the age index is out of bounds.
        KeyError
            If the parameter name is not found.
        TypeError
            If the key is neither an integer nor a string.
        """
        if isinstance(item_or_key, (int, np.integer)):
            age = int(item_or_key)
            if age < 0 or age >= self._length:
                raise ValueError(
                    f"Age {age} is out of bounds (valid: 0-{self._length - 1})."
                )

            params = {key: self._parameters[key] for key in self._invariant_params}
            for key in self._varying_params:
                value = self._parameters[key]
                # Parameters forced time-varying via overrides may be scalars;
                # those are passed through unchanged.
                if isinstance(value, (list, tuple, np.ndarray)):
                    params[key] = value[age]
                else:
                    params[key] = value
            return Parameters(**params)
        elif isinstance(item_or_key, str):
            return self._parameters[item_or_key]
        else:
            raise TypeError("Key must be an integer (age) or string (parameter name).")

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Set parameter values, automatically inferring time variance.

        If the parameter is a scalar, numpy array, boolean, distribution, callable
        or None, it is assumed to be invariant over time. An empty list or tuple
        is stored as-is and treated as invariant; a single-element list or tuple
        is unwrapped and treated as invariant. A longer list or tuple is
        time-varying, and its length must match the `_length` attribute of the
        Parameters object. The exception is when `_length` is 1: then the
        sequence's length becomes the new `_length` and the stored `T_cycle`.

        2D numpy arrays with first dimension matching T_cycle are treated as
        time-varying parameters.

        Parameters
        ----------
        key : str
            Name of the parameter.
        value : Any
            Value of the parameter.

        Raises
        ------
        ValueError
            If the parameter name is not a string or if the value type is unsupported.
            If the parameter value is inconsistent with the current model length.
        RuntimeError
            If the Parameters object is frozen.
        """
        if self._frozen:
            raise RuntimeError("Cannot modify frozen Parameters object")

        if not isinstance(key, str):
            raise ValueError(f"Parameter name must be a string, got {type(key)}")

        # Step 1: classify the value (may unwrap it or update the model length)
        if isinstance(value, np.ndarray) and value.ndim >= 2:
            # Multi-dimensional arrays are age-indexed along their first axis
            # only when that axis matches the model length.
            is_varying = value.shape[0] == self._length
        elif isinstance(value, (list, tuple)):
            if len(value) == 0:
                # An empty sequence cannot be age-indexed; never let it drive
                # the model length to 0.
                is_varying = False
            elif len(value) == 1:
                value = value[0]
                is_varying = False
            elif self._length is None or self._length == 1:
                # First multi-element sequence defines the model length; keep
                # the stored T_cycle consistent with it.
                self._length = len(value)
                self._parameters["T_cycle"] = self._length
                is_varying = True
            elif len(value) == self._length:
                is_varying = True
            else:
                raise ValueError(
                    f"Parameter {key} must have length 1 or {self._length}, not {len(value)}"
                )
        elif isinstance(
            value,
            (
                int,
                float,
                np.ndarray,
                type(None),
                Distribution,
                bool,
                Callable,
                MetricObject,
            ),
        ):
            is_varying = False
        else:
            raise ValueError(f"Unsupported type for parameter {key}: {type(value)}")

        # Step 2: record the classification and store the value
        if is_varying:
            self._varying_params.add(key)
            self._invariant_params.discard(key)
        else:
            self._invariant_params.add(key)
            self._varying_params.discard(key)

        if key not in self._parameters:
            # A new name changes the namedtuple fields; drop the cached class.
            self._namedtuple_cache = None
        self._parameters[key] = value

    def __iter__(self) -> Iterator[str]:
        """Allow iteration over parameter names."""
        return iter(self._parameters)

    def __len__(self) -> int:
        """Return the number of parameters."""
        return len(self._parameters)

    def keys(self) -> Iterator[str]:
        """Return a view of parameter names."""
        return self._parameters.keys()

    def values(self) -> Iterator[Any]:
        """Return a view of parameter values."""
        return self._parameters.values()

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Return a view of parameter (name, value) pairs."""
        return self._parameters.items()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a plain dictionary.

        Returns
        -------
        Dict[str, Any]
            A shallow copy of the internal parameter dictionary.
        """
        return dict(self._parameters)

    def to_namedtuple(self) -> Tuple[Any, ...]:
        """
        Convert parameters to a namedtuple.

        The namedtuple class is cached for efficiency on repeated calls; the
        cache is reset whenever a new parameter name is added.

        Returns
        -------
        namedtuple
            A namedtuple containing all parameters.
        """
        if self._namedtuple_cache is None:
            self._namedtuple_cache = namedtuple("Parameters", self.keys())
        return self._namedtuple_cache(**self.to_dict())

    def update(self, other: Union["Parameters", Dict[str, Any]]) -> None:
        """
        Update parameters from another Parameters object or dictionary.

        Each value goes through ``__setitem__``, so time variance is re-inferred.

        Parameters
        ----------
        other : Union[Parameters, Dict[str, Any]]
            The source of parameters to update from.

        Raises
        ------
        TypeError
            If the input is neither a Parameters object nor a dictionary.
        """
        if isinstance(other, Parameters):
            source = other._parameters
        elif isinstance(other, dict):
            source = other
        else:
            raise TypeError(
                "Update source must be a Parameters object or a dictionary."
            )
        for key, value in source.items():
            self[key] = value

    def __repr__(self) -> str:
        """Return a detailed string representation of the Parameters object."""
        return (
            f"Parameters(_length={self._length}, "
            f"_invariant_params={self._invariant_params}, "
            f"_varying_params={self._varying_params}, "
            f"_parameters={self._parameters})"
        )

    def __str__(self) -> str:
        """Return a simple string representation of the Parameters object."""
        return f"Parameters({str(self._parameters)})"

    def __getattr__(self, name: str) -> Any:
        """
        Allow attribute-style access to parameters.

        Only called when normal attribute lookup fails. Names starting with an
        underscore are never treated as parameters; this keeps copy/pickle
        protocol lookups (and unset slots) raising AttributeError.

        Parameters
        ----------
        name : str
            Name of the parameter to access.

        Returns
        -------
        Any
            The value of the specified parameter.

        Raises
        ------
        AttributeError
            If the parameter name is not found.
        """
        if name.startswith("_"):
            return super().__getattribute__(name)
        try:
            return self._parameters[name]
        except KeyError:
            raise AttributeError(f"'Parameters' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Allow attribute-style setting of parameters.

        Underscore-prefixed names set internal slots directly; all other names
        are routed through ``__setitem__``.

        Parameters
        ----------
        name : str
            Name of the parameter to set.
        value : Any
            Value to set for the parameter.
        """
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            self[name] = value

    def __contains__(self, item: str) -> bool:
        """Check if a parameter exists in the Parameters object."""
        return item in self._parameters

    def copy(self) -> "Parameters":
        """
        Create a deep copy of the Parameters object.

        Returns
        -------
        Parameters
            A new Parameters object with the same contents.
        """
        return deepcopy(self)

    def add_to_time_vary(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-varying set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_vary.
            Unknown names trigger a warning and are skipped.
        """
        for param in params:
            if param in self._parameters:
                self._varying_params.add(param)
                self._invariant_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_vary."
                )

    def add_to_time_inv(self, *params: str) -> None:
        """
        Adds any number of parameters to the time-invariant set.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be added to time_inv.
            Unknown names trigger a warning and are skipped.
        """
        for param in params:
            if param in self._parameters:
                self._invariant_params.add(param)
                self._varying_params.discard(param)
            else:
                warn(
                    f"Parameter '{param}' does not exist and cannot be added to time_inv."
                )

    def del_from_time_vary(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-varying set.

        The parameter values themselves are kept.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_vary.
        """
        for param in params:
            self._varying_params.discard(param)

    def del_from_time_inv(self, *params: str) -> None:
        """
        Removes any number of parameters from the time-invariant set.

        The parameter values themselves are kept.

        Parameters
        ----------
        *params : str
            Any number of strings naming parameters to be removed from time_inv.
        """
        for param in params:
            self._invariant_params.discard(param)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a parameter value, returning a default if not found.

        Parameters
        ----------
        key : str
            The parameter name.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The parameter value or the default.
        """
        return self._parameters.get(key, default)

    def set_many(self, **kwargs: Any) -> None:
        """
        Set multiple parameters at once.

        Parameters
        ----------
        **kwargs : Keyword arguments representing parameter names and values.
        """
        for key, value in kwargs.items():
            self[key] = value

    def is_time_varying(self, key: str) -> bool:
        """
        Check if a parameter is time-varying.

        Parameters
        ----------
        key : str
            The parameter name.

        Returns
        -------
        bool
            True if the parameter is time-varying, False otherwise.
        """
        return key in self._varying_params

    def at_age(self, age: int) -> "Parameters":
        """
        Get parameters for a specific age.

        This is an alternative to integer indexing (params[age]) that is more
        explicit and avoids potential confusion with dictionary-style access.

        Parameters
        ----------
        age : int
            The age index to retrieve parameters for.

        Returns
        -------
        Parameters
            A new Parameters object with parameters for the specified age.

        Raises
        ------
        ValueError
            If the age index is out of bounds.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97], sigma=2.0)
        >>> age_1_params = params.at_age(1)
        >>> age_1_params.beta
        0.96
        """
        return self[age]

    def validate(self) -> None:
        """
        Validate parameter consistency.

        Checks that all time-varying parameters have length matching T_cycle.
        This is useful after manual modifications or when parameters are set
        programmatically. Problems are reported in sorted parameter order.

        Raises
        ------
        ValueError
            If any time-varying parameter has incorrect length.

        Examples
        --------
        >>> params = Parameters(T_cycle=3, beta=[0.95, 0.96, 0.97])
        >>> params.validate()  # Passes
        >>> params.add_to_time_vary("beta")
        >>> params.validate()  # Still passes
        """
        errors = []
        for param in sorted(self._varying_params):
            value = self._parameters[param]
            if isinstance(value, (list, tuple)):
                if len(value) != self._length:
                    errors.append(
                        f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                    )
            elif isinstance(value, np.ndarray):
                if value.ndim == 0:
                    errors.append(
                        f"Parameter '{param}' is a 0-dimensional array (scalar), "
                        "which should not be time-varying"
                    )
                elif value.ndim == 1:
                    if len(value) != self._length:
                        errors.append(
                            f"Parameter '{param}' has length {len(value)}, expected {self._length}"
                        )
                elif value.shape[0] != self._length:
                    errors.append(
                        f"Parameter '{param}' has first dimension {value.shape[0]}, expected {self._length}"
                    )

        if errors:
            raise ValueError(
                "Parameter validation failed:\n" + "\n".join(f"  - {e}" for e in errors)
            )

616 

617 

class Model:
    """
    A class with special handling of parameters assignment.

    Instances keep a ``parameters`` dictionary whose entries are mirrored onto
    attributes by ``assign_parameters``, and a ``constructors`` dictionary that
    maps names of constructed objects to the callables that build them (see
    ``construct``).
    """

    def __init__(self):
        # Subclasses may already have set these before calling this initializer.
        if not hasattr(self, "parameters"):
            self.parameters = {}
        if not hasattr(self, "constructors"):
            self.constructors = {}

    def assign_parameters(self, **kwds):
        """
        Assign an arbitrary number of attributes to this agent.

        Parameters
        ----------
        **kwds : keyword arguments
            Any number of keyword arguments of the form key=value.
            Each value will be assigned to the attribute named in self
            and recorded in self.parameters.

        Returns
        -------
        None
        """
        self.parameters.update(kwds)
        for key in kwds:
            setattr(self, key, kwds[key])

    def get_parameter(self, name):
        """
        Returns a parameter of this model

        Parameters
        ----------
        name : str
            The name of the parameter to get

        Returns
        -------
        value : The value of the parameter

        Raises
        ------
        KeyError
            If the name is not in self.parameters.
        """
        return self.parameters[name]

    def __eq__(self, other):
        # Two models of the same type are equal when their parameter dicts are.
        if isinstance(other, type(self)):
            return self.parameters == other.parameters

        return NotImplemented

    def __str__(self):
        type_ = type(self)
        module = type_.__module__
        qualname = type_.__qualname__

        s = f"<{module}.{qualname} object at {hex(id(self))}.\n"
        s += "Parameters:"

        for p in self.parameters:
            s += f"\n{p}: {self.parameters[p]}"

        s += ">"
        return s

    def describe(self):
        return self.__str__()

    def del_param(self, param_name):
        """
        Deletes a parameter from this instance, removing it both from the object's
        namespace (if it's there) and the parameters dictionary (likewise).

        Parameters
        ----------
        param_name : str
            A string naming a parameter or data to be deleted from this instance.
            Removes information from self.parameters dictionary and own namespace.

        Returns
        -------
        None
        """
        if param_name in self.parameters:
            del self.parameters[param_name]
        if hasattr(self, param_name):
            delattr(self, param_name)

    def _gather_constructor_inputs(self, key, constructor):
        """
        Collect keyword arguments for one constructor from this instance.

        Parameters
        ----------
        key : str
            Name of the object being constructed.
        constructor : callable or get_it_from
            The constructor registered for ``key``.

        Returns
        -------
        inputs : dict
            Keyword arguments to pass to ``constructor``.
        missing_args : list of str
            Names of required inputs that could not be found. For a
            ``get_it_from`` constructor this is the missing parent's name.
        """
        # SPECIAL: get_it_from pulls `key` out of a named parent object
        if isinstance(constructor, get_it_from):
            try:
                parent = getattr(self, constructor.name)
            except AttributeError:
                return {"parent": None, "query": None}, [constructor.name]
            return {"parent": parent, "query": key}, []

        args_needed = get_arg_names(constructor)
        has_no_default = {
            k: v.default is inspect.Parameter.empty
            for k, v in inspect.signature(constructor).parameters.items()
        }
        inputs = {}
        missing_args = []
        for this_arg in args_needed:
            # Attributes take precedence over entries of self.parameters
            if hasattr(self, this_arg):
                inputs[this_arg] = getattr(self, this_arg)
            elif this_arg in self.parameters:
                inputs[this_arg] = self.parameters[this_arg]
            elif has_no_default[this_arg]:
                missing_args.append(this_arg)
            # Arguments with a default value may be omitted
        return inputs, missing_args

    def construct(self, *args, force=False):
        """
        Top-level method for building constructed inputs. If called without any
        inputs, construct builds each of the objects named in the keys of the
        constructors dictionary; it draws inputs for the constructors from the
        parameters dictionary and adds its results to the same. If passed one or
        more strings as arguments, the method builds only the named keys. The
        method will do multiple "passes" over the requested keys, as some cons-
        tructors require inputs built by other constructors. If any requested
        constructors failed to build due to missing data, those keys (and the
        missing data) will be named in self._missing_key_data. Other errors are
        recorded in the dictionary attribute _constructor_errors.

        This method tries to "start from scratch" by removing prior constructed
        objects, holding them in a backup dictionary during construction. This
        is done so that dependencies among constructors are resolved properly,
        without mistakenly relying on "old information". A backup value is used
        if a constructor function is set to None (i.e. "don't do anything"), or
        if the construct method fails to produce a new object -- including when
        it exits by raising an exception.

        Parameters
        ----------
        *args : str, optional
            Keys of self.constructors that are requested to be constructed.
            If no arguments are passed, *all* elements of the dictionary are implied.
        force : bool, optional
            When True, the method will force its way past any errors, including
            missing constructors, missing arguments for constructors, and errors
            raised during execution of constructors. Information about all such
            errors is stored in the dictionary attributes described above. When
            False (default), any errors or exception will be raised.

        Returns
        -------
        None

        Raises
        ------
        KeyError
            If a requested key has no constructor and force is False.
        ValueError
            If some requested objects could not be constructed and force is False.
        Exception
            Any exception raised by a constructor is re-raised when force is False.
        """
        # Set up the requested work
        keys = args if len(args) > 0 else list(self.constructors.keys())
        N_keys = len(keys)
        if N_keys == 0:
            return  # Do nothing if there are no constructed objects
        keys_complete = np.zeros(N_keys, dtype=bool)

        # Remove pre-existing constructed objects, preventing "incomplete" updates,
        # but store the current values in a backup dictionary in case something fails
        backup = {}
        for key in keys:
            if hasattr(self, key):
                backup[key] = getattr(self, key)
            self.del_param(key)

        # Get the dictionary of constructor errors
        if not hasattr(self, "_constructor_errors"):
            self._constructor_errors = {}
        errors = self._constructor_errors

        missing_key_data = []
        try:
            # As long as the work isn't complete and we made some progress on the
            # last pass, repeatedly perform passes of trying to construct objects
            go = True
            while go:
                anything_accomplished_this_pass = False
                missing_key_data = []  # Reflects only the most recent pass

                for i, key in enumerate(keys):
                    if keys_complete[i]:
                        continue  # This key has already been built

                    try:
                        constructor = self.constructors[key]
                    except KeyError as not_found:
                        errors[key] = "No constructor found for " + str(not_found)
                        if force:
                            continue
                        raise KeyError("No constructor found for " + key) from None

                    # A None constructor means "keep whatever was there before"
                    if constructor is None:
                        if key in backup:
                            setattr(self, key, backup[key])
                            self.parameters[key] = backup[key]
                        keys_complete[i] = True
                        anything_accomplished_this_pass = True
                        continue

                    temp_dict, missing_args = self._gather_constructor_inputs(
                        key, constructor
                    )
                    if missing_args:
                        # Never raise here, as the arguments might be filled in
                        # by another constructor on a later pass
                        missing_key_data.extend((key, arg) for arg in missing_args)
                        errors[key] = "Missing required arguments: " + ", ".join(
                            missing_args
                        )
                        self.del_param(key)
                        continue

                    try:
                        temp = constructor(**temp_dict)
                    except Exception as problem:
                        errors[key] = str(type(problem)) + ": " + str(problem)
                        self.del_param(key)
                        if force:
                            continue
                        raise
                    setattr(self, key, temp)
                    self.parameters[key] = temp
                    errors.pop(key, None)
                    keys_complete[i] = True
                    anything_accomplished_this_pass = True

                go = (not keys_complete.all()) and anything_accomplished_this_pass
        finally:
            # Record diagnostics and restore backups for anything not rebuilt,
            # whether we finished normally or are propagating an exception.
            self._missing_key_data = missing_key_data
            self._constructor_errors = errors
            for i, key in enumerate(keys):
                if not keys_complete[i] and key in backup:
                    setattr(self, key, backup[key])
                    self.parameters[key] = backup[key]

        incomplete = [keys[i] for i in range(N_keys) if not keys_complete[i]]
        if incomplete and not force:
            raise ValueError(
                "Did not construct these objects: " + ", ".join(incomplete)
            )
        return

    def describe_constructors(self, *args):
        """
        Prints to screen a string describing this instance's constructed objects,
        including their names, the function that constructs them, the names of
        those functions inputs, and whether those inputs are present.

        Symbols: a check mark means present, "X" means missing, and "*" means
        missing but the constructor argument has a default value.

        Parameters
        ----------
        *args : str, optional
            Optional list of strings naming constructed inputs to be described.
            If none are passed, all constructors are described.

        Returns
        -------
        None
        """
        keys = args if len(args) > 0 else list(self.constructors.keys())
        yes = "\u2713"
        no = "X"
        maybe = "*"
        noyes = [no, yes]

        out = ""
        for key in keys:
            has_val = hasattr(self, key) or (key in self.parameters)
            status = noyes[int(has_val)]

            try:
                constructor = self.constructors[key]
            except KeyError:
                out += status + " " + key + " : NO CONSTRUCTOR FOUND\n"
                continue

            # construct() treats a None constructor as "keep the existing value"
            if constructor is None:
                out += status + " " + key + " : None (keeps existing value)\n"
                continue

            if isinstance(constructor, get_it_from):
                out += status + " " + key + " : get it from " + constructor.name + "\n"
                continue

            # Callables such as functools.partial objects lack __name__
            func_name = getattr(constructor, "__name__", repr(constructor))
            out += status + " " + key + " : " + func_name + "\n"

            # Get constructor argument names
            arg_names = get_arg_names(constructor)
            has_no_default = {
                k: v.default is inspect.Parameter.empty
                for k, v in inspect.signature(constructor).parameters.items()
            }

            # Check whether each argument exists
            for this_arg in arg_names:
                if hasattr(self, this_arg) or this_arg in self.parameters:
                    symb = yes
                elif not has_no_default[this_arg]:
                    symb = maybe
                else:
                    symb = no
                out += "    " + symb + " " + this_arg + "\n"

        # Print the string to screen
        print(out)
        return

    # This is a "synonym" method so that old calls to update() still work
    def update(self, *args, **kwargs):
        self.construct(*args, **kwargs)

967 

968 

969class AgentType(Model): 

970 """ 

971 A superclass for economic agents in the HARK framework. Each model should 

972 specify its own subclass of AgentType, inheriting its methods and overwriting 

973 as necessary. Critically, every subclass of AgentType should define class- 

974 specific static values of the attributes time_vary and time_inv as lists of 

975 strings. Each element of time_vary is the name of a field in AgentSubType 

976 that varies over time in the model. Each element of time_inv is the name of 

977 a field in AgentSubType that is constant over time in the model. 

978 

979 Parameters 

980 ---------- 

981 solution_terminal : Solution 

982 A representation of the solution to the terminal period problem of 

983 this AgentType instance, or an initial guess of the solution if this 

984 is an infinite horizon problem. 

985 cycles : int 

986 The number of times the sequence of periods is experienced by this 

987 AgentType in their "lifetime". cycles=1 corresponds to a lifecycle 

988 model, with a certain sequence of one period problems experienced 

989 once before terminating. cycles=0 corresponds to an infinite horizon 

990 model, with a sequence of one period problems repeating indefinitely. 

991 pseudo_terminal : bool 

992 Indicates whether solution_terminal isn't actually part of the 

993 solution to the problem (as a known solution to the terminal period 

994 problem), but instead represents a "scrap value"-style termination. 

995 When True, solution_terminal is not included in the solution; when 

996 False, solution_terminal is the last element of the solution. 

997 tolerance : float 

998 Maximum acceptable "distance" between successive solutions to the 

999 one period problem in an infinite horizon (cycles=0) model in order 

1000 for the solution to be considered as having "converged". Inoperative 

1001 when cycles>0. 

1002 verbose : int 

1003 Level of output to be displayed by this instance, default is 1. 

1004 quiet : bool 

1005 Indicator for whether this instance should operate "quietly", default False. 

1006 seed : int 

1007 A seed for this instance's random number generator. 

1008 construct : bool 

1009 Indicator for whether this instance's construct() method should be run 

1010 when initialized (default True). When False, an instance of the class 

1011 can be created even if not all of its attributes can be constructed. 

1012 use_defaults : bool 

1013 Indicator for whether this instance should use the values in the class' 

1014 default dictionary to fill in parameters and constructors for those not 

1015 provided by the user (default True). Setting this to False is useful for 

1016 situations where the user wants to be absolutely sure that they know what 

1017 is being passed to the class initializer, without resorting to defaults. 

1018 

1019 Attributes 

1020 ---------- 

1021 AgentCount : int 

1022 The number of agents of this type to use in simulation. 

1023 

1024 state_vars : list of string 

1025 The string labels for this AgentType's model state variables. 

1026 """ 

1027 

1028 time_vary_ = [] 

1029 time_inv_ = [] 

1030 shock_vars_ = [] 

1031 state_vars = [] 

1032 poststate_vars = [] 

1033 distributions = [] 

1034 default_ = {"params": {}, "solver": NullFunc()} 

1035 

1036 def __init__( 

1037 self, 

1038 solution_terminal=None, 

1039 pseudo_terminal=True, 

1040 tolerance=0.000001, 

1041 verbose=1, 

1042 quiet=False, 

1043 seed=0, 

1044 construct=True, 

1045 use_defaults=True, 

1046 **kwds, 

1047 ): 

1048 super().__init__() 

1049 params = deepcopy(self.default_["params"]) if use_defaults else {} 

1050 params.update(kwds) 

1051 

1052 # Correctly handle constructors that have been passed in kwds 

1053 if "constructors" in self.default_["params"].keys() and use_defaults: 

1054 constructors = deepcopy(self.default_["params"]["constructors"]) 

1055 else: 

1056 constructors = {} 

1057 if "constructors" in kwds.keys(): 

1058 constructors.update(kwds["constructors"]) 

1059 params["constructors"] = constructors 

1060 

1061 # Set default track_vars 

1062 if "track_vars" in self.default_.keys() and use_defaults: 

1063 self.track_vars = copy(self.default_["track_vars"]) 

1064 else: 

1065 self.track_vars = [] 

1066 

1067 # Set model file name if possible 

1068 try: 

1069 self.model_file = copy(self.default_["model"]) 

1070 except (KeyError, TypeError): 

1071 # Fallback to None if "model" key is missing or invalid for copying 

1072 self.model_file = None 

1073 

1074 if solution_terminal is None: 

1075 solution_terminal = NullFunc() 

1076 

1077 self.solve_one_period = self.default_["solver"] # NOQA 

1078 self.solution_terminal = solution_terminal # NOQA 

1079 self.pseudo_terminal = pseudo_terminal # NOQA 

1080 self.tolerance = tolerance # NOQA 

1081 self.verbose = verbose 

1082 self.quiet = quiet 

1083 self.seed = seed # NOQA 

1084 self.state_now = {sv: None for sv in self.state_vars} 

1085 self.state_prev = self.state_now.copy() 

1086 self.controls = {} 

1087 self.shocks = {} 

1088 self.read_shocks = False # NOQA 

1089 self.shock_history = {} 

1090 self.newborn_init_history = {} 

1091 self.history = {} 

1092 self.assign_parameters(**params) # NOQA 

1093 self.reset_rng() # NOQA 

1094 self.bilt = {} 

1095 if construct: 

1096 self.construct() 

1097 

1098 # Add instance-level lists and objects 

1099 self.time_vary = deepcopy(self.time_vary_) 

1100 self.time_inv = deepcopy(self.time_inv_) 

1101 self.shock_vars = deepcopy(self.shock_vars_) 

1102 

1103 def add_to_time_vary(self, *params): 

1104 """ 

1105 Adds any number of parameters to time_vary for this instance. 

1106 

1107 Parameters 

1108 ---------- 

1109 params : string 

1110 Any number of strings naming attributes to be added to time_vary 

1111 

1112 Returns 

1113 ------- 

1114 None 

1115 """ 

1116 for param in params: 

1117 if param not in self.time_vary: 

1118 self.time_vary.append(param) 

1119 

1120 def add_to_time_inv(self, *params): 

1121 """ 

1122 Adds any number of parameters to time_inv for this instance. 

1123 

1124 Parameters 

1125 ---------- 

1126 params : string 

1127 Any number of strings naming attributes to be added to time_inv 

1128 

1129 Returns 

1130 ------- 

1131 None 

1132 """ 

1133 for param in params: 

1134 if param not in self.time_inv: 

1135 self.time_inv.append(param) 

1136 

1137 def del_from_time_vary(self, *params): 

1138 """ 

1139 Removes any number of parameters from time_vary for this instance. 

1140 

1141 Parameters 

1142 ---------- 

1143 params : string 

1144 Any number of strings naming attributes to be removed from time_vary 

1145 

1146 Returns 

1147 ------- 

1148 None 

1149 """ 

1150 for param in params: 

1151 if param in self.time_vary: 

1152 self.time_vary.remove(param) 

1153 

1154 def del_from_time_inv(self, *params): 

1155 """ 

1156 Removes any number of parameters from time_inv for this instance. 

1157 

1158 Parameters 

1159 ---------- 

1160 params : string 

1161 Any number of strings naming attributes to be removed from time_inv 

1162 

1163 Returns 

1164 ------- 

1165 None 

1166 """ 

1167 for param in params: 

1168 if param in self.time_inv: 

1169 self.time_inv.remove(param) 

1170 

1171 def unpack(self, name): 

1172 """ 

1173 Unpacks an attribute from a solution object for easier access. 

1174 After the model has been solved, its components (like consumption function) 

1175 reside in the attributes of each element of `ThisType.solution` (e.g. `cFunc`). 

1176 This method creates a (time varying) attribute of the given attribute name 

1177 that contains a list of elements accessible by `ThisType.parameter`. 

1178 

1179 Parameters 

1180 ---------- 

1181 name: str 

1182 Name of the attribute to unpack from the solution 

1183 

1184 Returns 

1185 ------- 

1186 none 

1187 """ 

1188 # Use list comprehension for better performance instead of loop with append 

1189 if type(self.solution[0]) is dict: 

1190 setattr(self, name, [soln_t[name] for soln_t in self.solution]) 

1191 else: 

1192 setattr(self, name, [soln_t.__dict__[name] for soln_t in self.solution]) 

1193 self.add_to_time_vary(name) 

1194 

1195 def solve( 

1196 self, 

1197 verbose=False, 

1198 presolve=True, 

1199 postsolve=True, 

1200 from_solution=None, 

1201 from_t=None, 

1202 ): 

1203 """ 

1204 Solve the model for this instance of an agent type by backward induction. 

1205 Loops through the sequence of one period problems, passing the solution 

1206 from period t+1 to the problem for period t. 

1207 

1208 Parameters 

1209 ---------- 

1210 verbose : bool, optional 

1211 If True, solution progress is printed to screen. Default False. 

1212 presolve : bool, optional 

1213 If True (default), the pre_solve method is run before solving. 

1214 postsolve : bool, optional 

1215 If True (default), the post_solve method is run after solving. 

1216 from_solution: Solution 

1217 If different from None, will be used as the starting point of backward 

1218 induction, instead of self.solution_terminal. 

1219 from_t : int or None 

1220 If not None, indicates which period of the model the solver should start 

1221 from. It should usually only be used in combination with from_solution. 

1222 Stands for the time index that from_solution represents, and thus is 

1223 only compatible with cycles=1 and will be reset to None otherwise. 

1224 

1225 Returns 

1226 ------- 

1227 none 

1228 """ 

1229 

1230 # Ignore floating point "errors". Numpy calls it "errors", but really it's excep- 

1231 # tions with well-defined answers such as 1.0/0.0 that is np.inf, -1.0/0.0 that is 

1232 # -np.inf, np.inf/np.inf is np.nan and so on. 

1233 with np.errstate( 

1234 divide="ignore", over="ignore", under="ignore", invalid="ignore" 

1235 ): 

1236 if presolve: 

1237 self.pre_solve() # Do pre-solution stuff 

1238 self.solution = solve_agent( 

1239 self, 

1240 verbose, 

1241 from_solution, 

1242 from_t, 

1243 ) # Solve the model by backward induction 

1244 if postsolve: 

1245 self.post_solve() # Do post-solution stuff 

1246 

1247 def reset_rng(self): 

1248 """ 

1249 Reset the random number generator and all distributions for this type. 

1250 Type-checking for lists is to handle the following three cases: 

1251 

1252 1) The target is a single distribution object 

1253 2) The target is a list of distribution objects (probably time-varying) 

1254 3) The target is a nested list of distributions, as in ConsMarkovModel. 

1255 """ 

1256 self.RNG = np.random.default_rng(self.seed) 

1257 for name in self.distributions: 

1258 if not hasattr(self, name): 

1259 continue 

1260 

1261 dstn = getattr(self, name) 

1262 if isinstance(dstn, list): 

1263 for D in dstn: 

1264 if isinstance(D, list): 

1265 for d in D: 

1266 d.reset() 

1267 else: 

1268 D.reset() 

1269 else: 

1270 dstn.reset() 

1271 

1272 def check_elements_of_time_vary_are_lists(self): 

1273 """ 

1274 A method to check that elements of time_vary are lists. 

1275 """ 

1276 for param in self.time_vary: 

1277 if not hasattr(self, param): 

1278 continue 

1279 if not isinstance( 

1280 getattr(self, param), 

1281 (IndexDistribution,), 

1282 ): 

1283 assert type(getattr(self, param)) == list, ( 

1284 param 

1285 + " is not a list or time varying distribution," 

1286 + " but should be because it is in time_vary" 

1287 ) 

1288 

1289 def check_restrictions(self): 

1290 """ 

1291 A method to check that various restrictions are met for the model class. 

1292 """ 

1293 return 

1294 

1295 def pre_solve(self): 

1296 """ 

1297 A method that is run immediately before the model is solved, to check inputs or to prepare 

1298 the terminal solution, perhaps. 

1299 

1300 Parameters 

1301 ---------- 

1302 none 

1303 

1304 Returns 

1305 ------- 

1306 none 

1307 """ 

1308 self.check_restrictions() 

1309 self.check_elements_of_time_vary_are_lists() 

1310 return None 

1311 

1312 def post_solve(self): 

1313 """ 

1314 A method that is run immediately after the model is solved, to finalize 

1315 the solution in some way. Does nothing here. 

1316 

1317 Parameters 

1318 ---------- 

1319 none 

1320 

1321 Returns 

1322 ------- 

1323 none 

1324 """ 

1325 return None 

1326 

1327 def initialize_sym(self, **kwargs): 

1328 """ 

1329 Use the new simulator structure to build a simulator from the agents' 

1330 attributes, storing it in a private attribute. 

1331 """ 

1332 self.reset_rng() # ensure seeds are set identically each time 

1333 self._simulator = make_simulator_from_agent(self, **kwargs) 

1334 self._simulator.reset() 

1335 

1336 def initialize_sim(self): 

1337 """ 

1338 Prepares this AgentType for a new simulation. Resets the internal random number generator, 

1339 makes initial states for all agents (using sim_birth), clears histories of tracked variables. 

1340 

1341 Parameters 

1342 ---------- 

1343 None 

1344 

1345 Returns 

1346 ------- 

1347 None 

1348 """ 

1349 if not hasattr(self, "T_sim"): 

1350 raise Exception( 

1351 "To initialize simulation variables it is necessary to first " 

1352 + "set the attribute T_sim to the largest number of observations " 

1353 + "you plan to simulate for each agent including re-births." 

1354 ) 

1355 elif self.T_sim <= 0: 

1356 raise Exception( 

1357 "T_sim represents the largest number of observations " 

1358 + "that can be simulated for an agent, and must be a positive number." 

1359 ) 

1360 

1361 self.reset_rng() 

1362 self.t_sim = 0 

1363 all_agents = np.ones(self.AgentCount, dtype=bool) 

1364 blank_array = np.empty(self.AgentCount) 

1365 blank_array[:] = np.nan 

1366 for var in self.state_vars: 

1367 self.state_now[var] = copy(blank_array) 

1368 

1369 # Number of periods since agent entry 

1370 self.t_age = np.zeros(self.AgentCount, dtype=int) 

1371 # Which cycle period each agent is on 

1372 self.t_cycle = np.zeros(self.AgentCount, dtype=int) 

1373 self.sim_birth(all_agents) 

1374 

1375 # If we are asked to use existing shocks and a set of initial conditions 

1376 # exist, use them 

1377 if self.read_shocks and bool(self.newborn_init_history): 

1378 for var_name in self.state_now: 

1379 # Check that we are actually given a value for the variable 

1380 if var_name in self.newborn_init_history.keys(): 

1381 # Copy only array-like idiosyncratic states. Aggregates should 

1382 # not be set by newborns 

1383 idio = ( 

1384 isinstance(self.state_now[var_name], np.ndarray) 

1385 and len(self.state_now[var_name]) == self.AgentCount 

1386 ) 

1387 if idio: 

1388 self.state_now[var_name] = self.newborn_init_history[var_name][ 

1389 0 

1390 ] 

1391 

1392 else: 

1393 warn( 

1394 "The option for reading shocks was activated but " 

1395 + "the model requires state " 

1396 + var_name 

1397 + ", not contained in " 

1398 + "newborn_init_history." 

1399 ) 

1400 

1401 self.clear_history() 

1402 return None 

1403 

1404 def sim_one_period(self): 

1405 """ 

1406 Simulates one period for this type. Calls the methods get_mortality(), get_shocks() or 

1407 read_shocks, get_states(), get_controls(), and get_poststates(). These should be defined for 

1408 AgentType subclasses, except get_mortality (define its components sim_death and sim_birth 

1409 instead) and read_shocks. 

1410 

1411 Parameters 

1412 ---------- 

1413 None 

1414 

1415 Returns 

1416 ------- 

1417 None 

1418 """ 

1419 if not hasattr(self, "solution"): 

1420 raise Exception( 

1421 "Model instance does not have a solution stored. To simulate, it is necessary" 

1422 " to run the `solve()` method first." 

1423 ) 

1424 

1425 # Mortality adjusts the agent population 

1426 self.get_mortality() # Replace some agents with "newborns" 

1427 

1428 # state_{t-1} 

1429 for var in self.state_now: 

1430 self.state_prev[var] = self.state_now[var] 

1431 

1432 if isinstance(self.state_now[var], np.ndarray): 

1433 self.state_now[var] = np.empty(self.AgentCount) 

1434 else: 

1435 # Probably an aggregate variable. It may be getting set by the Market. 

1436 pass 

1437 

1438 if self.read_shocks: # If shock histories have been pre-specified, use those 

1439 self.read_shocks_from_history() 

1440 else: # Otherwise, draw shocks as usual according to subclass-specific method 

1441 self.get_shocks() 

1442 self.get_states() # Determine each agent's state at decision time 

1443 self.get_controls() # Determine each agent's choice or control variables based on states 

1444 self.get_poststates() # Calculate variables that come *after* decision-time 

1445 

1446 # Advance time for all agents 

1447 self.t_age = self.t_age + 1 # Age all consumers by one period 

1448 self.t_cycle = self.t_cycle + 1 # Age all consumers within their cycle 

1449 self.t_cycle[self.t_cycle == self.T_cycle] = ( 

1450 0 # Resetting to zero for those who have reached the end 

1451 ) 

1452 

1453 def make_shock_history(self): 

1454 """ 

1455 Makes a pre-specified history of shocks for the simulation. Shock variables should be named 

1456 in self.shock_vars, a list of strings that is subclass-specific. This method runs a subset 

1457 of the standard simulation loop by simulating only mortality and shocks; each variable named 

1458 in shock_vars is stored in a T_sim x AgentCount array in history dictionary self.history[X]. 

1459 Automatically sets self.read_shocks to True so that these pre-specified shocks are used for 

1460 all subsequent calls to simulate(). 

1461 

1462 Parameters 

1463 ---------- 

1464 None 

1465 

1466 Returns 

1467 ------- 

1468 None 

1469 """ 

1470 # Re-initialize the simulation 

1471 self.initialize_sim() 

1472 

1473 # Make blank history arrays for each shock variable (and mortality) 

1474 for var_name in self.shock_vars: 

1475 self.shock_history[var_name] = ( 

1476 np.zeros((self.T_sim, self.AgentCount)) + np.nan 

1477 ) 

1478 self.shock_history["who_dies"] = np.zeros( 

1479 (self.T_sim, self.AgentCount), dtype=bool 

1480 ) 

1481 

1482 # Also make blank arrays for the draws of newborns' initial conditions 

1483 for var_name in self.state_vars: 

1484 self.newborn_init_history[var_name] = ( 

1485 np.zeros((self.T_sim, self.AgentCount)) + np.nan 

1486 ) 

1487 

1488 # Record the initial condition of the newborns created by 

1489 # initialize_sim -> sim_births 

1490 for var_name in self.state_vars: 

1491 # Check whether the state is idiosyncratic or an aggregate 

1492 idio = ( 

1493 isinstance(self.state_now[var_name], np.ndarray) 

1494 and len(self.state_now[var_name]) == self.AgentCount 

1495 ) 

1496 if idio: 

1497 self.newborn_init_history[var_name][self.t_sim] = self.state_now[ 

1498 var_name 

1499 ] 

1500 else: 

1501 # Aggregate state is a scalar. Assign it to every agent. 

1502 self.newborn_init_history[var_name][self.t_sim, :] = self.state_now[ 

1503 var_name 

1504 ] 

1505 

1506 # Make and store the history of shocks for each period 

1507 for t in range(self.T_sim): 

1508 # Deaths 

1509 self.get_mortality() 

1510 self.shock_history["who_dies"][t, :] = self.who_dies 

1511 

1512 # Initial conditions of newborns 

1513 if self.who_dies.any(): 

1514 for var_name in self.state_vars: 

1515 # Check whether the state is idiosyncratic or an aggregate 

1516 idio = ( 

1517 isinstance(self.state_now[var_name], np.ndarray) 

1518 and len(self.state_now[var_name]) == self.AgentCount 

1519 ) 

1520 if idio: 

1521 self.newborn_init_history[var_name][t, self.who_dies] = ( 

1522 self.state_now[var_name][self.who_dies] 

1523 ) 

1524 else: 

1525 self.newborn_init_history[var_name][t, self.who_dies] = ( 

1526 self.state_now[var_name] 

1527 ) 

1528 

1529 # Other Shocks 

1530 self.get_shocks() 

1531 for var_name in self.shock_vars: 

1532 self.shock_history[var_name][t, :] = self.shocks[var_name] 

1533 

1534 self.t_sim += 1 

1535 self.t_age = self.t_age + 1 # Age all consumers by one period 

1536 self.t_cycle = self.t_cycle + 1 # Age all consumers within their cycle 

1537 self.t_cycle[self.t_cycle == self.T_cycle] = ( 

1538 0 # Resetting to zero for those who have reached the end 

1539 ) 

1540 

1541 # Flag that shocks can be read rather than simulated 

1542 self.read_shocks = True 

1543 

1544 def get_mortality(self): 

1545 """ 

1546 Simulates mortality or agent turnover according to some model-specific rules named sim_death 

1547 and sim_birth (methods of an AgentType subclass). sim_death takes no arguments and returns 

1548 a Boolean array of size AgentCount, indicating which agents of this type have "died" and 

1549 must be replaced. sim_birth takes such a Boolean array as an argument and generates initial 

1550 post-decision states for those agent indices. 

1551 

1552 Parameters 

1553 ---------- 

1554 None 

1555 

1556 Returns 

1557 ------- 

1558 None 

1559 """ 

1560 if self.read_shocks: 

1561 who_dies = self.shock_history["who_dies"][self.t_sim, :] 

1562 # Instead of simulating births, assign the saved newborn initial conditions 

1563 if who_dies.any(): 

1564 for var_name in self.state_now: 

1565 if var_name in self.newborn_init_history.keys(): 

1566 # Copy only array-like idiosyncratic states. Aggregates should 

1567 # not be set by newborns 

1568 idio = ( 

1569 isinstance(self.state_now[var_name], np.ndarray) 

1570 and len(self.state_now[var_name]) == self.AgentCount 

1571 ) 

1572 if idio: 

1573 self.state_now[var_name][who_dies] = ( 

1574 self.newborn_init_history[var_name][ 

1575 self.t_sim, who_dies 

1576 ] 

1577 ) 

1578 

1579 else: 

1580 warn( 

1581 "The option for reading shocks was activated but " 

1582 + "the model requires state " 

1583 + var_name 

1584 + ", not contained in " 

1585 + "newborn_init_history." 

1586 ) 

1587 

1588 # Reset ages of newborns 

1589 self.t_age[who_dies] = 0 

1590 self.t_cycle[who_dies] = 0 

1591 else: 

1592 who_dies = self.sim_death() 

1593 self.sim_birth(who_dies) 

1594 self.who_dies = who_dies 

1595 return None 

1596 

1597 def sim_death(self): 

1598 """ 

1599 Determines which agents in the current population "die" or should be replaced. Takes no 

1600 inputs, returns a Boolean array of size self.AgentCount, which has True for agents who die 

1601 and False for those that survive. Returns all False by default, must be overwritten by a 

1602 subclass to have replacement events. 

1603 

1604 Parameters 

1605 ---------- 

1606 None 

1607 

1608 Returns 

1609 ------- 

1610 who_dies : np.array 

1611 Boolean array of size self.AgentCount indicating which agents die and are replaced. 

1612 """ 

1613 who_dies = np.zeros(self.AgentCount, dtype=bool) 

1614 return who_dies 

1615 

1616 def sim_birth(self, which_agents): # pragma: nocover 

1617 """ 

1618 Makes new agents for the simulation. Takes a boolean array as an input, indicating which 

1619 agent indices are to be "born". Does nothing by default, must be overwritten by a subclass. 

1620 

1621 Parameters 

1622 ---------- 

1623 which_agents : np.array(Bool) 

1624 Boolean array of size self.AgentCount indicating which agents should be "born". 

1625 

1626 Returns 

1627 ------- 

1628 None 

1629 """ 

1630 raise Exception("AgentType subclass must define method sim_birth!") 

1631 

1632 def get_shocks(self): # pragma: nocover 

1633 """ 

1634 Gets values of shock variables for the current period. Does nothing by default, but can 

1635 be overwritten by subclasses of AgentType. 

1636 

1637 Parameters 

1638 ---------- 

1639 None 

1640 

1641 Returns 

1642 ------- 

1643 None 

1644 """ 

1645 return None 

1646 

1647 def read_shocks_from_history(self): 

1648 """ 

1649 Reads values of shock variables for the current period from history arrays. 

1650 For each variable X named in self.shock_vars, this attribute of self is 

1651 set to self.history[X][self.t_sim,:]. 

1652 

1653 This method is only ever called if self.read_shocks is True. This can 

1654 be achieved by using the method make_shock_history() (or manually after 

1655 storing a "handcrafted" shock history). 

1656 

1657 Parameters 

1658 ---------- 

1659 None 

1660 

1661 Returns 

1662 ------- 

1663 None 

1664 """ 

1665 for var_name in self.shock_vars: 

1666 self.shocks[var_name] = self.shock_history[var_name][self.t_sim, :] 

1667 

1668 def get_states(self): 

1669 """ 

1670 Gets values of state variables for the current period. 

1671 By default, calls transition function and assigns values 

1672 to the state_now dictionary. 

1673 

1674 Parameters 

1675 ---------- 

1676 None 

1677 

1678 Returns 

1679 ------- 

1680 None 

1681 """ 

1682 new_states = self.transition() 

1683 

1684 for i, var in enumerate(self.state_now): 

1685 # a hack for now to deal with 'post-states' 

1686 if i < len(new_states): 

1687 self.state_now[var] = new_states[i] 

1688 

1689 def transition(self): # pragma: nocover 

1690 """ 

1691 

1692 Parameters 

1693 ---------- 

1694 None 

1695 

1696 [Eventually, to match dolo spec: 

1697 exogenous_prev, endogenous_prev, controls, exogenous, parameters] 

1698 

1699 Returns 

1700 ------- 

1701 

1702 endogenous_state: () 

1703 Tuple with new values of the endogenous states 

1704 """ 

1705 return () 

1706 

1707 def get_controls(self): # pragma: nocover 

1708 """ 

1709 Gets values of control variables for the current period, probably by using current states. 

1710 Does nothing by default, but can be overwritten by subclasses of AgentType. 

1711 

1712 Parameters 

1713 ---------- 

1714 None 

1715 

1716 Returns 

1717 ------- 

1718 None 

1719 """ 

1720 return None 

1721 

1722 def get_poststates(self): 

1723 """ 

1724 Gets values of post-decision state variables for the current period, 

1725 probably by current 

1726 states and controls and maybe market-level events or shock variables. 

1727 Does nothing by 

1728 default, but can be overwritten by subclasses of AgentType. 

1729 

1730 Parameters 

1731 ---------- 

1732 None 

1733 

1734 Returns 

1735 ------- 

1736 None 

1737 """ 

1738 return None 

1739 

1740 def symulate(self, T=None): 

1741 """ 

1742 Run the new simulation structure, with history results written to the 

1743 hystory attribute of self. 

1744 """ 

1745 self._simulator.simulate(T) 

1746 self.hystory = self._simulator.history 

1747 

1748 def describe_model(self, display=True): 

1749 """ 

1750 Print to screen information about this agent's model, based on its model 

1751 file. This is useful for learning about outcome variable names for tracking 

1752 during simulation, or for use with sequence space Jacobians. 

1753 """ 

1754 if not hasattr(self, "_simulator"): 

1755 self.initialize_sym() 

1756 self._simulator.describe(display=display) 

1757 

1758 def simulate(self, sim_periods=None): 

1759 """ 

1760 Simulates this agent type for a given number of periods. Defaults to self.T_sim, 

1761 or all remaining periods to simulate (T_sim - t_sim). Records histories of 

1762 attributes named in self.track_vars in self.history[varname]. 

1763 

1764 Parameters 

1765 ---------- 

1766 sim_periods : int or None 

1767 Number of periods to simulate. Default is all remaining periods (usually T_sim). 

1768 

1769 Returns 

1770 ------- 

1771 history : dict 

1772 The history tracked during the simulation. 

1773 """ 

1774 if not hasattr(self, "t_sim"): 

1775 raise Exception( 

1776 "It seems that the simulation variables were not initialize before calling " 

1777 + "simulate(). Call initialize_sim() to initialize the variables before calling simulate() again." 

1778 ) 

1779 

1780 if not hasattr(self, "T_sim"): 

1781 raise Exception( 

1782 "This agent type instance must have the attribute T_sim set to a positive integer." 

1783 + "Set T_sim to match the largest dataset you might simulate, and run this agent's" 

1784 + "initialize_sim() method before running simulate() again." 

1785 ) 

1786 

1787 if sim_periods is not None and self.T_sim < sim_periods: 

1788 raise Exception( 

1789 "To simulate, sim_periods has to be larger than the maximum data set size " 

1790 + "T_sim. Either increase the attribute T_sim of this agent type instance " 

1791 + "and call the initialize_sim() method again, or set sim_periods <= T_sim." 

1792 ) 

1793 

1794 # Ignore floating point "errors". Numpy calls it "errors", but really it's excep- 

1795 # tions with well-defined answers such as 1.0/0.0 that is np.inf, -1.0/0.0 that is 

1796 # -np.inf, np.inf/np.inf is np.nan and so on. 

1797 with np.errstate( 

1798 divide="ignore", over="ignore", under="ignore", invalid="ignore" 

1799 ): 

1800 if sim_periods is None: 

1801 sim_periods = self.T_sim - self.t_sim 

1802 

1803 for t in range(sim_periods): 

1804 self.sim_one_period() 

1805 

1806 for var_name in self.track_vars: 

1807 if var_name in self.state_now: 

1808 self.history[var_name][self.t_sim, :] = self.state_now[var_name] 

1809 elif var_name in self.shocks: 

1810 self.history[var_name][self.t_sim, :] = self.shocks[var_name] 

1811 elif var_name in self.controls: 

1812 self.history[var_name][self.t_sim, :] = self.controls[var_name] 

1813 else: 

1814 if var_name == "who_dies" and self.t_sim > 1: 

1815 self.history[var_name][self.t_sim - 1, :] = getattr( 

1816 self, var_name 

1817 ) 

1818 else: 

1819 self.history[var_name][self.t_sim, :] = getattr( 

1820 self, var_name 

1821 ) 

1822 self.t_sim += 1 

1823 

1824 def clear_history(self): 

1825 """ 

1826 Clears the histories of the attributes named in self.track_vars. 

1827 

1828 Parameters 

1829 ---------- 

1830 None 

1831 

1832 Returns 

1833 ------- 

1834 None 

1835 """ 

1836 for var_name in self.track_vars: 

1837 self.history[var_name] = np.empty((self.T_sim, self.AgentCount)) 

1838 self.history[var_name].fill(np.nan) 

1839 

1840 def make_basic_SSJ(self, shock, outcomes, grids, **kwargs): 

1841 """ 

1842 Construct and return sequence space Jacobian matrices for specified outcomes 

1843 with respect to specified "shock" variable. This "basic" method only works 

1844 for "one period infinite horizon" models (cycles=0, T_cycle=1). See documen- 

1845 tation for simulator.make_basic_SSJ_matrices for more information. 

1846 """ 

1847 return make_basic_SSJ_matrices(self, shock, outcomes, grids, **kwargs) 

1848 

1849 def calc_impulse_response_manually(self, shock, outcomes, grids, **kwargs): 

1850 """ 

1851 Calculate and return the impulse response(s) of a perturbation to the shock 

1852 parameter in period t=s, essentially computing one column of the sequence 

1853 space Jacobian matrix manually. This "basic" method only works for "one 

1854 period infinite horizon" models (cycles=0, T_cycle=1). See documentation 

1855 for simulator.calc_shock_response_manually for more information. 

1856 """ 

1857 return calc_shock_response_manually(self, shock, outcomes, grids, **kwargs) 

1858 

1859 

1860def solve_agent(agent, verbose, from_solution=None, from_t=None): 

1861 """ 

1862 Solve the dynamic model for one agent type using backwards induction. This 

1863 function iterates on "cycles" of an agent's model either a given number of 

1864 times or until solution convergence if an infinite horizon model is used 

1865 (with agent.cycles = 0). 

1866 

1867 Parameters 

1868 ---------- 

1869 agent : AgentType 

1870 The microeconomic AgentType whose dynamic problem 

1871 is to be solved. 

1872 verbose : boolean 

1873 If True, solution progress is printed to screen (when cycles != 1). 

1874 from_solution: Solution 

1875 If different from None, will be used as the starting point of backward 

1876 induction, instead of self.solution_terminal 

1877 from_t : int or None 

1878 If not None, indicates which period of the model the solver should start 

1879 from. It should usually only be used in combination with from_solution. 

1880 Stands for the time index that from_solution represents, and thus is 

1881 only compatible with cycles=1 and will be reset to None otherwise. 

1882 

1883 Returns 

1884 ------- 

1885 solution : [Solution] 

1886 A list of solutions to the one period problems that the agent will 

1887 encounter in his "lifetime". 

1888 """ 

1889 # Check to see whether this is an (in)finite horizon problem 

1890 cycles_left = agent.cycles # NOQA 

1891 infinite_horizon = cycles_left == 0 # NOQA 

1892 

1893 if from_solution is None: 

1894 solution_last = agent.solution_terminal # NOQA 

1895 else: 

1896 solution_last = from_solution 

1897 if agent.cycles != 1: 

1898 from_t = None 

1899 

1900 # Initialize the solution, which includes the terminal solution if it's not a pseudo-terminal period 

1901 solution = [] 

1902 if not agent.pseudo_terminal: 

1903 solution.insert(0, deepcopy(solution_last)) 

1904 

1905 # Initialize the process, then loop over cycles 

1906 go = True # NOQA 

1907 completed_cycles = 0 # NOQA 

1908 max_cycles = 5000 # NOQA - escape clause 

1909 if verbose: 

1910 t_last = time() 

1911 while go: 

1912 # Solve a cycle of the model, recording it if horizon is finite 

1913 solution_cycle = solve_one_cycle(agent, solution_last, from_t) 

1914 if not infinite_horizon: 

1915 solution = solution_cycle + solution 

1916 

1917 # Check for termination: identical solutions across 

1918 # cycle iterations or run out of cycles 

1919 solution_now = solution_cycle[0] 

1920 if infinite_horizon: 

1921 if completed_cycles > 0: 

1922 solution_distance = distance_metric(solution_now, solution_last) 

1923 agent.solution_distance = ( 

1924 solution_distance # Add these attributes so users can 

1925 ) 

1926 agent.completed_cycles = ( 

1927 completed_cycles # query them to see if solution is ready 

1928 ) 

1929 go = ( 

1930 solution_distance > agent.tolerance 

1931 and completed_cycles < max_cycles 

1932 ) 

1933 else: # Assume solution does not converge after only one cycle 

1934 solution_distance = 100.0 

1935 go = True 

1936 else: 

1937 cycles_left += -1 

1938 go = cycles_left > 0 

1939 

1940 # Update the "last period solution" 

1941 solution_last = solution_now 

1942 completed_cycles += 1 

1943 

1944 # Display progress if requested 

1945 if verbose: 

1946 t_now = time() 

1947 if infinite_horizon: 

1948 print( 

1949 "Finished cycle #" 

1950 + str(completed_cycles) 

1951 + " in " 

1952 + str(t_now - t_last) 

1953 + " seconds, solution distance = " 

1954 + str(solution_distance) 

1955 ) 

1956 else: 

1957 print( 

1958 "Finished cycle #" 

1959 + str(completed_cycles) 

1960 + " of " 

1961 + str(agent.cycles) 

1962 + " in " 

1963 + str(t_now - t_last) 

1964 + " seconds." 

1965 ) 

1966 t_last = t_now 

1967 

1968 # Record the last cycle if horizon is infinite (solution is still empty!) 

1969 if infinite_horizon: 

1970 solution = ( 

1971 solution_cycle # PseudoTerminal=False impossible for infinite horizon 

1972 ) 

1973 

1974 return solution 

1975 

1976 

def _cycle_period_order(T, cycles):
    """
    Return the period indices of one cycle in the order they are solved.

    With cycles == 1 the periods are solved backward from T-1 to 0. Otherwise
    period 0 is solved first, followed by periods T-1 down to 1.
    """
    if cycles == 1:
        return list(range(T - 1, -1, -1))
    return [0] + list(range(T - 1, 0, -1))


def _get_period_solver(agent, k):
    """
    Return the one-period solver for period k and the names of its arguments.

    agent.solve_one_period may be a single callable or an indexable collection
    of per-period callables. Argument names come from a ``solver_args``
    attribute when the solver has one, otherwise from get_arg_names.
    """
    if hasattr(agent.solve_one_period, "__getitem__"):
        solver = agent.solve_one_period[k]
    else:
        solver = agent.solve_one_period

    if hasattr(solver, "solver_args"):
        arg_names = solver.solver_args
    else:
        arg_names = get_arg_names(solver)
    return solver, arg_names


def solve_one_cycle(agent, solution_last, from_t):
    """
    Solve one "cycle" of the dynamic model for one agent type. This function
    iterates over the periods within an agent's cycle, updating the time-varying
    parameters and passing them to the single period solver(s).

    Parameters
    ----------
    agent : AgentType
        The microeconomic AgentType whose dynamic problem is to be solved. If it
        has a ``parameters`` attribute that is a Parameters instance, the inputs
        for period k are read from ``agent.parameters[k]``. Otherwise they are
        read from the attributes named in ``agent.time_inv`` (used as-is) and
        ``agent.time_vary`` (indexed by period).
    solution_last : Solution
        A representation of the solution of the period that comes after the
        end of the sequence of one period problems. This might be the terminal
        period solution, a "pseudo terminal" solution, or simply the solution
        to the earliest period from the succeeding cycle.
    from_t : int or None
        If not None, indicates which period of the model the solver should start
        from. When used, represents the time index that solution_last is from.

    Returns
    -------
    solution_cycle : [Solution]
        A list of one period solutions for one "cycle" of the AgentType's
        microeconomic model. The list is in the reverse of the solving order.
        When agent.cycles == 1 that is chronological order (period 0 first).
    """
    # Prefer the Parameters container when the agent has one; else use the
    # legacy time_inv / time_vary attribute lookup.
    use_parameters = hasattr(agent, "parameters") and isinstance(
        agent.parameters, Parameters
    )

    if use_parameters:
        T = agent.parameters._length if from_t is None else from_t
    else:
        # Number of periods per cycle; 1 if every input is time invariant
        if len(agent.time_vary) > 0:
            T = agent.T_cycle if from_t is None else from_t
        else:
            T = 1
        solve_dict = {
            parameter: agent.__dict__[parameter] for parameter in agent.time_inv
        }
        solve_dict.update({parameter: None for parameter in agent.time_vary})

    # Solutions are collected in solving order and reversed at the end, which
    # gives the same result as repeatedly inserting at the front.
    solved = []
    solution_next = solution_last
    for k in _cycle_period_order(T, agent.cycles):
        solve_one_period, these_args = _get_period_solver(agent, k)

        # Build the keyword arguments for this period's solver
        if use_parameters:
            temp_pars = agent.parameters[k]
            temp_dict = {
                name: solution_next if name == "solution_next" else temp_pars[name]
                for name in these_args
            }
        else:
            for name in agent.time_vary:
                if name in these_args:
                    solve_dict[name] = agent.__dict__[name][k]
            solve_dict["solution_next"] = solution_next
            temp_dict = {name: solve_dict[name] for name in these_args}

        # Solve one period; it becomes the "next" solution for the period before
        solution_t = solve_one_period(**temp_dict)
        solved.append(solution_t)
        solution_next = solution_t

    solved.reverse()
    return solved

2082 

2083 

def make_one_period_oo_solver(solver_class):
    """
    Wrap an object-oriented Solver class as a function that solves one period.

    Parameters
    ----------
    solver_class : Solver
        A class of Solver to be used. It is instantiated with the keyword
        arguments given to the returned function, and the result of its
        ``solve`` method is returned.

    Returns
    -------
    solver_function : function
        A function for solving one period of a problem. It accepts keyword
        arguments only and carries two attributes: ``solver_class`` (the class
        passed in) and ``solver_args`` (the argument names of
        ``solver_class.__init__`` with the first one, ``self``, dropped).
    """

    def one_period_solver(**kwds):
        instance = solver_class(**kwds)
        # Only some Solver classes define prepare_to_solve, so call it when it
        # exists (it would be cleaner if every Solver class defined it).
        if hasattr(instance, "prepare_to_solve"):
            instance.prepare_to_solve()
        return instance.solve()

    one_period_solver.solver_class = solver_class
    # Argument names are read from __init__; this can be revisited once
    # solvers are able to export their parameters.
    init_arg_names = get_arg_names(solver_class.__init__)
    one_period_solver.solver_args = init_arg_names[1:]

    return one_period_solver

2112 

2113 

2114# ======================================================================== 

2115# ======================================================================== 

2116 

2117 

class Market(Model):
    """
    A superclass to represent a central clearinghouse of information. Used for
    dynamic general equilibrium models to solve the "macroeconomic" model as a
    layer on top of the "microeconomic" models of one or more AgentTypes.

    Parameters
    ----------
    agents : [AgentType]
        A list of all the AgentTypes in this market.
    sow_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be "sown" to the agents in the market. Aggregate state, etc. Their current
        values are kept in the dictionary self.sow_state and their initial values
        in self.sow_init (both start as None).
    reap_vars : [string]
        Names of variables to be collected ("reaped") from agents in the market
        to be used in the "aggregate market process". Reaped values are kept in
        the dictionary self.reap_state.
    const_vars : [string]
        Names of inputs to the "aggregate market process" that do not come from
        agents-- they are constant or simply parameters inherent to the process.
        They become the keys of the dictionary self.const_vars, whose values start
        as None and are passed to mill_rule as keyword arguments.
        NOTE(review): nothing in this class fills in those values; confirm that
        subclasses or users populate self.const_vars before mill() runs.
    track_vars : [string]
        Names of variables generated by the "aggregate market process" that should
        be tracked as a "history" so that a new dynamic rule can be calculated.
        This is often a subset of sow_vars.
    dyn_vars : [string]
        Names of variables that constitute a "dynamic rule". update_dynamics()
        copies each of these attributes of the new rule onto every agent.
    mill_rule : function
        A function that takes inputs named in reap_vars (plus the keys of
        self.const_vars) and returns a tuple the same size and order as sow_vars.
        The "aggregate market process" that transforms individual agent
        actions/states/data into aggregate data to be sent back to agents.
        If None, a mill_rule method defined by a subclass is left in place.
    calc_dynamics : function
        A function that takes inputs named in track_vars and returns an object
        with attributes named in dyn_vars. Looks at histories of aggregate
        variables and generates a new "dynamic rule" for agents to believe and
        act on. If None, a calc_dynamics method defined by a subclass is left
        in place.
    distributions : [string]
        Names of attributes of the Market holding a distribution (or a list of
        distributions) whose reset() method is called by reset().
    act_T : int
        The number of times that the "aggregate market process" should be run
        in order to generate a history of aggregate variables.
    tolerance: float
        Minimum acceptable distance between "dynamic rules" to consider the
        Market solution process converged. Distance is a user-defined metric.
    seed : int
        Seed for the market's random number generator, stored as self.RNG.
    **kwds
        Any other parameters, passed to assign_parameters().
    """

    def __init__(
        self,
        agents=None,
        sow_vars=None,
        reap_vars=None,
        const_vars=None,
        track_vars=None,
        dyn_vars=None,
        mill_rule=None,
        calc_dynamics=None,
        distributions=None,
        act_T=1000,
        tolerance=0.000001,
        seed=0,
        **kwds,
    ):
        super().__init__()
        self.agents = agents if agents is not None else list()  # NOQA

        self.reap_vars = reap_vars if reap_vars is not None else list()  # NOQA
        self.reap_state = {var: [] for var in self.reap_vars}

        self.sow_vars = sow_vars if sow_vars is not None else list()  # NOQA
        # dictionaries for tracking initial and current values
        # of the sow variables.
        self.sow_init = {var: None for var in self.sow_vars}
        self.sow_state = {var: None for var in self.sow_vars}

        const_vars = const_vars if const_vars is not None else list()  # NOQA
        self.const_vars = {var: None for var in const_vars}

        self.track_vars = track_vars if track_vars is not None else list()  # NOQA
        self.dyn_vars = dyn_vars if dyn_vars is not None else list()  # NOQA
        self.distributions = distributions if distributions is not None else list()  # NOQA

        if mill_rule is not None:  # To prevent overwriting of method-based mill_rules
            self.mill_rule = mill_rule
        if calc_dynamics is not None:  # Ditto for calc_dynamics
            self.calc_dynamics = calc_dynamics
        self.act_T = act_T  # NOQA
        self.tolerance = tolerance  # NOQA
        self.seed = seed
        self.max_loops = 1000  # NOQA
        self.history = {}
        self.assign_parameters(**kwds)
        self.RNG = np.random.default_rng(self.seed)

        self.print_parallel_error_once = True
        # Print the error associated with calling the parallel method
        # "solve_agents" one time. If set to false, the error will never
        # print. See "solve_agents" for why this prints once or never.

    def solve_agents(self):
        """
        Solves the microeconomic problem for all AgentTypes in this market.

        The agents are solved in parallel with multi_thread_commands. If that
        raises any exception, a warning is printed (only the first time, as
        controlled by self.print_parallel_error_once) and the agents are solved
        serially with multi_thread_commands_fake instead.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        try:
            multi_thread_commands(self.agents, ["solve()"])
        except Exception as err:
            if self.print_parallel_error_once:
                # Set flag to False so this is only printed once.
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.Market.solve_agents() ",
                    "so using the serial version instead. This will likely be slower. "
                    "The multi_thread_commands() functions failed with the following error:",
                    "\n",
                    sys.exc_info()[0],
                    ":",
                    err,
                )  # sys.exc_info()[0])
            multi_thread_commands_fake(self.agents, ["solve()"])

    def solve(self):
        """
        "Solves" the market by finding a "dynamic rule" that governs the aggregate
        market state such that when agents believe in these dynamics, their actions
        collectively generate the same dynamic rule.

        Each loop solves the agents, runs the market to build a history, and
        computes a new dynamic rule. Looping stops when the distance between
        successive rules is below self.tolerance or after self.max_loops loops.
        The final rule is stored in self.dynamics.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        go = True
        max_loops = self.max_loops  # Failsafe against infinite solution loop
        completed_loops = 0
        old_dynamics = None

        while go:  # Loop until the dynamic process converges or we hit the loop cap
            self.solve_agents()  # Solve each AgentType's micro problem
            self.make_history()  # "Run" the model while tracking aggregate variables
            new_dynamics = self.update_dynamics()  # Find a new aggregate dynamic rule

            # Check to see if the dynamic rule has converged (if this is not the first loop)
            if completed_loops > 0:
                distance = new_dynamics.distance(old_dynamics)
            else:
                distance = 1000000.0

            # Move to the next loop if the terminal conditions are not met
            old_dynamics = new_dynamics
            completed_loops += 1
            go = distance >= self.tolerance and completed_loops < max_loops

        self.dynamics = new_dynamics  # Store the final dynamic rule in self

    def reap(self):
        """
        Collect the variables named in reap_vars from each AgentType in the market.

        For each key of self.reap_state, the agents' values are taken from their
        state_now dictionaries; agents whose state_now lacks that key are skipped.
        The resulting list replaces self.reap_state[key].

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var in self.reap_state:
            # TODO: generalized variable lookup across namespaces
            self.reap_state[var] = [
                agent.state_now[var] for agent in self.agents if var in agent.state_now
            ]

    def sow(self):
        """
        Distribute the current values in self.sow_state to each AgentType in the market.

        For each sow variable and each agent, the value is written into
        agent.state_now if that dictionary has the key. It is also written into
        agent.shocks if that dictionary has the key; otherwise it is set as an
        attribute of the agent. Consequently a variable found in state_now but not
        in shocks is stored both in state_now and as an agent attribute.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for sow_var in self.sow_state:
            for this_type in self.agents:
                if sow_var in this_type.state_now:
                    this_type.state_now[sow_var] = self.sow_state[sow_var]
                if sow_var in this_type.shocks:
                    this_type.shocks[sow_var] = self.sow_state[sow_var]
                else:
                    setattr(this_type, sow_var, self.sow_state[sow_var])

    def mill(self):
        """
        Process the variables collected from agents using the function mill_rule.

        mill_rule is called with the reaped values (self.reap_state) and the
        entries of self.const_vars as keyword arguments. Its output is read
        positionally: element i is stored in self.sow_state under the i-th
        sow variable name.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Make a dictionary of inputs for the mill_rule
        mill_dict = copy(self.reap_state)
        mill_dict.update(self.const_vars)

        # Run the mill_rule and store its output in self
        product = self.mill_rule(**mill_dict)

        for i, sow_var in enumerate(self.sow_state):
            self.sow_state[sow_var] = product[i]

    def cultivate(self):
        """
        Has each AgentType in agents perform their market_action method, using
        variables sown from the market (and maybe also "private" variables).
        The market_action method should store new results in variables named in
        reap_vars (in state_now) to be reaped later.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for this_type in self.agents:
            this_type.market_action()

    def reset(self):
        """
        Reset the state of the market (attributes in sow_vars, etc) to some
        user-defined initial state, and erase the histories of tracked variables.
        Also resets the distributions named in self.distributions so that draws
        can be reproduced, and calls reset() on each AgentType in the market.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        # Reset internal RNG and distributions; names without a matching
        # attribute are skipped.
        for name in self.distributions:
            if not hasattr(self, name):
                continue
            dstn = getattr(self, name)
            if isinstance(dstn, list):
                for D in dstn:
                    D.reset()
            else:
                dstn.reset()

        # Reset the history of tracked variables
        self.history = {var_name: [] for var_name in self.track_vars}

        # Set the sow variables to their initial levels
        for var_name in self.sow_state:
            self.sow_state[var_name] = self.sow_init[var_name]

        # Reset each AgentType in the market
        for this_type in self.agents:
            this_type.reset()

    def store(self):
        """
        Record the current value of each variable X named in track_vars in a
        dictionary field named history[X].

        The value is looked up in self.sow_state, then self.reap_state, then
        self.const_vars, and finally as an attribute of the Market. history[X]
        must already exist, which reset() ensures.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        for var_name in self.track_vars:
            if var_name in self.sow_state:
                value_now = self.sow_state[var_name]
            elif var_name in self.reap_state:
                value_now = self.reap_state[var_name]
            elif var_name in self.const_vars:
                value_now = self.const_vars[var_name]
            else:
                value_now = getattr(self, var_name)

            self.history[var_name].append(value_now)

    def make_history(self):
        """
        Runs a loop of sow-->cultivate-->reap-->mill act_T times, tracking the
        evolution of variables X named in track_vars in dictionary fields
        history[X]. The market is reset() first.

        Parameters
        ----------
        none

        Returns
        -------
        none
        """
        self.reset()  # Initialize the state of the market
        for _ in range(self.act_T):
            self.sow()  # Distribute aggregated information/state to agents
            self.cultivate()  # Agents take action
            self.reap()  # Collect individual data from agents
            self.mill()  # Process individual data into aggregate data
            self.store()  # Record variables of interest

    def update_dynamics(self):
        """
        Calculates a new "aggregate dynamic rule" using the history of variables
        named in track_vars, and distributes this rule to AgentTypes in agents.

        calc_dynamics is called with keyword arguments taken from self.history,
        one per argument name of calc_dynamics (excluding "self"). Each attribute
        of the result named in dyn_vars is then set on every agent.

        Parameters
        ----------
        none

        Returns
        -------
        dynamics : instance
            The new "aggregate dynamic rule" that agents believe in and act on.
            Should have attributes named in dyn_vars.
        """
        # Make a dictionary of inputs for the dynamics calculator
        arg_names = list(get_arg_names(self.calc_dynamics))
        if "self" in arg_names:
            arg_names.remove("self")
        update_dict = {name: self.history[name] for name in arg_names}
        # Calculate a new dynamic rule and distribute it to the agents in agent_list
        dynamics = self.calc_dynamics(**update_dict)  # User-defined dynamics calculator
        for var_name in self.dyn_vars:
            this_obj = getattr(dynamics, var_name)
            for this_type in self.agents:
                setattr(this_type, var_name, this_obj)
        return dynamics

2477 

2478 

def distribute_params(agent, param_name, param_count, distribution):
    """
    Make ex-ante heterogeneous copies of an agent that differ in one parameter.

    The distribution is discretized into param_count points. The j-th copy of
    agent receives the j-th atom as the value of param_name. It also receives
    an AgentCount equal to the original AgentCount times the j-th probability,
    truncated with int().

    Parameters
    ----------
    agent : AgentType
        An agent to clone. It is deep-copied; the original is not modified.
    param_name : string
        Name of the parameter to be assigned.
    param_count : int
        Number of different values the parameter will take on.
    distribution : Distribution
        A 1-D distribution. The result of distribution.discretize(N=param_count)
        must provide pmv[j] and atoms[0, j] for each point j.

    Returns
    -------
    agent_set : [AgentType]
        A list of param_count agents, ex ante heterogeneous with respect to
        param_name. The AgentCount of the original is split between them in
        proportion to the discretized probabilities. Because each share is
        truncated, the counts may sum to slightly less than the original.
    """
    point_dstn = distribution.discretize(N=param_count)

    # Make every clone first, then assign parameters to each one
    clones = [deepcopy(agent) for _ in range(param_count)]

    for j, clone in enumerate(clones):
        clone.assign_parameters(AgentCount=int(agent.AgentCount * point_dstn.pmv[j]))
        clone.assign_parameters(**{param_name: point_dstn.atoms[0, j]})

    return clones

2512 

2513 

@dataclass
class AgentPopulation:
    """
    A class for representing a population of ex-ante heterogeneous agents.

    A parameter is treated as "distributed" (differing across agents) when it is
    a non-empty list of lists, a Distribution, or a DataArray whose first
    dimension is named "agent".
    """

    agent_type: AgentType  # type of agent in the population
    parameters: dict  # dictionary of parameters
    seed: int = 0  # random seed
    time_var: List[str] = field(init=False)
    time_inv: List[str] = field(init=False)
    distributed_params: List[str] = field(init=False)
    agent_type_count: Optional[int] = field(init=False)
    term_age: Optional[int] = field(init=False)
    continuous_distributions: Dict[str, Distribution] = field(init=False)
    discrete_distributions: Dict[str, Distribution] = field(init=False)
    population_parameters: List[Dict[str, Union[List[float], float]]] = field(
        init=False
    )
    agents: List[AgentType] = field(init=False)
    agent_database: pd.DataFrame = field(init=False)
    solution: List[Any] = field(init=False)

    @staticmethod
    def _is_nested_list(param):
        """
        Return True if param is a non-empty list whose first element is a list.

        The emptiness check avoids the IndexError that param[0] would raise.
        """
        return isinstance(param, list) and len(param) > 0 and isinstance(param[0], list)

    def __post_init__(self):
        """
        Initialize the population of agents, determine distributed parameters,
        and infer `agent_type_count` and `term_age`.
        """
        # create a dummy agent and obtain its time-varying
        # and time-invariant attributes
        dummy_agent = self.agent_type()
        self.time_var = dummy_agent.time_vary
        self.time_inv = dummy_agent.time_inv

        # create list of distributed parameters
        # these are parameters that differ across agents
        self.distributed_params = [
            key
            for key, param in self.parameters.items()
            if self._is_nested_list(param)
            or isinstance(param, Distribution)
            or (isinstance(param, DataArray) and param.dims[0] == "agent")
        ]

        self.__infer_counts__()

        self.print_parallel_error_once = True
        # Print warning once if parallel simulation fails

    def __infer_counts__(self):
        """
        Infer `agent_type_count` and `term_age` from the parameters.
        If parameters include a `Distribution` type, a list of lists,
        or a `DataArray` with `agent` as the first dimension, then
        the AgentPopulation contains ex-ante heterogenous agents.

        Both counts are set to None (with a warning) while any parameter is
        still a Distribution; approx_distributions() resolves that.
        """

        # infer agent_type_count from distributed parameters
        agent_type_count = 1
        for key in self.distributed_params:
            param = self.parameters[key]
            if isinstance(param, Distribution):
                agent_type_count = None
                warn(
                    "Cannot infer agent_type_count from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif isinstance(param, list):
                agent_type_count = max(agent_type_count, len(param))
            elif isinstance(param, DataArray) and param.dims[0] == "agent":
                agent_type_count = max(agent_type_count, param.shape[0])

        self.agent_type_count = agent_type_count

        # infer term_age from all parameters
        term_age = 1
        for param in self.parameters.values():
            if isinstance(param, Distribution):
                term_age = None
                warn(
                    "Cannot infer term_age from a Distribution. "
                    "Please provide approximation parameters."
                )
                break
            elif self._is_nested_list(param):
                term_age = max(term_age, len(param[0]))
            elif isinstance(param, DataArray) and param.dims[-1] == "age":
                term_age = max(term_age, param.shape[-1])

        self.term_age = term_age

    def approx_distributions(self, approx_params: dict):
        """
        Approximate continuous distributions with discrete ones. If the initial
        parameters include a `Distribution` type, then the AgentPopulation is
        not ready to solve, and stands for an abstract population. To solve the
        AgentPopulation, we need discretization parameters for each continuous
        distribution. This method approximates the continuous distributions with
        discrete ones, and updates the parameters dictionary.

        Parameters
        ----------
        approx_params : dict
            Maps the name of each Distribution parameter to a dict of keyword
            arguments for that distribution's discretize() method. If empty,
            nothing is replaced.

        Raises
        ------
        ValueError
            If a key names a parameter that is not a distributed Distribution.
        """
        self.continuous_distributions = {}
        self.discrete_distributions = {}

        for key, args in approx_params.items():
            param = self.parameters[key]
            if key in self.distributed_params and isinstance(param, Distribution):
                self.continuous_distributions[key] = param
                self.discrete_distributions[key] = param.discretize(**args)
            else:
                raise ValueError(
                    f"Warning: parameter {key} is not a Distribution found "
                    f"in agent type {self.agent_type}"
                )

        if not self.discrete_distributions:
            # Nothing was approximated, so the parameters are unchanged.
            return

        if len(self.discrete_distributions) > 1:
            joint_dist = combine_indep_dstns(*self.discrete_distributions.values())
        else:
            joint_dist = next(iter(self.discrete_distributions.values()))

        # atoms[i] is taken as the values of the i-th discretized parameter
        # (assumes combine_indep_dstns stacks atoms in argument order).
        for i, key in enumerate(self.discrete_distributions):
            self.parameters[key] = DataArray(joint_dist.atoms[i], dims=("agent",))

        self.__infer_counts__()

    def __parse_parameters__(self) -> None:
        """
        Creates distributed dictionaries of parameters for each ex-ante
        heterogeneous agent in the parameterized population. The parameters
        are stored in a list of dictionaries, where each dictionary contains
        the parameters for one agent. Expands parameters that vary over time
        to a list of length `term_age`.
        """

        population_parameters = []  # container for dictionaries of each agent subgroup
        for agent in range(self.agent_type_count):
            agent_parameters = {}
            for key, param in self.parameters.items():
                if key in self.time_var:
                    # parameters that vary over time have to be repeated
                    if isinstance(param, (int, float)):
                        parameter_per_t = [param] * self.term_age
                    elif self._is_nested_list(param):
                        parameter_per_t = param[agent]
                    elif isinstance(param, DataArray) and param.dims[0] == "agent":
                        if len(param.dims) > 1 and param.dims[-1] == "age":
                            parameter_per_t = param[agent].values.tolist()
                        else:
                            parameter_per_t = param[agent].item()
                    elif isinstance(param, DataArray) and param.dims[0] == "age":
                        parameter_per_t = param.values.tolist()
                    else:
                        # A flat list is already one value per period; any other
                        # type is passed through unchanged. Assigning here keeps
                        # a value from a previous key from leaking into this one.
                        parameter_per_t = param

                    agent_parameters[key] = parameter_per_t

                elif key in self.time_inv:
                    if isinstance(param, (int, float)):
                        agent_parameters[key] = param
                    elif isinstance(param, list):
                        agent_parameters[key] = (
                            param[agent] if self._is_nested_list(param) else param
                        )
                    elif isinstance(param, DataArray) and param.dims[0] == "agent":
                        agent_parameters[key] = param[agent].item()
                    # NOTE(review): time-invariant parameters of other types are
                    # not copied, so agents presumably use their own defaults for
                    # them; confirm this is intended.

                else:
                    if isinstance(param, (int, float)):
                        agent_parameters[key] = param  # assume time inv
                    elif self._is_nested_list(param):
                        agent_parameters[key] = param[agent]  # assume agent vary
                    elif isinstance(param, list):
                        agent_parameters[key] = param  # assume time vary
                    elif isinstance(param, DataArray):
                        if param.dims[0] == "agent":
                            if len(param.dims) > 1 and param.dims[-1] == "age":
                                agent_parameters[key] = param[agent].values.tolist()
                            else:
                                agent_parameters[key] = param[agent].item()
                        elif param.dims[0] == "age":
                            agent_parameters[key] = param.values.tolist()

            population_parameters.append(agent_parameters)

        self.population_parameters = population_parameters

    def create_distributed_agents(self):
        """
        Parses the parameters dictionary and creates a list of agents with the
        appropriate parameters. Also sets the seed for each agent, drawn from
        a generator seeded with self.seed.
        """

        self.__parse_parameters__()

        rng = np.random.default_rng(self.seed)

        self.agents = [
            self.agent_type(seed=rng.integers(0, 2**31 - 1), **agent_dict)
            for agent_dict in self.population_parameters
        ]

    def create_database(self):
        """
        Optionally creates a pandas DataFrame with the parameters for each agent,
        plus an "agents" column holding the agent objects.
        """
        database = pd.DataFrame(self.population_parameters)
        database["agents"] = self.agents

        self.agent_database = database

    def solve(self):
        """
        Solves each agent of the population serially.
        """

        # see Market class for an example of how to solve distributed agents in parallel

        for agent in self.agents:
            agent.solve()

    def unpack_solutions(self):
        """
        Unpacks the solutions of each agent into an attribute of the population.
        """
        self.solution = [agent.solution for agent in self.agents]

    def initialize_sim(self):
        """
        Initializes the simulation for each agent.
        """
        for agent in self.agents:
            agent.initialize_sim()

    def simulate(self, num_jobs=None):
        """
        Simulates each agent of the population.

        Parameters
        ----------
        num_jobs : int, optional
            Number of parallel jobs to use. Defaults to using all available
            cores when ``None``. Falls back to serial execution if parallel
            processing fails.
        """
        try:
            multi_thread_commands(self.agents, ["simulate()"], num_jobs)
        except Exception as err:
            if getattr(self, "print_parallel_error_once", False):
                self.print_parallel_error_once = False
                print(
                    "**** WARNING: could not execute multi_thread_commands in HARK.core.AgentPopulation.simulate() ",
                    "so using the serial version instead. This will likely be slower. ",
                    "The multi_thread_commands() function failed with the following error:\n",
                    sys.exc_info()[0],
                    ":",
                    err,
                )
            multi_thread_commands_fake(self.agents, ["simulate()"], num_jobs)

    def __iter__(self):
        """
        Allows for iteration over the agents in the population.
        """
        return iter(self.agents)

    def __getitem__(self, idx):
        """
        Allows for indexing into the population.
        """
        return self.agents[idx]

2788 

2789 

2790############################################################################### 

2791 

2792 

def multi_thread_commands_fake(
    agent_list: List, command_list: List, num_jobs=None
) -> None:
    """
    Run every command in command_list on each AgentType in agent_list, one
    after another in an ordinary single-threaded loop. Each command names a
    method of the AgentType subclass. This function exists so that
    multithreading can easily be disabled: it accepts the same arguments as
    multi_thread_commands.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
        Agents are processed in order; all commands finish on one agent before
        the next agent starts.
    command_list : [string]
        A list of commands to run for each AgentType, given either as
        "method_name()" or as "method_name". Methods are called with no arguments.
    num_jobs : None
        Dummy input to match syntax of multi_thread_commands. Does nothing.

    Returns
    -------
    none
    """
    for agent in agent_list:
        for command in command_list:
            # Strip a trailing "()" so both spellings name the same method
            method_name = command[:-2] if command[-2:] == "()" else command
            getattr(agent, method_name)()

2822 

2823 

def multi_thread_commands(agent_list: List, command_list: List, num_jobs=None) -> None:
    """
    Executes the list of commands in command_list for each AgentType in agent_list
    in parallel using joblib. Each command should be a method of that AgentType subclass.

    Parameters
    ----------
    agent_list : [AgentType]
        A list of instances of AgentType on which the commands will be run.
        Each entry is replaced in place by the agent object returned from the
        parallel call.
    command_list : [string]
        A list of commands to run for each AgentType in agent_list. A method
        name may be given with or without a trailing "()".
    num_jobs : int or None, optional
        Number of parallel jobs passed to ``joblib.Parallel``. When None
        (default), uses the smaller of ``len(agent_list)`` and the number of
        available CPU cores.

    Returns
    -------
    None
    """
    # With zero or one agent there is nothing to parallelize, so run serially.
    # This also prevents an empty agent_list from producing num_jobs == 0,
    # which joblib.Parallel rejects with a ValueError.
    if len(agent_list) <= 1:
        multi_thread_commands_fake(agent_list, command_list)
        return None

    # Default number of parallel jobs is the smaller of number of AgentTypes in
    # the input and the number of available cores.
    if num_jobs is None:
        num_jobs = min(len(agent_list), multiprocessing.cpu_count())

    # Send every command in command_list to each agent; joblib returns the
    # results in the same order as the inputs.
    agent_list_out = Parallel(n_jobs=num_jobs)(
        delayed(run_commands)(agent, command_list) for agent in agent_list
    )

    # Replace the original agents with the objects returned by the workers.
    # Presumably needed because process-based backends operate on copies, so
    # changes would otherwise not reach the caller's objects.
    for j, agent in enumerate(agent_list_out):
        agent_list[j] = agent
    return None

2858 

2859 

def run_commands(agent: Any, command_list: List) -> Any:
    """
    Call each named method in command_list on agent, in order, and return the
    agent afterwards.

    Parameters
    ----------
    agent : AgentType
        The object whose methods are invoked.
    command_list : [string]
        Names of zero-argument methods to call. A trailing "()" is optional and
        is stripped before the attribute lookup.

    Returns
    -------
    agent : AgentType
        The very same object that was passed in, after all commands ran.
    """
    for command in command_list:
        # "solve" and "solve()" both resolve to the method named "solve".
        method_name = command[:-2] if command[-2:] == "()" else command
        getattr(agent, method_name)()
    return agent