Coverage for HARK / Labeled / solvers.py: 91%

225 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-25 05:22 +0000

1""" 

2Solvers for labeled consumption-saving models. 

3 

4This module implements the Template Method pattern for the Endogenous Grid 

5Method (EGM) algorithm. The base solver defines the algorithm skeleton, 

6while concrete solvers override specific hook methods for their model type. 

7""" 

8 

9from __future__ import annotations 

10 

11import warnings 

12from types import SimpleNamespace 

13from typing import TYPE_CHECKING, Any 

14 

15import numpy as np 

16import xarray as xr 

17 

18from HARK.metric import MetricObject 

19from HARK.rewards import UtilityFuncCRRA 

20 

21from .solution import ConsumerSolutionLabeled, ValueFuncCRRALabeled 

22from .transitions import ( 

23 FixedPortfolioTransitions, 

24 IndShockTransitions, 

25 PerfectForesightTransitions, 

26 PortfolioTransitions, 

27 RiskyAssetTransitions, 

28) 

29 

30if TYPE_CHECKING: 

31 from HARK.distributions import DiscreteDistributionLabeled 

32 

33__all__ = [ 

34 "BaseLabeledSolver", 

35 "ConsPerfForesightLabeledSolver", 

36 "ConsIndShockLabeledSolver", 

37 "ConsRiskyAssetLabeledSolver", 

38 "ConsFixedPortfolioLabeledSolver", 

39 "ConsPortfolioLabeledSolver", 

40] 

41 

42 

43class BaseLabeledSolver(MetricObject): 

44 """ 

45 Base solver implementing Template Method pattern for EGM algorithm. 

46 

47 This class provides the algorithm skeleton for solving consumption-saving 

48 problems using the Endogenous Grid Method. Subclasses customize behavior 

49 by: 

50 1. Setting TransitionsClass to specify model-specific transitions 

51 2. Overriding hook methods for model-specific logic 

52 

53 Template Method: solve() 

54 Hook Methods: 

55 - create_params_namespace(): Add model-specific parameters 

56 - calculate_borrowing_constraint(): Model-specific constraint logic 

57 - create_post_state(): Add extra state dimensions (e.g., stigma) 

58 - create_continuation_function(): Handle shock integration 

59 

60 Parameters 

61 ---------- 

62 solution_next : ConsumerSolutionLabeled 

63 Solution to next period's problem. 

64 LivPrb : float 

65 Survival probability. 

66 DiscFac : float 

67 Intertemporal discount factor. 

68 CRRA : float 

69 Coefficient of relative risk aversion. 

70 Rfree : float 

71 Risk-free interest factor. 

72 PermGroFac : float 

73 Permanent income growth factor. 

74 BoroCnstArt : float or None 

75 Artificial borrowing constraint. 

76 aXtraGrid : np.ndarray 

77 Grid of end-of-period asset values above minimum. 

78 **kwargs 

79 Additional model-specific parameters. 

80 

81 Raises 

82 ------ 

83 ValueError 

84 If CRRA is invalid or aXtraGrid is malformed. 

85 """ 

86 

87 # Class-level strategy specification - override in subclasses 

88 TransitionsClass: type = PerfectForesightTransitions 

89 

90 def __init__( 

91 self, 

92 solution_next: ConsumerSolutionLabeled, 

93 LivPrb: float, 

94 DiscFac: float, 

95 CRRA: float, 

96 Rfree: float, 

97 PermGroFac: float, 

98 BoroCnstArt: float | None, 

99 aXtraGrid: np.ndarray, 

100 **kwargs, 

101 ) -> None: 

102 # Input validation - solution_next 

103 if solution_next is None: 

104 raise ValueError("solution_next cannot be None") 

105 if not isinstance(solution_next, ConsumerSolutionLabeled): 

106 raise TypeError( 

107 f"solution_next must be ConsumerSolutionLabeled, got {type(solution_next)}" 

108 ) 

109 if "m_nrm_min" not in solution_next.attrs: 

110 raise ValueError( 

111 "solution_next.attrs must contain 'm_nrm_min'. " 

112 "Use make_solution_terminal_labeled() to create valid terminal solutions." 

113 ) 

114 

115 # Input validation - CRRA 

116 if not np.isfinite(CRRA): 

117 raise ValueError(f"CRRA must be finite, got {CRRA}") 

118 if CRRA < 0: 

