Coverage for HARK / Labeled / factories.py: 82%
90 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
1"""
2Factory functions for creating labeled solutions and distributions.
4These functions create terminal period solutions and convert standard
5HARK distributions to labeled versions for use with labeled solvers.
6"""
8from __future__ import annotations
10from typing import TYPE_CHECKING
12import numpy as np
13import xarray as xr
15from HARK.Calibration.Assets.AssetProcesses import (
16 combine_IncShkDstn_and_RiskyDstn,
17 make_lognormal_RiskyDstn,
18)
19from HARK.Calibration.Income.IncomeProcesses import (
20 construct_lognormal_income_process_unemployment,
21)
22from HARK.distributions import DiscreteDistributionLabeled
23from HARK.rewards import UtilityFuncCRRA
25from .solution import ConsumerSolutionLabeled, ValueFuncCRRALabeled
27if TYPE_CHECKING:
28 from numpy.random import Generator
# Public API of this module: the factory functions exported by
# ``from ... import *``.
__all__ = [
    "make_solution_terminal_labeled",
    "make_labeled_inc_shk_dstn",
    "make_labeled_risky_dstn",
    "make_labeled_shock_dstn",
]
def make_solution_terminal_labeled(
    CRRA: float,
    aXtraGrid: np.ndarray,
) -> ConsumerSolutionLabeled:
    """
    Construct the terminal period solution for a labeled consumption model.

    In the terminal period, the optimal policy is to consume all resources,
    so consumption equals cash on hand (``cNrm = mNrm``). This function
    creates the value function and policy function for this terminal period.

    The state grid is ``mNrm = [0, *aXtraGrid]``. If ``aXtraGrid`` already
    starts at 0, the zero point is not added a second time, so the ``mNrm``
    coordinate never contains a duplicated label.

    Parameters
    ----------
    CRRA : float
        Coefficient of relative risk aversion. Must be finite and positive.
    aXtraGrid : np.ndarray
        One-dimensional grid of assets above minimum. Values must be finite,
        non-negative and strictly increasing. Used to construct the state grid.

    Returns
    -------
    ConsumerSolutionLabeled
        Terminal period solution with value and policy functions, with
        ``attrs={"m_nrm_min": 0.0}`` and no continuation.

    Raises
    ------
    ValueError
        If CRRA is not finite or not positive, or if aXtraGrid is malformed
        (not one-dimensional, empty, non-finite, negative, or not strictly
        increasing).
    """
    # Input validation
    if not np.isfinite(CRRA):
        raise ValueError(f"CRRA must be finite, got {CRRA}")
    if CRRA <= 0:
        raise ValueError(f"CRRA must be positive, got {CRRA}")

    aXtraGrid = np.asarray(aXtraGrid)
    # Check dimensionality first: len() raises TypeError on a 0-d array, and
    # a multi-dimensional grid would be silently flattened by np.append.
    if aXtraGrid.ndim != 1:
        raise ValueError(
            f"aXtraGrid must be one-dimensional, got {aXtraGrid.ndim} dimensions"
        )
    if aXtraGrid.size == 0:
        raise ValueError("aXtraGrid cannot be empty")
    if not np.all(np.isfinite(aXtraGrid)):
        raise ValueError("aXtraGrid values must be finite")
    if np.any(aXtraGrid < 0):
        raise ValueError("aXtraGrid values must be non-negative")
    if not np.all(np.diff(aXtraGrid) > 0):
        raise ValueError("aXtraGrid must be strictly increasing")

    u = UtilityFuncCRRA(CRRA)

    # Create state grid starting at m = 0. Only prepend the zero point when
    # the grid does not already contain it; otherwise the mNrm coordinate
    # would carry a duplicated label.
    if aXtraGrid[0] == 0:
        m_values = aXtraGrid.astype(float)  # copy, so the caller's array is not aliased
    else:
        m_values = np.append(0.0, aXtraGrid)

    mNrm = xr.DataArray(
        m_values,
        name="mNrm",
        dims=("mNrm",),
        attrs={"long_name": "cash_on_hand"},
    )
    state = xr.Dataset({"mNrm": mNrm})

    # Optimal decision: consume everything in terminal period
    cNrm = xr.DataArray(
        mNrm,
        name="cNrm",
        dims=("mNrm",),
        coords=state.coords,
        attrs={"long_name": "consumption"},
    )

    # Compute value function variables
    v = u(cNrm)
    v.name = "v"
    v.attrs = {"long_name": "value function"}

    v_der = u.der(cNrm)
    v_der.name = "v_der"
    v_der.attrs = {"long_name": "marginal value function"}

    # The inverse value and inverse marginal value functions are both set
    # equal to consumption here.
    v_inv = cNrm.copy()
    v_inv.name = "v_inv"
    v_inv.attrs = {"long_name": "inverse value function"}

    v_der_inv = cNrm.copy()
    v_der_inv.name = "v_der_inv"
    v_der_inv.attrs = {"long_name": "inverse marginal value function"}

    dataset = xr.Dataset(
        {
            "cNrm": cNrm,
            "v": v,
            "v_der": v_der,
            "v_inv": v_inv,
            "v_der_inv": v_der_inv,
        }
    )

    vfunc = ValueFuncCRRALabeled(dataset[["v", "v_der", "v_inv", "v_der_inv"]], CRRA)

    solution_terminal = ConsumerSolutionLabeled(
        value=vfunc,
        policy=dataset[["cNrm"]],
        continuation=None,
        attrs={"m_nrm_min": 0.0},
    )
    return solution_terminal
