Coverage for HARK/ConsumptionSaving/ConsLabeledModel.py: 72%

364 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-02 05:14 +0000

1from dataclasses import dataclass 

2from types import SimpleNamespace 

3from typing import Mapping 

4 

5import numpy as np 

6import xarray as xr 

7 

8from HARK.Calibration.Assets.AssetProcesses import ( 

9 make_lognormal_RiskyDstn, 

10 combine_IncShkDstn_and_RiskyDstn, 

11) 

12from HARK.ConsumptionSaving.ConsIndShockModel import ( 

13 IndShockConsumerType, 

14 init_perfect_foresight, 

15 init_idiosyncratic_shocks, 

16 IndShockConsumerType_aXtraGrid_default, 

17) 

18from HARK.ConsumptionSaving.ConsPortfolioModel import ( 

19 PortfolioConsumerType, 

20 init_portfolio, 

21) 

22from HARK.ConsumptionSaving.ConsRiskyAssetModel import ( 

23 RiskyAssetConsumerType, 

24 init_risky_asset, 

25 IndShockRiskyAssetConsumerType_constructor_default, 

26) 

27from HARK.Calibration.Income.IncomeProcesses import ( 

28 construct_lognormal_income_process_unemployment, 

29) 

30from HARK.ConsumptionSaving.LegacyOOsolvers import ConsIndShockSetup 

31from HARK.core import make_one_period_oo_solver 

32from HARK.distributions import DiscreteDistributionLabeled 

33from HARK.metric import MetricObject 

34from HARK.rewards import UtilityFuncCRRA 

35from HARK.utilities import make_assets_grid 

36 

37 

class ValueFuncCRRALabeled(MetricObject):
    """
    Class to allow for value function interpolation using xarray.

    The underlying dataset stores the inverse value function ``v_inv`` and
    the inverse marginal value function ``v_der_inv`` on a grid of state
    coordinates. Evaluation interpolates the inverse representation and maps
    the result back through the CRRA utility function (or its derivative).
    """

    def __init__(self, dataset: xr.Dataset, CRRA: float):
        """
        Initialize a value function.

        Parameters
        ----------
        dataset : xr.Dataset
            Underlying dataset. ``__call__`` reads the variables ``"v_inv"``
            and ``"v"``; ``derivative`` reads ``"v_der_inv"`` (and ``"v_der"``
            when present, for labeling); ``evaluate`` reads every variable.
        CRRA : float
            Coefficient of relative risk aversion.
        """

        self.dataset = dataset
        self.CRRA = CRRA
        self.u = UtilityFuncCRRA(CRRA)

    def __call__(self, state: Mapping[str, np.ndarray]) -> xr.DataArray:
        """
        Interpolate inverse value function then invert to get value function at given state.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State to evaluate value function at; must contain every
            coordinate of the underlying dataset.

        Returns
        -------
        result : xr.DataArray
            Value function named ``"v"``, carrying the attrs of the
            dataset's ``"v"`` variable.
        """

        state_dict = self._validate_state(state)

        # Interpolate the inverse value function (extrapolating outside the
        # grid), then map back to value through the utility function.
        result = self.u(
            self.dataset["v_inv"].interp(
                state_dict,
                assume_sorted=True,
                kwargs={"fill_value": "extrapolate"},
            )
        )

        result.name = "v"
        result.attrs = self.dataset["v"].attrs

        return result

    def derivative(self, state: Mapping[str, np.ndarray]) -> xr.DataArray:
        """
        Interpolate inverse marginal value function then invert to get marginal value function at given state.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State to evaluate marginal value function at; must contain every
            coordinate of the underlying dataset.

        Returns
        -------
        result : xr.DataArray
            Marginal value function named ``"v_der"``.
        """

        state_dict = self._validate_state(state)

        result = self.u.der(
            self.dataset["v_der_inv"].interp(
                state_dict,
                assume_sorted=True,
                kwargs={"fill_value": "extrapolate"},
            )
        )

        result.name = "v_der"
        # Label the result with the marginal value function's metadata when
        # the dataset carries it, rather than the value function's.
        attrs_source = "v_der" if "v_der" in self.dataset else "v"
        result.attrs = self.dataset[attrs_source].attrs

        return result

    def evaluate(self, state: Mapping[str, np.ndarray]) -> xr.Dataset:
        """
        Interpolate all data variables in the dataset.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State to evaluate all data variables at; must contain every
            coordinate of the underlying dataset.

        Returns
        -------
        result : xr.Dataset
        """

        state_dict = self._validate_state(state)

        result = self.dataset.interp(
            state_dict,
            kwargs={"fill_value": None},
        )
        result.attrs = self.dataset["v"].attrs

        return result

    def _validate_state(self, state):
        """
        Pick out, from ``state``, the entry for every coordinate of the dataset.

        Entries of ``state`` that are not coordinates of the dataset are
        ignored.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State to validate, typically a dict or an xr.Dataset.

        Returns
        -------
        state_dict : dict
            Mapping from each dataset coordinate name to its value in ``state``.

        Raises
        ------
        ValueError
            If ``state`` is not a mapping.
        KeyError
            If ``state`` lacks one of the dataset's coordinates.
        """

        if not isinstance(state, Mapping):
            raise ValueError("state must be a mapping such as a dict or xr.Dataset")

        coord_names = list(self.dataset.coords)
        missing = [name for name in coord_names if name not in state]
        if missing:
            raise KeyError(f"state is missing coordinate(s): {missing}")

        return {name: state[name] for name in coord_names}

167 

168 