119 raise ValueError(f"CRRA must be non-negative, got {CRRA}") 

120 

121 # Input validation - economic parameters 

122 if LivPrb <= 0 or LivPrb > 1: 

123 raise ValueError(f"LivPrb must be in (0, 1], got {LivPrb}") 

124 if DiscFac <= 0: 

125 raise ValueError(f"DiscFac must be positive, got {DiscFac}") 

126 if Rfree <= 0: 

127 raise ValueError(f"Rfree must be positive, got {Rfree}") 

128 if PermGroFac <= 0: 

129 raise ValueError(f"PermGroFac must be positive, got {PermGroFac}") 

130 

131 # Input validation - asset grid 

132 aXtraGrid = np.asarray(aXtraGrid) 

133 if len(aXtraGrid) == 0: 

134 raise ValueError("aXtraGrid cannot be empty") 

135 if np.any(aXtraGrid < 0): 

136 raise ValueError("aXtraGrid values must be non-negative") 

137 if not np.all(np.diff(aXtraGrid) > 0): 

138 raise ValueError("aXtraGrid must be strictly increasing") 

139 

140 # Store parameters 

141 self.solution_next = solution_next 

142 self.LivPrb = LivPrb 

143 self.DiscFac = DiscFac 

144 self.CRRA = CRRA 

145 self.Rfree = Rfree 

146 self.PermGroFac = PermGroFac 

147 self.BoroCnstArt = BoroCnstArt 

148 self.aXtraGrid = aXtraGrid 

149 

150 # Initialize utility function 

151 self.u = UtilityFuncCRRA(CRRA) 

152 

153 # Initialize transitions strategy 

154 self.transitions = self.TransitionsClass() 

155 

156 # Store additional kwargs 

157 for key, value in kwargs.items(): 

158 setattr(self, key, value) 

159 

160 # ========================================================================= 

161 # TEMPLATE METHOD - The algorithm skeleton 

162 # ========================================================================= 

163 

164 def solve(self) -> ConsumerSolutionLabeled: 

165 """ 

166 Solve the consumption-saving problem using EGM. 

167 

168 This is the template method that defines the algorithm skeleton. 

169 It calls hook methods that subclasses can override. 

170 

171 Returns 

172 ------- 

173 ConsumerSolutionLabeled 

174 Solution containing value function, policy, and continuation. 

175 """ 

176 self.prepare_to_solve() 

177 self.endogenous_grid_method() 

178 return self.solution 

179 

180 # ========================================================================= 

181 # HOOK METHODS - Override in subclasses for customization 

182 # ========================================================================= 

183 

184 def create_params_namespace(self) -> SimpleNamespace: 

185 """ 

186 Create parameters namespace. 

187 

188 Override in subclasses to add model-specific parameters. 

189 

190 Returns 

191 ------- 

192 SimpleNamespace 

193 Parameters for this period's problem. 

194 """ 

195 return SimpleNamespace( 

196 Discount=self.DiscFac * self.LivPrb, 

197 CRRA=self.CRRA, 

198 Rfree=self.Rfree, 

199 PermGroFac=self.PermGroFac, 

200 ) 

201 

202 def calculate_borrowing_constraint(self) -> None: 

203 """ 

204 Calculate the natural borrowing constraint. 

205 

206 Override in shock models to account for minimum shock realizations. 

207 Sets self.BoroCnstNat. 

208 """ 

209 self.BoroCnstNat = ( 

210 (self.solution_next.attrs["m_nrm_min"] - 1) 

211 * self.params.PermGroFac 

212 / self.params.Rfree 

213 ) 

214 

215 def create_post_state(self) -> xr.Dataset: 

216 """ 

217 Create the post-decision state grid. 

218 

219 Override in portfolio models to add the risky share dimension. 

220 

221 Returns 

222 ------- 

223 xr.Dataset 

224 Post-decision state grid. 

225 """ 

226 if self.nat_boro_cnst: 

227 a_grid = self.aXtraGrid + self.m_nrm_min 

228 else: 

229 a_grid = np.append(0.0, self.aXtraGrid) + self.m_nrm_min 

230 

231 aVec = xr.DataArray( 

232 a_grid, 

233 name="aNrm", 

234 dims=("aNrm"), 

235 attrs={"long_name": "savings", "state": True}, 

236 ) 

