Coverage for HARK / Labeled / solvers.py: 91%

214 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-10 06:19 +0000

"""
Solvers for labeled consumption-saving models.

This module implements the Template Method pattern for the Endogenous Grid
Method (EGM) algorithm. ``BaseLabeledSolver`` defines the algorithm skeleton
(``solve`` -> ``prepare_to_solve`` -> ``endogenous_grid_method``), while
concrete solvers pick a transitions strategy and override hook methods for
their model type:

- ``ConsPerfForesightLabeledSolver``: no shocks, risk-free return only.
- ``ConsIndShockLabeledSolver``: integrates over income shocks.
- ``ConsRiskyAssetLabeledSolver``: all savings earn a stochastic return.
- ``ConsFixedPortfolioLabeledSolver``: fixed share of savings in the risky asset.
- ``ConsPortfolioLabeledSolver``: risky share chosen optimally each period.
"""

8 

9from __future__ import annotations 

10 

11import warnings 

12from types import SimpleNamespace 

13from typing import TYPE_CHECKING, Any 

14 

15import numpy as np 

16import xarray as xr 

17 

18from HARK.metric import MetricObject 

19from HARK.rewards import UtilityFuncCRRA 

20 

21from .solution import ConsumerSolutionLabeled, ValueFuncCRRALabeled 

22from .transitions import ( 

23 FixedPortfolioTransitions, 

24 IndShockTransitions, 

25 PerfectForesightTransitions, 

26 PortfolioTransitions, 

27 RiskyAssetTransitions, 

28) 

29 

30if TYPE_CHECKING: 

31 from HARK.distributions import DiscreteDistributionLabeled 

32 

# Public API: the template base solver first, then the concrete solvers in
# the order they are defined in this module.
__all__ = [
    "BaseLabeledSolver",
    "ConsPerfForesightLabeledSolver",
    "ConsIndShockLabeledSolver",
    "ConsRiskyAssetLabeledSolver",
    "ConsFixedPortfolioLabeledSolver",
    "ConsPortfolioLabeledSolver",
]

41 

42 

