Coverage for HARK / Labeled / transitions.py: 98%

91 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-08 05:31 +0000

1""" 

2Transition functions for labeled consumption-saving models. 

3 

4This module implements the Strategy pattern for state transitions, 

5allowing different model types to share the same solver structure 

6while varying only the transition dynamics. 

7""" 

8 

9from __future__ import annotations 

10 

11from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable 

12 

13import numpy as np 

14 

15if TYPE_CHECKING: 

16 from types import SimpleNamespace 

17 

18 from HARK.rewards import UtilityFuncCRRA 

19 

20 from .solution import ValueFuncCRRALabeled 

21 

# Public API of this module: the Transitions protocol plus the concrete
# transitions classes defined below.
__all__ = [
    "Transitions",
    "PerfectForesightTransitions",
    "IndShockTransitions",
    "RiskyAssetTransitions",
    "FixedPortfolioTransitions",
    "PortfolioTransitions",
]

30 

31 

32def _validate_shock_keys( 

33 shocks: dict[str, Any], required_keys: set[str], class_name: str 

34) -> None: 

35 """ 

36 Validate that shocks dictionary contains required keys. 

37 

38 Parameters 

39 ---------- 

40 shocks : dict 

41 Shock dictionary to validate. 

42 required_keys : set 

43 Set of required key names. 

44 class_name : str 

45 Name of the class for error messages. 

46 

47 Raises 

48 ------ 

49 KeyError 

50 If any required key is missing from shocks. 

51 """ 

52 missing_keys = required_keys - set(shocks.keys()) 

53 if missing_keys: 

54 raise KeyError( 

55 f"{class_name} requires shock keys {required_keys} but got {set(shocks.keys())}. " 

56 f"Missing: {missing_keys}. " 

57 f"Ensure the shock distribution has the correct variable names." 

58 ) 

59 

60 

61def _base_continuation( 

62 transitions, 

63 post_state: dict[str, Any], 

64 shocks: dict[str, Any], 

65 value_next: "ValueFuncCRRALabeled", 

66 params: "SimpleNamespace", 

67 return_factor: Any, 

68) -> dict[str, Any]: 

69 """ 

70 Shared computation kernel for stochastic continuation methods. 

71 

72 Computes the next state, permanent income scaling (psi), value (v), 

73 marginal value (v_der), contributions, and aggregate value that are 

74 common to all stochastic transition classes. 

75 

76 Parameters 

77 ---------- 

78 transitions : Transitions instance 

79 The calling transitions object, whose ``post_state`` method is 

80 used to map the post-decision state forward through shocks. 

81 post_state : dict 

82 Post-decision state (e.g. containing 'aNrm'). 

83 shocks : dict 

84 Realized shocks for this quadrature node (must include 'perm'). 

85 value_next : ValueFuncCRRALabeled 

86 Next period's value function, callable and with a ``derivative`` 

87 method. 

88 params : SimpleNamespace 

89 Model parameters; must expose ``PermGroFac`` and ``CRRA``. 

90 return_factor : scalar, array, or callable 

91 Factor that scales the marginal value ``v_der``. Pass 

92 ``params.Rfree`` for IndShock, ``shocks["risky"]`` for 

93 RiskyAsset, ``1.0`` for Portfolio (which applies its own scaling 

94 afterward), or a callable ``(next_state) -> factor`` when the 

95 factor depends on ``next_state`` (e.g. ``lambda ns: ns["rPort"]`` 

96 for FixedPortfolio). 

97 

98 Returns 

99 ------- 

100 dict 

101 Variables dict containing: all entries from ``next_state``, 

102 ``psi``, ``v``, ``v_der``, ``contributions``, and ``value``. 

103 """ 

104 variables = {} 

105 next_state = transitions.post_state(post_state, shocks, params) 

106 variables.update(next_state) 

107 

108 psi = params.PermGroFac * shocks["perm"] 

109 variables["psi"] = psi 

110 

111 # Allow return_factor to depend on next_state without a second post_state call. 

112 if callable(return_factor): 

113 factor = return_factor(next_state) 

114 else: 

115 factor = return_factor 

116 

117 variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state) 

118 variables["v_der"] = ( 

119 factor * psi ** (-params.CRRA) * value_next.derivative(next_state) 

120 ) 

121 

122 variables["contributions"] = variables["v"] 

123 variables["value"] = np.sum(variables["v"]) 

124 

125 return variables 

126 

127 

@runtime_checkable
class Transitions(Protocol):
    """
    Structural interface satisfied by every model-specific transitions object.

    Implementations (perfect foresight, income shocks, risky asset, fixed
    and optimal portfolio) each provide:

    - ``requires_shocks``: whether the model's transitions use realized shocks;
    - ``post_state``: map today's post-decision state (savings) to
      tomorrow's state (resources);
    - ``continuation``: evaluate continuation-value quantities from the
      post-decision state.

    Because the protocol is runtime-checkable, ``isinstance`` can test an
    object for these members (presence only, not signatures).
    """

    requires_shocks: bool

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Return next period's state built from the post-decision state."""

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """Return continuation-value variables for the post-decision state."""

