Coverage for HARK / Labeled / transitions.py: 91%
112 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
1"""
2Transition functions for labeled consumption-saving models.
4This module implements the Strategy pattern for state transitions,
5allowing different model types to share the same solver structure
6while varying only the transition dynamics.
7"""
9from __future__ import annotations
11from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
13import numpy as np
15if TYPE_CHECKING:
16 from types import SimpleNamespace
18 from HARK.rewards import UtilityFuncCRRA
20 from .solution import ValueFuncCRRALabeled
# Public API of this module: the ``Transitions`` protocol plus one concrete
# transition strategy per model family. ``_validate_shock_keys`` is a private
# helper and is intentionally not exported.
__all__ = [
    "Transitions",
    "PerfectForesightTransitions",
    "IndShockTransitions",
    "RiskyAssetTransitions",
    "FixedPortfolioTransitions",
    "PortfolioTransitions",
]
def _format_key_set(keys: set[str] | frozenset[str]) -> str:
    """
    Render key names as a set literal with a deterministic (sorted) order.

    The repr of a ``set`` of strings depends on string hashing, which is
    randomized per process, so printing raw sets makes error messages differ
    between runs. Sorting (by ``str`` so mixed key types still compare) keeps
    the familiar ``{'a', 'b'}`` look while producing stable text.
    """
    ordered = sorted(keys, key=str)
    if not ordered:
        return "set()"
    return "{" + ", ".join(repr(key) for key in ordered) + "}"


def _validate_shock_keys(
    shocks: dict[str, Any] | None,
    required_keys: set[str] | frozenset[str],
    class_name: str,
) -> None:
    """
    Validate that shocks dictionary contains required keys.

    Parameters
    ----------
    shocks : dict or None
        Shock mapping to validate. ``None`` is treated as an empty mapping, so
        every required key is reported as missing (``KeyError``) instead of
        failing with an ``AttributeError`` on ``None.keys()``.
    required_keys : set or frozenset
        Set of required key names.
    class_name : str
        Name of the class for error messages.

    Raises
    ------
    KeyError
        If any required key is missing from shocks.
    """
    present_keys = set() if shocks is None else set(shocks.keys())
    missing_keys = set(required_keys) - present_keys
    if missing_keys:
        raise KeyError(
            f"{class_name} requires shock keys {_format_key_set(required_keys)} "
            f"but got {_format_key_set(present_keys)}. "
            f"Missing: {_format_key_set(missing_keys)}. "
            "Ensure the shock distribution has the correct variable names."
        )
@runtime_checkable
class Transitions(Protocol):
    """
    Structural interface for the transition step of a labeled model.

    A transition strategy supplies the model-specific pieces of the solver,
    so different model families (perfect foresight, income shocks, risky
    assets, portfolio choice) can share one solver skeleton:

    - ``post_state`` maps the post-decision state (e.g. savings) into next
      period's state, possibly using shock realizations.
    - ``continuation`` evaluates next period's value function at that state
      and returns the continuation-value variables.

    ``requires_shocks`` tells callers whether a shocks mapping is needed.
    Because the protocol is ``runtime_checkable``, ``isinstance`` checks only
    that these members exist; method signatures are not verified.
    """

    requires_shocks: bool

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Map the post-decision state to next period's state variables."""

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """Return continuation-value variables for the post-decision state."""
class PerfectForesightTransitions:
    """
    Deterministic transition strategy for the perfect foresight model.

    Nothing is random: next period's normalized market resources follow
    from today's savings, the risk-free return and permanent income growth,

        mNrm_{t+1} = aNrm_t * Rfree / PermGroFac + 1

    The ``shocks`` argument of each method exists only to satisfy the shared
    transition interface and is ignored.
    """

    requires_shocks: bool = False

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Map end-of-period assets to next period's market resources.

        Parameters
        ----------
        post_state : dict
            Post-decision state; only 'aNrm' (normalized assets) is read.
        shocks : dict or None
            Ignored.
        params : SimpleNamespace
            Must provide Rfree and PermGroFac.

        Returns
        -------
        dict
            A new dict holding only 'mNrm' (normalized market resources).
        """
        savings = post_state["aNrm"]
        return {"mNrm": savings * params.Rfree / params.PermGroFac + 1}

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any] | None,
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Evaluate the continuation value implied by a post-decision state.

        Parameters
        ----------
        post_state : dict
            Post-decision state containing 'aNrm'.
        shocks : dict or None
            Ignored.
        value_next : ValueFuncCRRALabeled
            Next period's value function; it is called on the next state and
            its ``derivative`` method is called on the same state.
        params : SimpleNamespace
            Must provide CRRA, Rfree and PermGroFac.
        utility : UtilityFuncCRRA
            Supplies ``inv`` and ``derinv`` used to invert 'v' and 'v_der'.

        Returns
        -------
        dict
            'mNrm', then 'v', 'v_der', their inversions 'v_inv' and
            'v_der_inv', 'contributions' (the same object as 'v'), and
            'value' (``np.sum`` of 'v').
        """
        next_state = self.post_state(post_state, shocks, params)
        growth = params.PermGroFac
        rho = params.CRRA

        # Normalization by permanent income rescales next period's value by
        # growth ** (1 - rho); the marginal value also picks up the return.
        v = growth ** (1 - rho) * value_next(next_state)
        v_der = params.Rfree * growth ** (-rho) * value_next.derivative(next_state)

        return {
            **next_state,
            "v": v,
            "v_der": v_der,
            "v_inv": utility.inv(v),
            "v_der_inv": utility.derinv(v_der),
            "contributions": v,
            "value": np.sum(v),
        }