237 return xr.Dataset({"aNrm": aVec}) 

238 

239 def create_continuation_function(self) -> ValueFuncCRRALabeled: 

240 """ 

241 Create the continuation value function. 

242 

243 Override in shock models to integrate over shock distributions. 

244 

245 Returns 

246 ------- 

247 ValueFuncCRRALabeled 

248 Continuation value function. 

249 """ 

250 value_next = self.solution_next.value 

251 

252 v_end = self.transitions.continuation( 

253 self.post_state, None, value_next, self.params, self.u 

254 ) 

255 v_end = xr.Dataset(v_end).drop_vars(["mNrm"]) 

256 

257 v_end["v_inv"] = self.u.inv(v_end["v"]) 

258 v_end["v_der_inv"] = self.u.derinv(v_end["v_der"]) 

259 

260 borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm") 

261 if self.nat_boro_cnst: 

262 v_end = xr.merge([borocnst, v_end], join="outer", compat="no_conflicts") 

263 

264 return ValueFuncCRRALabeled(v_end, self.params.CRRA) 

265 

266 # ========================================================================= 

267 # CORE METHODS - Shared implementation, rarely overridden 

268 # ========================================================================= 

269 

270 def prepare_to_solve(self) -> None: 

271 """Prepare solver state before running EGM.""" 

272 self.params = self.create_params_namespace() 

273 self.calculate_borrowing_constraint() 

274 self.define_boundary_constraint() 

275 self.post_state = self.create_post_state() 

276 

277 def define_boundary_constraint(self) -> None: 

278 """Define borrowing constraint boundary conditions.""" 

279 if self.BoroCnstArt is None or self.BoroCnstArt <= self.BoroCnstNat: 

280 self.m_nrm_min = self.BoroCnstNat 

281 self.nat_boro_cnst = True 

282 self.borocnst = xr.Dataset( 

283 coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min}, 

284 data_vars={ 

285 "cNrm": 0.0, 

286 "v": -np.inf, 

287 "v_inv": 0.0, 

288 "reward": -np.inf, 

289 "marginal_reward": np.inf, 

290 "v_der": np.inf, 

291 "v_der_inv": 0.0, 

292 }, 

293 ) 

294 else: 

295 self.m_nrm_min = self.BoroCnstArt 

296 self.nat_boro_cnst = False 

297 self.borocnst = xr.Dataset( 

298 coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min}, 

299 data_vars={"cNrm": 0.0}, 

300 ) 

301 

302 def state_transition( 

303 self, state: dict[str, Any], action: dict[str, Any], params: SimpleNamespace 

304 ) -> dict[str, Any]: 

305 """Compute post-decision state from state and action.""" 

306 return {"aNrm": state["mNrm"] - action["cNrm"]} 

307 

308 def reverse_transition( 

309 self, 

310 post_state: dict[str, Any], 

311 action: dict[str, Any], 

312 params: SimpleNamespace, 

313 ) -> dict[str, Any]: 

314 """Recover state from post-decision state and action (for EGM).""" 

315 return {"mNrm": post_state["aNrm"] + action["cNrm"]} 

316 

317 def egm_transition( 

318 self, 

319 post_state: dict[str, Any], 

320 continuation: ValueFuncCRRALabeled, 

321 params: SimpleNamespace, 

322 ) -> dict[str, Any]: 

323 """Compute optimal action using first-order condition (EGM).""" 

324 return { 

325 "cNrm": self.u.derinv(params.Discount * continuation.derivative(post_state)) 

326 } 

327 

328 def value_transition( 

329 self, 

330 action: dict[str, Any], 

331 state: dict[str, Any], 

332 continuation: ValueFuncCRRALabeled, 

333 params: SimpleNamespace, 

334 ) -> dict[str, Any]: 

335 """Compute value function variables from action, state, and continuation.""" 

336 variables = {} 

337 post_state = self.state_transition(state, action, params) 

338 variables.update(post_state) 

339 

340 variables["reward"] = self.u(action["cNrm"]) 

341 variables["v"] = variables["reward"] + params.Discount * continuation( 

342 post_state 

343 ) 

344 variables["v_inv"] = self.u.inv(variables["v"]) 

345 

346 variables["marginal_reward"] = self.u.der(action["cNrm"]) 

347 variables["v_der"] = variables["marginal_reward"] 

348 variables["v_der_inv"] = action["cNrm"] 