161 

162 

class PerfectForesightTransitions:
    """
    Deterministic transitions for the perfect foresight consumption model.

    No shocks enter the transition. Tomorrow's normalized market resources
    depend only on today's normalized savings, the risk-free return, and
    permanent income growth, plus one unit of normalized income.

    State transition: mNrm_{t+1} = aNrm_t * Rfree / PermGroFac + 1
    """

    requires_shocks: bool = False

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map normalized savings to next period's normalized market resources.

        Parameters
        ----------
        post_state : dict
            Post-decision state; 'aNrm' (normalized assets) is read.
        shocks : dict or None
            Ignored; present only to match the Transitions interface.
        params : SimpleNamespace
            Supplies ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{"mNrm": ...}`` with next period's normalized resources.
        """
        resources = post_state["aNrm"] * params.Rfree / params.PermGroFac + 1
        return {"mNrm": resources}

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate continuation value and marginal value without uncertainty.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict or None
            Ignored; forwarded to ``post_state`` unchanged.
        value_next : ValueFuncCRRALabeled
            Next period's value function (called, and via ``derivative``).
        params : SimpleNamespace
            Supplies ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Provides ``inv`` and ``derinv`` for the inverted quantities.

        Returns
        -------
        dict
            'mNrm', then ``v``, ``v_der``, ``v_inv``, ``v_der_inv``,
            ``contributions`` (same object as ``v``) and ``value``
            (``np.sum`` of ``v``).
        """
        state_next = self.post_state(post_state, shocks, params)
        growth = params.PermGroFac
        rho = params.CRRA

        # Value is rescaled by growth; marginal value by return and growth.
        value_part = growth ** (1 - rho) * value_next(state_next)
        marginal_part = (
            params.Rfree * growth ** (-rho) * value_next.derivative(state_next)
        )

        result = dict(state_next)
        result["v"] = value_part
        result["v_der"] = marginal_part
        result["v_inv"] = utility.inv(value_part)
        result["v_der_inv"] = utility.derinv(marginal_part)
        result["contributions"] = value_part
        result["value"] = np.sum(value_part)
        return result

252 

253 

class IndShockTransitions:
    """
    Transitions for the model with idiosyncratic income shocks.

    Next period's normalized resources are deflated by growth times a
    permanent shock and incremented by a transitory shock.

    State transition: mNrm_{t+1} = aNrm_t * Rfree / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    # Shock names that post_state validates before using them.
    _required_shock_keys: set[str] = {"perm", "tran"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map normalized savings to next period's resources under income shocks.

        Parameters
        ----------
        post_state : dict
            Post-decision state; 'aNrm' is read.
        shocks : dict
            Realized shocks; 'perm' and 'tran' are read.
        params : SimpleNamespace
            Supplies ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{"mNrm": ...}``.

        Raises
        ------
        KeyError
            If 'perm' or 'tran' is absent from ``shocks``.
        """
        _validate_shock_keys(shocks, self._required_shock_keys, "IndShockTransitions")
        income_growth = params.PermGroFac * shocks["perm"]
        return {
            "mNrm": post_state["aNrm"] * params.Rfree / income_growth + shocks["tran"]
        }

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate continuation quantities for one node of income shocks.

        Delegates to ``_base_continuation`` with the risk-free return as the
        marginal-value scaling factor.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Realized shocks with 'perm' and 'tran'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Supplies ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation variables produced by ``_base_continuation``.
        """
        return _base_continuation(
            self,
            post_state,
            shocks,
            value_next,
            params,
            return_factor=params.Rfree,
        )

334 

335 

class RiskyAssetTransitions:
    """
    Transitions for model with risky asset returns.

    Savings earn a stochastic risky return instead of the risk-free rate.

    State transition: mNrm_{t+1} = aNrm_t * risky / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    # Shock names that post_state validates before using them.
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings with risky asset return.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        params : SimpleNamespace
            Parameters including PermGroFac.

        Returns
        -------
        dict
            Next state with 'mNrm'.

        Raises
        ------
        KeyError
            If required shock keys are missing.
        """
        _validate_shock_keys(shocks, self._required_shock_keys, "RiskyAssetTransitions")
        next_state = {}
        next_state["mNrm"] = (
            post_state["aNrm"] * shocks["risky"] / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation value with risky asset.

        The marginal value is scaled by the risky return instead of Rfree.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters including CRRA, PermGroFac.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation value variables.

        Raises
        ------
        KeyError
            If required shock keys are missing. The descriptive error comes
            from ``post_state``, which ``_base_continuation`` calls first.
        """
        # Defer reading shocks["risky"] until _base_continuation resolves the
        # factor, which happens after post_state has validated the shocks.
        # Evaluating it eagerly as an argument raised a bare KeyError('risky')
        # (or TypeError for None) before the descriptive validation could run.
        return _base_continuation(
            self,
            post_state,
            shocks,
            value_next,
            params,
            lambda _next_state: shocks["risky"],
        )

