Coverage for HARK / Labeled / transitions.py: 98%
91 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 05:31 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 05:31 +0000
"""
Transition functions for labeled consumption-saving models.

The module follows the Strategy pattern: each model type supplies its own
transitions object, so a single solver structure can be shared across
models while only the transition dynamics differ.
"""
9from __future__ import annotations
11from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
13import numpy as np
15if TYPE_CHECKING:
16 from types import SimpleNamespace
18 from HARK.rewards import UtilityFuncCRRA
20 from .solution import ValueFuncCRRALabeled
# Public API: the structural protocol followed by its concrete strategies.
__all__ = [
    # Structural interface shared by every strategy.
    "Transitions",
    # Deterministic income, risk-free savings.
    "PerfectForesightTransitions",
    # Stochastic income with risk-free or risky savings.
    "IndShockTransitions",
    "RiskyAssetTransitions",
    # Portfolio allocation: fixed share vs. chosen share.
    "FixedPortfolioTransitions",
    "PortfolioTransitions",
]
32def _validate_shock_keys(
33 shocks: dict[str, Any], required_keys: set[str], class_name: str
34) -> None:
35 """
36 Validate that shocks dictionary contains required keys.
38 Parameters
39 ----------
40 shocks : dict
41 Shock dictionary to validate.
42 required_keys : set
43 Set of required key names.
44 class_name : str
45 Name of the class for error messages.
47 Raises
48 ------
49 KeyError
50 If any required key is missing from shocks.
51 """
52 missing_keys = required_keys - set(shocks.keys())
53 if missing_keys:
54 raise KeyError(
55 f"{class_name} requires shock keys {required_keys} but got {set(shocks.keys())}. "
56 f"Missing: {missing_keys}. "
57 f"Ensure the shock distribution has the correct variable names."
58 )
61def _base_continuation(
62 transitions,
63 post_state: dict[str, Any],
64 shocks: dict[str, Any],
65 value_next: "ValueFuncCRRALabeled",
66 params: "SimpleNamespace",
67 return_factor: Any,
68) -> dict[str, Any]:
69 """
70 Shared computation kernel for stochastic continuation methods.
72 Computes the next state, permanent income scaling (psi), value (v),
73 marginal value (v_der), contributions, and aggregate value that are
74 common to all stochastic transition classes.
76 Parameters
77 ----------
78 transitions : Transitions instance
79 The calling transitions object, whose ``post_state`` method is
80 used to map the post-decision state forward through shocks.
81 post_state : dict
82 Post-decision state (e.g. containing 'aNrm').
83 shocks : dict
84 Realized shocks for this quadrature node (must include 'perm').
85 value_next : ValueFuncCRRALabeled
86 Next period's value function, callable and with a ``derivative``
87 method.
88 params : SimpleNamespace
89 Model parameters; must expose ``PermGroFac`` and ``CRRA``.
90 return_factor : scalar, array, or callable
91 Factor that scales the marginal value ``v_der``. Pass
92 ``params.Rfree`` for IndShock, ``shocks["risky"]`` for
93 RiskyAsset, ``1.0`` for Portfolio (which applies its own scaling
94 afterward), or a callable ``(next_state) -> factor`` when the
95 factor depends on ``next_state`` (e.g. ``lambda ns: ns["rPort"]``
96 for FixedPortfolio).
98 Returns
99 -------
100 dict
101 Variables dict containing: all entries from ``next_state``,
102 ``psi``, ``v``, ``v_der``, ``contributions``, and ``value``.
103 """
104 variables = {}
105 next_state = transitions.post_state(post_state, shocks, params)
106 variables.update(next_state)
108 psi = params.PermGroFac * shocks["perm"]
109 variables["psi"] = psi
111 # Allow return_factor to depend on next_state without a second post_state call.
112 if callable(return_factor):
113 factor = return_factor(next_state)
114 else:
115 factor = return_factor
117 variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state)
118 variables["v_der"] = (
119 factor * psi ** (-params.CRRA) * value_next.derivative(next_state)
120 )
122 variables["contributions"] = variables["v"]
123 variables["value"] = np.sum(variables["v"])
125 return variables
@runtime_checkable
class Transitions(Protocol):
    """
    Structural interface satisfied by every model-specific transitions strategy.

    Concrete strategies (perfect foresight, idiosyncratic income shocks,
    risky asset, fixed or chosen portfolio) provide two operations:

    - ``post_state``: map today's post-decision state (savings) into next
      period's state.
    - ``continuation``: evaluate the continuation value implied by the
      post-decision state.

    ``requires_shocks`` flags whether the strategy consumes realized shocks.
    """

    requires_shocks: bool

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Map the post-decision state to next period's state."""
        ...

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """Evaluate the continuation value from the post-decision state."""
        ...
