Coverage for HARK/Labeled/solution.py: 91%
68 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
1"""
2Solution classes for labeled consumption-saving models.
4This module contains the value function and solution classes that use
5xarray for labeled, multidimensional data handling.
6"""
8from __future__ import annotations
10from typing import Mapping
12import numpy as np
13import xarray as xr
15from HARK.metric import MetricObject
16from HARK.rewards import UtilityFuncCRRA
# Names exported by ``from <module> import *``: the two public classes.
__all__ = ["ValueFuncCRRALabeled", "ConsumerSolutionLabeled"]
class ValueFuncCRRALabeled(MetricObject):
    """
    Value function interpolation using xarray for labeled arrays.

    This class evaluates the value function and the marginal value function
    using xarray's labeled data structures. Rather than interpolating ``v``
    and ``v_der`` directly, it interpolates their stored inverse forms
    (``v_inv`` / ``v_der_inv``). It then maps the result back through a CRRA
    utility object, which is numerically better behaved under CRRA utility.

    Parameters
    ----------
    dataset : xr.Dataset
        Underlying dataset containing variables:
        - "v": value function
        - "v_der": marginal value function
        - "v_inv": inverse of value function
        - "v_der_inv": inverse of marginal value function
        Every coordinate of the dataset is treated as a state variable that
        callers must supply when evaluating.
    CRRA : float
        Coefficient of relative risk aversion. Must be non-negative and finite.

    Raises
    ------
    ValueError
        If CRRA is negative or not finite, or if ``dataset`` lacks any of the
        required variables.
    TypeError
        If ``dataset`` is not an ``xr.Dataset``.
    """

    def __init__(self, dataset: xr.Dataset, CRRA: float) -> None:
        if not np.isfinite(CRRA):
            raise ValueError(f"CRRA must be finite, got {CRRA}")
        if CRRA < 0:
            raise ValueError(f"CRRA must be non-negative, got {CRRA}")

        # Validate dataset structure
        if not isinstance(dataset, xr.Dataset):
            raise TypeError(f"dataset must be xr.Dataset, got {type(dataset)}")

        required_vars = {"v", "v_der", "v_inv", "v_der_inv"}
        missing_vars = required_vars - set(dataset.data_vars)
        if missing_vars:
            raise ValueError(
                f"Dataset missing required variables: {missing_vars}. "
                f"Required: {required_vars}"
            )

        self.dataset = dataset
        self.CRRA = CRRA
        # Utility object used to map interpolated inverses back to levels.
        self.u = UtilityFuncCRRA(CRRA)

    def __call__(self, state: Mapping[str, np.ndarray]) -> xr.DataArray:
        """
        Evaluate value function at given state via interpolation.

        Interpolates ``v_inv`` (extrapolating outside the grid) and passes
        the result through the CRRA utility object to recover the value
        function.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State variables to evaluate value function at. Must contain
            all coordinates present in the dataset.

        Returns
        -------
        xr.DataArray
            Value function evaluated at the given state, named ``"v"`` and
            carrying the attrs of the dataset's ``"v"`` variable.

        Raises
        ------
        TypeError
            If state is not a mapping.
        KeyError
            If state is missing required coordinates.
        """
        result = self.u(self._interp_extrapolated("v_inv", state))

        result.name = "v"
        result.attrs = self.dataset["v"].attrs

        return result

    def derivative(self, state: Mapping[str, np.ndarray]) -> xr.DataArray:
        """
        Evaluate marginal value function at given state via interpolation.

        Interpolates ``v_der_inv`` (extrapolating outside the grid) and passes
        the result through the utility object's ``der`` method to recover the
        marginal value function.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State variables to evaluate marginal value function at. Must
            contain all coordinates present in the dataset.

        Returns
        -------
        xr.DataArray
            Marginal value function evaluated at the given state, named
            ``"v_der"`` and carrying the attrs of the dataset's ``"v_der"``
            variable.

        Raises
        ------
        TypeError
            If state is not a mapping.
        KeyError
            If state is missing required coordinates.
        """
        result = self.u.der(self._interp_extrapolated("v_der_inv", state))

        result.name = "v_der"
        # Metadata must describe the variable this result represents
        # (previously the attrs of "v" were copied here by mistake).
        result.attrs = self.dataset["v_der"].attrs

        return result

    def evaluate(self, state: Mapping[str, np.ndarray]) -> xr.Dataset:
        """
        Interpolate all data variables in the dataset at given state.

        Unlike ``__call__`` and ``derivative``, this interpolates every stored
        variable directly (not via the inverse forms) and passes
        ``fill_value=None`` rather than ``"extrapolate"``.
        NOTE(review): xarray/scipy treat ``fill_value=None`` as a NaN fill for
        linear interpolation, so off-grid points are presumably NaN here.
        Confirm that this asymmetry is intended.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State variables to evaluate all data variables at. Must contain
            all coordinates present in the dataset.

        Returns
        -------
        xr.Dataset
            All interpolated data variables at the given state, with the
            dataset-level attrs taken from the ``"v"`` variable.

        Raises
        ------
        TypeError
            If state is not a mapping.
        KeyError
            If state is missing required coordinates.
        """
        state_dict = self._validate_state(state)

        result = self.dataset.interp(
            state_dict,
            kwargs={"fill_value": None},
        )
        result.attrs = self.dataset["v"].attrs

        return result

    def _interp_extrapolated(
        self, var: str, state: Mapping[str, np.ndarray]
    ) -> xr.DataArray:
        """
        Interpolate data variable ``var`` at ``state``, extrapolating off-grid.

        Uses xarray's default (linear) interpolation. ``assume_sorted=True``
        skips sorting, so the dataset's coordinates are expected to be
        monotonically increasing.
        """
        state_dict = self._validate_state(state)
        return self.dataset[var].interp(
            state_dict,
            assume_sorted=True,
            kwargs={"fill_value": "extrapolate"},
        )

    def _validate_state(self, state: Mapping[str, np.ndarray]) -> dict:
        """
        Validate state and extract required coordinates.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State to validate. Any mapping is accepted (e.g. a dict or an
            ``xr.Dataset``).

        Returns
        -------
        dict
            One entry per coordinate of the dataset, mapping the coordinate
            name to the corresponding value taken from ``state``. Keys of
            ``state`` that are not dataset coordinates are dropped.

        Raises
        ------
        TypeError
            If state is not a mapping.
        KeyError
            If a required coordinate is missing from state.
        """
        # xr.Dataset is itself a Mapping; it is listed for readability.
        if not isinstance(state, (xr.Dataset, Mapping)):
            raise TypeError(
                f"state must be a dict or xr.Dataset or other Mapping, "
                f"got {type(state)}"
            )

        state_dict = {}
        for coord in self.dataset.coords:
            if coord not in state:
                raise KeyError(
                    f"Required coordinate '{coord}' not found in state. "
                    f"Available keys: {list(state.keys())}"
                )
            state_dict[coord] = state[coord]

        return state_dict
