Coverage for HARK / Labeled / transitions.py: 98%

87 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-10 06:19 +0000

1""" 

2Transition functions for labeled consumption-saving models. 

3 

4This module implements the Strategy pattern for state transitions, 

5allowing different model types to share the same solver structure 

6while varying only the transition dynamics. 

7""" 

8 

9from __future__ import annotations 

10 

11from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable 

12 

13import numpy as np 

14 

15if TYPE_CHECKING: 

16 from types import SimpleNamespace 

17 

18 from HARK.rewards import UtilityFuncCRRA 

19 

20 from .solution import ValueFuncCRRALabeled 

21 

# Public names exported by ``from <this module> import *``: the
# ``Transitions`` protocol plus the concrete transition strategies below.
__all__ = [
    "Transitions",
    "PerfectForesightTransitions",
    "IndShockTransitions",
    "RiskyAssetTransitions",
    "FixedPortfolioTransitions",
    "PortfolioTransitions",
]

30 

31 

32def _validate_shock_keys( 

33 shocks: dict[str, Any], required_keys: set[str], class_name: str 

34) -> None: 

35 """ 

36 Validate that shocks dictionary contains required keys. 

37 

38 Parameters 

39 ---------- 

40 shocks : dict 

41 Shock dictionary to validate. 

42 required_keys : set 

43 Set of required key names. 

44 class_name : str 

45 Name of the class for error messages. 

46 

47 Raises 

48 ------ 

49 KeyError 

50 If any required key is missing from shocks. 

51 """ 

52 missing_keys = required_keys - set(shocks.keys()) 

53 if missing_keys: 

54 raise KeyError( 

55 f"{class_name} requires shock keys {required_keys} but got {set(shocks.keys())}. " 

56 f"Missing: {missing_keys}. " 

57 f"Ensure the shock distribution has the correct variable names." 

58 ) 

59 

60 

def _simple_post_state(
    transitions,
    post_state: dict[str, Any],
    shocks: dict[str, Any],
    params: "SimpleNamespace",
    return_rate: Any,
) -> dict[str, Any]:
    """
    Map post-decision assets to next-period market resources for strategies
    whose savings earn a single return (no portfolio split).

    Computes ``mNrm = aNrm * return_rate / (PermGroFac * perm) + tran`` after
    checking ``shocks`` against ``transitions._required_shock_keys``.
    :class:`IndShockTransitions` passes ``params.Rfree`` as the return and
    :class:`RiskyAssetTransitions` passes ``shocks["risky"]``.

    Parameters
    ----------
    transitions : object
        Calling strategy; supplies ``_required_shock_keys`` and its class name
        for error messages.
    post_state : dict
        Post-decision state containing ``"aNrm"``.
    shocks : dict
        Realized shocks containing at least ``"perm"`` and ``"tran"``.
    params : SimpleNamespace
        Model parameters exposing ``PermGroFac``.
    return_rate : scalar or array
        Gross return earned on ``aNrm``.

    Returns
    -------
    dict
        ``{"mNrm": ...}``.

    Raises
    ------
    KeyError
        If a required shock key is absent.
    """
    _validate_shock_keys(
        shocks=shocks,
        required_keys=transitions._required_shock_keys,
        class_name=type(transitions).__name__,
    )
    gross_savings = post_state["aNrm"] * return_rate
    m_next = gross_savings / (params.PermGroFac * shocks["perm"]) + shocks["tran"]
    return {"mNrm": m_next}

86 

87 