418 

419 

class FixedPortfolioTransitions:
    """
    Transitions for the model with a fixed portfolio allocation.

    A constant share ``RiskyShareFixed`` of savings is held in the risky
    asset, and the rest earns ``Rfree``.

    Portfolio return: rPort = Rfree + (risky - Rfree) * RiskyShareFixed
    State transition: mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    # Shock names that post_state validates before using them.
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map normalized savings to next period's resources under a fixed portfolio.

        Parameters
        ----------
        post_state : dict
            Post-decision state; 'aNrm' is read.
        shocks : dict
            Realized shocks; 'perm', 'tran' and 'risky' are read.
        params : SimpleNamespace
            Supplies ``Rfree``, ``PermGroFac`` and ``RiskyShareFixed``.

        Returns
        -------
        dict
            'rDiff' (excess return), 'rPort' (portfolio return) and 'mNrm'.

        Raises
        ------
        KeyError
            If any of 'perm', 'tran', 'risky' is absent from ``shocks``.
        """
        _validate_shock_keys(
            shocks, self._required_shock_keys, "FixedPortfolioTransitions"
        )
        excess = shocks["risky"] - params.Rfree
        portfolio_return = params.Rfree + excess * params.RiskyShareFixed
        income_growth = params.PermGroFac * shocks["perm"]
        return {
            "rDiff": excess,
            "rPort": portfolio_return,
            "mNrm": post_state["aNrm"] * portfolio_return / income_growth
            + shocks["tran"],
        }

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate continuation quantities with the marginal value scaled by rPort.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Realized shocks with 'perm', 'tran' and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Supplies ``CRRA``, ``Rfree``, ``PermGroFac``, ``RiskyShareFixed``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation variables produced by ``_base_continuation``.
        """

        # rPort only exists after post_state runs, so hand over a getter
        # that _base_continuation applies to the computed next state.
        def portfolio_return(next_state):
            return next_state["rPort"]

        return _base_continuation(
            self, post_state, shocks, value_next, params, portfolio_return
        )

511 

512 

class PortfolioTransitions:
    """
    Transitions for the model with an optimally chosen risky share.

    The risky share ``stigma`` is read from the post-decision state, so it
    can differ across states and periods.

    Portfolio return: rPort = Rfree + (risky - Rfree) * stigma
    State transition: mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran

    ``continuation`` additionally returns the derivatives used by the
    optimization:
    - dvda: derivative of value with respect to assets
    - dvds: derivative of value with respect to the risky share
    """

    requires_shocks: bool = True
    # Shock names that post_state validates before using them.
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map savings and risky share to next period's resources.

        Parameters
        ----------
        post_state : dict
            Post-decision state; 'aNrm' and 'stigma' (risky share) are read.
        shocks : dict
            Realized shocks; 'perm', 'tran' and 'risky' are read.
        params : SimpleNamespace
            Supplies ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            'rDiff' (excess return), 'rPort' (portfolio return) and 'mNrm'.

        Raises
        ------
        KeyError
            If any of 'perm', 'tran', 'risky' is absent from ``shocks``.
        """
        _validate_shock_keys(shocks, self._required_shock_keys, "PortfolioTransitions")
        excess = shocks["risky"] - params.Rfree
        portfolio_return = params.Rfree + excess * post_state["stigma"]
        income_growth = params.PermGroFac * shocks["perm"]
        return {
            "rDiff": excess,
            "rPort": portfolio_return,
            "mNrm": post_state["aNrm"] * portfolio_return / income_growth
            + shocks["tran"],
        }

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate continuation quantities plus portfolio-choice derivatives.

        ``_base_continuation`` is called with a unit return factor, so
        ``v_der`` is left unscaled. From it this method derives:

        - ``dvda = rPort * v_der``, the marginal value of assets used in the
          consumption first-order condition;
        - ``dvds = rDiff * aNrm * v_der``, the marginal value of the risky
          share used in the portfolio first-order condition (zero at an
          interior optimum).

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm' and 'stigma'.
        shocks : dict
            Realized shocks with 'perm', 'tran' and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Supplies ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation variables from ``_base_continuation`` plus
            ``dvda`` and ``dvds``.
        """
        out = _base_continuation(self, post_state, shocks, value_next, params, 1.0)
        marginal = out["v_der"]
        out["dvda"] = out["rPort"] * marginal
        out["dvds"] = out["rDiff"] * post_state["aNrm"] * marginal
        return out