class ConsumerSolutionLabeled(MetricObject):
    """
    Class to allow for solution interpolation using xarray.
    Represents a solution object for labeled models.

    Bundles one period's value function, policy function and continuation
    (end-of-period) value function, plus free-form attributes such as the
    minimum normalized market resources ``m_nrm_min``.
    """

    def __init__(
        self,
        value: ValueFuncCRRALabeled,
        policy: xr.Dataset,
        continuation: ValueFuncCRRALabeled,
        attrs=None,
    ):
        """
        Consumer Solution for labeled models.

        Parameters
        ----------
        value : ValueFuncCRRALabeled
            Value function and marginal value function.
        policy : xr.Dataset
            Policy function.
        continuation : ValueFuncCRRALabeled
            Continuation value function and marginal value function
            (``None`` for a terminal solution).
        attrs : dict, optional
            Attributes of the solution. The default is None, which is
            replaced by a fresh empty dict.
        """

        self.value = value  # value function
        self.policy = policy  # policy function
        self.continuation = continuation  # continuation function

        # fresh dict per instance; never share a mutable default
        self.attrs = {} if attrs is None else attrs

    def distance(self, other: "ConsumerSolutionLabeled") -> float:
        """
        Compute the distance between two solutions.

        ``other``'s value-function dataset is interpolated onto this
        solution's grid, and the largest absolute difference across all data
        variables is reported.

        Parameters
        ----------
        other : ConsumerSolutionLabeled
            Other solution to compare to.

        Returns
        -------
        float
            Distance between the two solutions.
        """

        # TODO: is there a faster way to compare two xr.Datasets?

        value = self.value.dataset
        other_value = other.value.dataset.interp_like(value)

        # np.max on the stacked DataArray yields a 0-d DataArray; convert it so
        # the documented float is actually returned.
        return float(np.max(np.abs(value - other_value).to_array()))

227 

228 

229############################################################################### 

230 

231 

def make_solution_terminal_labeled(CRRA, aXtraGrid):
    """
    Build the terminal-period solution used as the ``solution_terminal`` constructor.

    In the last period the agent consumes all market resources, so the
    policy is c(m) = m on a market-resources grid made of 0.0 followed by
    ``aXtraGrid``. The value function, marginal value function and their
    inverses are tabulated on that grid.

    Parameters
    ----------
    CRRA : float
        Coefficient of relative risk aversion.
    aXtraGrid : np.array
        Grid of assets above minimum.

    Returns
    -------
    solution_terminal : ConsumerSolutionLabeled
        Terminal period solution with ``m_nrm_min`` equal to 0.0 and no
        continuation function.
    """
    utility = UtilityFuncCRRA(CRRA)

    m_grid = xr.DataArray(
        np.append(0.0, aXtraGrid),
        name="mNrm",
        dims=("mNrm",),
        attrs={"long_name": "cash_on_hand"},
    )
    state = xr.Dataset({"mNrm": m_grid})  # market resources is the only state

    # consuming everything is optimal in the final period: c(m) = m
    consumption = xr.DataArray(
        m_grid,
        name="cNrm",
        dims=state.dims,
        coords=state.coords,
        attrs={"long_name": "consumption"},
    )

    def _labeled(array, label, long_name):
        # attach a name and a descriptive long_name to a DataArray in place
        array.name = label
        array.attrs = {"long_name": long_name}
        return array

    value = _labeled(utility(consumption), "v", "value function")
    marginal_value = _labeled(
        utility.der(consumption), "v_der", "marginal value function"
    )
    # with c(m) = m, both inverse representations coincide with consumption
    value_inverse = _labeled(consumption.copy(), "v_inv", "inverse value function")
    marginal_value_inverse = _labeled(
        consumption.copy(), "v_der_inv", "inverse marginal value function"
    )

    terminal_data = xr.Dataset(
        {
            "cNrm": consumption,
            "v": value,
            "v_der": marginal_value,
            "v_inv": value_inverse,
            "v_der_inv": marginal_value_inverse,
        }
    )

    return ConsumerSolutionLabeled(
        value=ValueFuncCRRALabeled(
            terminal_data[["v", "v_der", "v_inv", "v_der_inv"]], CRRA
        ),
        policy=terminal_data[["cNrm"]],
        continuation=None,
        attrs={"m_nrm_min": 0.0},  # minimum normalized market resources
    )

304 

305 

def make_labeled_inc_shk_dstn(
    T_cycle,
    PermShkStd,
    PermShkCount,
    TranShkStd,
    TranShkCount,
    T_retire,
    UnempPrb,
    IncUnemp,
    UnempPrbRet,
    IncUnempRet,
    RNG,
    neutral_measure=False,
):
    """
    Build the income shock distributions and label their variables.

    All arguments are forwarded positionally to
    ``construct_lognormal_income_process_unemployment``. Each element of the
    resulting distribution is converted to a ``DiscreteDistributionLabeled``
    whose variables are named ``"perm"`` and ``"tran"``.

    Returns
    -------
    IncShkDstn : list of DiscreteDistributionLabeled
        One labeled distribution per entry of the underlying ``dstns``.
    """
    unlabeled = construct_lognormal_income_process_unemployment(
        T_cycle,
        PermShkStd,
        PermShkCount,
        TranShkStd,
        TranShkCount,
        T_retire,
        UnempPrb,
        IncUnemp,
        UnempPrbRet,
        IncUnempRet,
        RNG,
        neutral_measure,
    )
    period_count = len(unlabeled.dstns)
    return [
        DiscreteDistributionLabeled.from_unlabeled(
            unlabeled[t],
            name="Distribution of Shocks to Income",
            var_names=["perm", "tran"],
        )
        for t in range(period_count)
    ]

348 

349 