def _portfolio_post_state(
    transitions,
    post_state: dict[str, Any],
    shocks: dict[str, Any],
    params: "SimpleNamespace",
    risky_share: Any,
) -> dict[str, Any]:
    """
    Map post-decision assets to next-period resources under a two-asset
    portfolio, shared by both portfolio strategies.

    After validating ``shocks``, computes

    * ``rDiff = risky - Rfree`` (excess return),
    * ``rPort = Rfree + rDiff * risky_share`` (portfolio return),
    * ``mNrm = aNrm * rPort / (PermGroFac * perm) + tran``.

    ``FixedPortfolioTransitions`` supplies ``params.RiskyShareFixed`` as
    ``risky_share``, while ``PortfolioTransitions`` supplies
    ``post_state["stigma"]``.

    Parameters
    ----------
    transitions : object
        Calling strategy; supplies ``_required_shock_keys`` and its class name
        for error messages.
    post_state : dict
        Post-decision state containing ``"aNrm"``.
    shocks : dict
        Realized shocks containing ``"perm"``, ``"tran"`` and ``"risky"``.
    params : SimpleNamespace
        Model parameters exposing ``Rfree`` and ``PermGroFac``.
    risky_share : scalar or array
        Share of savings held in the risky asset.

    Returns
    -------
    dict
        ``{"rDiff": ..., "rPort": ..., "mNrm": ...}``.

    Raises
    ------
    KeyError
        If a required shock key is absent.
    """
    _validate_shock_keys(
        shocks=shocks,
        required_keys=transitions._required_shock_keys,
        class_name=type(transitions).__name__,
    )
    r_free = params.Rfree
    excess_return = shocks["risky"] - r_free
    portfolio_return = r_free + excess_return * risky_share
    m_next = (
        post_state["aNrm"] * portfolio_return / (params.PermGroFac * shocks["perm"])
        + shocks["tran"]
    )
    return {"rDiff": excess_return, "rPort": portfolio_return, "mNrm": m_next}

117 

118 

119def _base_continuation( 

120 transitions, 

121 post_state: dict[str, Any], 

122 shocks: dict[str, Any], 

123 value_next: "ValueFuncCRRALabeled", 

124 params: "SimpleNamespace", 

125 return_factor: Any, 

126) -> dict[str, Any]: 

127 """ 

128 Shared computation kernel for stochastic continuation methods. 

129 

130 Computes the next state, permanent income scaling (psi), value (v), 

131 marginal value (v_der), contributions, and aggregate value that are 

132 common to all stochastic transition classes. 

133 

134 Parameters 

135 ---------- 

136 transitions : Transitions instance 

137 The calling transitions object, whose ``post_state`` method is 

138 used to map the post-decision state forward through shocks. 

139 post_state : dict 

140 Post-decision state (e.g. containing 'aNrm'). 

141 shocks : dict 

142 Realized shocks for this quadrature node (must include 'perm'). 

143 value_next : ValueFuncCRRALabeled 

144 Next period's value function, callable and with a ``derivative`` 

145 method. 

146 params : SimpleNamespace 

147 Model parameters; must expose ``PermGroFac`` and ``CRRA``. 

148 return_factor : scalar, array, or callable 

149 Factor that scales the marginal value ``v_der``. Pass 

150 ``params.Rfree`` for IndShock, ``shocks["risky"]`` for 

151 RiskyAsset, ``1.0`` for Portfolio (which applies its own scaling 

152 afterward), or a callable ``(next_state) -> factor`` when the 

153 factor depends on ``next_state`` (e.g. ``lambda ns: ns["rPort"]`` 

154 for FixedPortfolio). 

155 

156 Returns 

157 ------- 

158 dict 

159 Variables dict containing: all entries from ``next_state``, 

160 ``psi``, ``v``, ``v_der``, ``contributions``, and ``value``. 

161 """ 

162 variables = {} 

163 next_state = transitions.post_state(post_state, shocks, params) 

164 variables.update(next_state) 

165 

166 psi = params.PermGroFac * shocks["perm"] 

167 variables["psi"] = psi 

168 

169 # Allow return_factor to depend on next_state without a second post_state call. 

170 if callable(return_factor): 

171 factor = return_factor(next_state) 

172 else: 

173 factor = return_factor 

174 

175 variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state) 

176 variables["v_der"] = ( 

177 factor * psi ** (-params.CRRA) * value_next.derivative(next_state) 

178 ) 

179 

180 variables["contributions"] = variables["v"] 

181 variables["value"] = np.sum(variables["v"]) 

182 

183 return variables 

184 

185 

