Coverage for HARK / Labeled / transitions.py: 91%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-25 05:22 +0000

1""" 

2Transition functions for labeled consumption-saving models. 

3 

4This module implements the Strategy pattern for state transitions, 

5allowing different model types to share the same solver structure 

6while varying only the transition dynamics. 

7""" 

8 

9from __future__ import annotations 

10 

11from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable 

12 

13import numpy as np 

14 

15if TYPE_CHECKING: 

16 from types import SimpleNamespace 

17 

18 from HARK.rewards import UtilityFuncCRRA 

19 

20 from .solution import ValueFuncCRRALabeled 

21 

# Public API exported by ``from ... import *``: the structural protocol
# followed by one concrete transition strategy per model family.
__all__ = [
    "Transitions",
    "PerfectForesightTransitions",
    "IndShockTransitions",
    "RiskyAssetTransitions",
    "FixedPortfolioTransitions",
    "PortfolioTransitions",
]

30 

31 

32def _validate_shock_keys( 

33 shocks: dict[str, Any], required_keys: set[str], class_name: str 

34) -> None: 

35 """ 

36 Validate that shocks dictionary contains required keys. 

37 

38 Parameters 

39 ---------- 

40 shocks : dict 

41 Shock dictionary to validate. 

42 required_keys : set 

43 Set of required key names. 

44 class_name : str 

45 Name of the class for error messages. 

46 

47 Raises 

48 ------ 

49 KeyError 

50 If any required key is missing from shocks. 

51 """ 

52 missing_keys = required_keys - set(shocks.keys()) 

53 if missing_keys: 

54 raise KeyError( 

55 f"{class_name} requires shock keys {required_keys} but got {set(shocks.keys())}. " 

56 f"Missing: {missing_keys}. " 

57 f"Ensure the shock distribution has the correct variable names." 

58 ) 

59 

60 

@runtime_checkable
class Transitions(Protocol):
    """
    Structural interface shared by every model-specific transition strategy.

    A solver written against this protocol can be specialised by plugging in
    a concrete strategy (perfect foresight, income shocks, risky asset, fixed
    or optimal portfolio). A conforming object exposes:

    - ``requires_shocks``: flag telling the caller whether the strategy
      consumes a ``shocks`` mapping.
    - ``post_state``: maps the post-decision state (savings) to next
      period's state.
    - ``continuation``: evaluates continuation-value quantities at a
      post-decision state.

    Being ``runtime_checkable``, the protocol supports ``isinstance``
    checks; these only verify that the members exist, not their signatures.
    """

    requires_shocks: bool

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Return next period's state reached from the post-decision state."""
        ...

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """Return continuation-value variables evaluated at the post-decision state."""
        ...

95 

class PerfectForesightTransitions:
    """
    Deterministic transitions for the perfect-foresight consumption model.

    No shocks enter the problem, so next period's normalized market
    resources depend only on savings, the risk-free return and permanent
    income growth::

        mNrm_{t+1} = aNrm_t * Rfree / PermGroFac + 1
    """

    requires_shocks: bool = False

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map end-of-period savings to next period's market resources.

        Parameters
        ----------
        post_state : dict
            Post-decision state; only ``'aNrm'`` (normalized assets) is read.
        shocks : dict or None
            Ignored; accepted only to match the common interface.
        params : SimpleNamespace
            Must provide ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{'mNrm': ...}``, next period's normalized market resources.
        """
        savings = post_state["aNrm"]
        return {"mNrm": savings * params.Rfree / params.PermGroFac + 1}

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate continuation-value quantities at a post-decision state.

        Parameters
        ----------
        post_state : dict
            Post-decision state containing ``'aNrm'``.
        shocks : dict or None
            Ignored; forwarded to ``post_state`` for interface symmetry.
        value_next : ValueFuncCRRALabeled
            Next period's value function; called for the level and via
            ``.derivative`` for the marginal value.
        params : SimpleNamespace
            Must provide ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Supplies ``inv`` and ``derinv`` used for the inverted quantities.

        Returns
        -------
        dict
            Keys, in order: ``mNrm``, ``v``, ``v_der``, ``v_inv``,
            ``v_der_inv``, ``contributions`` (same object as ``v``) and
            ``value`` (``np.sum(v)``).
        """
        next_state = self.post_state(post_state, shocks, params)
        growth = params.PermGroFac
        crra = params.CRRA

        # The level of value is rescaled by G**(1 - CRRA); the marginal value
        # additionally picks up Rfree / G from dmNrm_{t+1}/daNrm_t.
        level = growth ** (1 - crra) * value_next(next_state)
        marginal = params.Rfree * growth ** (-crra) * value_next.derivative(next_state)

        variables = dict(next_state)
        variables["v"] = level
        variables["v_der"] = marginal
        variables["v_inv"] = utility.inv(level)
        variables["v_der_inv"] = utility.derinv(marginal)
        variables["contributions"] = level
        variables["value"] = np.sum(level)
        return variables