def make_labeled_inc_shk_dstn(
    T_cycle: int,
    PermShkStd: list[float],
    PermShkCount: int,
    TranShkStd: list[float],
    TranShkCount: int,
    T_retire: int,
    UnempPrb: float,
    IncUnemp: float,
    UnempPrbRet: float,
    IncUnempRet: float,
    RNG: Generator,
    neutral_measure: bool = False,
) -> list[DiscreteDistributionLabeled]:
    """
    Build one labeled income shock distribution per period.

    Delegates the construction to
    ``construct_lognormal_income_process_unemployment``. Each per-period
    distribution it returns is then wrapped with
    ``DiscreteDistributionLabeled.from_unlabeled``, with the two shock
    variables labeled ``"perm"`` and ``"tran"``.

    Parameters
    ----------
    T_cycle : int
        Number of periods in the cycle.
    PermShkStd : list[float]
        Standard deviation of permanent shocks by period.
    PermShkCount : int
        Number of permanent shock points.
    TranShkStd : list[float]
        Standard deviation of transitory shocks by period.
    TranShkCount : int
        Number of transitory shock points.
    T_retire : int
        Period of retirement (0 means never retire).
    UnempPrb : float
        Probability of unemployment during working life.
    IncUnemp : float
        Income during unemployment.
    UnempPrbRet : float
        Probability of "unemployment" in retirement.
    IncUnempRet : float
        Income during retirement "unemployment".
    RNG : Generator
        Random number generator.
    neutral_measure : bool, optional
        Whether to use risk-neutral measure. Default False.

    Returns
    -------
    list[DiscreteDistributionLabeled]
        One labeled income shock distribution per period.

    Raises
    ------
    ValueError
        On the first argument (in the order checked below) that fails
        validation.
    """
    # Guard clauses. They run in a fixed order, so the first invalid argument
    # determines which message the caller sees.
    if T_cycle <= 0:
        raise ValueError(f"T_cycle must be positive, got {T_cycle}")
    for label, count in (("PermShkCount", PermShkCount), ("TranShkCount", TranShkCount)):
        if count <= 0:
            raise ValueError(f"{label} must be positive, got {count}")
    for label, stds in (("PermShkStd", PermShkStd), ("TranShkStd", TranShkStd)):
        if len(stds) == 0:
            raise ValueError(f"{label} cannot be empty")
    for label, prob in (("UnempPrb", UnempPrb), ("UnempPrbRet", UnempPrbRet)):
        # Written as "not (in range)" so that NaN is rejected too.
        if not (0 <= prob <= 1):
            raise ValueError(f"{label} must be in [0, 1], got {prob}")
    if RNG is None:
        raise ValueError("RNG cannot be None")

    unlabeled = construct_lognormal_income_process_unemployment(
        T_cycle,
        PermShkStd,
        PermShkCount,
        TranShkStd,
        TranShkCount,
        T_retire,
        UnempPrb,
        IncUnemp,
        UnempPrbRet,
        IncUnempRet,
        RNG,
        neutral_measure,
    )

    n_periods = len(unlabeled.dstns)
    return [
        DiscreteDistributionLabeled.from_unlabeled(
            unlabeled[t],
            name="Distribution of Shocks to Income",
            var_names=["perm", "tran"],
        )
        for t in range(n_periods)
    ]