def make_labeled_risky_dstn(T_cycle, RiskyAvg, RiskyStd, RiskyCount, RNG):
    """
    Build the risky-return distribution and label its single variable.

    Arguments are forwarded unchanged to ``make_lognormal_RiskyDstn``; the
    result is converted to a ``DiscreteDistributionLabeled`` whose only
    variable is named ``"risky"``.
    """
    return DiscreteDistributionLabeled.from_unlabeled(
        make_lognormal_RiskyDstn(T_cycle, RiskyAvg, RiskyStd, RiskyCount, RNG),
        name="Distribution of Risky Asset Returns",
        var_names=["risky"],
    )

363 

364 

def make_labeled_shock_dstn(T_cycle, IncShkDstn, RiskyDstn):
    """
    Combine income and risky-return shocks and label the joint distributions.

    Delegates to ``combine_IncShkDstn_and_RiskyDstn`` (note its argument
    order: ``T_cycle, RiskyDstn, IncShkDstn``) and labels each element's
    variables as ``"perm"``, ``"tran"`` and ``"risky"``.

    Returns
    -------
    ShockDstn : list of DiscreteDistributionLabeled
        One labeled joint distribution per entry of the combined ``dstns``.
    """
    joint = combine_IncShkDstn_and_RiskyDstn(T_cycle, RiskyDstn, IncShkDstn)
    return [
        DiscreteDistributionLabeled.from_unlabeled(
            joint[t],
            name="Distribution of Shocks to Income and Risky Asset Returns",
            var_names=["perm", "tran", "risky"],
        )
        for t in range(len(joint.dstns))
    ]

380 

381 

382############################################################################### 

383 

384 