349 

350 variables["contributions"] = variables["v"] 

351 variables["value"] = np.sum(variables["v"]) 

352 

353 return variables 

354 

355 def _continuation_for_expectation( 

356 self, 

357 shocks: dict[str, Any], 

358 post_state: dict[str, Any], 

359 value_next: ValueFuncCRRALabeled, 

360 params: SimpleNamespace, 

361 ) -> dict[str, Any]: 

362 """ 

363 Wrapper for continuation transition compatible with expected(). 

364 

365 This method adapts the transitions.continuation() interface to work 

366 with the expected() function from DiscreteDistributionLabeled. 

367 

368 Parameters 

369 ---------- 

370 shocks : dict[str, Any] 

371 Shock realizations (e.g., perm, tran, risky). 

372 post_state : dict[str, Any] 

373 Post-decision state (e.g., aNrm). 

374 value_next : ValueFuncCRRALabeled 

375 Next period's value function. 

376 params : SimpleNamespace 

377 Model parameters. 

378 

379 Returns 

380 ------- 

381 dict[str, Any] 

382 Continuation value variables. 

383 """ 

384 return self.transitions.continuation( 

385 post_state, shocks, value_next, params, self.u 

386 ) 

387 

388 def endogenous_grid_method(self) -> None: 

389 """Execute the Endogenous Grid Method algorithm.""" 

390 wfunc = self.create_continuation_function() 

391 

392 # Check for numerical issues in continuation function 

393 if np.any(~np.isfinite(wfunc.dataset["v_der_inv"].values)): 

394 warnings.warn( 

395 "Continuation value function contains NaN or Inf values. " 

396 "This may indicate invalid parameters (CRRA too high, " 

397 "PermGroFac issues, or extreme shock realizations).", 

398 RuntimeWarning, 

399 stacklevel=2, 

400 ) 

401 

402 # EGM: Get optimal actions from first-order condition 

403 acted = self.egm_transition(self.post_state, wfunc, self.params) 

404 state = self.reverse_transition(self.post_state, acted, self.params) 

405 

406 # Check for numerical issues in EGM results 

407 if np.any(acted["cNrm"] < 0): 

408 warnings.warn( 

409 "EGM produced negative consumption values. " 

410 "Check discount factor and interest rate parameters.", 

411 RuntimeWarning, 

412 stacklevel=2, 

413 ) 

414 

415 # Swap dimensions for state-based indexing 

416 action = xr.Dataset(acted).swap_dims({"aNrm": "mNrm"}) 

417 state = xr.Dataset(state).swap_dims({"aNrm": "mNrm"}) 

418 

419 egm_dataset = xr.merge([action, state]) 

420 

421 if not self.nat_boro_cnst: 

422 egm_dataset = xr.concat( 

423 [self.borocnst, egm_dataset], dim="mNrm", data_vars="all" 

424 ) 

425 

426 # Compute values 

427 values = self.value_transition(egm_dataset, egm_dataset, wfunc, self.params) 

428 egm_dataset.update(values) 

429 

430 if self.nat_boro_cnst: 

431 egm_dataset = xr.concat( 

432 [self.borocnst, egm_dataset], 

433 dim="mNrm", 

434 data_vars="all", 

435 combine_attrs="no_conflicts", 

436 ) 

437 

438 egm_dataset = egm_dataset.drop_vars("aNrm") 

439 

440 # Build solution 

441 vfunc = ValueFuncCRRALabeled( 

442 egm_dataset[["v", "v_der", "v_inv", "v_der_inv"]], self.params.CRRA 

443 ) 

444 pfunc = egm_dataset[["cNrm"]] 

445 

446 self.solution = ConsumerSolutionLabeled( 

447 value=vfunc, 

448 policy=pfunc, 

449 continuation=wfunc, 

450 attrs={"m_nrm_min": self.m_nrm_min, "dataset": egm_dataset}, 

451 ) 

452 

453 

454class ConsPerfForesightLabeledSolver(BaseLabeledSolver): 

455 """ 

456 Solver for perfect foresight consumption model. 

457 

458 Uses PerfectForesightTransitions - no shocks, risk-free return only. 

459 """ 

460 

461 TransitionsClass = PerfectForesightTransitions 

462 

463 

464class ConsIndShockLabeledSolver(BaseLabeledSolver): 