class ConsumerSolutionLabeled(MetricObject):
    """
    One-period solution of a labeled consumption-saving problem.

    Bundles three pieces: the period's value function, its consumption
    policy, and the continuation (post-decision) value function. It also
    carries a free-form ``attrs`` dictionary for extra solution metadata.

    Parameters
    ----------
    value : ValueFuncCRRALabeled
        Value function for this period.
    policy : xr.Dataset
        Policy function (consumption as function of state).
    continuation : ValueFuncCRRALabeled or None
        Continuation value function (value of post-decision state).
        May be None, e.g. for a terminal-period solution.
    attrs : dict, optional
        Extra solution attributes, such as minimum normalized market
        resources. When omitted, a fresh empty dict is used.

    Raises
    ------
    TypeError
        If ``value``, ``policy`` or ``continuation`` has the wrong type.
    """

    def __init__(
        self,
        value: ValueFuncCRRALabeled,
        policy: xr.Dataset,
        continuation: ValueFuncCRRALabeled | None,
        attrs: dict | None = None,
    ) -> None:
        # Reject badly typed arguments before storing anything.
        if not isinstance(value, ValueFuncCRRALabeled):
            raise TypeError(f"value must be ValueFuncCRRALabeled, got {type(value)}")
        if not isinstance(policy, xr.Dataset):
            raise TypeError(f"policy must be xr.Dataset, got {type(policy)}")
        continuation_is_valid = continuation is None or isinstance(
            continuation, ValueFuncCRRALabeled
        )
        if not continuation_is_valid:
            raise TypeError(
                "continuation must be ValueFuncCRRALabeled or None, "
                f"got {type(continuation)}"
            )

        self.value = value
        self.policy = policy
        self.continuation = continuation
        self.attrs = {} if attrs is None else attrs

    def distance(self, other: ConsumerSolutionLabeled) -> float:
        """
        Largest absolute gap between this solution's value data and another's.

        ``other``'s value dataset is first interpolated onto this solution's
        coordinates (``interp_like``). The absolute difference is then taken
        across all data variables, and its maximum is returned. Callers can
        use this as a convergence check between successive iterations.

        Parameters
        ----------
        other : ConsumerSolutionLabeled
            Solution to compare against.

        Returns
        -------
        float
            Maximum absolute difference between the two value datasets.

        Raises
        ------
        TypeError
            If ``other`` is not a ConsumerSolutionLabeled.
        ValueError
            If either solution's ``value`` attribute is None.
        """
        if not isinstance(other, ConsumerSolutionLabeled):
            raise TypeError(
                f"Cannot compute distance with {type(other)}. "
                f"Expected ConsumerSolutionLabeled."
            )
        # ``value`` is validated in __init__, but the attribute is public
        # and could have been reassigned afterwards.
        if any(vf is None for vf in (self.value, other.value)):
            raise ValueError(
                "Cannot compute distance: one or both solutions have no value function"
            )

        reference = self.value.dataset
        aligned = other.value.dataset.interp_like(reference)
        abs_gap = np.abs(reference - aligned).to_array()
        return float(np.max(abs_gap))