Coverage for HARK / Labeled / factories.py: 82%

90 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-25 05:22 +0000

1""" 

2Factory functions for creating labeled solutions and distributions. 

3 

4These functions create terminal period solutions and convert standard 

5HARK distributions to labeled versions for use with labeled solvers. 

6""" 

7 

8from __future__ import annotations 

9 

10from typing import TYPE_CHECKING 

11 

12import numpy as np 

13import xarray as xr 

14 

15from HARK.Calibration.Assets.AssetProcesses import ( 

16 combine_IncShkDstn_and_RiskyDstn, 

17 make_lognormal_RiskyDstn, 

18) 

19from HARK.Calibration.Income.IncomeProcesses import ( 

20 construct_lognormal_income_process_unemployment, 

21) 

22from HARK.distributions import DiscreteDistributionLabeled 

23from HARK.rewards import UtilityFuncCRRA 

24 

25from .solution import ConsumerSolutionLabeled, ValueFuncCRRALabeled 

26 

27if TYPE_CHECKING: 

28 from numpy.random import Generator 

29 

# Public API of this module: one terminal-solution factory followed by the
# three factories that produce labeled shock distributions.
__all__ = [
    # terminal period solution
    "make_solution_terminal_labeled",
    # labeled distributions
    "make_labeled_inc_shk_dstn",
    "make_labeled_risky_dstn",
    "make_labeled_shock_dstn",
]

36 

37 

def make_solution_terminal_labeled(
    CRRA: float,
    aXtraGrid: np.ndarray,
) -> ConsumerSolutionLabeled:
    """
    Construct the terminal period solution for a labeled consumption model.

    In the terminal period the optimal policy is to consume all resources,
    so consumption equals cash on hand at every grid point. This function
    builds that policy together with the CRRA value function (value,
    marginal value, and their inverses) on the cash-on-hand grid.

    Parameters
    ----------
    CRRA : float
        Coefficient of relative risk aversion. Must be finite and positive.
    aXtraGrid : np.ndarray
        One-dimensional, strictly increasing grid of finite, non-negative
        assets above the minimum. The cash-on-hand grid is ``0.0`` followed
        by this grid; if the grid already starts at 0 that point is not
        duplicated.

    Returns
    -------
    ConsumerSolutionLabeled
        Terminal period solution with value and policy functions, with
        ``attrs={"m_nrm_min": 0.0}``.

    Raises
    ------
    ValueError
        If CRRA is not finite or not positive, or if aXtraGrid is not a
        non-empty, one-dimensional, finite, non-negative, strictly
        increasing grid.
    """
    # Input validation
    if not np.isfinite(CRRA):
        raise ValueError(f"CRRA must be finite, got {CRRA}")
    if CRRA <= 0:
        raise ValueError(f"CRRA must be positive, got {CRRA}")

    aXtraGrid = np.asarray(aXtraGrid)
    # Check the rank explicitly: len() raises TypeError on scalars, and
    # np.diff / np.append would silently mis-handle multi-dimensional input.
    if aXtraGrid.ndim != 1:
        raise ValueError(
            f"aXtraGrid must be one-dimensional, got {aXtraGrid.ndim} dimensions"
        )
    if aXtraGrid.size == 0:
        raise ValueError("aXtraGrid cannot be empty")
    # NaN compares False with everything, so it would slip past the sign
    # and monotonicity checks below (e.g. a single-element [nan] grid).
    if not np.all(np.isfinite(aXtraGrid)):
        raise ValueError("aXtraGrid values must be finite")
    if np.any(aXtraGrid < 0):
        raise ValueError("aXtraGrid values must be non-negative")
    if not np.all(np.diff(aXtraGrid) > 0):
        raise ValueError("aXtraGrid must be strictly increasing")

    u = UtilityFuncCRRA(CRRA)

    # Cash-on-hand grid: the minimum m_nrm_min = 0.0 followed by the asset
    # grid. The grid is strictly increasing, so only its first element can
    # be 0; drop it in that case so the mNrm coordinate has no duplicate
    # point (a non-unique coordinate is problematic for interpolation).
    if aXtraGrid[0] == 0:
        aXtraGrid = aXtraGrid[1:]
    mNrm = xr.DataArray(
        np.append(0.0, aXtraGrid),
        name="mNrm",
        dims=("mNrm",),
        attrs={"long_name": "cash_on_hand"},
    )
    state = xr.Dataset({"mNrm": mNrm})

    # Optimal decision: consume everything in terminal period
    cNrm = xr.DataArray(
        mNrm,
        name="cNrm",
        dims=state.dims,
        coords=state.coords,
        attrs={"long_name": "consumption"},
    )

    # Compute value function variables
    v = u(cNrm)
    v.name = "v"
    v.attrs = {"long_name": "value function"}

    v_der = u.der(cNrm)
    v_der.name = "v_der"
    v_der.attrs = {"long_name": "marginal value function"}

    # With c = m, inverting u(c) and u'(c) both recover c itself, so the
    # inverse value and inverse marginal value are copies of consumption.
    v_inv = cNrm.copy()
    v_inv.name = "v_inv"
    v_inv.attrs = {"long_name": "inverse value function"}

    v_der_inv = cNrm.copy()
    v_der_inv.name = "v_der_inv"
    v_der_inv.attrs = {"long_name": "inverse marginal value function"}

    dataset = xr.Dataset(
        {
            "cNrm": cNrm,
            "v": v,
            "v_der": v_der,
            "v_inv": v_inv,
            "v_der_inv": v_der_inv,
        }
    )

    vfunc = ValueFuncCRRALabeled(dataset[["v", "v_der", "v_inv", "v_der_inv"]], CRRA)

    solution_terminal = ConsumerSolutionLabeled(
        value=vfunc,
        policy=dataset[["cNrm"]],
        continuation=None,
        attrs={"m_nrm_min": 0.0},
    )
    return solution_terminal