465 """ 

466 Solver for consumption model with idiosyncratic income shocks. 

467 

468 Uses IndShockTransitions and integrates continuation value over 

469 the income shock distribution. 

470 

471 Additional Parameters 

472 --------------------- 

473 IncShkDstn : DiscreteDistributionLabeled 

474 Distribution of income shocks with 'perm' and 'tran' variables. 

475 """ 

476 

477 TransitionsClass = IndShockTransitions 

478 

479 def __init__( 

480 self, 

481 solution_next: ConsumerSolutionLabeled, 

482 IncShkDstn: DiscreteDistributionLabeled, 

483 LivPrb: float, 

484 DiscFac: float, 

485 CRRA: float, 

486 Rfree: float, 

487 PermGroFac: float, 

488 BoroCnstArt: float | None, 

489 aXtraGrid: np.ndarray, 

490 **kwargs, 

491 ) -> None: 

492 self.IncShkDstn = IncShkDstn 

493 super().__init__( 

494 solution_next=solution_next, 

495 LivPrb=LivPrb, 

496 DiscFac=DiscFac, 

497 CRRA=CRRA, 

498 Rfree=Rfree, 

499 PermGroFac=PermGroFac, 

500 BoroCnstArt=BoroCnstArt, 

501 aXtraGrid=aXtraGrid, 

502 **kwargs, 

503 ) 

504 

505 def calculate_borrowing_constraint(self) -> None: 

506 """Calculate constraint accounting for minimum shock realizations.""" 

507 PermShkMinNext = np.min(self.IncShkDstn.atoms[0]) 

508 TranShkMinNext = np.min(self.IncShkDstn.atoms[1]) 

509 

510 self.BoroCnstNat = ( 

511 (self.solution_next.attrs["m_nrm_min"] - TranShkMinNext) 

512 * (self.params.PermGroFac * PermShkMinNext) 

513 / self.params.Rfree 

514 ) 

515 

516 def create_continuation_function(self) -> ValueFuncCRRALabeled: 

517 """Create continuation function by integrating over income shocks.""" 

518 value_next = self.solution_next.value 

519 

520 v_end = self.IncShkDstn.expected( 

521 func=self._continuation_for_expectation, 

522 post_state=self.post_state, 

523 value_next=value_next, 

524 params=self.params, 

525 ) 

526 

527 v_end["v_inv"] = self.u.inv(v_end["v"]) 

528 v_end["v_der_inv"] = self.u.derinv(v_end["v_der"]) 

529 

530 borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm") 

531 if self.nat_boro_cnst: 

532 v_end = xr.merge([borocnst, v_end], join="outer", compat="no_conflicts") 

533 

534 return ValueFuncCRRALabeled(v_end, self.params.CRRA) 

535 

536 

537class ConsRiskyAssetLabeledSolver(BaseLabeledSolver): 

538 """ 

539 Solver for consumption model with risky asset. 

540 

541 Uses RiskyAssetTransitions - all savings earn stochastic risky return. 

542 

543 Additional Parameters 

544 --------------------- 

545 ShockDstn : DiscreteDistributionLabeled 

546 Joint distribution of income and risky return shocks. 

547 """ 

548 

549 TransitionsClass = RiskyAssetTransitions 

550 

551 def __init__( 

552 self, 

553 solution_next: ConsumerSolutionLabeled, 

554 ShockDstn: DiscreteDistributionLabeled, 

555 LivPrb: float, 

556 DiscFac: float, 

557 CRRA: float, 

558 Rfree: float, 

559 PermGroFac: float, 

560 BoroCnstArt: float | None, 

561 aXtraGrid: np.ndarray, 

562 **kwargs, 

563 ) -> None: 

564 self.ShockDstn = ShockDstn 

565 super().__init__( 

566 solution_next=solution_next, 

567 LivPrb=LivPrb, 

568 DiscFac=DiscFac, 

569 CRRA=CRRA, 

570 Rfree=Rfree, 

571 PermGroFac=PermGroFac, 

572 BoroCnstArt=BoroCnstArt, 

573 aXtraGrid=aXtraGrid, 

574 **kwargs, 

575 ) 

576 

577 def calculate_borrowing_constraint(self) -> None: 

578 """Calculate constraint with artificial borrowing constraint.""" 

579 self.BoroCnstArt = 0.0 

580 self.IncShkDstn = self.ShockDstn 

581 