class PerfectForesightTransitions:
    """
    Transition strategy for the perfect-foresight consumption model.

    There are no shocks: next period's normalized market resources follow
    deterministically from savings, the risk-free return, and growth.

    State transition: mNrm_{t+1} = aNrm_t * Rfree / PermGroFac + 1
    """

    requires_shocks: bool = False

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map normalized savings into next period's market resources.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding 'aNrm' (normalized assets).
        shocks : dict or None
            Ignored; perfect foresight has no shocks.
        params : SimpleNamespace
            Supplies ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{"mNrm": ...}`` with next period's normalized market resources.
        """
        resources = post_state["aNrm"] * params.Rfree / params.PermGroFac + 1
        return {"mNrm": resources}

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate the deterministic continuation value.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding 'aNrm'.
        shocks : dict or None
            Ignored; forwarded to ``post_state`` only.
        value_next : ValueFuncCRRALabeled
            Next period's value function (callable, with ``.derivative``).
        params : SimpleNamespace
            Supplies ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Provides ``inv`` and ``derinv`` for the inverted quantities.

        Returns
        -------
        dict
            Next-state entries plus ``v``, ``v_der``, ``v_inv``,
            ``v_der_inv``, ``contributions`` and ``value``.
        """
        next_state = self.post_state(post_state, shocks, params)
        growth = params.PermGroFac
        rho = params.CRRA

        # Value rescaled by growth**(1 - rho); marginal value additionally
        # carries the risk-free return and growth**(-rho).
        value = growth ** (1 - rho) * value_next(next_state)
        marginal = params.Rfree * growth ** (-rho) * value_next.derivative(next_state)

        variables = dict(next_state)
        variables["v"] = value
        variables["v_der"] = marginal
        variables["v_inv"] = utility.inv(value)
        variables["v_der_inv"] = utility.derinv(marginal)
        variables["contributions"] = value
        variables["value"] = np.sum(value)
        return variables
class IndShockTransitions:
    """
    Transition strategy for the model with idiosyncratic income shocks.

    Permanent ('perm') and transitory ('tran') income shocks enter the law
    of motion for normalized market resources.

    State transition: mNrm_{t+1} = aNrm_t * Rfree / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map savings into next period's market resources under income shocks.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding 'aNrm'.
        shocks : dict
            Realized shocks holding 'perm' and 'tran'.
        params : SimpleNamespace
            Supplies ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            ``{"mNrm": ...}``.

        Raises
        ------
        KeyError
            When 'perm' or 'tran' is absent from ``shocks``.
        """
        _validate_shock_keys(shocks, self._required_shock_keys, "IndShockTransitions")
        growth = params.PermGroFac * shocks["perm"]
        return {"mNrm": post_state["aNrm"] * params.Rfree / growth + shocks["tran"]}

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate the continuation value at one income-shock realization.

        Delegates to the shared kernel with ``Rfree`` as the factor on the
        marginal value.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding 'aNrm'.
        shocks : dict
            Realized shocks holding 'perm' and 'tran'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Supplies ``CRRA``, ``Rfree`` and ``PermGroFac``.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation variables produced by the shared kernel.
        """
        return _base_continuation(
            self,
            post_state,
            shocks,
            value_next,
            params,
            return_factor=params.Rfree,
        )
class RiskyAssetTransitions:
    """
    Transitions for model with risky asset returns.

    Savings earn a stochastic risky return instead of the risk-free rate.

    State transition: mNrm_{t+1} = aNrm_t * risky / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings with risky asset return.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        params : SimpleNamespace
            Parameters including PermGroFac.

        Returns
        -------
        dict
            Next state with 'mNrm'.

        Raises
        ------
        KeyError
            If required shock keys are missing.
        """
        _validate_shock_keys(shocks, self._required_shock_keys, "RiskyAssetTransitions")
        next_state = {}
        next_state["mNrm"] = (
            post_state["aNrm"] * shocks["risky"] / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation value with risky asset.

        The marginal value is scaled by the risky return instead of Rfree.

        The risky-return lookup is passed as a callable rather than evaluated
        eagerly. ``_base_continuation`` therefore calls ``post_state`` first,
        and its shock-key validation runs before ``shocks["risky"]`` is read.
        A missing 'risky' key then yields the descriptive validation
        ``KeyError`` instead of a bare ``KeyError: 'risky'``.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters including CRRA, PermGroFac.
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation value variables.

        Raises
        ------
        KeyError
            If required shock keys are missing (raised by ``post_state``).
        """
        return _base_continuation(
            self,
            post_state,
            shocks,
            value_next,
            params,
            lambda _next_state: shocks["risky"],
        )