def make_labeled_risky_dstn(
    T_cycle: int,
    RiskyAvg: float,
    RiskyStd: float,
    RiskyCount: int,
    RNG: Generator,
) -> DiscreteDistributionLabeled:
    """
    Create a labeled risky asset return distribution.

    Wrapper around make_lognormal_RiskyDstn that converts the result
    to a labeled distribution whose single variable is named ``"risky"``.

    Parameters
    ----------
    T_cycle : int
        Number of periods in the cycle.
    RiskyAvg : float
        Mean risky return. Must be finite and positive.
    RiskyStd : float
        Standard deviation of risky return. Must be finite and non-negative.
    RiskyCount : int
        Number of risky return points.
    RNG : Generator
        Random number generator.

    Returns
    -------
    DiscreteDistributionLabeled
        Labeled distribution of risky asset returns.

    Raises
    ------
    ValueError
        If input parameters fail validation checks.
    """
    # Input validation
    if T_cycle <= 0:
        raise ValueError(f"T_cycle must be positive, got {T_cycle}")
    # NaN fails every ordered comparison, so it would slip past the sign
    # checks below; reject non-finite moments explicitly (as is done for
    # CRRA in make_solution_terminal_labeled).
    if not np.isfinite(RiskyAvg):
        raise ValueError(f"RiskyAvg must be finite, got {RiskyAvg}")
    if RiskyAvg <= 0:
        raise ValueError(f"RiskyAvg must be positive, got {RiskyAvg}")
    if not np.isfinite(RiskyStd):
        raise ValueError(f"RiskyStd must be finite, got {RiskyStd}")
    if RiskyStd < 0:
        raise ValueError(f"RiskyStd must be non-negative, got {RiskyStd}")
    if RiskyCount <= 0:
        raise ValueError(f"RiskyCount must be positive, got {RiskyCount}")
    if RNG is None:
        raise ValueError("RNG cannot be None")

    RiskyDstnBase = make_lognormal_RiskyDstn(
        T_cycle, RiskyAvg, RiskyStd, RiskyCount, RNG
    )

    RiskyDstn = DiscreteDistributionLabeled.from_unlabeled(
        RiskyDstnBase,
        name="Distribution of Risky Asset Returns",
        var_names=["risky"],
    )
    return RiskyDstn
def make_labeled_shock_dstn(
    T_cycle: int,
    IncShkDstn: list[DiscreteDistributionLabeled],
    RiskyDstn: DiscreteDistributionLabeled,
) -> list[DiscreteDistributionLabeled]:
    """
    Build one labeled joint (income, risky return) shock distribution per period.

    The income shock and risky return distributions are combined with
    ``combine_IncShkDstn_and_RiskyDstn``. Each per-period result is then
    wrapped with ``DiscreteDistributionLabeled.from_unlabeled``, with its
    variables labeled ``"perm"``, ``"tran"`` and ``"risky"``.

    Parameters
    ----------
    T_cycle : int
        Number of periods in the cycle.
    IncShkDstn : list[DiscreteDistributionLabeled]
        List of income shock distributions.
    RiskyDstn : DiscreteDistributionLabeled
        Risky asset return distribution.

    Returns
    -------
    list[DiscreteDistributionLabeled]
        One labeled joint shock distribution per period.

    Raises
    ------
    ValueError
        If T_cycle is not positive, IncShkDstn is None or empty, or
        RiskyDstn is None.
    """
    if T_cycle <= 0:
        raise ValueError(f"T_cycle must be positive, got {T_cycle}")
    if IncShkDstn is None or not len(IncShkDstn):
        raise ValueError("IncShkDstn cannot be None or empty")
    if RiskyDstn is None:
        raise ValueError("RiskyDstn cannot be None")

    # Argument order here is (T_cycle, RiskyDstn, IncShkDstn), which differs
    # from this function's own signature.
    joint = combine_IncShkDstn_and_RiskyDstn(T_cycle, RiskyDstn, IncShkDstn)

    n_periods = len(joint.dstns)
    return [
        DiscreteDistributionLabeled.from_unlabeled(
            joint[t],
            name="Distribution of Shocks to Income and Risky Asset Returns",
            var_names=["perm", "tran", "risky"],
        )
        for t in range(n_periods)
    ]