582 PermShkMinNext = np.min(self.ShockDstn.atoms[0]) 

583 TranShkMinNext = np.min(self.ShockDstn.atoms[1]) 

584 

585 self.BoroCnstNat = ( 

586 (self.solution_next.attrs["m_nrm_min"] - TranShkMinNext) 

587 * (self.params.PermGroFac * PermShkMinNext) 

588 / self.params.Rfree 

589 ) 

590 

591 def create_continuation_function(self) -> ValueFuncCRRALabeled: 

592 """Create continuation function integrating over shock distribution.""" 

593 value_next = self.solution_next.value 

594 

595 v_end = self.ShockDstn.expected( 

596 func=self._continuation_for_expectation, 

597 post_state=self.post_state, 

598 value_next=value_next, 

599 params=self.params, 

600 ) 

601 

602 v_end["v_inv"] = self.u.inv(v_end["v"]) 

603 v_end["v_der_inv"] = self.u.derinv(v_end["v_der"]) 

604 

605 borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm") 

606 if self.nat_boro_cnst: 

607 v_end = xr.merge([borocnst, v_end], join="outer", compat="no_conflicts") 

608 

609 v_end = v_end.transpose("aNrm", ...) 

610 

611 return ValueFuncCRRALabeled(v_end, self.params.CRRA) 

612 

613 

614class ConsFixedPortfolioLabeledSolver(ConsRiskyAssetLabeledSolver): 

615 """ 

616 Solver for consumption model with fixed portfolio allocation. 

617 

618 Uses FixedPortfolioTransitions - agent allocates fixed share to risky asset. 

619 

620 Additional Parameters 

621 --------------------- 

622 RiskyShareFixed : float 

623 Fixed share of savings allocated to risky asset. 

624 """ 

625 

626 TransitionsClass = FixedPortfolioTransitions 

627 

628 def __init__( 

629 self, 

630 solution_next: ConsumerSolutionLabeled, 

631 ShockDstn: DiscreteDistributionLabeled, 

632 LivPrb: float, 

633 DiscFac: float, 

634 CRRA: float, 

635 Rfree: float, 

636 PermGroFac: float, 

637 BoroCnstArt: float | None, 

638 aXtraGrid: np.ndarray, 

639 RiskyShareFixed: float, 

640 **kwargs, 

641 ) -> None: 

642 # Validate RiskyShareFixed 

643 if RiskyShareFixed < 0 or RiskyShareFixed > 1: 

644 raise ValueError( 

645 f"RiskyShareFixed must be in [0, 1], got {RiskyShareFixed}" 

646 ) 

647 

648 self.RiskyShareFixed = RiskyShareFixed 

649 super().__init__( 

650 solution_next=solution_next, 

651 ShockDstn=ShockDstn, 

652 LivPrb=LivPrb, 

653 DiscFac=DiscFac, 

654 CRRA=CRRA, 

655 Rfree=Rfree, 

656 PermGroFac=PermGroFac, 

657 BoroCnstArt=BoroCnstArt, 

658 aXtraGrid=aXtraGrid, 

659 **kwargs, 

660 ) 

661 

662 def create_params_namespace(self) -> SimpleNamespace: 

663 """Add RiskyShareFixed to parameters.""" 

664 params = super().create_params_namespace() 

665 params.RiskyShareFixed = self.RiskyShareFixed 

666 return params 

667 

668 

669class ConsPortfolioLabeledSolver(ConsRiskyAssetLabeledSolver): 

670 """ 

671 Solver for consumption model with optimal portfolio choice. 

672 

673 Uses PortfolioTransitions - agent optimally chooses risky share each period. 

674 The optimal share is found by solving the portfolio first-order condition. 

675 

676 Additional Parameters 

677 --------------------- 

678 ShareGrid : np.ndarray 

679 Grid of risky share values to search over. 

680 """ 

681 

682 TransitionsClass = PortfolioTransitions 

683 

684 def __init__( 

685 self, 

686 solution_next: ConsumerSolutionLabeled, 

687 ShockDstn: DiscreteDistributionLabeled, 

688 LivPrb: float, 

689 DiscFac: float, 

690 CRRA: float, 

691 Rfree: float, 

692 PermGroFac: float, 

693 BoroCnstArt: float | None, 

694 aXtraGrid: np.ndarray, 

695 ShareGrid: np.ndarray, 

696 **kwargs, 

697 ) -> None: 