class FixedPortfolioTransitions:
    """
    Transition strategy for a fixed allocation to the risky asset.

    A constant share ``RiskyShareFixed`` of savings is held in the risky
    asset, so savings earn the resulting portfolio return.

    Portfolio return: rPort = Rfree + (risky - Rfree) * RiskyShareFixed
    State transition: mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map savings into next period's resources via the fixed-share portfolio.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding 'aNrm'.
        shocks : dict
            Realized shocks holding 'perm', 'tran' and 'risky'.
        params : SimpleNamespace
            Supplies ``Rfree``, ``PermGroFac`` and ``RiskyShareFixed``.

        Returns
        -------
        dict
            Keys 'rDiff' (excess return), 'rPort' (portfolio return) and
            'mNrm', in that order.

        Raises
        ------
        KeyError
            When any of 'perm', 'tran', 'risky' is absent from ``shocks``.
        """
        _validate_shock_keys(
            shocks, self._required_shock_keys, "FixedPortfolioTransitions"
        )
        excess = shocks["risky"] - params.Rfree
        portfolio_return = params.Rfree + excess * params.RiskyShareFixed
        resources = (
            post_state["aNrm"] * portfolio_return / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return {"rDiff": excess, "rPort": portfolio_return, "mNrm": resources}

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate the continuation value under the fixed-share portfolio.

        The marginal value is scaled by the portfolio return, read from the
        computed next state so ``post_state`` runs only once.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding 'aNrm'.
        shocks : dict
            Realized shocks holding 'perm', 'tran' and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Supplies ``CRRA``, ``PermGroFac`` (plus the ``post_state`` inputs).
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Continuation variables produced by the shared kernel.
        """
        return _base_continuation(
            self,
            post_state,
            shocks,
            value_next,
            params,
            return_factor=lambda next_state: next_state["rPort"],
        )
class PortfolioTransitions:
    """
    Transition strategy for a chosen risky share.

    The risky share ``stigma`` is part of the post-decision state, so the
    portfolio return varies with the agent's allocation choice.

    Portfolio return: rPort = Rfree + (risky - Rfree) * stigma
    State transition: mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran

    The continuation also reports the two first-order-condition terms:
    - dvda: derivative of value with respect to assets
    - dvds: derivative of value with respect to the risky share
    """

    requires_shocks: bool = True
    _required_shock_keys: set[str] = {"perm", "tran", "risky"}

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map savings and risky share into next period's resources.

        Parameters
        ----------
        post_state : dict
            Post-decision state holding 'aNrm' and 'stigma' (risky share).
        shocks : dict
            Realized shocks holding 'perm', 'tran' and 'risky'.
        params : SimpleNamespace
            Supplies ``Rfree`` and ``PermGroFac``.

        Returns
        -------
        dict
            Keys 'rDiff' (excess return), 'rPort' (portfolio return) and
            'mNrm', in that order.

        Raises
        ------
        KeyError
            When any of 'perm', 'tran', 'risky' is absent from ``shocks``.
        """
        _validate_shock_keys(shocks, self._required_shock_keys, "PortfolioTransitions")
        excess = shocks["risky"] - params.Rfree
        portfolio_return = params.Rfree + excess * post_state["stigma"]
        resources = (
            post_state["aNrm"] * portfolio_return / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return {"rDiff": excess, "rPort": portfolio_return, "mNrm": resources}

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate the continuation value and portfolio-choice derivatives.

        The shared kernel is called with a unit return factor, so ``v_der``
        is ``psi**(-CRRA)`` times next period's marginal value. That term is
        then scaled by the portfolio return to get ``dvda`` (consumption
        FOC). It is scaled by excess return times assets to get ``dvds``
        (portfolio FOC; zero at an interior optimum).

        Parameters
        ----------
        post_state : dict
            Post-decision state holding 'aNrm' and 'stigma'.
        shocks : dict
            Realized shocks holding 'perm', 'tran' and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Supplies ``CRRA``, ``PermGroFac`` (plus the ``post_state`` inputs).
        utility : UtilityFuncCRRA
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Kernel continuation variables plus ``dvda`` and ``dvds``.
        """
        variables = _base_continuation(
            self, post_state, shocks, value_next, params, return_factor=1.0
        )
        marginal = variables["v_der"]

        # d(continuation)/d(aNrm): portfolio return times the unscaled marginal value.
        variables["dvda"] = variables["rPort"] * marginal
        # d(continuation)/d(stigma): excess return times assets times marginal value.
        variables["dvds"] = variables["rDiff"] * post_state["aNrm"] * marginal
        return variables