class ConsPerfForesightLabeledSolver(ConsIndShockSetup):
    """
    Solver for PerfForesightLabeledType.

    Solves one period of the perfect foresight consumption-saving problem
    with the endogenous grid method, representing every function as a
    labeled xarray object. The inputs read here (DiscFac, LivPrb, CRRA,
    Rfree, PermGroFac, BoroCnstArt, aXtraGrid, solution_next) and the
    utility object ``self.u`` are not set in this class; presumably they are
    provided by ``ConsIndShockSetup`` -- verify there.
    """

    def create_params_namespace(self):
        """
        Create a namespace for parameters.

        ``Discount`` is the survival-adjusted discount factor
        ``DiscFac * LivPrb``.
        """

        self.params = SimpleNamespace(
            Discount=self.DiscFac * self.LivPrb,
            CRRA=self.CRRA,
            Rfree=self.Rfree,
            PermGroFac=self.PermGroFac,
        )

    def calculate_borrowing_constraint(self):
        """
        Calculate the minimum allowable value of money resources in this period.

        Next period's normalized market resources are
        ``aNrm * Rfree / PermGroFac + 1`` (see ``post_state_transition``).
        Requiring them to be at least next period's ``m_nrm_min`` gives the
        natural borrowing constraint
        ``(m_nrm_min_next - 1) * PermGroFac / Rfree``. The growth factor must
        appear for the constraint to agree with that transition; the
        income-shock solver's constraint includes it as well.
        """

        self.BoroCnstNat = (
            (self.solution_next.attrs["m_nrm_min"] - 1)
            * self.params.PermGroFac
            / self.params.Rfree
        )

    def define_boundary_constraint(self):
        """
        Determine which borrowing constraint binds and build its boundary point.

        Sets ``self.m_nrm_min``, ``self.nat_boro_cnst`` and ``self.borocnst``.
        If the natural borrowing constraint binds, the value function cannot
        be evaluated at that point, so its data (zero consumption, -inf value,
        +inf marginal value) are filled in by hand. If the artificial
        constraint binds, only zero consumption is recorded.
        """

        if self.BoroCnstArt is None or self.BoroCnstArt <= self.BoroCnstNat:
            self.m_nrm_min = self.BoroCnstNat
            self.nat_boro_cnst = True  # natural borrowing constraint is binding

            self.borocnst = xr.Dataset(
                coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min},
                data_vars={
                    "cNrm": 0.0,
                    "v": -np.inf,
                    "v_inv": 0.0,
                    "reward": -np.inf,
                    "marginal_reward": np.inf,
                    "v_der": np.inf,
                    "v_der_inv": 0.0,
                },
            )

        elif self.BoroCnstArt > self.BoroCnstNat:
            self.m_nrm_min = self.BoroCnstArt
            self.nat_boro_cnst = False  # artificial borrowing constraint is binding

            self.borocnst = xr.Dataset(
                coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min},
                data_vars={"cNrm": 0.0},
            )

    def create_post_state(self):
        """
        Create the post state variable, which in this case is
        the normalized assets saved this period.

        The grid is ``aXtraGrid`` shifted by ``m_nrm_min``.
        """

        if self.nat_boro_cnst:
            # don't include natural borrowing constraint: its values are
            # filled in by hand in define_boundary_constraint
            a_grid = self.aXtraGrid + self.m_nrm_min
        else:
            # include artificial borrowing constraint
            a_grid = np.append(0.0, self.aXtraGrid) + self.m_nrm_min

        aVec = xr.DataArray(
            a_grid,
            name="aNrm",
            dims=("aNrm",),
            attrs={"long_name": "savings", "state": True},
        )
        post_state = xr.Dataset({"aNrm": aVec})

        self.post_state = post_state

    def state_transition(self, state=None, action=None, params=None):
        """
        State to post_state transition: ``aNrm = mNrm - cNrm``.

        Parameters
        ----------
        state : xr.Dataset
            State variables.
        action : xr.Dataset
            Action variables.
        params : SimpleNamespace
            Parameters.

        Returns
        -------
        post_state : dict
            Post state variables.
        """

        post_state = {}  # pytree
        post_state["aNrm"] = state["mNrm"] - action["cNrm"]
        return post_state

    def post_state_transition(self, post_state=None, params=None):
        """
        Post_state to next_state transition: ``mNrm' = aNrm * Rfree / PermGroFac + 1``.

        Parameters
        ----------
        post_state : xr.Dataset
            Post state variables.
        params : SimpleNamespace
            Parameters.

        Returns
        -------
        next_state : dict
            Next period's state variables.
        """

        next_state = {}  # pytree
        next_state["mNrm"] = post_state["aNrm"] * params.Rfree / params.PermGroFac + 1
        return next_state

    def reverse_transition(self, post_state=None, action=None, params=None):
        """
        State from post state and actions: ``mNrm = aNrm + cNrm``.

        Parameters
        ----------
        post_state : xr.Dataset
            Post state variables.
        action : xr.Dataset
            Action variables.
        params : SimpleNamespace

        Returns
        -------
        state : dict
            State variables.
        """

        state = {}  # pytree
        state["mNrm"] = post_state["aNrm"] + action["cNrm"]

        return state

    def egm_transition(self, post_state=None, continuation=None, params=None):
        """
        Actions from post state using the endogenous grid method.

        Inverts the first-order condition:
        ``cNrm = u'^{-1}(Discount * continuation'(aNrm))``.

        Parameters
        ----------
        post_state : xr.Dataset
            Post state variables.
        continuation : ValueFuncCRRALabeled
            Continuation value function, next period's value function.
        params : SimpleNamespace

        Returns
        -------
        action : dict
            Action variables.
        """

        action = {}  # pytree
        action["cNrm"] = self.u.derinv(
            params.Discount * continuation.derivative(post_state)
        )

        return action

    def value_transition(self, action=None, state=None, continuation=None, params=None):
        """
        Value of action given state and continuation

        Parameters
        ----------
        action : xr.Dataset
            Action variables.
        state : xr.Dataset
            State variables.
        continuation : ValueFuncCRRALabeled
            Continuation value function, next period's value function.
        params : SimpleNamespace
            Parameters

        Returns
        -------
        variables : dict
            Post state, value, marginal value, their inverses, reward,
            marginal reward, and contributions.
        """

        variables = {}  # pytree
        post_state = self.state_transition(state, action, params)
        variables.update(post_state)

        variables["reward"] = self.u(action["cNrm"])
        variables["v"] = variables["reward"] + params.Discount * continuation(
            post_state
        )
        variables["v_inv"] = self.u.inv(variables["v"])

        # envelope condition: marginal value equals marginal utility of consumption
        variables["marginal_reward"] = self.u.der(action["cNrm"])
        variables["v_der"] = variables["marginal_reward"]
        variables["v_der_inv"] = action["cNrm"]

        # for estimagic purposes
        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def continuation_transition(self, post_state=None, value_next=None, params=None):
        """
        Continuation value function of post state.

        Parameters
        ----------
        post_state : xr.Dataset
            Post state variables.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters.

        Returns
        -------
        variables : dict
            Next state, value, marginal value, inverse value, and inverse
            marginal value.
        """

        variables = {}  # pytree
        next_state = self.post_state_transition(post_state, params)
        variables.update(next_state)
        # rescale normalized values by the permanent income growth factor
        variables["v"] = params.PermGroFac ** (1 - params.CRRA) * value_next(next_state)
        variables["v_der"] = (
            params.Rfree
            * params.PermGroFac ** (-params.CRRA)
            * value_next.derivative(next_state)
        )

        variables["v_inv"] = self.u.inv(variables["v"])
        variables["v_der_inv"] = self.u.derinv(variables["v_der"])

        # for estimagic purposes
        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def prepare_to_solve(self):
        """
        Prepare to solve the model by creating the parameters namespace,
        calculating the borrowing constraint, defining the boundary constraint,
        and creating the post state.
        """

        self.create_params_namespace()
        self.calculate_borrowing_constraint()
        self.define_boundary_constraint()
        self.create_post_state()

    def create_continuation_function(self):
        """
        Create the continuation function, or the value function
        of every possible post state.

        Returns
        -------
        wfunc : ValueFuncCRRALabeled
            Continuation function.
        """

        # unpack next period's solution
        vfunc_next = self.solution_next.value

        v_end = self.continuation_transition(self.post_state, vfunc_next, self.params)
        # need to drop m because it's next period's m
        v_end = xr.Dataset(v_end).drop_vars(["mNrm"])
        if self.nat_boro_cnst:
            # add the hand-filled point at the natural borrowing constraint
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end])

        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        return wfunc

    def endogenous_grid_method(self):
        """
        Solve the model using the endogenous grid method, which consists of
        solving the model backwards in time using the following steps:

        1. Create the continuation function, or the value function of every
           possible post state.
        2. Get the optimal actions/decisions from the endogenous grid transition.
        3. Get the state from the actions and post state using the reverse transition.
        4. EGM requires swapping dimensions; make actions and state functions of state.
        5. Merge the actions and state into a single dataset.
        6. If the natural borrowing constraint is not used, concatenate the
           borrowing constraint to the dataset.
        7. Create the value function from the variables in the dataset.
        8. Create the policy function from the variables in the dataset.
        9. Create the solution from the value and policy functions.

        The result is stored in ``self.solution``.
        """
        wfunc = self.create_continuation_function()

        # get optimal actions/decisions from egm
        acted = self.egm_transition(self.post_state, wfunc, self.params)
        # get state from actions and post_state
        state = self.reverse_transition(self.post_state, acted, self.params)

        # egm requires swap dimensions; make actions and state functions of state
        action = xr.Dataset(acted).swap_dims({"aNrm": "mNrm"})
        state = xr.Dataset(state).swap_dims({"aNrm": "mNrm"})

        egm_dataset = xr.merge([action, state])

        if not self.nat_boro_cnst:
            egm_dataset = xr.concat([self.borocnst, egm_dataset], dim="mNrm")

        values = self.value_transition(egm_dataset, egm_dataset, wfunc, self.params)
        egm_dataset.update(values)

        if self.nat_boro_cnst:
            egm_dataset = xr.concat(
                [self.borocnst, egm_dataset], dim="mNrm", combine_attrs="no_conflicts"
            )

        egm_dataset = egm_dataset.drop_vars("aNrm")

        vfunc = ValueFuncCRRALabeled(
            egm_dataset[["v", "v_der", "v_inv", "v_der_inv"]], self.params.CRRA
        )
        pfunc = egm_dataset[["cNrm"]]

        self.solution = ConsumerSolutionLabeled(
            value=vfunc,
            policy=pfunc,
            continuation=wfunc,
            attrs={"m_nrm_min": self.m_nrm_min, "dataset": egm_dataset},
        )

    def solve(self):
        """
        Solve the model by endogenous grid method.

        Returns
        -------
        solution : ConsumerSolutionLabeled
        """

        self.endogenous_grid_method()

        return self.solution