136 

137 

def make_labeled_inc_shk_dstn(
    T_cycle: int,
    PermShkStd: list[float],
    PermShkCount: int,
    TranShkStd: list[float],
    TranShkCount: int,
    T_retire: int,
    UnempPrb: float,
    IncUnemp: float,
    UnempPrbRet: float,
    IncUnempRet: float,
    RNG: Generator,
    neutral_measure: bool = False,
) -> list[DiscreteDistributionLabeled]:
    """
    Build one labeled income shock distribution per period of the cycle.

    The unlabeled distributions are produced by
    ``construct_lognormal_income_process_unemployment``. Each period's
    distribution is then wrapped as a ``DiscreteDistributionLabeled`` whose
    variables are named ``"perm"`` and ``"tran"``.

    Parameters
    ----------
    T_cycle : int
        Number of periods in the cycle; must be positive.
    PermShkStd : list[float]
        Standard deviation of permanent shocks by period; must be non-empty.
    PermShkCount : int
        Number of permanent shock points; must be positive.
    TranShkStd : list[float]
        Standard deviation of transitory shocks by period; must be non-empty.
    TranShkCount : int
        Number of transitory shock points; must be positive.
    T_retire : int
        Period of retirement (0 means never retire).
    UnempPrb : float
        Probability of unemployment during working life, in [0, 1].
    IncUnemp : float
        Income during unemployment.
    UnempPrbRet : float
        Probability of "unemployment" in retirement, in [0, 1].
    IncUnempRet : float
        Income during retirement "unemployment".
    RNG : Generator
        Random number generator; must not be None.
    neutral_measure : bool, optional
        Whether to use risk-neutral measure. Default False.

    Returns
    -------
    list[DiscreteDistributionLabeled]
        Labeled income shock distributions, one per entry of the underlying
        indexed distribution.

    Raises
    ------
    ValueError
        If any of the validation checks above fails.
    """
    # Counts must be strictly positive (checked in this order).
    for label, count in (
        ("T_cycle", T_cycle),
        ("PermShkCount", PermShkCount),
        ("TranShkCount", TranShkCount),
    ):
        if count <= 0:
            raise ValueError(f"{label} must be positive, got {count}")

    # Standard-deviation profiles need at least one entry.
    if len(PermShkStd) == 0:
        raise ValueError("PermShkStd cannot be empty")
    if len(TranShkStd) == 0:
        raise ValueError("TranShkStd cannot be empty")

    # Unemployment probabilities must lie in the closed unit interval
    # (a NaN fails the chained comparison and is rejected too).
    for label, prob in (("UnempPrb", UnempPrb), ("UnempPrbRet", UnempPrbRet)):
        if not 0 <= prob <= 1:
            raise ValueError(f"{label} must be in [0, 1], got {prob}")

    if RNG is None:
        raise ValueError("RNG cannot be None")

    unlabeled = construct_lognormal_income_process_unemployment(
        T_cycle,
        PermShkStd,
        PermShkCount,
        TranShkStd,
        TranShkCount,
        T_retire,
        UnempPrb,
        IncUnemp,
        UnempPrbRet,
        IncUnempRet,
        RNG,
        neutral_measure,
    )

    return [
        DiscreteDistributionLabeled.from_unlabeled(
            unlabeled[period],
            name="Distribution of Shocks to Income",
            var_names=["perm", "tran"],
        )
        for period in range(len(unlabeled.dstns))
    ]

