Coverage for HARK / Labeled / transitions.py: 98%
87 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-10 06:19 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-10 06:19 +0000
"""
Transition functions for labeled consumption-saving models.

This module implements the Strategy pattern for state transitions,
allowing different model types to share the same solver structure
while varying only the transition dynamics.

Every transitions class exposes ``post_state`` (map the post-decision state
into next period's state) and ``continuation`` (compute continuation-value
variables from the post-decision state), as described by the
:class:`Transitions` protocol. Concrete strategies are
``PerfectForesightTransitions``, ``IndShockTransitions``,
``RiskyAssetTransitions``, ``FixedPortfolioTransitions`` and
``PortfolioTransitions``.
"""
9from __future__ import annotations
11from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
13import numpy as np
15if TYPE_CHECKING:
16 from types import SimpleNamespace
18 from HARK.rewards import UtilityFuncCRRA
20 from .solution import ValueFuncCRRALabeled
# Public API: the structural protocol followed by its concrete strategies,
# listed from the simplest transition dynamics to the richest.
__all__ = [
    # Interface.
    "Transitions",
    # No shocks.
    "PerfectForesightTransitions",
    # Income shocks, single asset return.
    "IndShockTransitions",
    "RiskyAssetTransitions",
    # Income shocks, portfolio of risk-free and risky assets.
    "FixedPortfolioTransitions",
    "PortfolioTransitions",
]
32def _validate_shock_keys(
33 shocks: dict[str, Any], required_keys: set[str], class_name: str
34) -> None:
35 """
36 Validate that shocks dictionary contains required keys.
38 Parameters
39 ----------
40 shocks : dict
41 Shock dictionary to validate.
42 required_keys : set
43 Set of required key names.
44 class_name : str
45 Name of the class for error messages.
47 Raises
48 ------
49 KeyError
50 If any required key is missing from shocks.
51 """
52 missing_keys = required_keys - set(shocks.keys())
53 if missing_keys:
54 raise KeyError(
55 f"{class_name} requires shock keys {required_keys} but got {set(shocks.keys())}. "
56 f"Missing: {missing_keys}. "
57 f"Ensure the shock distribution has the correct variable names."
58 )
def _simple_post_state(
    transitions,
    post_state: dict[str, Any],
    shocks: dict[str, Any],
    params: "SimpleNamespace",
    return_rate: Any,
) -> dict[str, Any]:
    """
    Map post-decision assets to next-period resources under one asset return.

    Common ``post_state`` implementation for transitions that have no portfolio
    decomposition, only a single return rate on savings. The shock dictionary
    is first checked against ``transitions._required_shock_keys``. Then

        ``mNrm = aNrm * return_rate / (PermGroFac * perm) + tran``

    ``IndShockTransitions`` supplies ``params.Rfree`` as the return rate and
    ``RiskyAssetTransitions`` supplies ``shocks["risky"]``.

    Returns
    -------
    dict
        ``{"mNrm": ...}``, next period's normalized market resources.
    """
    _validate_shock_keys(
        shocks, transitions._required_shock_keys, type(transitions).__name__
    )
    # Gross savings first, then deflate by growth times the permanent shock.
    gross_savings = post_state["aNrm"] * return_rate
    growth = params.PermGroFac * shocks["perm"]
    return {"mNrm": gross_savings / growth + shocks["tran"]}
def _portfolio_post_state(
    transitions,
    post_state: dict[str, Any],
    shocks: dict[str, Any],
    params: "SimpleNamespace",
    risky_share: Any,
) -> dict[str, Any]:
    """
    Map post-decision assets to next-period resources under a portfolio return.

    Common ``post_state`` implementation for the portfolio transition classes.
    After the shock keys are checked, it computes:

    - the excess return ``rDiff = risky - Rfree``;
    - the portfolio return ``rPort = Rfree + rDiff * risky_share``;
    - next-period resources ``mNrm = aNrm * rPort / (PermGroFac * perm) + tran``.

    ``FixedPortfolioTransitions`` passes ``params.RiskyShareFixed`` as the
    share, and ``PortfolioTransitions`` passes ``post_state["stigma"]``.

    Returns
    -------
    dict
        Keys ``rDiff``, ``rPort`` and ``mNrm``, in that order.
    """
    _validate_shock_keys(
        shocks, transitions._required_shock_keys, type(transitions).__name__
    )
    excess_return = shocks["risky"] - params.Rfree
    portfolio_return = params.Rfree + excess_return * risky_share
    gross_savings = post_state["aNrm"] * portfolio_return
    growth = params.PermGroFac * shocks["perm"]
    return {
        "rDiff": excess_return,
        "rPort": portfolio_return,
        "mNrm": gross_savings / growth + shocks["tran"],
    }