738 

739 

740############################################################################### 

741 

# Constructors: the idiosyncratic-shocks set, with a labeled terminal
# solution and the standard assets-grid constructor.
PF_labeled_constructor_dict = {
    **init_idiosyncratic_shocks["constructors"],
    "solution_terminal": make_solution_terminal_labeled,
    "aXtraGrid": make_assets_grid,
}
# Parameters: idiosyncratic-shock defaults overlaid with perfect-foresight
# values, then the labeled constructors, then the default aXtraGrid settings.
init_perf_foresight_labeled = {
    **init_idiosyncratic_shocks,
    **init_perfect_foresight,
    "constructors": PF_labeled_constructor_dict,
    **IndShockConsumerType_aXtraGrid_default,
}

749 

750############################################################################### 

751 

752 

class PerfForesightLabeledType(IndShockConsumerType):
    """
    Labeled perfect foresight consumer type.

    Subclass of IndShockConsumerType whose default solver is
    ConsPerfForesightLabeledSolver. With no uncertainty about income or
    interest rates, the only state variable is market resources m.
    """

    default_ = dict(
        params=init_perf_foresight_labeled,
        solver=make_one_period_oo_solver(ConsPerfForesightLabeledSolver),
        model="ConsPerfForesight.yaml",
    )

    def post_solve(self):
        """
        Skip post-solve processing.

        Deliberately a no-op so that calc_stable_points is not attempted.
        """
        return None

770 

771 

772############################################################################### 

773 

774 

class ConsIndShockLabeledSolver(ConsPerfForesightLabeledSolver):
    """
    Solver for IndShockLabeledType.

    Extends the perfect foresight labeled solver with permanent and
    transitory income shocks taken from ``self.IncShkDstn``, whose variables
    are read as ``"perm"`` and ``"tran"``.
    """

    def calculate_borrowing_constraint(self):
        """
        Calculate the minimum allowable value of money resources in this period.
        This is different from the perfect foresight natural borrowing constraint
        because of the presence of income uncertainty: the worst permanent and
        transitory shocks determine how low end-of-period assets may go.
        """

        # atoms[0] and atoms[1] hold the permanent and transitory shock values
        PermShkMinNext = np.min(self.IncShkDstn.atoms[0])
        TranShkMinNext = np.min(self.IncShkDstn.atoms[1])

        self.BoroCnstNat = (
            (self.solution_next.attrs["m_nrm_min"] - TranShkMinNext)
            * (self.params.PermGroFac * PermShkMinNext)
            / self.params.Rfree
        )

    def post_state_transition(self, post_state=None, shocks=None, params=None):
        """
        Post state to next state transition now depends on income shocks:
        ``mNrm' = aNrm * Rfree / (PermGroFac * perm) + tran``.

        Parameters
        ----------
        post_state : dict
            Post state variables.
        shocks : dict
            Shocks to income.
        params : dict
            Parameters.

        Returns
        -------
        next_state : dict
            Next period's state variables.
        """

        next_state = {}  # pytree
        next_state["mNrm"] = (
            post_state["aNrm"] * params.Rfree / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation_transition(
        self, shocks=None, post_state=None, v_next=None, params=None
    ):
        """
        Continuation value function of post state, for one realization of shocks.

        Parameters
        ----------
        shocks : dict
            Shocks to income.
        post_state : dict
            Post state variables.
        v_next : ValueFuncCRRALabeled
            Next period's value function.
        params : dict
            Parameters.

        Returns
        -------
        variables : dict
            Next state, effective growth factor ``psi``, continuation value
            function and its derivative.
        """

        variables = {}  # pytree
        next_state = self.post_state_transition(post_state, shocks, params)
        variables.update(next_state)

        # effective growth of permanent income for this shock realization
        variables["psi"] = params.PermGroFac * shocks["perm"]

        variables["v"] = variables["psi"] ** (1 - params.CRRA) * v_next(next_state)

        variables["v_der"] = (
            params.Rfree
            * variables["psi"] ** (-params.CRRA)
            * v_next.derivative(next_state)
        )

        # for estimagic purposes

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def create_continuation_function(self):
        """
        Create the continuation function. Because of the income uncertainty
        in this model, we need to integrate over the income shocks to get the
        continuation value function. Depending on the natural borrowing constraint,
        we may also have to append the minimum allowable value of money resources.

        Returns
        -------
        wfunc : ValueFuncCRRALabeled
            Continuation value function.
        """

        # unpack next period's solution
        vfunc_next = self.solution_next.value

        v_end = self.IncShkDstn.expected(
            func=self.continuation_transition,
            post_state=self.post_state,
            v_next=vfunc_next,
            params=self.params,
        )

        v_end["v_inv"] = self.u.inv(v_end["v"])
        v_end["v_der_inv"] = self.u.derinv(v_end["v_der"])

        if self.nat_boro_cnst:
            # add the hand-filled point at the natural borrowing constraint
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end])

        # NOTE(review): unlike the perfect foresight solver, next period's
        # mNrm is not dropped from v_end here; confirm this is intended.
        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        return wfunc