@runtime_checkable
class Transitions(Protocol):
    """
    Structural interface implemented by every model-specific transition
    strategy.

    The solver interacts with a model's dynamics only through these members.
    Each model type (perfect foresight, income shocks, risky asset, fixed or
    optimal portfolio) therefore supplies its own implementation while
    sharing the solver.

    Members
    -------
    requires_shocks : bool
        Whether the strategy consumes shock realizations. It is ``False`` for
        perfect foresight, whose methods ignore ``shocks``.
    post_state
        How today's post-decision state (savings) becomes tomorrow's state.
    continuation
        How continuation-value quantities are computed from the post-decision
        state.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    Transitions)`` verifies only that these members exist, not their
    signatures.
    """

    requires_shocks: bool

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Map the post-decision state to next period's state."""

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """Return continuation-value variables for the post-decision state."""

219 

220 

class PerfectForesightTransitions:
    """
    Transition strategy for the perfect-foresight consumption model.

    There is no uncertainty. Savings earn the risk-free return, are deflated
    by permanent income growth, and are topped up by a transitory income of
    one:

        mNrm_{t+1} = aNrm_t * Rfree / PermGroFac + 1
    """

    requires_shocks: bool = False

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map normalized savings to next period's normalized market resources.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding ``"aNrm"``.
        shocks : dict or None
            Ignored; perfect foresight has no shocks.
        params : SimpleNamespace
            Must expose ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{"mNrm": ...}``.
        """
        m_next = post_state["aNrm"] * params.Rfree / params.PermGroFac + 1
        return {"mNrm": m_next}

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation value, marginal value and their inverses.

        Value is scaled by ``PermGroFac ** (1 - CRRA)``. Marginal value is
        scaled by ``Rfree * PermGroFac ** (-CRRA)``. Both are then passed
        through the utility's ``inv`` and ``derinv``.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding ``"aNrm"``.
        shocks : dict or None
            Ignored; perfect foresight has no shocks.
        value_next : ValueFuncCRRALabeled
            Next period's value function (callable, with ``derivative``).
        params : SimpleNamespace
            Must expose ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Supplies ``inv`` and ``derinv``.

        Returns
        -------
        dict
            ``next_state`` entries plus ``v``, ``v_der``, ``v_inv``,
            ``v_der_inv``, ``contributions`` (same object as ``v``) and
            ``value`` (``np.sum(v)``).
        """
        next_state = self.post_state(post_state, shocks, params)
        growth = params.PermGroFac
        rho = params.CRRA

        level = growth ** (1 - rho) * value_next(next_state)
        slope = params.Rfree * growth ** (-rho) * value_next.derivative(next_state)

        return {
            **next_state,
            "v": level,
            "v_der": slope,
            "v_inv": utility.inv(level),
            "v_der_inv": utility.derinv(slope),
            "contributions": level,
            "value": np.sum(level),
        }

310 

311 

class IndShockTransitions:
    """
    Transition strategy for the model with idiosyncratic income shocks.

    Savings earn the risk-free return and are deflated by growth times the
    permanent shock; the transitory shock is then added:

        mNrm_{t+1} = aNrm_t * Rfree / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map normalized savings to next period's resources under income shocks.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding ``"aNrm"``.
        shocks : dict
            Shock realizations holding ``"perm"`` and ``"tran"``.
        params : SimpleNamespace
            Must expose ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{"mNrm": ...}``.

        Raises
        ------
        KeyError
            If ``"perm"`` or ``"tran"`` is missing from ``shocks``.
        """
        return _simple_post_state(
            self, post_state, shocks, params, return_rate=params.Rfree
        )

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation quantities for one shock realization.

        The marginal value is scaled by ``Rfree``.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding ``"aNrm"``.
        shocks : dict
            Shock realizations holding ``"perm"`` and ``"tran"``.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Must expose ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation value variables.
        """
        return _base_continuation(
            self, post_state, shocks, value_next, params, return_factor=params.Rfree
        )

386 

387 

class RiskyAssetTransitions:
    """
    Transitions for model with risky asset returns.

    Savings earn a stochastic risky return instead of the risk-free rate.

    State transition: mNrm_{t+1} = aNrm_t * risky / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings with risky asset return.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        params : SimpleNamespace
            Parameters including PermGroFac.

        Returns
        -------
        dict
            Next state with 'mNrm'.

        Raises
        ------
        KeyError
            If required shock keys are missing.
        """
        # Validate before reading shocks["risky"]. Otherwise evaluating that
        # argument raises a bare KeyError('risky') before _simple_post_state
        # can report which keys are required and missing.
        _validate_shock_keys(shocks, self._required_shock_keys, type(self).__name__)
        return _simple_post_state(self, post_state, shocks, params, shocks["risky"])

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation value with risky asset.

        The marginal value is scaled by the risky return instead of Rfree.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters including CRRA, PermGroFac.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation value variables.

        Raises
        ------
        KeyError
            If required shock keys are missing.
        """
        # Pass a callable so shocks["risky"] is read only after
        # _base_continuation has called post_state, which validates the shock
        # keys and raises the descriptive error if "risky" is absent.
        return _base_continuation(
            self,
            post_state,
            shocks,
            value_next,
            params,
            lambda _next_state: shocks["risky"],
        )