185 

186 

class IndShockTransitions:
    """
    Transitions for the consumption model with idiosyncratic income shocks.

    Permanent (``perm``) and transitory (``tran``) income shocks enter the
    law of motion for normalized market resources::

        mNrm_{t+1} = aNrm_t * Rfree / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    # Shock variables read by ``post_state`` and ``continuation``.
    _required_shock_keys: set[str] = {"perm", "tran"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings into next period's market resources under income shocks.

        Parameters
        ----------
        post_state : dict
            Post-decision state; ``'aNrm'`` (normalized assets) is read.
        shocks : dict
            Income shocks; must contain ``'perm'`` and ``'tran'``.
        params : SimpleNamespace
            Must provide ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{'mNrm': ...}``, next period's normalized market resources.

        Raises
        ------
        KeyError
            If ``shocks`` is missing ``'perm'`` or ``'tran'``.
        """
        # Report the runtime class so subclasses produce accurate error messages.
        _validate_shock_keys(shocks, self._required_shock_keys, type(self).__name__)
        next_state = {}
        next_state["mNrm"] = (
            post_state["aNrm"] * params.Rfree / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation-value quantities under income shocks.

        Parameters
        ----------
        post_state : dict
            Post-decision state containing ``'aNrm'``.
        shocks : dict
            Income shocks; must contain ``'perm'`` and ``'tran'``.
        value_next : ValueFuncCRRALabeled
            Next period's value function; called for the level and via
            ``.derivative`` for the marginal value.
        params : SimpleNamespace
            Must provide ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Keys, in order: ``mNrm``, ``psi`` (``PermGroFac * perm``), ``v``,
            ``v_der``, ``contributions`` (same object as ``v``) and
            ``value`` (``np.sum(v)``).

        Raises
        ------
        KeyError
            If ``shocks`` is missing ``'perm'`` or ``'tran'``.
        """
        variables = {}
        next_state = self.post_state(post_state, shocks, params)
        variables.update(next_state)

        # Permanent-income growth factor between t and t+1.
        psi = params.PermGroFac * shocks["perm"]
        variables["psi"] = psi

        variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state)
        # dmNrm_{t+1}/daNrm_t = Rfree / psi; combined with the psi**(1 - CRRA)
        # level scaling this gives Rfree * psi**(-CRRA). This assumes
        # value_next.derivative is the derivative with respect to mNrm.
        variables["v_der"] = (
            params.Rfree * psi ** (-params.CRRA) * value_next.derivative(next_state)
        )

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

281 

282 

class RiskyAssetTransitions:
    """
    Transitions for the model in which all savings earn a risky return.

    The stochastic gross return ``risky`` replaces ``Rfree`` in the law of
    motion::

        mNrm_{t+1} = aNrm_t * risky / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    # Shock variables read by ``post_state`` and ``continuation``.
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings into next period's market resources using the risky return.

        Parameters
        ----------
        post_state : dict
            Post-decision state; ``'aNrm'`` (normalized assets) is read.
        shocks : dict
            Must contain ``'perm'``, ``'tran'`` and ``'risky'``.
        params : SimpleNamespace
            Must provide ``PermGroFac`` (``Rfree`` is not used).

        Returns
        -------
        dict
            ``{'mNrm': ...}``, next period's normalized market resources.

        Raises
        ------
        KeyError
            If any of ``'perm'``, ``'tran'``, ``'risky'`` is missing from ``shocks``.
        """
        # Report the runtime class so subclasses produce accurate error messages.
        _validate_shock_keys(shocks, self._required_shock_keys, type(self).__name__)
        next_state = {}
        next_state["mNrm"] = (
            post_state["aNrm"] * shocks["risky"] / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation-value quantities with a risky return on savings.

        The marginal value is scaled by the realized risky return rather than
        ``Rfree``.

        Parameters
        ----------
        post_state : dict
            Post-decision state containing ``'aNrm'``.
        shocks : dict
            Must contain ``'perm'``, ``'tran'`` and ``'risky'``.
        value_next : ValueFuncCRRALabeled
            Next period's value function; called for the level and via
            ``.derivative`` for the marginal value.
        params : SimpleNamespace
            Must provide ``CRRA`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Keys, in order: ``mNrm``, ``psi`` (``PermGroFac * perm``), ``v``,
            ``v_der``, ``contributions`` (same object as ``v``) and
            ``value`` (``np.sum(v)``).

        Raises
        ------
        KeyError
            If any of ``'perm'``, ``'tran'``, ``'risky'`` is missing from ``shocks``.
        """
        variables = {}
        next_state = self.post_state(post_state, shocks, params)
        variables.update(next_state)

        # Permanent-income growth factor between t and t+1.
        psi = params.PermGroFac * shocks["perm"]
        variables["psi"] = psi

        variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state)
        # dmNrm_{t+1}/daNrm_t = risky / psi, hence the risky factor here.
        variables["v_der"] = (
            shocks["risky"] * psi ** (-params.CRRA) * value_next.derivative(next_state)
        )

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