238 

239 

def make_labeled_risky_dstn(
    T_cycle: int,
    RiskyAvg: float,
    RiskyStd: float,
    RiskyCount: int,
    RNG: Generator,
) -> DiscreteDistributionLabeled:
    """
    Create a labeled risky asset return distribution.

    Wrapper around make_lognormal_RiskyDstn that converts the result
    to a labeled distribution with the single variable ``"risky"``.

    Parameters
    ----------
    T_cycle : int
        Number of periods in the cycle; must be positive.
    RiskyAvg : float
        Mean risky return; must be finite and positive.
    RiskyStd : float
        Standard deviation of risky return; must be finite and non-negative.
    RiskyCount : int
        Number of risky return points; must be positive.
    RNG : Generator
        Random number generator; must not be None.

    Returns
    -------
    DiscreteDistributionLabeled
        Labeled distribution of risky asset returns.

    Raises
    ------
    ValueError
        If input parameters fail validation checks.
    """
    # Input validation
    if T_cycle <= 0:
        raise ValueError(f"T_cycle must be positive, got {T_cycle}")
    # Finiteness is checked before the sign checks: NaN compares False
    # against everything and would otherwise pass `<= 0` / `< 0` unnoticed,
    # and inf is accepted by both sign checks.
    if not np.isfinite(RiskyAvg):
        raise ValueError(f"RiskyAvg must be finite, got {RiskyAvg}")
    if RiskyAvg <= 0:
        raise ValueError(f"RiskyAvg must be positive, got {RiskyAvg}")
    if not np.isfinite(RiskyStd):
        raise ValueError(f"RiskyStd must be finite, got {RiskyStd}")
    if RiskyStd < 0:
        raise ValueError(f"RiskyStd must be non-negative, got {RiskyStd}")
    if RiskyCount <= 0:
        raise ValueError(f"RiskyCount must be positive, got {RiskyCount}")
    if RNG is None:
        raise ValueError("RNG cannot be None")

    RiskyDstnBase = make_lognormal_RiskyDstn(
        T_cycle, RiskyAvg, RiskyStd, RiskyCount, RNG
    )

    RiskyDstn = DiscreteDistributionLabeled.from_unlabeled(
        RiskyDstnBase,
        name="Distribution of Risky Asset Returns",
        var_names=["risky"],
    )
    return RiskyDstn

298 

299 

def make_labeled_shock_dstn(
    T_cycle: int,
    IncShkDstn: list[DiscreteDistributionLabeled],
    RiskyDstn: DiscreteDistributionLabeled,
) -> list[DiscreteDistributionLabeled]:
    """
    Build one labeled joint (income, risky return) shock distribution per period.

    The income shock and risky return distributions are combined by
    ``combine_IncShkDstn_and_RiskyDstn``. Each resulting distribution is then
    wrapped as a ``DiscreteDistributionLabeled`` with variables ``"perm"``,
    ``"tran"`` and ``"risky"``.

    Parameters
    ----------
    T_cycle : int
        Number of periods in the cycle; must be positive.
    IncShkDstn : list[DiscreteDistributionLabeled]
        Income shock distributions; must be non-empty and not None.
    RiskyDstn : DiscreteDistributionLabeled
        Risky asset return distribution; must not be None.

    Returns
    -------
    list[DiscreteDistributionLabeled]
        Labeled joint shock distributions, one per entry of the combined
        indexed distribution.

    Raises
    ------
    ValueError
        If any of the validation checks above fails.
    """
    if T_cycle <= 0:
        raise ValueError(f"T_cycle must be positive, got {T_cycle}")

    income_missing = IncShkDstn is None or len(IncShkDstn) == 0
    if income_missing:
        raise ValueError("IncShkDstn cannot be None or empty")

    if RiskyDstn is None:
        raise ValueError("RiskyDstn cannot be None")

    joint = combine_IncShkDstn_and_RiskyDstn(T_cycle, RiskyDstn, IncShkDstn)

    return [
        DiscreteDistributionLabeled.from_unlabeled(
            joint[period],
            name="Distribution of Shocks to Income and Risky Asset Returns",
            var_names=["perm", "tran", "risky"],
        )
        for period in range(len(joint.dstns))
    ]