464 

465 

class FixedPortfolioTransitions:
    """
    Transition strategy for a model with a fixed portfolio allocation.

    A constant share ``RiskyShareFixed`` of savings is held in the risky
    asset:

        rPort = Rfree + (risky - Rfree) * RiskyShareFixed
        mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map normalized savings to next period's resources at the fixed share.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding ``"aNrm"``.
        shocks : dict
            Shock realizations holding ``"perm"``, ``"tran"`` and ``"risky"``.
        params : SimpleNamespace
            Must expose ``Rfree``, ``PermGroFac`` and ``RiskyShareFixed``.

        Returns
        -------
        dict
            ``{"rDiff": ..., "rPort": ..., "mNrm": ...}``.

        Raises
        ------
        KeyError
            If a required shock key is missing.
        """
        return _portfolio_post_state(
            self, post_state, shocks, params, risky_share=params.RiskyShareFixed
        )

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation quantities for one shock realization.

        The marginal value is scaled by the portfolio return ``rPort``, which
        is read from the next state computed inside ``_base_continuation``.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding ``"aNrm"``.
        shocks : dict
            Shock realizations holding ``"perm"``, ``"tran"`` and ``"risky"``.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Must expose ``CRRA`` and ``PermGroFac`` (plus the fields
            ``post_state`` needs).
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation value variables.
        """
        return _base_continuation(
            self,
            post_state,
            shocks,
            value_next,
            params,
            return_factor=lambda next_state: next_state["rPort"],
        )

545 

546 

class PortfolioTransitions:
    """
    Transition strategy for a model with optimal portfolio choice.

    The risky share ``stigma`` is a post-decision variable chosen each
    period:

        rPort = Rfree + (risky - Rfree) * stigma
        mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran

    ``continuation`` additionally reports the derivatives used by the
    portfolio optimizer:

    - dvda: derivative of value with respect to assets
    - dvds: derivative of value with respect to the risky share
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map normalized savings to next period's resources at the chosen share.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding ``"aNrm"`` and ``"stigma"``.
        shocks : dict
            Shock realizations holding ``"perm"``, ``"tran"`` and ``"risky"``.
        params : SimpleNamespace
            Must expose ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{"rDiff": ..., "rPort": ..., "mNrm": ...}``.

        Raises
        ------
        KeyError
            If a required shock key is missing.
        """
        return _portfolio_post_state(
            self, post_state, shocks, params, risky_share=post_state["stigma"]
        )

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation quantities plus portfolio-optimization derivatives.

        ``_base_continuation`` is called with ``return_factor=1.0`` so that
        ``v_der`` is the unscaled marginal value. Two products of it are then
        added:

        * ``dvda = rPort * v_der`` for the consumption first-order condition;
        * ``dvds = rDiff * aNrm * v_der`` for the portfolio first-order
          condition (zero at an interior optimum).

        Parameters
        ----------
        post_state : dict
            Post-decision state holding ``"aNrm"`` and ``"stigma"``.
        shocks : dict
            Shock realizations holding ``"perm"``, ``"tran"`` and ``"risky"``.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Must expose ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation value variables including ``dvda`` and ``dvds``.
        """
        result = _base_continuation(
            self, post_state, shocks, value_next, params, return_factor=1.0
        )
        marginal = result["v_der"]
        result["dvda"] = result["rPort"] * marginal
        result["dvds"] = result["rDiff"] * post_state["aNrm"] * marginal
        return result