def _base_continuation(
    transitions,
    post_state: dict[str, Any],
    shocks: dict[str, Any],
    value_next: "ValueFuncCRRALabeled",
    params: "SimpleNamespace",
    return_factor: Any,
) -> dict[str, Any]:
    """
    Compute the continuation-value variables shared by stochastic transitions.

    The post-decision state is pushed forward with ``transitions.post_state``.
    The next-period value and marginal value are then rescaled by permanent
    income growth ``psi = PermGroFac * perm``:

    - ``v = psi ** (1 - CRRA) * value_next(next_state)``
    - ``v_der = factor * psi ** (-CRRA) * value_next.derivative(next_state)``

    Parameters
    ----------
    transitions : Transitions instance
        Object whose ``post_state`` method maps the post-decision state
        forward through the shocks.
    post_state : dict
        Post-decision state (e.g. containing 'aNrm').
    shocks : dict
        Realized shocks for this quadrature node (must include 'perm').
    value_next : ValueFuncCRRALabeled
        Next period's value function. It must be callable and provide a
        ``derivative`` method.
    params : SimpleNamespace
        Model parameters. ``PermGroFac`` and ``CRRA`` are read here.
    return_factor : scalar, array, or callable
        Multiplier applied to ``v_der``. A callable is invoked as
        ``return_factor(next_state)``, which lets the factor depend on the
        next state (e.g. ``lambda ns: ns["rPort"]``) without a second
        ``post_state`` call.

    Returns
    -------
    dict
        Every entry of ``next_state``, plus ``psi``, ``v``, ``v_der``,
        ``contributions`` (the same object as ``v``) and ``value``
        (``np.sum(v)``).
    """
    next_state = transitions.post_state(post_state, shocks, params)
    variables = dict(next_state)

    psi = params.PermGroFac * shocks["perm"]
    variables["psi"] = psi

    factor = return_factor(next_state) if callable(return_factor) else return_factor

    value = psi ** (1 - params.CRRA) * value_next(next_state)
    marginal_value = factor * psi ** (-params.CRRA) * value_next.derivative(
        next_state
    )

    variables["v"] = value
    variables["v_der"] = marginal_value
    variables["contributions"] = value
    variables["value"] = np.sum(value)
    return variables