901 

902 

903############################################################################### 

904 

# Constructors: the perfect foresight labeled set plus a labeled income
# shock distribution.
ind_shock_labeled_constructor_dict = {
    **PF_labeled_constructor_dict,
    "IncShkDstn": make_labeled_inc_shk_dstn,
}
init_ind_shock_labeled = {
    **init_perf_foresight_labeled,
    "constructors": ind_shock_labeled_constructor_dict,
}

909 

910 

class IndShockLabeledType(PerfForesightLabeledType):
    """
    Labeled counterpart of IndShockConsumerType.

    Inherits from PerfForesightLabeledType and adds income uncertainty by
    solving each period with ConsIndShockLabeledSolver.
    """

    default_ = dict(
        params=init_ind_shock_labeled,
        solver=make_one_period_oo_solver(ConsIndShockLabeledSolver),
        model="ConsIndShock.yaml",
    )

922 

923 

924############################################################################### 

925 

926 

@dataclass
class ConsRiskyAssetLabeledSolver(ConsIndShockLabeledSolver):
    """
    Solver for an agent that can save in an asset that has a risky return.

    The dataclass fields below form the generated ``__init__`` in this order.
    Income and return shocks are drawn jointly from ``ShockDstn``, whose
    variables are read as ``"perm"``, ``"tran"`` and ``"risky"``.
    """

    solution_next: ConsumerSolutionLabeled  # solution to next period's problem
    ShockDstn: DiscreteDistributionLabeled  # distribution of shocks to income and returns
    LivPrb: float  # survival probability
    DiscFac: float  # intertemporal discount factor
    CRRA: float  # coefficient of relative risk aversion
    Rfree: float  # interest factor on assets
    PermGroFac: float  # permanent income growth factor
    BoroCnstArt: float  # artificial borrowing constraint
    aXtraGrid: np.ndarray  # grid of end-of-period assets

    def __post_init__(self):
        """
        Define utility functions.
        """

        self.def_utility_funcs()

    def calculate_borrowing_constraint(self):
        """
        Calculate the borrowing constraint by enforcing a 0.0 artificial borrowing
        constraint and setting the shocks to income to come from the shock distribution.

        NOTE(review): the inherited formula divides by Rfree even though
        saving here earns the risky return; with BoroCnstArt fixed at 0.0
        this only matters if the natural constraint comes out positive.
        """
        self.BoroCnstArt = 0.0
        self.IncShkDstn = self.ShockDstn
        return super().calculate_borrowing_constraint()

    def post_state_transition(self, post_state=None, shocks=None, params=None):
        """
        Post_state to next_state transition with risky asset return:
        ``mNrm' = aNrm * risky / (PermGroFac * perm) + tran``.

        Parameters
        ----------
        post_state : dict
            Post-state variables.
        shocks : dict
            Shocks to income and risky asset return.
        params : dict
            Parameters of the model.

        Returns
        -------
        next_state : dict
            Next period's state variables.
        """

        next_state = {}  # pytree
        next_state["mNrm"] = (
            post_state["aNrm"] * shocks["risky"] / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation_transition(
        self, shocks=None, post_state=None, v_next=None, params=None
    ):
        """
        Continuation value function of post_state with risky asset return.

        Parameters
        ----------
        shocks : dict
            Shocks to income and risky asset return.
        post_state : dict
            Post-state variables.
        v_next : function
            Value function of next period.
        params : dict
            Parameters of the model.

        Returns
        -------
        variables : dict
            Variables of the continuation value function.
        """

        variables = {}  # pytree
        next_state = self.post_state_transition(post_state, shocks, params)
        variables.update(next_state)

        # effective growth of permanent income for this shock realization
        variables["psi"] = params.PermGroFac * shocks["perm"]

        variables["v"] = variables["psi"] ** (1 - params.CRRA) * v_next(next_state)

        # the risky return replaces Rfree in the marginal value
        variables["v_der"] = (
            shocks["risky"]
            * variables["psi"] ** (-params.CRRA)
            * v_next.derivative(next_state)
        )

        # for estimagic purposes

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def create_continuation_function(self):
        """
        Create the continuation value function taking expectation
        over the shock distribution which includes shocks to income and
        the risky asset return.

        Returns
        -------
        wfunc : ValueFuncCRRALabeled
            Continuation value function.
        """
        # unpack next period's solution
        vfunc_next = self.solution_next.value

        v_end = self.ShockDstn.expected(
            func=self.continuation_transition,
            post_state=self.post_state,
            v_next=vfunc_next,
            params=self.params,
        )

        v_end["v_inv"] = self.u.inv(v_end["v"])
        v_end["v_der_inv"] = self.u.derinv(v_end["v_der"])

        if self.nat_boro_cnst:
            # add the hand-filled point at the natural borrowing constraint
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end])

        # make aNrm the leading dimension of every variable
        v_end = v_end.transpose("aNrm", ...)

        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        return wfunc

1066 

1067 

1068############################################################################### 

1069 

# Constructors: risky-asset defaults with labeled shock distributions and a
# labeled terminal solution.
risky_asset_labeled_constructor_dict = {
    **IndShockRiskyAssetConsumerType_constructor_default,
    "IncShkDstn": make_labeled_inc_shk_dstn,
    "RiskyDstn": make_labeled_risky_dstn,
    "ShockDstn": make_labeled_shock_dstn,
    "solution_terminal": make_solution_terminal_labeled,
}
# Remove the default "solve_one_period" constructor (raises KeyError if it
# is absent). NOTE(review): presumably because RiskyAssetLabeledType
# supplies its solver through default_["solver"]; confirm.
risky_asset_labeled_constructor_dict.pop("solve_one_period")
init_risky_asset_labeled = {
    **init_risky_asset,
    "constructors": risky_asset_labeled_constructor_dict,
}