class BaseLabeledSolver(MetricObject):
    """
    Base solver implementing Template Method pattern for EGM algorithm.

    This class provides the algorithm skeleton for solving consumption-saving
    problems using the Endogenous Grid Method. Subclasses customize behavior
    by:
    1. Setting TransitionsClass to specify model-specific transitions
    2. Overriding hook methods for model-specific logic

    Template Method: solve()
    Hook Methods:
    - create_params_namespace(): Add model-specific parameters
    - calculate_borrowing_constraint(): Model-specific constraint logic
    - create_post_state(): Add extra state dimensions (e.g., stigma)
    - create_continuation_function(): Handle shock integration

    Parameters
    ----------
    solution_next : ConsumerSolutionLabeled
        Solution to next period's problem. Its ``attrs`` must contain
        ``"m_nrm_min"``.
    LivPrb : float
        Survival probability, in (0, 1].
    DiscFac : float
        Intertemporal discount factor (must be positive).
    CRRA : float
        Coefficient of relative risk aversion (finite and non-negative).
    Rfree : float
        Risk-free interest factor (must be positive).
    PermGroFac : float
        Permanent income growth factor (must be positive).
    BoroCnstArt : float or None
        Artificial borrowing constraint. ``None`` means only the natural
        borrowing constraint applies.
    aXtraGrid : np.ndarray
        One-dimensional, finite, non-negative, strictly increasing grid of
        end-of-period asset values above the minimum.
    **kwargs
        Additional model-specific parameters; each is stored as an instance
        attribute of the same name.

    Raises
    ------
    TypeError
        If solution_next is not a ConsumerSolutionLabeled.
    ValueError
        If solution_next is None or lacks ``m_nrm_min``, if any economic
        parameter is outside its valid range (NaN included), or if aXtraGrid
        is malformed.
    """

    # Class-level strategy specification - override in subclasses
    TransitionsClass: type = PerfectForesightTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # Input validation - solution_next
        if solution_next is None:
            raise ValueError("solution_next cannot be None")
        if not isinstance(solution_next, ConsumerSolutionLabeled):
            raise TypeError(
                f"solution_next must be ConsumerSolutionLabeled, got {type(solution_next)}"
            )
        if "m_nrm_min" not in solution_next.attrs:
            raise ValueError(
                "solution_next.attrs must contain 'm_nrm_min'. "
                "Use make_solution_terminal_labeled() to create valid terminal solutions."
            )

        # Input validation - CRRA
        if not np.isfinite(CRRA):
            raise ValueError(f"CRRA must be finite, got {CRRA}")
        if CRRA < 0:
            raise ValueError(f"CRRA must be non-negative, got {CRRA}")

        # Input validation - economic parameters. The checks are phrased as
        # "not (valid range)" so that NaN, which compares False against
        # everything, is rejected instead of silently accepted.
        if not (0 < LivPrb <= 1):
            raise ValueError(f"LivPrb must be in (0, 1], got {LivPrb}")
        if not (DiscFac > 0):
            raise ValueError(f"DiscFac must be positive, got {DiscFac}")
        if not (Rfree > 0):
            raise ValueError(f"Rfree must be positive, got {Rfree}")
        if not (PermGroFac > 0):
            raise ValueError(f"PermGroFac must be positive, got {PermGroFac}")

        # Input validation - asset grid
        aXtraGrid = np.asarray(aXtraGrid)
        if aXtraGrid.ndim != 1:
            # create_post_state builds a 1-D "aNrm" DataArray from this grid.
            raise ValueError(
                f"aXtraGrid must be one-dimensional, got shape {aXtraGrid.shape}"
            )
        if len(aXtraGrid) == 0:
            raise ValueError("aXtraGrid cannot be empty")
        if not np.all(np.isfinite(aXtraGrid)):
            raise ValueError("aXtraGrid values must be finite")
        if np.any(aXtraGrid < 0):
            raise ValueError("aXtraGrid values must be non-negative")
        if not np.all(np.diff(aXtraGrid) > 0):
            raise ValueError("aXtraGrid must be strictly increasing")

        # Store parameters
        self.solution_next = solution_next
        self.LivPrb = LivPrb
        self.DiscFac = DiscFac
        self.CRRA = CRRA
        self.Rfree = Rfree
        self.PermGroFac = PermGroFac
        self.BoroCnstArt = BoroCnstArt
        self.aXtraGrid = aXtraGrid

        # Initialize utility function
        self.u = UtilityFuncCRRA(CRRA)

        # Initialize transitions strategy
        self.transitions = self.TransitionsClass()

        # Store additional kwargs (set last, so they can override the
        # attributes above if a caller passes one of the same name)
        for key, value in kwargs.items():
            setattr(self, key, value)

    # =========================================================================
    # TEMPLATE METHOD - The algorithm skeleton
    # =========================================================================

    def solve(self) -> ConsumerSolutionLabeled:
        """
        Solve the consumption-saving problem using EGM.

        This is the template method that defines the algorithm skeleton.
        It calls hook methods that subclasses can override.

        Returns
        -------
        ConsumerSolutionLabeled
            Solution containing value function, policy, and continuation.
        """
        self.prepare_to_solve()
        self.endogenous_grid_method()
        return self.solution

    # =========================================================================
    # HOOK METHODS - Override in subclasses for customization
    # =========================================================================

    def create_params_namespace(self) -> SimpleNamespace:
        """
        Create parameters namespace.

        Override in subclasses to add model-specific parameters.

        Returns
        -------
        SimpleNamespace
            Parameters for this period's problem. ``Discount`` is the
            survival-adjusted discount factor ``DiscFac * LivPrb``.
        """
        return SimpleNamespace(
            Discount=self.DiscFac * self.LivPrb,
            CRRA=self.CRRA,
            Rfree=self.Rfree,
            PermGroFac=self.PermGroFac,
        )

    def calculate_borrowing_constraint(self) -> None:
        """
        Calculate the natural borrowing constraint.

        Override in shock models to account for minimum shock realizations.
        Sets self.BoroCnstNat. Requires ``self.params`` to be set.
        """
        self.BoroCnstNat = (
            (self.solution_next.attrs["m_nrm_min"] - 1)
            * self.params.PermGroFac
            / self.params.Rfree
        )

    def create_post_state(self) -> xr.Dataset:
        """
        Create the post-decision state grid.

        Override in portfolio models to add the risky share dimension.

        Under the natural constraint the grid is ``aXtraGrid`` shifted by
        ``m_nrm_min``; under the artificial constraint an extra point is
        prepended at ``m_nrm_min`` itself.

        Returns
        -------
        xr.Dataset
            Post-decision state grid with a single ``aNrm`` variable.
        """
        if self.nat_boro_cnst:
            a_grid = self.aXtraGrid + self.m_nrm_min
        else:
            a_grid = np.append(0.0, self.aXtraGrid) + self.m_nrm_min

        aVec = xr.DataArray(
            a_grid,
            name="aNrm",
            dims=("aNrm",),
            attrs={"long_name": "savings", "state": True},
        )
        return xr.Dataset({"aNrm": aVec})

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """
        Create the continuation value function.

        Override in shock models to integrate over shock distributions.

        Returns
        -------
        ValueFuncCRRALabeled
            Continuation value function.
        """
        value_next = self.solution_next.value

        v_end = self.transitions.continuation(
            self.post_state, None, value_next, self.params, self.u
        )
        v_end = xr.Dataset(v_end).drop_vars(["mNrm"])

        return self._finalize_value_func(v_end)

    def _finalize_value_func(self, v_end, transform=None) -> ValueFuncCRRALabeled:
        """Apply the post-processing shared by ``create_continuation_function``
        across solver subclasses: write inverse value variables, merge in the
        borrowing-constraint boundary when natural, and optionally apply
        ``transform`` (e.g. a ``.transpose(...)`` reordering)."""
        v_end["v_inv"] = self.u.inv(v_end["v"])
        v_end["v_der_inv"] = self.u.derinv(v_end["v_der"])
        if self.nat_boro_cnst:
            # The boundary point is only merged under the natural constraint,
            # so build it only in that case.
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end], join="outer", compat="no_conflicts")
        if transform is not None:
            v_end = transform(v_end)
        return ValueFuncCRRALabeled(v_end, self.params.CRRA)

    def _natural_boro_cnst_with_shocks(self, perm_shk_min, tran_shk_min) -> float:
        """Compute ``BoroCnstNat`` accounting for minimum shock realizations.
        Shared by ``ConsIndShockLabeledSolver`` and ``ConsRiskyAssetLabeledSolver``."""
        return (
            (self.solution_next.attrs["m_nrm_min"] - tran_shk_min)
            * (self.params.PermGroFac * perm_shk_min)
            / self.params.Rfree
        )

    # =========================================================================
    # CORE METHODS - Shared implementation, rarely overridden
    # =========================================================================

    def prepare_to_solve(self) -> None:
        """Prepare solver state before running EGM.

        Order matters: the borrowing constraint reads ``self.params``, the
        boundary reads ``BoroCnstNat``, and the post-state grid reads
        ``m_nrm_min`` / ``nat_boro_cnst``.
        """
        self.params = self.create_params_namespace()
        self.calculate_borrowing_constraint()
        self.define_boundary_constraint()
        self.post_state = self.create_post_state()

    def define_boundary_constraint(self) -> None:
        """Define borrowing constraint boundary conditions.

        Sets ``m_nrm_min``, ``nat_boro_cnst`` and the single-point
        ``borocnst`` dataset located at ``mNrm = aNrm = m_nrm_min``. The
        natural constraint binds when there is no artificial constraint or
        the artificial one is looser (not above ``BoroCnstNat``).
        """
        if self.BoroCnstArt is None or self.BoroCnstArt <= self.BoroCnstNat:
            self.m_nrm_min = self.BoroCnstNat
            self.nat_boro_cnst = True
            # Boundary values at the natural constraint: zero consumption,
            # -inf value/reward and infinite marginal value.
            data_vars = {
                "cNrm": 0.0,
                "v": -np.inf,
                "v_inv": 0.0,
                "reward": -np.inf,
                "marginal_reward": np.inf,
                "v_der": np.inf,
                "v_der_inv": 0.0,
            }
        else:
            self.m_nrm_min = self.BoroCnstArt
            self.nat_boro_cnst = False
            data_vars = {"cNrm": 0.0}

        self.borocnst = xr.Dataset(
            coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min},
            data_vars=data_vars,
        )

    def state_transition(
        self, state: dict[str, Any], action: dict[str, Any], params: SimpleNamespace
    ) -> dict[str, Any]:
        """Compute post-decision state from state and action (aNrm = mNrm - cNrm)."""
        return {"aNrm": state["mNrm"] - action["cNrm"]}

    def reverse_transition(
        self,
        post_state: dict[str, Any],
        action: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Recover state from post-decision state and action (for EGM): mNrm = aNrm + cNrm."""
        return {"mNrm": post_state["aNrm"] + action["cNrm"]}

    def egm_transition(
        self,
        post_state: dict[str, Any],
        continuation: ValueFuncCRRALabeled,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Compute optimal action using first-order condition (EGM).

        Inverts marginal utility at the discounted marginal continuation value.
        """
        return {
            "cNrm": self.u.derinv(params.Discount * continuation.derivative(post_state))
        }

    def value_transition(
        self,
        action: dict[str, Any],
        state: dict[str, Any],
        continuation: ValueFuncCRRALabeled,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Compute value function variables from action, state, and continuation.

        Returns a dict with the post-decision state (``aNrm``) plus
        ``reward``, ``v``, ``v_inv``, ``marginal_reward``, ``v_der``,
        ``v_der_inv``, ``contributions`` and the scalar total ``value``.
        """
        variables = {}
        post_state = self.state_transition(state, action, params)
        variables.update(post_state)

        variables["reward"] = self.u(action["cNrm"])
        variables["v"] = variables["reward"] + params.Discount * continuation(
            post_state
        )
        variables["v_inv"] = self.u.inv(variables["v"])

        # Envelope condition: marginal value equals marginal utility, so its
        # inverse is consumption itself.
        variables["marginal_reward"] = self.u.der(action["cNrm"])
        variables["v_der"] = variables["marginal_reward"]
        variables["v_der_inv"] = action["cNrm"]

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def _continuation_for_expectation(
        self,
        shocks: dict[str, Any],
        post_state: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Wrapper for continuation transition compatible with expected().

        This method adapts the transitions.continuation() interface to work
        with the expected() function from DiscreteDistributionLabeled.

        Parameters
        ----------
        shocks : dict[str, Any]
            Shock realizations (e.g., perm, tran, risky).
        post_state : dict[str, Any]
            Post-decision state (e.g., aNrm).
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Model parameters.

        Returns
        -------
        dict[str, Any]
            Continuation value variables.
        """
        return self.transitions.continuation(
            post_state, shocks, value_next, params, self.u
        )

    def endogenous_grid_method(self) -> None:
        """Execute the Endogenous Grid Method algorithm.

        Builds the continuation function, inverts the FOC on the post-state
        grid, recovers the endogenous ``mNrm`` grid, attaches the borrowing
        constraint boundary and stores the result in ``self.solution``.

        Warns
        -----
        RuntimeWarning
            If the continuation contains non-finite ``v_der_inv`` values or
            EGM yields negative consumption.
        """
        wfunc = self.create_continuation_function()

        # Check for numerical issues in continuation function
        if np.any(~np.isfinite(wfunc.dataset["v_der_inv"].values)):
            warnings.warn(
                "Continuation value function contains NaN or Inf values. "
                "This may indicate invalid parameters (CRRA too high, "
                "PermGroFac issues, or extreme shock realizations).",
                RuntimeWarning,
                stacklevel=2,
            )

        # EGM: Get optimal actions from first-order condition
        acted = self.egm_transition(self.post_state, wfunc, self.params)
        state = self.reverse_transition(self.post_state, acted, self.params)

        # Check for numerical issues in EGM results
        if np.any(acted["cNrm"] < 0):
            warnings.warn(
                "EGM produced negative consumption values. "
                "Check discount factor and interest rate parameters.",
                RuntimeWarning,
                stacklevel=2,
            )

        # Swap dimensions for state-based indexing
        action = xr.Dataset(acted).swap_dims({"aNrm": "mNrm"})
        state = xr.Dataset(state).swap_dims({"aNrm": "mNrm"})

        egm_dataset = xr.merge([action, state])

        # Artificial constraint: the kink point is added before values are
        # computed so it gets evaluated like any other grid point.
        if not self.nat_boro_cnst:
            egm_dataset = xr.concat(
                [self.borocnst, egm_dataset], dim="mNrm", data_vars="all"
            )

        # Compute values
        values = self.value_transition(egm_dataset, egm_dataset, wfunc, self.params)
        egm_dataset.update(values)

        # Natural constraint: the boundary already carries its limiting values
        # (see define_boundary_constraint), so it is prepended afterwards.
        if self.nat_boro_cnst:
            egm_dataset = xr.concat(
                [self.borocnst, egm_dataset],
                dim="mNrm",
                data_vars="all",
                combine_attrs="no_conflicts",
            )

        egm_dataset = egm_dataset.drop_vars("aNrm")

        # Build solution
        vfunc = ValueFuncCRRALabeled(
            egm_dataset[["v", "v_der", "v_inv", "v_der_inv"]], self.params.CRRA
        )
        pfunc = egm_dataset[["cNrm"]]

        self.solution = ConsumerSolutionLabeled(
            value=vfunc,
            policy=pfunc,
            continuation=wfunc,
            attrs={"m_nrm_min": self.m_nrm_min, "dataset": egm_dataset},
        )

468 

469 

class ConsPerfForesightLabeledSolver(BaseLabeledSolver):
    """
    EGM solver for the perfect foresight consumption-saving model.

    There are no shocks to integrate over: the continuation value comes
    directly from ``PerfectForesightTransitions``, which uses only the
    risk-free return. Every other step is inherited unchanged from
    ``BaseLabeledSolver``; the transitions strategy is pinned here
    explicitly so the model choice is visible on the class itself.
    """

    TransitionsClass = PerfectForesightTransitions

478 

479 

class ConsIndShockLabeledSolver(BaseLabeledSolver):
    """
    EGM solver for the consumption model with idiosyncratic income shocks.

    The continuation value is the expectation, taken with ``IncShkDstn``, of
    the ``IndShockTransitions`` continuation evaluated on the post-decision
    grid. The natural borrowing constraint accounts for the worst shocks.

    Additional Parameters
    ---------------------
    IncShkDstn : DiscreteDistributionLabeled
        Distribution of income shocks with 'perm' and 'tran' variables.
        The borrowing-constraint step reads ``atoms[0]`` as the permanent
        shock and ``atoms[1]`` as the transitory shock.
    """

    TransitionsClass = IndShockTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        IncShkDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # The shock distribution is the only extra input; everything else is
        # validated and stored by the base initializer.
        self.IncShkDstn = IncShkDstn
        super().__init__(
            solution_next=solution_next,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
            **kwargs,
        )

    def calculate_borrowing_constraint(self) -> None:
        """Set ``BoroCnstNat`` from the smallest permanent and transitory shocks."""
        shock_atoms = self.IncShkDstn.atoms
        worst_perm = np.min(shock_atoms[0])
        worst_tran = np.min(shock_atoms[1])
        self.BoroCnstNat = self._natural_boro_cnst_with_shocks(worst_perm, worst_tran)

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """Integrate next period's value over income shocks and finalize it."""
        expected_continuation = self.IncShkDstn.expected(
            func=self._continuation_for_expectation,
            post_state=self.post_state,
            value_next=self.solution_next.value,
            params=self.params,
        )
        return self._finalize_value_func(expected_continuation)

537 

538 

class ConsRiskyAssetLabeledSolver(BaseLabeledSolver):
    """
    EGM solver for the consumption model with a risky asset.

    Uses ``RiskyAssetTransitions``: all savings earn the stochastic risky
    return. The continuation value is integrated over ``ShockDstn``.

    Note that ``calculate_borrowing_constraint`` overwrites ``BoroCnstArt``
    with ``0.0`` regardless of the value passed in, so this model never
    allows borrowing.

    Additional Parameters
    ---------------------
    ShockDstn : DiscreteDistributionLabeled
        Joint distribution of income and risky return shocks. The
        borrowing-constraint step reads ``atoms[0]`` as the permanent shock
        and ``atoms[1]`` as the transitory shock.
    """

    TransitionsClass = RiskyAssetTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        ShockDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # Keep the joint shock distribution; the base class handles the rest.
        self.ShockDstn = ShockDstn
        super().__init__(
            solution_next=solution_next,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
            **kwargs,
        )

    def calculate_borrowing_constraint(self) -> None:
        """Force a zero artificial constraint and set ``BoroCnstNat`` from worst shocks.

        ``IncShkDstn`` is aliased to ``ShockDstn`` so code expecting the
        income-shock attribute name finds the joint distribution.
        """
        self.BoroCnstArt = 0.0
        self.IncShkDstn = self.ShockDstn
        shock_atoms = self.ShockDstn.atoms
        worst_perm = np.min(shock_atoms[0])
        worst_tran = np.min(shock_atoms[1])
        self.BoroCnstNat = self._natural_boro_cnst_with_shocks(worst_perm, worst_tran)

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """Integrate next period's value over the joint shocks, aNrm axis first."""

        def _savings_axis_first(dataset):
            # Put aNrm as the leading dimension; remaining dims keep their order.
            return dataset.transpose("aNrm", ...)

        expected_continuation = self.ShockDstn.expected(
            func=self._continuation_for_expectation,
            post_state=self.post_state,
            value_next=self.solution_next.value,
            params=self.params,
        )
        return self._finalize_value_func(
            expected_continuation, transform=_savings_axis_first
        )

599 

600 

class ConsFixedPortfolioLabeledSolver(ConsRiskyAssetLabeledSolver):
    """
    Solver for consumption model with fixed portfolio allocation.

    Uses FixedPortfolioTransitions - agent allocates fixed share to risky asset.

    Additional Parameters
    ---------------------
    RiskyShareFixed : float
        Fixed share of savings allocated to risky asset, in [0, 1].
        Exposed to the transitions as ``params.RiskyShareFixed``.

    Raises
    ------
    ValueError
        If RiskyShareFixed is outside [0, 1] or is NaN.
    """

    TransitionsClass = FixedPortfolioTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        ShockDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        RiskyShareFixed: float,
        **kwargs,
    ) -> None:
        # Validate RiskyShareFixed. Phrased as "not in range" so that NaN
        # (which compares False with everything) is rejected too.
        if not (0 <= RiskyShareFixed <= 1):
            raise ValueError(
                f"RiskyShareFixed must be in [0, 1], got {RiskyShareFixed}"
            )

        self.RiskyShareFixed = RiskyShareFixed
        super().__init__(
            solution_next=solution_next,
            ShockDstn=ShockDstn,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
            **kwargs,
        )

    def create_params_namespace(self) -> SimpleNamespace:
        """Add RiskyShareFixed to parameters."""
        params = super().create_params_namespace()
        params.RiskyShareFixed = self.RiskyShareFixed
        return params

654 

655 

class ConsPortfolioLabeledSolver(ConsRiskyAssetLabeledSolver):
    """
    Solver for consumption model with optimal portfolio choice.

    Uses PortfolioTransitions - agent optimally chooses risky share each period.
    The optimal share is found by solving the portfolio first-order condition.

    Additional Parameters
    ---------------------
    ShareGrid : np.ndarray
        One-dimensional, strictly increasing grid of risky share values in
        [0, 1] to search over. At least two points are required because the
        FOC root is bracketed between adjacent grid points. Corner solutions
        are clamped to the grid's first and last points.

    Raises
    ------
    ValueError
        If ShareGrid is not 1-D, is empty, has fewer than two points, has
        values outside [0, 1] (NaN/inf included), or is not strictly
        increasing.
    """

    TransitionsClass = PortfolioTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        ShockDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        ShareGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # Validate ShareGrid
        ShareGrid = np.asarray(ShareGrid)
        if ShareGrid.ndim != 1:
            raise ValueError(
                f"ShareGrid must be one-dimensional, got shape {ShareGrid.shape}"
            )
        if len(ShareGrid) == 0:
            raise ValueError("ShareGrid cannot be empty")
        if len(ShareGrid) < 2:
            # The FOC root search indexes share_idx + 1 and needs a bracket.
            raise ValueError("ShareGrid must contain at least two points")
        # Positive form so NaN (and +/-inf) are rejected as out of range.
        if not np.all((ShareGrid >= 0) & (ShareGrid <= 1)):
            raise ValueError("ShareGrid values must be in [0, 1]")
        if not np.all(np.diff(ShareGrid) > 0):
            raise ValueError("ShareGrid must be strictly increasing")

        self.ShareGrid = ShareGrid
        super().__init__(
            solution_next=solution_next,
            ShockDstn=ShockDstn,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
            **kwargs,
        )

    def create_post_state(self) -> xr.Dataset:
        """Add risky share dimension to post-decision state."""
        post_state = super().create_post_state()
        post_state["stigma"] = xr.DataArray(
            self.ShareGrid, dims=["stigma"], attrs={"long_name": "risky share"}
        )
        return post_state

    def _optimal_share_from_foc(self, dvds: np.ndarray) -> np.ndarray:
        """Return the optimal risky share for each row of the portfolio FOC.

        ``dvds`` is indexed ``[aNrm, stigma]``; the aNrm-first order comes
        from the parent's ``transpose("aNrm", ...)``. Interior roots are found
        by linear interpolation between the first adjacent pair of grid
        points where the FOC changes sign from >= 0 to <= 0; corners are
        clamped to the ends of ``ShareGrid``.
        """
        share_grid = self.ShareGrid

        # argmax returns the first bracketing pair, or 0 when none exists;
        # rows without a bracket are overwritten by the corner handling below.
        crossing = np.logical_and(dvds[:, 1:] <= 0.0, dvds[:, :-1] >= 0.0)
        share_idx = np.argmax(crossing, axis=1)
        a_idx = np.arange(dvds.shape[0])

        bottom_share = share_grid[share_idx]
        top_share = share_grid[share_idx + 1]
        bottom_foc = dvds[a_idx, share_idx]
        top_foc = dvds[a_idx, share_idx + 1]

        # Linear interpolation with division-by-zero protection
        denominator = top_foc - bottom_foc
        fallback_mask = np.abs(denominator) <= 1e-12
        if np.any(fallback_mask):
            n_fallbacks = np.sum(fallback_mask)
            warnings.warn(
                f"Portfolio optimization used fallback interpolation for {n_fallbacks} "
                "grid points due to near-zero FOC difference. "
                "Consider refining ShareGrid for more accurate results.",
                RuntimeWarning,
                stacklevel=3,
            )
        # np.where evaluates both branches; the masked (near-zero) divisions
        # are discarded, so silence numpy's divide/invalid warnings for them.
        with np.errstate(divide="ignore", invalid="ignore"):
            alpha = np.where(~fallback_mask, 1.0 - top_foc / denominator, 0.5)
        opt_share = (1.0 - alpha) * bottom_share + alpha * top_share

        # Handle corner solutions, clamped to the searchable share grid so the
        # continuation is never evaluated outside it (0.0 / 1.0 for a [0, 1] grid).
        opt_share[dvds[:, -1] > 0.0] = share_grid[-1]  # Want more than the top share
        opt_share[dvds[:, 0] < 0.0] = share_grid[0]  # Want less than the bottom share
        return opt_share

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """
        Create continuation function with optimal portfolio choice.

        First computes continuation value over the (aNrm, stigma) grid,
        then finds the optimal stigma for each aNrm level. As a side effect,
        ``stigma`` is dropped from ``self.post_state`` so the subsequent EGM
        step works on the aNrm grid only.
        """
        # Get continuation value over full (aNrm, stigma) grid
        wfunc = super().create_continuation_function()

        opt_share = self._optimal_share_from_foc(wfunc.dataset["dvds"].values)

        if not self.nat_boro_cnst:
            # At aNrm = 0 the portfolio share is irrelevant; use the top of the
            # share grid (1.0 for a [0, 1] grid), the limit as a --> 0.
            opt_share[0] = self.ShareGrid[-1]

        opt_share = xr.DataArray(
            opt_share,
            coords={"aNrm": self.post_state["aNrm"].values},
            dims=["aNrm"],
            attrs={"long_name": "optimal risky share"},
        )

        # Evaluate continuation at optimal share
        v_end = wfunc.evaluate({"aNrm": self.post_state["aNrm"], "stigma": opt_share})
        v_end = v_end.reset_coords(names="stigma")

        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        # Remove stigma from post_state for EGM
        self.post_state = self.post_state.drop_vars("stigma")

        return wfunc