@runtime_checkable
class Transitions(Protocol):
    """
    Structural interface implemented by every model-specific transitions class.

    A transitions object is the strategy plugged into a shared solver. The
    solver structure stays fixed while the concrete class (perfect foresight,
    income shocks, risky asset, fixed or optimal portfolio) supplies the
    dynamics through two methods:

    - ``post_state`` maps today's post-decision state (savings) into next
      period's state (resources).
    - ``continuation`` computes continuation-value variables from the
      post-decision state.

    The protocol is ``runtime_checkable``, so ``isinstance(obj, Transitions)``
    checks that ``obj`` provides ``requires_shocks``, ``post_state`` and
    ``continuation``.
    """

    # Whether the implementation consumes a realized shock dictionary
    # (``False`` for perfect foresight, which ignores ``shocks``).
    requires_shocks: bool

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Return next period's state implied by the post-decision state."""

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """Return continuation-value variables for the post-decision state."""
class PerfectForesightTransitions:
    """
    Deterministic transitions for the perfect foresight consumption model.

    No shocks enter the problem. Next period's normalized market resources
    depend only on savings, the risk-free return and permanent income growth.

    State transition: mNrm_{t+1} = aNrm_t * Rfree / PermGroFac + 1
    """

    requires_shocks: bool = False

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Turn today's normalized savings into next period's market resources.

        Parameters
        ----------
        post_state : dict
            Post-decision state; ``aNrm`` holds normalized assets.
        shocks : dict or None
            Ignored, because there is no uncertainty in this model.
        params : SimpleNamespace
            Supplies ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{"mNrm": ...}``, next period's normalized market resources.
        """
        gross_savings = post_state["aNrm"] * params.Rfree
        return {"mNrm": gross_savings / params.PermGroFac + 1}

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate the continuation value of a post-decision state.

        Value is rescaled by ``PermGroFac ** (1 - CRRA)``. Marginal value is
        rescaled by ``Rfree * PermGroFac ** (-CRRA)``. Both are also returned
        mapped through the utility inverses.

        Parameters
        ----------
        post_state : dict
            Post-decision state; ``aNrm`` holds normalized assets.
        shocks : dict or None
            Ignored, because there is no uncertainty in this model.
        value_next : ValueFuncCRRALabeled
            Next period's value function (callable, with ``derivative``).
        params : SimpleNamespace
            Supplies ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Provides ``inv`` and ``derinv`` for the inverse transforms.

        Returns
        -------
        dict
            ``mNrm``, ``v``, ``v_der``, ``v_inv``, ``v_der_inv``,
            ``contributions`` (same object as ``v``) and ``value``
            (``np.sum(v)``).
        """
        next_state = self.post_state(post_state, shocks, params)
        growth = params.PermGroFac
        crra = params.CRRA

        value = growth ** (1 - crra) * value_next(next_state)
        marginal_value = (
            params.Rfree * growth ** (-crra) * value_next.derivative(next_state)
        )

        variables = {**next_state}
        variables["v"] = value
        variables["v_der"] = marginal_value
        variables["v_inv"] = utility.inv(value)
        variables["v_der_inv"] = utility.derinv(marginal_value)
        variables["contributions"] = value
        variables["value"] = np.sum(value)
        return variables
class IndShockTransitions:
    """
    Transitions for a model with idiosyncratic income shocks.

    Savings earn the risk-free return. Next period's resources are deflated
    by permanent income growth times the permanent shock, then the
    transitory shock is added.

    State transition: mNrm_{t+1} = aNrm_t * Rfree / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Turn savings into next period's market resources given income shocks.

        Parameters
        ----------
        post_state : dict
            Post-decision state; ``aNrm`` holds normalized assets.
        shocks : dict
            Income shock realization with ``perm`` and ``tran``.
        params : SimpleNamespace
            Supplies ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{"mNrm": ...}``.

        Raises
        ------
        KeyError
            When ``shocks`` lacks ``perm`` or ``tran``.
        """
        return _simple_post_state(
            self, post_state, shocks, params, return_rate=params.Rfree
        )

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate continuation-value variables for one income-shock realization.

        Parameters
        ----------
        post_state : dict
            Post-decision state; ``aNrm`` holds normalized assets.
        shocks : dict
            Income shock realization with ``perm`` and ``tran``.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Supplies ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation value variables.
        """
        # The marginal value is scaled by the deterministic risk-free return.
        return _base_continuation(
            self, post_state, shocks, value_next, params, return_factor=params.Rfree
        )
class RiskyAssetTransitions:
    """
    Transitions for model with risky asset returns.

    Savings earn a stochastic risky return instead of the risk-free rate.

    State transition: mNrm_{t+1} = aNrm_t * risky / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings with risky asset return.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        params : SimpleNamespace
            Parameters including PermGroFac.

        Returns
        -------
        dict
            Next state with 'mNrm'.

        Raises
        ------
        KeyError
            If required shock keys are missing. The descriptive validation
            message is also produced when 'risky' itself is missing.
        """
        # Validate before reading shocks["risky"]. The return rate has to be
        # looked up eagerly to hand it to the shared helper, so without this
        # check a missing 'risky' key would raise a bare KeyError('risky')
        # and skip the informative message. The helper's own repeat check is cheap.
        _validate_shock_keys(shocks, self._required_shock_keys, type(self).__name__)
        return _simple_post_state(self, post_state, shocks, params, shocks["risky"])

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation value with risky asset.

        The marginal value is scaled by the risky return instead of Rfree.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters including CRRA, PermGroFac.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation value variables.

        Raises
        ------
        KeyError
            If required shock keys are missing (raised by ``post_state``).
        """
        # Pass the factor as a callable so shocks["risky"] is read only after
        # post_state has validated the shock keys. This is the same deferred
        # form FixedPortfolioTransitions uses.
        return _base_continuation(
            self,
            post_state,
            shocks,
            value_next,
            params,
            lambda _next_state: shocks["risky"],
        )