379 

380 

class FixedPortfolioTransitions:
    """
    Transitions for the model with a fixed portfolio allocation.

    A constant share ``RiskyShareFixed`` of savings is held in the risky
    asset, so savings earn the portfolio return::

        rPort = Rfree + (risky - Rfree) * RiskyShareFixed
        mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    # Shock variables read by ``post_state`` and ``continuation``.
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings into next period's market resources at the fixed-share portfolio return.

        Parameters
        ----------
        post_state : dict
            Post-decision state; ``'aNrm'`` (normalized assets) is read.
        shocks : dict
            Must contain ``'perm'``, ``'tran'`` and ``'risky'``.
        params : SimpleNamespace
            Must provide ``Rfree``, ``PermGroFac`` and ``RiskyShareFixed``.

        Returns
        -------
        dict
            Keys, in order: ``rDiff`` (excess return ``risky - Rfree``),
            ``rPort`` (portfolio return) and ``mNrm``.

        Raises
        ------
        KeyError
            If any of ``'perm'``, ``'tran'``, ``'risky'`` is missing from ``shocks``.
        """
        # Report the runtime class so subclasses produce accurate error messages.
        _validate_shock_keys(shocks, self._required_shock_keys, type(self).__name__)
        next_state = {}
        next_state["rDiff"] = shocks["risky"] - params.Rfree
        next_state["rPort"] = params.Rfree + next_state["rDiff"] * params.RiskyShareFixed
        next_state["mNrm"] = (
            post_state["aNrm"]
            * next_state["rPort"]
            / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation-value quantities with a fixed portfolio share.

        The marginal value is scaled by the realized portfolio return.

        Parameters
        ----------
        post_state : dict
            Post-decision state containing ``'aNrm'``.
        shocks : dict
            Must contain ``'perm'``, ``'tran'`` and ``'risky'``.
        value_next : ValueFuncCRRALabeled
            Next period's value function; called for the level and via
            ``.derivative`` for the marginal value.
        params : SimpleNamespace
            Must provide ``CRRA`` and ``PermGroFac``, plus ``Rfree`` and
            ``RiskyShareFixed``, which ``post_state`` reads.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Keys, in order: ``rDiff``, ``rPort``, ``mNrm``, ``psi``
            (``PermGroFac * perm``), ``v``, ``v_der``, ``contributions``
            (same object as ``v``) and ``value`` (``np.sum(v)``).

        Raises
        ------
        KeyError
            If any of ``'perm'``, ``'tran'``, ``'risky'`` is missing from ``shocks``.
        """
        variables = {}
        next_state = self.post_state(post_state, shocks, params)
        variables.update(next_state)

        # Permanent-income growth factor between t and t+1.
        psi = params.PermGroFac * shocks["perm"]
        variables["psi"] = psi

        variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state)
        # dmNrm_{t+1}/daNrm_t = rPort / psi, hence the portfolio-return factor.
        variables["v_der"] = (
            next_state["rPort"]
            * psi ** (-params.CRRA)
            * value_next.derivative(next_state)
        )

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

488 

489 

class PortfolioTransitions:
    """
    Transitions for the model with an optimally chosen risky share.

    The risky share ``stigma`` is part of the post-decision state, so the
    portfolio return varies with the choice::

        rPort = Rfree + (risky - Rfree) * stigma
        mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran

    ``continuation`` also returns the derivatives used by the optimizer:

    - ``dvda``: derivative of value with respect to assets.
    - ``dvds``: derivative of value with respect to the risky share.
    """

    requires_shocks: bool = True
    # Shock variables read by ``post_state`` and ``continuation``.
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings into next period's market resources at the chosen portfolio return.

        Parameters
        ----------
        post_state : dict
            Post-decision state; ``'aNrm'`` (normalized assets) and
            ``'stigma'`` (risky share) are read.
        shocks : dict
            Must contain ``'perm'``, ``'tran'`` and ``'risky'``.
        params : SimpleNamespace
            Must provide ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            Keys, in order: ``rDiff`` (excess return ``risky - Rfree``),
            ``rPort`` (portfolio return) and ``mNrm``.

        Raises
        ------
        KeyError
            If any of ``'perm'``, ``'tran'``, ``'risky'`` is missing from
            ``shocks``, or ``'stigma'``/``'aNrm'`` is missing from
            ``post_state``.
        """
        # Report the runtime class so subclasses produce accurate error messages.
        _validate_shock_keys(shocks, self._required_shock_keys, type(self).__name__)
        next_state = {}
        next_state["rDiff"] = shocks["risky"] - params.Rfree
        next_state["rPort"] = params.Rfree + next_state["rDiff"] * post_state["stigma"]
        next_state["mNrm"] = (
            post_state["aNrm"]
            * next_state["rPort"]
            / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation-value quantities with an optimally chosen portfolio.

        Unlike the other transition classes, ``v_der`` here is NOT scaled by
        a return factor: it is ``psi**(-CRRA) * value_next.derivative``. The
        return-scaled marginal value with respect to assets is ``dvda``.
        This is the quantity comparable to ``v_der`` elsewhere.

        - ``dvda``: used for the consumption first-order condition.
        - ``dvds``: used for the portfolio first-order condition (its
          expectation should equal 0 at an interior optimum).

        Parameters
        ----------
        post_state : dict
            Post-decision state containing ``'aNrm'`` and ``'stigma'``.
        shocks : dict
            Must contain ``'perm'``, ``'tran'`` and ``'risky'``.
        value_next : ValueFuncCRRALabeled
            Next period's value function; called for the level and via
            ``.derivative`` for the marginal value.
        params : SimpleNamespace
            Must provide ``CRRA`` and ``PermGroFac``, plus ``Rfree``, which
            ``post_state`` reads.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Keys, in order: ``rDiff``, ``rPort``, ``mNrm``, ``psi``
            (``PermGroFac * perm``), ``v``, ``v_der``, ``dvda``, ``dvds``,
            ``contributions`` (same object as ``v``) and ``value``
            (``np.sum(v)``).

        Raises
        ------
        KeyError
            If any required shock or post-state key is missing.
        """
        variables = {}
        next_state = self.post_state(post_state, shocks, params)
        variables.update(next_state)

        # Permanent-income growth factor between t and t+1.
        psi = params.PermGroFac * shocks["perm"]
        variables["psi"] = psi

        variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state)
        # Deliberately left without a return factor; the chain-rule factors
        # for assets and for the risky share are applied separately below.
        variables["v_der"] = psi ** (-params.CRRA) * value_next.derivative(next_state)

        # dmNrm_{t+1}/daNrm_t = rPort / psi and dmNrm_{t+1}/dstigma = aNrm * rDiff / psi.
        variables["dvda"] = next_state["rPort"] * variables["v_der"]
        variables["dvds"] = next_state["rDiff"] * post_state["aNrm"] * variables["v_der"]

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables