Coverage for HARK / Labeled / solution.py: 91%

68 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-25 05:22 +0000

1""" 

2Solution classes for labeled consumption-saving models. 

3 

4This module contains the value function and solution classes that use 

5xarray for labeled, multidimensional data handling. 

6""" 

7 

8from __future__ import annotations 

9 

10from typing import Mapping 

11 

12import numpy as np 

13import xarray as xr 

14 

15from HARK.metric import MetricObject 

16from HARK.rewards import UtilityFuncCRRA 

17 

# Public API of this module: the names exported by a star-import.
__all__ = [
    "ValueFuncCRRALabeled",
    "ConsumerSolutionLabeled",
]

22 

23 

class ValueFuncCRRALabeled(MetricObject):
    """
    Value function interpolation using xarray for labeled arrays.

    This class enables value function interpolation and derivative computation
    using xarray's labeled data structures. It stores the value function in
    inverse form for numerical stability with CRRA utility.

    Parameters
    ----------
    dataset : xr.Dataset
        Underlying dataset containing variables:
        - "v": value function
        - "v_der": marginal value function
        - "v_inv": inverse of value function
        - "v_der_inv": inverse of marginal value function
    CRRA : float
        Coefficient of relative risk aversion. Must be non-negative and finite.

    Raises
    ------
    ValueError
        If CRRA is negative or not finite, or if ``dataset`` lacks any of the
        required variables listed above.
    TypeError
        If ``dataset`` is not an ``xr.Dataset``.
    """

    def __init__(self, dataset: xr.Dataset, CRRA: float) -> None:
        # The finiteness check runs first so that NaN / inf are reported as
        # "not finite" rather than slipping through the sign check.
        if not np.isfinite(CRRA):
            raise ValueError(f"CRRA must be finite, got {CRRA}")
        if CRRA < 0:
            raise ValueError(f"CRRA must be non-negative, got {CRRA}")

        # Validate dataset structure
        if not isinstance(dataset, xr.Dataset):
            raise TypeError(f"dataset must be xr.Dataset, got {type(dataset)}")

        required_vars = {"v", "v_der", "v_inv", "v_der_inv"}
        missing_vars = required_vars - set(dataset.data_vars)
        if missing_vars:
            raise ValueError(
                f"Dataset missing required variables: {missing_vars}. "
                f"Required: {required_vars}"
            )

        self.dataset = dataset
        self.CRRA = CRRA
        self.u = UtilityFuncCRRA(CRRA)

    def __call__(self, state: Mapping[str, np.ndarray]) -> xr.DataArray:
        """
        Evaluate value function at given state via interpolation.

        Interpolates the inverse value function ("v_inv") with extrapolation
        enabled and then applies the CRRA utility function to recover the
        value function, which is more numerically stable for CRRA utility.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State variables to evaluate value function at. Must be a dict or
            xr.Dataset containing all coordinates present in the dataset.

        Returns
        -------
        xr.DataArray
            Value function evaluated at the given state, named "v" and
            carrying the attrs of the dataset's "v" variable.

        Raises
        ------
        TypeError
            If state is not a dict or xr.Dataset.
        KeyError
            If state is missing required coordinates.
        """
        state_dict = self._validate_state(state)

        result = self.u(
            self.dataset["v_inv"].interp(
                state_dict,
                assume_sorted=True,
                kwargs={"fill_value": "extrapolate"},
            )
        )

        result.name = "v"
        result.attrs = self.dataset["v"].attrs

        return result

    def derivative(self, state: Mapping[str, np.ndarray]) -> xr.DataArray:
        """
        Evaluate marginal value function at given state via interpolation.

        Interpolates the inverse marginal value function ("v_der_inv") with
        extrapolation enabled and then applies the derivative of the CRRA
        utility function to recover the marginal value function.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State variables to evaluate marginal value function at. Must be a
            dict or xr.Dataset containing all coordinates present in the
            dataset.

        Returns
        -------
        xr.DataArray
            Marginal value function evaluated at the given state, named
            "v_der" and carrying the attrs of the dataset's "v_der" variable.

        Raises
        ------
        TypeError
            If state is not a dict or xr.Dataset.
        KeyError
            If state is missing required coordinates.
        """
        state_dict = self._validate_state(state)

        result = self.u.der(
            self.dataset["v_der_inv"].interp(
                state_dict,
                assume_sorted=True,
                kwargs={"fill_value": "extrapolate"},
            )
        )

        result.name = "v_der"
        # Metadata must describe the marginal value function, not the value
        # function itself, so it is taken from "v_der" rather than "v".
        result.attrs = self.dataset["v_der"].attrs

        return result

    def evaluate(self, state: Mapping[str, np.ndarray]) -> xr.Dataset:
        """
        Interpolate all data variables in the dataset at given state.

        Unlike ``__call__`` and ``derivative``, every variable is interpolated
        directly (no inverse transformation is undone) and ``fill_value=None``
        is handed to the interpolator instead of ``"extrapolate"``.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State variables to evaluate all data variables at. Must be a dict
            or xr.Dataset containing all coordinates present in the dataset.

        Returns
        -------
        xr.Dataset
            All interpolated data variables at the given state. The dataset's
            attrs are set to those of the "v" variable.

        Raises
        ------
        TypeError
            If state is not a dict or xr.Dataset.
        KeyError
            If state is missing required coordinates.
        """
        state_dict = self._validate_state(state)

        result = self.dataset.interp(
            state_dict,
            kwargs={"fill_value": None},
        )
        result.attrs = self.dataset["v"].attrs

        return result

    def _validate_state(self, state: Mapping[str, np.ndarray]) -> dict:
        """
        Validate state and extract required coordinates.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State to validate. Must be a dict or xr.Dataset.

        Returns
        -------
        dict
            Dictionary mapping every coordinate name of the dataset to the
            matching entry of ``state``. Entries of ``state`` that are not
            coordinates of the dataset are left out.

        Raises
        ------
        TypeError
            If state is not a dict or xr.Dataset.
        KeyError
            If a required coordinate is missing from state.
        """
        if not isinstance(state, (xr.Dataset, dict)):
            raise TypeError(f"state must be a dict or xr.Dataset, got {type(state)}")

        state_dict = {}
        for coord in self.dataset.coords:
            if coord not in state:
                raise KeyError(
                    f"Required coordinate '{coord}' not found in state. "
                    f"Available keys: {list(state.keys())}"
                )
            state_dict[coord] = state[coord]

        return state_dict

210 

211 

class ConsumerSolutionLabeled(MetricObject):
    """
    Container for the solution of one period of a labeled
    consumption-saving problem.

    It bundles the period's value function, the policy (consumption)
    dataset, and the continuation value function.

    Parameters
    ----------
    value : ValueFuncCRRALabeled
        Value function for this period.
    policy : xr.Dataset
        Policy function (consumption as function of state).
    continuation : ValueFuncCRRALabeled or None
        Continuation value function (value of post-decision state).
        May be None, e.g. for terminal period solutions.
    attrs : dict, optional
        Additional attributes of the solution, such as minimum
        normalized market resources. When None (the default) an empty
        dict is stored.

    Raises
    ------
    TypeError
        If ``value``, ``policy`` or ``continuation`` has the wrong type.
    """

    def __init__(
        self,
        value: ValueFuncCRRALabeled,
        policy: xr.Dataset,
        continuation: ValueFuncCRRALabeled | None,
        attrs: dict | None = None,
    ) -> None:
        # Arguments are checked in declaration order so the first offending
        # one is the one reported.
        if not isinstance(value, ValueFuncCRRALabeled):
            raise TypeError(f"value must be ValueFuncCRRALabeled, got {type(value)}")

        if not isinstance(policy, xr.Dataset):
            raise TypeError(f"policy must be xr.Dataset, got {type(policy)}")

        continuation_ok = continuation is None or isinstance(
            continuation, ValueFuncCRRALabeled
        )
        if not continuation_ok:
            raise TypeError(
                f"continuation must be ValueFuncCRRALabeled or None, got {type(continuation)}"
            )

        self.value = value
        self.policy = policy
        self.continuation = continuation
        self.attrs = {} if attrs is None else attrs

    def distance(self, other: ConsumerSolutionLabeled) -> float:
        """
        Return the largest absolute gap between two solutions' value datasets.

        The other solution's value dataset is interpolated onto this
        solution's grid, the element-wise absolute difference is taken over
        all data variables, and the maximum is returned. Used as a
        convergence measure for infinite horizon problems.

        Parameters
        ----------
        other : ConsumerSolutionLabeled
            Other solution to compare to.

        Returns
        -------
        float
            Maximum absolute difference between value functions.

        Raises
        ------
        TypeError
            If other is not a ConsumerSolutionLabeled.
        ValueError
            If one or both solutions have no value function.
        """
        if not isinstance(other, ConsumerSolutionLabeled):
            raise TypeError(
                f"Cannot compute distance with {type(other)}. "
                f"Expected ConsumerSolutionLabeled."
            )

        # ``value`` can be reassigned after construction, so guard against None.
        if any(solution.value is None for solution in (self, other)):
            raise ValueError(
                "Cannot compute distance: one or both solutions have no value function"
            )

        own_grid = self.value.dataset
        # Put the other solution on this solution's coordinates before comparing.
        other_on_own_grid = other.value.dataset.interp_like(own_grid)
        abs_gap = np.abs(own_grid - other_on_own_grid)

        return float(np.max(abs_gap.to_array()))