class FixedPortfolioTransitions:
    """
    Transitions for a model with a fixed portfolio allocation.

    A constant share ``RiskyShareFixed`` of savings goes into the risky asset
    and the remainder earns ``Rfree``.

    Portfolio return: rPort = Rfree + (risky - Rfree) * RiskyShareFixed
    State transition: mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Turn savings into next period's resources under the fixed portfolio.

        Parameters
        ----------
        post_state : dict
            Post-decision state; ``aNrm`` holds normalized assets.
        shocks : dict
            Shock realization with ``perm``, ``tran`` and ``risky``.
        params : SimpleNamespace
            Supplies ``Rfree``, ``PermGroFac`` and ``RiskyShareFixed``.

        Returns
        -------
        dict
            ``rDiff``, ``rPort`` and ``mNrm``.

        Raises
        ------
        KeyError
            When a required shock key is absent.
        """
        return _portfolio_post_state(
            self, post_state, shocks, params, risky_share=params.RiskyShareFixed
        )

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate continuation-value variables under the fixed portfolio.

        The marginal value is scaled by the realized portfolio return
        ``rPort``, which is computed inside ``post_state``.

        Parameters
        ----------
        post_state : dict
            Post-decision state; ``aNrm`` holds normalized assets.
        shocks : dict
            Shock realization with ``perm``, ``tran`` and ``risky``.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Supplies ``CRRA`` and ``PermGroFac``, plus ``Rfree`` and
            ``RiskyShareFixed`` for ``post_state``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation value variables.
        """
        # rPort is only known after post_state runs, so pass a lookup callable.
        return _base_continuation(
            self,
            post_state,
            shocks,
            value_next,
            params,
            return_factor=lambda state: state["rPort"],
        )
class PortfolioTransitions:
    """
    Transitions for a model with optimal portfolio choice.

    The risky share ``stigma`` is a choice variable carried in the
    post-decision state.

    Portfolio return: rPort = Rfree + (risky - Rfree) * stigma
    State transition: mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran

    ``continuation`` additionally reports the two first-order-condition terms:
    - dvda: derivative of value with respect to assets
    - dvds: derivative of value with respect to the risky share
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Turn savings into next period's resources under the chosen risky share.

        Parameters
        ----------
        post_state : dict
            Post-decision state with ``aNrm`` and ``stigma`` (risky share).
        shocks : dict
            Shock realization with ``perm``, ``tran`` and ``risky``.
        params : SimpleNamespace
            Supplies ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``rDiff``, ``rPort`` and ``mNrm``.

        Raises
        ------
        KeyError
            When a required shock key (or ``stigma``) is absent.
        """
        chosen_share = post_state["stigma"]
        return _portfolio_post_state(
            self, post_state, shocks, params, risky_share=chosen_share
        )

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate continuation-value variables and portfolio FOC terms.

        The shared kernel is called with ``return_factor=1.0``, so ``v_der``
        is left unscaled. Two terms are derived from it:

        - ``dvda = rPort * v_der``, used by the consumption FOC;
        - ``dvds = rDiff * aNrm * v_der``, used by the portfolio FOC (zero at
          an interior optimum).

        Parameters
        ----------
        post_state : dict
            Post-decision state with ``aNrm`` and ``stigma``.
        shocks : dict
            Shock realization with ``perm``, ``tran`` and ``risky``.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Supplies ``CRRA`` and ``PermGroFac``, plus ``Rfree`` for
            ``post_state``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation value variables including ``dvda`` and ``dvds``.
        """
        variables = _base_continuation(
            self, post_state, shocks, value_next, params, return_factor=1.0
        )
        unscaled_marginal = variables["v_der"]

        # Consumption FOC: the marginal value of an extra unit saved.
        variables["dvda"] = variables["rPort"] * unscaled_marginal
        # Portfolio FOC: the marginal value of shifting savings into the risky asset.
        variables["dvds"] = variables["rDiff"] * post_state["aNrm"] * unscaled_marginal
        return variables