698 # Validate ShareGrid 

699 ShareGrid = np.asarray(ShareGrid) 

700 if len(ShareGrid) == 0: 

701 raise ValueError("ShareGrid cannot be empty") 

702 if np.any(ShareGrid < 0) or np.any(ShareGrid > 1): 

703 raise ValueError("ShareGrid values must be in [0, 1]") 

704 if not np.all(np.diff(ShareGrid) > 0): 

705 raise ValueError("ShareGrid must be strictly increasing") 

706 

707 self.ShareGrid = ShareGrid 

708 super().__init__( 

709 solution_next=solution_next, 

710 ShockDstn=ShockDstn, 

711 LivPrb=LivPrb, 

712 DiscFac=DiscFac, 

713 CRRA=CRRA, 

714 Rfree=Rfree, 

715 PermGroFac=PermGroFac, 

716 BoroCnstArt=BoroCnstArt, 

717 aXtraGrid=aXtraGrid, 

718 **kwargs, 

719 ) 

720 

721 def create_post_state(self) -> xr.Dataset: 

722 """Add risky share dimension to post-decision state.""" 

723 post_state = super().create_post_state() 

724 post_state["stigma"] = xr.DataArray( 

725 self.ShareGrid, dims=["stigma"], attrs={"long_name": "risky share"} 

726 ) 

727 return post_state 

728 

729 def create_continuation_function(self) -> ValueFuncCRRALabeled: 

730 """ 

731 Create continuation function with optimal portfolio choice. 

732 

733 First computes continuation value over the (aNrm, stigma) grid, 

734 then finds the optimal stigma for each aNrm level. 

735 """ 

736 # Get continuation value over full (aNrm, stigma) grid 

737 wfunc = super().create_continuation_function() 

738 

739 dvds = wfunc.dataset["dvds"].values 

740 

741 # Find optimal share using linear interpolation on FOC 

742 crossing = np.logical_and(dvds[:, 1:] <= 0.0, dvds[:, :-1] >= 0.0) 

743 share_idx = np.argmax(crossing, axis=1) 

744 a_idx = np.arange(self.post_state["aNrm"].size) 

745 

746 bottom_share = self.ShareGrid[share_idx] 

747 top_share = self.ShareGrid[share_idx + 1] 

748 bottom_foc = dvds[a_idx, share_idx] 

749 top_foc = dvds[a_idx, share_idx + 1] 

750 

751 # Linear interpolation with division-by-zero protection 

752 denominator = top_foc - bottom_foc 

753 fallback_mask = np.abs(denominator) <= 1e-12 

754 if np.any(fallback_mask): 

755 n_fallbacks = np.sum(fallback_mask) 

756 warnings.warn( 

757 f"Portfolio optimization used fallback interpolation for {n_fallbacks} " 

758 f"grid points due to near-zero FOC difference. " 

759 f"Consider refining ShareGrid for more accurate results.", 

760 RuntimeWarning, 

761 stacklevel=2, 

762 ) 

763 alpha = np.where( 

764 ~fallback_mask, 

765 1.0 - top_foc / denominator, 

766 0.5, 

767 ) 

768 opt_share = (1.0 - alpha) * bottom_share + alpha * top_share 

769 

770 # Handle corner solutions 

771 opt_share[dvds[:, -1] > 0.0] = 1.0 # Want more than 100% risky 

772 opt_share[dvds[:, 0] < 0.0] = 0.0 # Want less than 0% risky 

773 

774 if not self.nat_boro_cnst: 

775 # At aNrm = 0 the portfolio share is irrelevant; 1.0 is limit as a --> 0 

776 opt_share[0] = 1.0 

777 

778 opt_share = xr.DataArray( 

779 opt_share, 

780 coords={"aNrm": self.post_state["aNrm"].values}, 

781 dims=["aNrm"], 

782 attrs={"long_name": "optimal risky share"}, 

783 ) 

784 

785 # Evaluate continuation at optimal share 

786 v_end = wfunc.evaluate({"aNrm": self.post_state["aNrm"], "stigma": opt_share}) 

787 v_end = v_end.reset_coords(names="stigma") 

788 

789 wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA) 

790 

791 # Remove stigma from post_state for EGM 

792 self.post_state = self.post_state.drop_vars("stigma") 

793 

794 return wfunc