1082 

1083############################################################################### 

1084 

1085 

class RiskyAssetLabeledType(IndShockLabeledType, RiskyAssetConsumerType):
    """
    Labeled (xarray-based) version of RiskyAssetConsumerType.

    Subclasses both IndShockLabeledType and RiskyAssetConsumerType and
    inherits their methods and attributes. Agents of this type can only
    save in a risky asset that pays a stochastic return.
    """

    # Default parameters, one-period solver and model file for this type.
    default_ = dict(
        params=init_risky_asset_labeled,
        solver=make_one_period_oo_solver(ConsRiskyAssetLabeledSolver),
        model="ConsRiskyAsset.yaml",
    )

1100 

1101 

1102############################################################################### 

1103 

1104 

@dataclass
class ConsFixedPortfolioLabeledSolver(ConsRiskyAssetLabeledSolver):
    """
    Solver for an agent that can save in a risk-free and a risky asset,
    holding the risky asset at the fixed portfolio share ``RiskyShareFixed``.
    """

    RiskyShareFixed: float  # share of risky assets in portfolio

    def create_params_namespace(self):
        """
        Create the parameter namespace used by the transition functions.

        Sets ``self.params`` to a SimpleNamespace with ``Discount``
        (DiscFac * LivPrb), ``CRRA``, ``Rfree``, ``PermGroFac`` and
        ``RiskyShareFixed``.
        """

        self.params = SimpleNamespace(
            Discount=self.DiscFac * self.LivPrb,
            CRRA=self.CRRA,
            Rfree=self.Rfree,
            PermGroFac=self.PermGroFac,
            RiskyShareFixed=self.RiskyShareFixed,
        )

    def post_state_transition(self, post_state=None, shocks=None, params=None):
        """
        Post_state to next_state transition with fixed portfolio share.

        Parameters
        ----------
        post_state : dict
            Post-state variables; must contain "aNrm".
        shocks : dict
            Shocks; must contain "perm", "tran" and "risky".
        params : SimpleNamespace
            Model parameters; uses Rfree, PermGroFac and RiskyShareFixed.

        Returns
        -------
        next_state : dict
            Next period's state: "rDiff" (excess return of the risky asset
            over Rfree), "rPort" (portfolio return) and "mNrm".
        """

        next_state = {}  # pytree
        # Excess return of the risky asset over the risk-free rate.
        # (Previously computed as Rfree - risky, which flipped the sign of the
        # risk premium in rPort.)
        next_state["rDiff"] = shocks["risky"] - params.Rfree
        # Portfolio return: Rfree + share * (Rrisky - Rfree).
        next_state["rPort"] = (
            params.Rfree + next_state["rDiff"] * params.RiskyShareFixed
        )
        next_state["mNrm"] = (
            post_state["aNrm"]
            * next_state["rPort"]
            / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation_transition(
        self, shocks=None, post_state=None, v_next=None, params=None
    ):
        """
        Continuation value function of post_state with fixed portfolio share.

        Parameters
        ----------
        shocks : dict
            Shocks to income ("perm", "tran") and the risky asset return
            ("risky").
        post_state : dict
            Post-state variables.
        v_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Model parameters; uses PermGroFac and CRRA plus those read by
            ``post_state_transition``.

        Returns
        -------
        variables : dict
            Next-state variables plus "psi" (PermGroFac * perm), "v",
            "v_der" (scaled by the portfolio return), and "contributions" /
            "value" for estimagic.
        """

        variables = {}  # pytree
        next_state = self.post_state_transition(post_state, shocks, params)
        variables.update(next_state)

        variables["psi"] = params.PermGroFac * shocks["perm"]

        # Normalized value scales with psi ** (1 - CRRA).
        variables["v"] = variables["psi"] ** (1 - params.CRRA) * v_next(next_state)

        # Marginal value w.r.t. aNrm: chain rule through mNrm brings in
        # rPort / psi, combined with the psi ** (1 - CRRA) scaling.
        variables["v_der"] = (
            next_state["rPort"]
            * variables["psi"] ** (-params.CRRA)
            * v_next.derivative(next_state)
        )

        # for estimagic purposes

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

1202 

1203 

1204############################################################################### 

1205 

1206 

@dataclass
class ConsPortfolioLabeledSolver(ConsFixedPortfolioLabeledSolver):
    """
    Solver for an agent that can save in a risk-free and a risky asset
    and chooses the risky share optimally on ``ShareGrid``.
    """

    ShareGrid: np.ndarray  # grid of risky shares

    def create_post_state(self):
        """
        Create post-state variables by adding the risky share, called
        stigma, defined on ``ShareGrid``.
        """

        super().create_post_state()

        self.post_state["stigma"] = xr.DataArray(
            self.ShareGrid, dims=["stigma"], attrs={"long_name": "risky share"}
        )

    def post_state_transition(self, post_state=None, shocks=None, params=None):
        """
        Post_state to next_state transition with a chosen portfolio share.

        Parameters
        ----------
        post_state : dict
            Post-state variables; must contain "aNrm" and "stigma".
        shocks : dict
            Shocks; must contain "perm", "tran" and "risky".
        params : SimpleNamespace
            Model parameters; uses Rfree and PermGroFac.

        Returns
        -------
        next_state : dict
            Next period's state: "rDiff" (excess return of the risky asset),
            "rPort" (portfolio return) and "mNrm".
        """

        next_state = {}  # pytree
        next_state["rDiff"] = shocks["risky"] - params.Rfree
        next_state["rPort"] = params.Rfree + next_state["rDiff"] * post_state["stigma"]
        next_state["mNrm"] = (
            post_state["aNrm"]
            * next_state["rPort"]
            / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation_transition(
        self, shocks=None, post_state=None, v_next=None, params=None
    ):
        """
        Continuation value function of post_state with a chosen portfolio share.

        Parameters
        ----------
        shocks : dict
            Shocks to income ("perm", "tran") and the risky asset return
            ("risky").
        post_state : dict
            Post-state variables ("aNrm", "stigma").
        v_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Model parameters; uses PermGroFac and CRRA plus those read by
            ``post_state_transition``.

        Returns
        -------
        variables : dict
            Next-state variables plus "psi", "v", "v_der" (not yet scaled
            by a return), "dvda" (marginal value w.r.t. aNrm), "dvds"
            (marginal value w.r.t. the risky share), and "contributions" /
            "value" for estimagic.
        """

        variables = {}  # pytree
        next_state = self.post_state_transition(post_state, shocks, params)
        variables.update(next_state)

        variables["psi"] = params.PermGroFac * shocks["perm"]

        variables["v"] = variables["psi"] ** (1 - params.CRRA) * v_next(next_state)

        variables["v_der"] = variables["psi"] ** (-params.CRRA) * v_next.derivative(
            next_state
        )

        # Chain rule through mNrm = aNrm * rPort / psi + tran.
        variables["dvda"] = next_state["rPort"] * variables["v_der"]
        variables["dvds"] = (
            next_state["rDiff"] * post_state["aNrm"] * variables["v_der"]
        )

        # for estimagic purposes

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def create_continuation_function(self):
        """
        Create continuation function with optimal portfolio share.

        The first continuation function (from the parent solver) is a
        function of assets and stigma. For each value of assets the agent
        picks the risky share where the share FOC (``dvds``) crosses zero.
        The returned continuation function is then a function of assets
        only, with the optimal share stored as the "stigma" variable.

        Returns
        -------
        wfunc : ValueFuncCRRALabeled
            Continuation value function evaluated at the optimal share.
        """

        wfunc = super().create_continuation_function()

        # Rows index aNrm (the parent transposes "aNrm" to the front);
        # columns presumably index stigma -- assumes no other dims remain.
        dvds = wfunc.dataset["dvds"].values

        # For each value of aNrm, find the value of Share such that FOC-Share == 0.
        crossing = np.logical_and(dvds[:, 1:] <= 0.0, dvds[:, :-1] >= 0.0)
        share_idx = np.argmax(crossing, axis=1)
        a_idx = np.arange(self.post_state["aNrm"].size)
        bot_s = self.ShareGrid[share_idx]
        top_s = self.ShareGrid[share_idx + 1]
        bot_f = dvds[a_idx, share_idx]
        top_f = dvds[a_idx, share_idx + 1]

        # Linear interpolation weight of the FOC root between bot_s and top_s.
        # A flat FOC segment (top_f == bot_f, e.g. aNrm == 0 makes dvds zero)
        # would give 0/0 = NaN; fall back to bot_s in that case instead.
        denom = top_f - bot_f
        with np.errstate(divide="ignore", invalid="ignore"):
            alpha = np.where(denom != 0.0, 1.0 - top_f / denom, 0.0)
        opt_share = (1.0 - alpha) * bot_s + alpha * top_s

        # Where the FOC is still positive at 100% risky share, the agent
        # would like to hold more than 100%: constrain to 1.
        opt_share[dvds[:, -1] > 0.0] = 1.0
        # Likewise if he wants to put less than 0% into risky asset
        opt_share[dvds[:, 0] < 0.0] = 0.0

        if not self.nat_boro_cnst:
            # aNrm=0, so there's no way to "optimize" the portfolio
            opt_share[0] = 1.0

        opt_share = xr.DataArray(
            opt_share,
            coords={"aNrm": self.post_state["aNrm"].values},
            dims=["aNrm"],
            attrs={"long_name": "optimal risky share"},
        )

        v_end = wfunc.evaluate({"aNrm": self.post_state["aNrm"], "stigma": opt_share})

        v_end = v_end.reset_coords(names="stigma")

        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        # stigma has been optimized out; drop_vars is the non-deprecated
        # equivalent of drop(name).
        self.post_state = self.post_state.drop_vars("stigma")

        return wfunc

1365 

1366 

1367############################################################################### 

1368 

# Constructors for the labeled portfolio model: the default portfolio
# constructors with the labeled (xarray) shock distributions and terminal
# solution swapped in.
init_portfolio_labeled_constructors = {
    **init_portfolio["constructors"],
    "IncShkDstn": make_labeled_inc_shk_dstn,
    "RiskyDstn": make_labeled_risky_dstn,
    "ShockDstn": make_labeled_shock_dstn,
    "solution_terminal": make_solution_terminal_labeled,
}
init_portfolio_labeled = {
    **init_portfolio,
    "constructors": init_portfolio_labeled_constructors,
    # This shouldn't exist. NOTE(review): presumably needed only because
    # ConsPortfolioLabeledSolver inherits the RiskyShareFixed dataclass
    # field from ConsFixedPortfolioLabeledSolver -- confirm.
    "RiskyShareFixed": [0.0],
}

1379 

1380 

class PortfolioLabeledType(PortfolioConsumerType):
    """
    Labeled (xarray-based) version of PortfolioConsumerType.

    Subclasses PortfolioConsumerType and inherits its methods and
    attributes. Agents of this type save in a risk-free and a risky asset
    and choose the risky share optimally.
    """

    # NOTE(review): unlike RiskyAssetLabeledType, this class does not mix in
    # IndShockLabeledType -- confirm that is intended.
    default_ = dict(
        params=init_portfolio_labeled,
        solver=make_one_period_oo_solver(ConsPortfolioLabeledSolver),
        model="ConsPortfolio.yaml",
    )