class IndShockTransitions:
    """
    Transitions for model with idiosyncratic income shocks.

    Adds permanent ('perm') and transitory ('tran') income shocks to the
    transition:

        mNrm_{t+1} = aNrm_t * Rfree / (PermGroFac * perm) + tran

    Results are computed per shock realization. Unlike the perfect foresight
    strategy, ``continuation`` does not produce 'v_inv'/'v_der_inv';
    presumably the caller inverts only after taking expectations over shocks
    (NOTE(review): confirm against the solver).
    """

    requires_shocks: bool = True
    # Class-level constant shared by every instance: a frozenset cannot be
    # mutated through one instance and silently affect all the others.
    _required_shock_keys: frozenset[str] = frozenset({"perm", "tran"})

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings to next period's market resources with income shocks.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Income shocks with 'perm' and 'tran'.
        params : SimpleNamespace
            Parameters including Rfree and PermGroFac.

        Returns
        -------
        dict
            Next state with 'mNrm'.

        Raises
        ------
        KeyError
            If required shock keys are missing.
        """
        # Name the concrete class so errors raised from subclasses are accurate.
        _validate_shock_keys(shocks, self._required_shock_keys, type(self).__name__)
        next_state = {}
        next_state["mNrm"] = (
            post_state["aNrm"] * params.Rfree / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation value with income shocks.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Income shocks with 'perm' and 'tran'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters including CRRA, Rfree, PermGroFac.
        utility : UtilityFuncCRRA
            Utility function (accepted for interface compatibility; not used
            here).

        Returns
        -------
        dict
            'mNrm', 'psi' (PermGroFac * perm), 'v', 'v_der', 'contributions'
            (the same object as 'v'), and 'value' (``np.sum`` of 'v').

        Raises
        ------
        KeyError
            If required shock keys are missing (raised by ``post_state``).
        """
        variables = {}
        next_state = self.post_state(post_state, shocks, params)
        variables.update(next_state)

        # Permanent income scaling: growth factor times the permanent shock.
        psi = params.PermGroFac * shocks["perm"]
        variables["psi"] = psi

        variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state)
        # Chain rule: d mNrm_{t+1} / d aNrm_t = Rfree / psi, so
        # psi ** (1 - CRRA) * (Rfree / psi) = Rfree * psi ** (-CRRA).
        variables["v_der"] = (
            params.Rfree * psi ** (-params.CRRA) * value_next.derivative(next_state)
        )

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables
class RiskyAssetTransitions:
    """
    Transitions for model with risky asset returns.

    Savings earn a stochastic risky return ('risky') instead of the
    risk-free rate; income shocks enter as in the income-shock model:

        mNrm_{t+1} = aNrm_t * risky / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    # Class-level constant shared by every instance: a frozenset cannot be
    # mutated through one instance and silently affect all the others.
    _required_shock_keys: frozenset[str] = frozenset({"perm", "tran", "risky"})

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings with risky asset return.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        params : SimpleNamespace
            Parameters including PermGroFac.

        Returns
        -------
        dict
            Next state with 'mNrm'.

        Raises
        ------
        KeyError
            If required shock keys are missing.
        """
        # Name the concrete class so errors raised from subclasses are accurate.
        _validate_shock_keys(shocks, self._required_shock_keys, type(self).__name__)
        next_state = {}
        next_state["mNrm"] = (
            post_state["aNrm"] * shocks["risky"] / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation value with risky asset.

        The marginal value is scaled by the risky return instead of Rfree.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters including CRRA, PermGroFac.
        utility : UtilityFuncCRRA
            Utility function (accepted for interface compatibility; not used
            here).

        Returns
        -------
        dict
            'mNrm', 'psi' (PermGroFac * perm), 'v', 'v_der', 'contributions'
            (the same object as 'v'), and 'value' (``np.sum`` of 'v').

        Raises
        ------
        KeyError
            If required shock keys are missing (raised by ``post_state``).
        """
        variables = {}
        next_state = self.post_state(post_state, shocks, params)
        variables.update(next_state)

        # Permanent income scaling: growth factor times the permanent shock.
        psi = params.PermGroFac * shocks["perm"]
        variables["psi"] = psi

        variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state)
        # Chain rule: d mNrm_{t+1} / d aNrm_t = risky / psi, so the risky
        # return (not Rfree) scales the marginal value.
        variables["v_der"] = (
            shocks["risky"] * psi ** (-params.CRRA) * value_next.derivative(next_state)
        )

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables
class FixedPortfolioTransitions:
    """
    Transitions for model with fixed portfolio allocation.

    The agent holds a fixed share (params.RiskyShareFixed) of savings in the
    risky asset and earns the resulting portfolio return:

        rPort = Rfree + (risky - Rfree) * RiskyShareFixed
        mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran
    """

    requires_shocks: bool = True
    # Class-level constant shared by every instance: a frozenset cannot be
    # mutated through one instance and silently affect all the others.
    _required_shock_keys: frozenset[str] = frozenset({"perm", "tran", "risky"})

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings with fixed portfolio return.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        params : SimpleNamespace
            Parameters including Rfree, PermGroFac, RiskyShareFixed.

        Returns
        -------
        dict
            Next state with 'rDiff' (excess return), 'rPort' (portfolio
            return) and 'mNrm', in that insertion order.

        Raises
        ------
        KeyError
            If required shock keys are missing.
        """
        # Name the concrete class so errors raised from subclasses are accurate.
        _validate_shock_keys(shocks, self._required_shock_keys, type(self).__name__)
        next_state = {}
        next_state["rDiff"] = shocks["risky"] - params.Rfree
        next_state["rPort"] = (
            params.Rfree + next_state["rDiff"] * params.RiskyShareFixed
        )
        next_state["mNrm"] = (
            post_state["aNrm"]
            * next_state["rPort"]
            / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation value with fixed portfolio.

        The marginal value is scaled by the portfolio return.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters including CRRA, Rfree, PermGroFac, RiskyShareFixed.
        utility : UtilityFuncCRRA
            Utility function (accepted for interface compatibility; not used
            here).

        Returns
        -------
        dict
            'rDiff', 'rPort', 'mNrm', 'psi' (PermGroFac * perm), 'v',
            'v_der', 'contributions' (the same object as 'v'), and 'value'
            (``np.sum`` of 'v').

        Raises
        ------
        KeyError
            If required shock keys are missing (raised by ``post_state``).
        """
        variables = {}
        next_state = self.post_state(post_state, shocks, params)
        variables.update(next_state)

        # Permanent income scaling: growth factor times the permanent shock.
        psi = params.PermGroFac * shocks["perm"]
        variables["psi"] = psi

        variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state)
        # Chain rule: d mNrm_{t+1} / d aNrm_t = rPort / psi, so the portfolio
        # return scales the marginal value.
        variables["v_der"] = (
            next_state["rPort"]
            * psi ** (-params.CRRA)
            * value_next.derivative(next_state)
        )

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables
class PortfolioTransitions:
    """
    Transitions for model with optimal portfolio choice.

    The agent chooses the risky share ('stigma', read from the post-decision
    state) each period:

        rPort = Rfree + (risky - Rfree) * stigma
        mNrm_{t+1} = aNrm_t * rPort / (PermGroFac * perm) + tran

    Also computes derivatives for portfolio optimization:
    - dvda: derivative of value wrt assets
    - dvds: derivative of value wrt risky share
    """

    requires_shocks: bool = True
    # Class-level constant shared by every instance: a frozenset cannot be
    # mutated through one instance and silently affect all the others.
    _required_shock_keys: frozenset[str] = frozenset({"perm", "tran", "risky"})

    def post_state(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Transform savings with optimal portfolio return.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm' and 'stigma' (risky share).
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        params : SimpleNamespace
            Parameters including Rfree, PermGroFac.

        Returns
        -------
        dict
            Next state with 'rDiff' (excess return), 'rPort' (portfolio
            return) and 'mNrm', in that insertion order.

        Raises
        ------
        KeyError
            If required shock keys are missing, or if 'aNrm'/'stigma' are
            absent from post_state.
        """
        # Name the concrete class so errors raised from subclasses are accurate.
        _validate_shock_keys(shocks, self._required_shock_keys, type(self).__name__)
        next_state = {}
        next_state["rDiff"] = shocks["risky"] - params.Rfree
        next_state["rPort"] = params.Rfree + next_state["rDiff"] * post_state["stigma"]
        next_state["mNrm"] = (
            post_state["aNrm"]
            * next_state["rPort"]
            / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation(
        self,
        post_state: dict[str, Any],
        shocks: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
        utility: UtilityFuncCRRA,
    ) -> dict[str, Any]:
        """
        Compute continuation value with optimal portfolio.

        Also computes derivatives needed for portfolio optimization:
        - dvda: used for consumption FOC
        - dvds: used for portfolio FOC (should equal 0 at optimum)

        Note that, unlike the other strategies, 'v_der' here is NOT scaled by
        a return; the return enters through 'dvda' and 'dvds' instead.

        Parameters
        ----------
        post_state : dict
            Post-decision state with 'aNrm' and 'stigma'.
        shocks : dict
            Shocks with 'perm', 'tran', and 'risky'.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters including CRRA, Rfree, PermGroFac.
        utility : UtilityFuncCRRA
            Utility function (accepted for interface compatibility; not used
            here).

        Returns
        -------
        dict
            'rDiff', 'rPort', 'mNrm', 'psi' (PermGroFac * perm), 'v',
            'v_der', 'dvda', 'dvds', 'contributions' (the same object as
            'v'), and 'value' (``np.sum`` of 'v').

        Raises
        ------
        KeyError
            If required shock keys are missing (raised by ``post_state``).
        """
        variables = {}
        next_state = self.post_state(post_state, shocks, params)
        variables.update(next_state)

        # Permanent income scaling: growth factor times the permanent shock.
        psi = params.PermGroFac * shocks["perm"]
        variables["psi"] = psi

        variables["v"] = psi ** (1 - params.CRRA) * value_next(next_state)
        # psi ** (1 - CRRA) * V'(m) / psi: the common factor of both FOCs.
        variables["v_der"] = psi ** (-params.CRRA) * value_next.derivative(next_state)

        # Derivatives for portfolio optimization (chain rule through mNrm):
        # d mNrm / d aNrm = rPort / psi and d mNrm / d stigma = rDiff * aNrm / psi.
        variables["dvda"] = next_state["rPort"] * variables["v_der"]
        variables["dvds"] = (
            next_state["rDiff"] * post_state["aNrm"] * variables["v_der"]
        )

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables