Coverage for HARK/ConsumptionSaving/ConsLabeledModel.py: 72%
364 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-02 05:14 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-02 05:14 +0000
1from dataclasses import dataclass
2from types import SimpleNamespace
3from typing import Mapping
5import numpy as np
6import xarray as xr
8from HARK.Calibration.Assets.AssetProcesses import (
9 make_lognormal_RiskyDstn,
10 combine_IncShkDstn_and_RiskyDstn,
11)
12from HARK.ConsumptionSaving.ConsIndShockModel import (
13 IndShockConsumerType,
14 init_perfect_foresight,
15 init_idiosyncratic_shocks,
16 IndShockConsumerType_aXtraGrid_default,
17)
18from HARK.ConsumptionSaving.ConsPortfolioModel import (
19 PortfolioConsumerType,
20 init_portfolio,
21)
22from HARK.ConsumptionSaving.ConsRiskyAssetModel import (
23 RiskyAssetConsumerType,
24 init_risky_asset,
25 IndShockRiskyAssetConsumerType_constructor_default,
26)
27from HARK.Calibration.Income.IncomeProcesses import (
28 construct_lognormal_income_process_unemployment,
29)
30from HARK.ConsumptionSaving.LegacyOOsolvers import ConsIndShockSetup
31from HARK.core import make_one_period_oo_solver
32from HARK.distributions import DiscreteDistributionLabeled
33from HARK.metric import MetricObject
34from HARK.rewards import UtilityFuncCRRA
35from HARK.utilities import make_assets_grid
class ValueFuncCRRALabeled(MetricObject):
    """
    Class to allow for value function interpolation using xarray.

    The wrapped dataset stores the inverse value function ("v_inv") and the
    inverse marginal value function ("v_der_inv"). These are interpolated and
    the CRRA utility function (or its derivative) is then applied to the
    interpolated result.
    """

    def __init__(self, dataset: xr.Dataset, CRRA: float):
        """
        Initialize a value function.

        Parameters
        ----------
        dataset : xr.Dataset
            Underlying dataset. ``__call__`` reads the variables "v_inv" and
            "v", ``derivative`` reads "v_der_inv" and "v_der", and
            ``evaluate`` interpolates every data variable.
        CRRA : float
            Coefficient of relative risk aversion.
        """

        self.dataset = dataset
        self.CRRA = CRRA
        self.u = UtilityFuncCRRA(CRRA)

    def __call__(self, state: Mapping[str, np.ndarray]) -> xr.DataArray:
        """
        Interpolate inverse value function then invert to get value function at given state.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State to evaluate value function at.

        Returns
        -------
        result : xr.DataArray
            Interpolated value function, named "v".
        """

        state_dict = self._validate_state(state)

        result = self.u(
            self.dataset["v_inv"].interp(
                state_dict,
                assume_sorted=True,
                kwargs={"fill_value": "extrapolate"},
            )
        )

        result.name = "v"
        result.attrs = self.dataset["v"].attrs

        return result

    def derivative(self, state: Mapping[str, np.ndarray]) -> xr.DataArray:
        """
        Interpolate inverse marginal value function then invert to get marginal value function at given state.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State to evaluate marginal value function at.

        Returns
        -------
        result : xr.DataArray
            Interpolated marginal value function, named "v_der".
        """

        state_dict = self._validate_state(state)

        result = self.u.der(
            self.dataset["v_der_inv"].interp(
                state_dict,
                assume_sorted=True,
                kwargs={"fill_value": "extrapolate"},
            )
        )

        result.name = "v_der"
        # Label the result with the marginal value function's own metadata
        # (previously the attrs of "v" were copied here by mistake).
        result.attrs = self.dataset["v_der"].attrs

        return result

    def evaluate(self, state: Mapping[str, np.ndarray]) -> xr.Dataset:
        """
        Interpolate all data variables in the dataset.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State to evaluate all data variables at.

        Returns
        -------
        result : xr.Dataset
            Interpolated dataset, carrying the attrs of the "v" variable.
        """

        state_dict = self._validate_state(state)

        result = self.dataset.interp(
            state_dict,
            kwargs={"fill_value": None},
        )
        result.attrs = self.dataset["v"].attrs

        return result

    def _validate_state(self, state):
        """
        Extract the dataset's coordinates from the input state.

        Allowed states are an xr.Dataset, a dict, or any other Mapping.
        Entries of ``state`` that are not coordinates of the dataset are
        ignored; every coordinate of the dataset is looked up in ``state``,
        so a missing one raises whatever ``state[...]`` raises for an
        unknown key.

        Parameters
        ----------
        state : Mapping[str, np.ndarray]
            State to validate.

        Returns
        -------
        state_dict : dict
            Mapping from each dataset coordinate name to its value in ``state``.

        Raises
        ------
        ValueError
            If ``state`` is not a mapping.
        """

        if not isinstance(state, (xr.Dataset, Mapping)):
            raise ValueError("state must be a dict or xr.Dataset")

        return {name: state[name] for name in self.dataset.coords}
class ConsumerSolutionLabeled(MetricObject):
    """
    Solution container for labeled models, built on xarray objects.

    Bundles a value function, a policy dataset, a continuation value
    function, and a dictionary of extra attributes.
    """

    def __init__(
        self,
        value: ValueFuncCRRALabeled,
        policy: xr.Dataset,
        continuation: ValueFuncCRRALabeled,
        attrs=None,
    ):
        """
        Consumer Solution for labeled models.

        Parameters
        ----------
        value : ValueFuncCRRALabeled
            Value function and marginal value function.
        policy : xr.Dataset
            Policy function.
        continuation : ValueFuncCRRALabeled
            Continuation value function and marginal value function.
        attrs : dict, optional
            Attributes of the solution. A fresh empty dict is used when
            None is given.
        """

        self.value = value  # value function
        self.policy = policy  # policy function
        self.continuation = continuation  # continuation function

        # Build the default per instance so solutions never share one dict.
        self.attrs = {} if attrs is None else attrs

    def distance(self, other: "ConsumerSolutionLabeled"):
        """
        Compute the distance between two solutions.

        The other solution's value dataset is interpolated onto this
        solution's grid, and the largest absolute difference over all
        variables is returned.

        Parameters
        ----------
        other : ConsumerSolutionLabeled
            Other solution to compare to.

        Returns
        -------
        Largest absolute difference between the two value datasets, as
        returned by ``np.max`` on the stacked differences.
        """

        # TODO: is there a faster way to compare two xr.Datasets?

        own = self.value.dataset
        theirs = other.value.dataset.interp_like(own)
        gap = np.abs(own - theirs)

        return np.max(gap.to_array())
229###############################################################################
def make_solution_terminal_labeled(CRRA, aXtraGrid):
    """
    Build the terminal-period solution, in which everything is consumed.

    The cash-on-hand grid is ``aXtraGrid`` with 0.0 prepended. Consumption
    equals cash on hand, and both inverse functions ("v_inv", "v_der_inv")
    are copies of consumption. This is used as the constructor for
    solution_terminal.

    Parameters
    ----------
    CRRA : float
        Coefficient of relative risk aversion.
    aXtraGrid : np.array
        Grid of assets above minimum.

    Returns
    -------
    solution_terminal : ConsumerSolutionLabeled
        Terminal period solution.
    """
    utility = UtilityFuncCRRA(CRRA)

    # only one state var in this model: normalized cash on hand
    m_grid = xr.DataArray(
        np.append(0.0, aXtraGrid),
        name="mNrm",
        dims="mNrm",
        attrs={"long_name": "cash_on_hand"},
    )
    state = xr.Dataset({"mNrm": m_grid})

    # optimal decision is to consume everything in the last period
    cNrm = xr.DataArray(
        m_grid,
        name="cNrm",
        dims=state.dims,
        coords=state.coords,
        attrs={"long_name": "consumption"},
    )

    # value and its inverse
    v = utility(cNrm)
    v.name = "v"
    v.attrs = {"long_name": "value function"}

    v_inv = cNrm.copy()
    v_inv.name = "v_inv"
    v_inv.attrs = {"long_name": "inverse value function"}

    # marginal value and its inverse
    v_der = utility.der(cNrm)
    v_der.name = "v_der"
    v_der.attrs = {"long_name": "marginal value function"}

    v_der_inv = cNrm.copy()
    v_der_inv.name = "v_der_inv"
    v_der_inv.attrs = {"long_name": "inverse marginal value function"}

    dataset = xr.Dataset(
        {
            "cNrm": cNrm,
            "v": v,
            "v_der": v_der,
            "v_inv": v_inv,
            "v_der_inv": v_der_inv,
        }
    )

    value_vars = ["v", "v_der", "v_inv", "v_der_inv"]
    return ConsumerSolutionLabeled(
        value=ValueFuncCRRALabeled(dataset[value_vars], CRRA),
        policy=dataset[["cNrm"]],
        continuation=None,
        attrs={"m_nrm_min": 0.0},  # minimum normalized market resources
    )
def make_labeled_inc_shk_dstn(
    T_cycle,
    PermShkStd,
    PermShkCount,
    TranShkStd,
    TranShkCount,
    T_retire,
    UnempPrb,
    IncUnemp,
    UnempPrbRet,
    IncUnempRet,
    RNG,
    neutral_measure=False,
):
    """
    Build the income shock process and relabel each period's distribution.

    All arguments are forwarded unchanged, in order, to
    construct_lognormal_income_process_unemployment. Each element of the
    result is converted to a DiscreteDistributionLabeled with variables
    named "perm" and "tran".

    Returns
    -------
    IncShkDstn : list
        One labeled distribution per entry of the base process's ``dstns``.
    """
    base = construct_lognormal_income_process_unemployment(
        T_cycle,
        PermShkStd,
        PermShkCount,
        TranShkStd,
        TranShkCount,
        T_retire,
        UnempPrb,
        IncUnemp,
        UnempPrbRet,
        IncUnempRet,
        RNG,
        neutral_measure,
    )
    return [
        DiscreteDistributionLabeled.from_unlabeled(
            base[t],
            name="Distribution of Shocks to Income",
            var_names=["perm", "tran"],
        )
        for t in range(len(base.dstns))
    ]
def make_labeled_risky_dstn(T_cycle, RiskyAvg, RiskyStd, RiskyCount, RNG):
    """
    Build the risky return distribution with make_lognormal_RiskyDstn and
    relabel it, naming its single variable "risky".

    All arguments are forwarded unchanged to make_lognormal_RiskyDstn.
    """
    return DiscreteDistributionLabeled.from_unlabeled(
        make_lognormal_RiskyDstn(T_cycle, RiskyAvg, RiskyStd, RiskyCount, RNG),
        name="Distribution of Risky Asset Returns",
        var_names=["risky"],
    )
def make_labeled_shock_dstn(T_cycle, IncShkDstn, RiskyDstn):
    """
    Combine the income shock and risky return distributions and relabel
    each period's joint distribution.

    The joint variables are named "perm", "tran" and "risky".

    Returns
    -------
    ShockDstn : list
        One labeled joint distribution per entry of the combined
        process's ``dstns``.
    """
    # NOTE: the combiner takes the risky distribution before the income one.
    joint = combine_IncShkDstn_and_RiskyDstn(T_cycle, RiskyDstn, IncShkDstn)
    return [
        DiscreteDistributionLabeled.from_unlabeled(
            joint[t],
            name="Distribution of Shocks to Income and Risky Asset Returns",
            var_names=["perm", "tran", "risky"],
        )
        for t in range(len(joint.dstns))
    ]
382###############################################################################
class ConsPerfForesightLabeledSolver(ConsIndShockSetup):
    """
    Solver for PerfForesightLabeledType.

    The one-period problem is solved with the endogenous grid method on
    labeled xarray datasets. ``self.u`` (the utility function) is not
    defined in this class; it is presumably provided by the parent
    ConsIndShockSetup -- TODO confirm.
    """

    def create_params_namespace(self):
        """
        Create a namespace for parameters.

        The effective discount factor is the product of the intertemporal
        discount factor and the survival probability.
        """

        self.params = SimpleNamespace(
            Discount=self.DiscFac * self.LivPrb,
            CRRA=self.CRRA,
            Rfree=self.Rfree,
            PermGroFac=self.PermGroFac,
        )

    def calculate_borrowing_constraint(self):
        """
        Calculate the minimum allowable value of money resources in this period.

        This inverts ``post_state_transition`` (mNrm' = aNrm * Rfree /
        PermGroFac + 1) at next period's minimum market resources, so the
        result must be scaled by PermGroFac as well as divided by Rfree.
        """

        self.BoroCnstNat = (
            (self.solution_next.attrs["m_nrm_min"] - 1)
            * self.params.PermGroFac
            / self.params.Rfree
        )

    def define_boundary_constraint(self):
        """
        If the natural borrowing constraint is a binding constraint,
        then we can not evaluate the value function at that point,
        so we must fill out the data by hand.

        Sets ``m_nrm_min``, ``nat_boro_cnst`` and ``borocnst`` (a one-point
        dataset located at the binding constraint).
        """

        if self.BoroCnstArt is None or self.BoroCnstArt <= self.BoroCnstNat:
            self.m_nrm_min = self.BoroCnstNat
            self.nat_boro_cnst = True  # natural borrowing constraint is binding

            # Limiting values at the natural constraint: zero consumption,
            # value of minus infinity, infinite marginal value.
            self.borocnst = xr.Dataset(
                coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min},
                data_vars={
                    "cNrm": 0.0,
                    "v": -np.inf,
                    "v_inv": 0.0,
                    "reward": -np.inf,
                    "marginal_reward": np.inf,
                    "v_der": np.inf,
                    "v_der_inv": 0.0,
                },
            )

        elif self.BoroCnstArt > self.BoroCnstNat:
            self.m_nrm_min = self.BoroCnstArt
            self.nat_boro_cnst = False  # artificial borrowing constraint is binding

            # Only consumption is fixed by hand here; the value variables
            # at this point are computed later by value_transition.
            self.borocnst = xr.Dataset(
                coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min},
                data_vars={"cNrm": 0.0},
            )

    def create_post_state(self):
        """
        Create the post state variable, which in this case is
        the normalized assets saved this period.
        """

        if self.nat_boro_cnst:
            # don't include natural borrowing constraint
            a_grid = self.aXtraGrid + self.m_nrm_min
        else:
            # include artificial borrowing constraint
            a_grid = np.append(0.0, self.aXtraGrid) + self.m_nrm_min

        aVec = xr.DataArray(
            a_grid,
            name="aNrm",
            dims=("aNrm"),
            attrs={"long_name": "savings", "state": True},
        )

        self.post_state = xr.Dataset({"aNrm": aVec})

    def state_transition(self, state=None, action=None, params=None):
        """
        State to post_state transition.

        Parameters
        ----------
        state : xr.Dataset
            State variables.
        action : xr.Dataset
            Action variables.
        params : SimpleNamespace
            Parameters (unused here).

        Returns
        -------
        post_state : dict
            Post state variables: "aNrm" is market resources minus consumption.
        """

        post_state = {}  # pytree
        post_state["aNrm"] = state["mNrm"] - action["cNrm"]
        return post_state

    def post_state_transition(self, post_state=None, params=None):
        """
        Post_state to next_state transition.

        Parameters
        ----------
        post_state : xr.Dataset
            Post state variables.
        params : SimpleNamespace
            Parameters; ``Rfree`` and ``PermGroFac`` are read.

        Returns
        -------
        next_state : dict
            Next period's state variables ("mNrm").
        """

        next_state = {}  # pytree
        next_state["mNrm"] = post_state["aNrm"] * params.Rfree / params.PermGroFac + 1
        return next_state

    def reverse_transition(self, post_state=None, action=None, params=None):
        """
        State from post state and actions.

        Parameters
        ----------
        post_state : xr.Dataset
            Post state variables.
        action : xr.Dataset
            Action variables.
        params : SimpleNamespace
            Parameters (unused here).

        Returns
        -------
        state : dict
            State variables: "mNrm" is savings plus consumption.
        """

        state = {}  # pytree
        state["mNrm"] = post_state["aNrm"] + action["cNrm"]

        return state

    def egm_transition(self, post_state=None, continuation=None, params=None):
        """
        Actions from post state using the endogenous grid method.

        Parameters
        ----------
        post_state : xr.Dataset
            Post state variables.
        continuation : ValueFuncCRRALabeled
            Continuation value function, next period's value function.
        params : SimpleNamespace
            Parameters; ``Discount`` is read.

        Returns
        -------
        action : dict
            Action variables ("cNrm").
        """

        action = {}  # pytree
        # Invert marginal utility at the discounted marginal continuation value.
        action["cNrm"] = self.u.derinv(
            params.Discount * continuation.derivative(post_state)
        )

        return action

    def value_transition(self, action=None, state=None, continuation=None, params=None):
        """
        Value of action given state and continuation

        Parameters
        ----------
        action : xr.Dataset
            Action variables.
        state : xr.Dataset
            State variables.
        continuation : ValueFuncCRRALabeled
            Continuation value function, next period's value function.
        params : SimpleNamespace
            Parameters

        Returns
        -------
        variables : dict
            Post state, value, marginal value, their inverses, reward,
            marginal reward, and contributions.
        """

        variables = {}  # pytree
        post_state = self.state_transition(state, action, params)
        variables.update(post_state)

        variables["reward"] = self.u(action["cNrm"])
        variables["v"] = variables["reward"] + params.Discount * continuation(
            post_state
        )
        variables["v_inv"] = self.u.inv(variables["v"])

        # Marginal value is set equal to marginal utility of consumption,
        # so its inverse is consumption itself.
        variables["marginal_reward"] = self.u.der(action["cNrm"])
        variables["v_der"] = variables["marginal_reward"]
        variables["v_der_inv"] = action["cNrm"]

        # for estimagic purposes
        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def continuation_transition(self, post_state=None, value_next=None, params=None):
        """
        Continuation value function of post state.

        Parameters
        ----------
        post_state : xr.Dataset
            Post state variables.
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters.

        Returns
        -------
        variables : dict
            Next state, value, marginal value, inverse value, and inverse
            marginal value.
        """

        variables = {}  # pytree
        next_state = self.post_state_transition(post_state, params)
        variables.update(next_state)
        variables["v"] = params.PermGroFac ** (1 - params.CRRA) * value_next(next_state)
        variables["v_der"] = (
            params.Rfree
            * params.PermGroFac ** (-params.CRRA)
            * value_next.derivative(next_state)
        )

        variables["v_inv"] = self.u.inv(variables["v"])
        variables["v_der_inv"] = self.u.derinv(variables["v_der"])

        # for estimagic purposes
        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def prepare_to_solve(self):
        """
        Prepare to solve the model by creating the parameters namespace,
        calculating the borrowing constraint, defining the boundary constraint,
        and creating the post state.
        """

        self.create_params_namespace()
        self.calculate_borrowing_constraint()
        self.define_boundary_constraint()
        self.create_post_state()

    def create_continuation_function(self):
        """
        Create the continuation function, or the value function
        of every possible post state.

        Returns
        -------
        wfunc : ValueFuncCRRALabeled
            Continuation function.
        """

        # unpack next period's solution
        vfunc_next = self.solution_next.value

        v_end = self.continuation_transition(self.post_state, vfunc_next, self.params)
        # need to drop m because it's next period's m
        v_end = xr.Dataset(v_end).drop_vars(["mNrm"])

        if self.nat_boro_cnst:
            # Add the hand-filled point at the natural borrowing constraint.
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end])

        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        return wfunc

    def endogenous_grid_method(self):
        """
        Solve the model using the endogenous grid method, which consists of
        solving the model backwards in time using the following steps:

        1. Create the continuation function, or the value function of every
           possible post state.
        2. Get the optimal actions/decisions from the endogenous grid transition.
        3. Get the state from the actions and post state using the reverse transition.
        4. EGM requires swapping dimensions; make actions and state functions of state.
        5. Merge the actions and state into a single dataset.
        6. If the natural borrowing constraint is not used, concatenate the
           borrowing constraint to the dataset.
        7. Create the value function from the variables in the dataset.
        8. Create the policy function from the variables in the dataset.
        9. Create the solution from the value and policy functions.

        The result is stored in ``self.solution``.
        """
        wfunc = self.create_continuation_function()

        # get optimal actions/decisions from egm
        acted = self.egm_transition(self.post_state, wfunc, self.params)
        # get state from actions and post_state
        state = self.reverse_transition(self.post_state, acted, self.params)

        # egm requires swap dimensions; make actions and state functions of state
        action = xr.Dataset(acted).swap_dims({"aNrm": "mNrm"})
        state = xr.Dataset(state).swap_dims({"aNrm": "mNrm"})

        egm_dataset = xr.merge([action, state])

        if not self.nat_boro_cnst:
            # Artificial constraint: prepend it before computing values so
            # the value variables are evaluated there too.
            egm_dataset = xr.concat([self.borocnst, egm_dataset], dim="mNrm")

        values = self.value_transition(egm_dataset, egm_dataset, wfunc, self.params)
        egm_dataset.update(values)

        if self.nat_boro_cnst:
            # Natural constraint: values there were filled by hand, so the
            # point is prepended only after the value computation.
            egm_dataset = xr.concat(
                [self.borocnst, egm_dataset], dim="mNrm", combine_attrs="no_conflicts"
            )

        egm_dataset = egm_dataset.drop_vars("aNrm")

        vfunc = ValueFuncCRRALabeled(
            egm_dataset[["v", "v_der", "v_inv", "v_der_inv"]], self.params.CRRA
        )
        pfunc = egm_dataset[["cNrm"]]

        self.solution = ConsumerSolutionLabeled(
            value=vfunc,
            policy=pfunc,
            continuation=wfunc,
            attrs={"m_nrm_min": self.m_nrm_min, "dataset": egm_dataset},
        )

    def solve(self):
        """
        Solve the model by endogenous grid method.

        Returns
        -------
        solution : ConsumerSolutionLabeled
            The solution built by ``endogenous_grid_method``.
        """

        self.endogenous_grid_method()

        return self.solution
740###############################################################################
# Constructors: the idiosyncratic-shock defaults, with the labeled terminal
# solution and the asset grid builder swapped in.
PF_labeled_constructor_dict = init_idiosyncratic_shocks["constructors"].copy()
PF_labeled_constructor_dict.update(
    solution_terminal=make_solution_terminal_labeled,
    aXtraGrid=make_assets_grid,
)

# Parameters: idiosyncratic-shock defaults, overridden in turn by the perfect
# foresight defaults, the labeled constructors, and the aXtraGrid defaults.
init_perf_foresight_labeled = init_idiosyncratic_shocks.copy()
init_perf_foresight_labeled.update(init_perfect_foresight)
init_perf_foresight_labeled["constructors"] = PF_labeled_constructor_dict
init_perf_foresight_labeled.update(IndShockConsumerType_aXtraGrid_default)
750###############################################################################
class PerfForesightLabeledType(IndShockConsumerType):
    """
    A labeled perfect foresight consumer type. This class is a subclass of
    IndShockConsumerType, and inherits all of its methods and attributes.

    Perfect foresight consumers have no uncertainty about income or interest
    rates, and so the only state variable is market resources m.
    """

    default_ = dict(
        params=init_perf_foresight_labeled,
        solver=make_one_period_oo_solver(ConsPerfForesightLabeledSolver),
        model="ConsPerfForesight.yaml",
    )

    def post_solve(self):
        """
        Do nothing, rather than try to run calc_stable_points.
        """
        return None
772###############################################################################
class ConsIndShockLabeledSolver(ConsPerfForesightLabeledSolver):
    """
    Solver for IndShockLabeledType.

    Extends the perfect foresight solver by taking expectations over
    permanent ("perm") and transitory ("tran") income shocks.
    """

    def calculate_borrowing_constraint(self):
        """
        Calculate the minimum allowable value of money resources in this period.
        This is different from the perfect foresight natural borrowing constraint
        because of the presence of income uncertainty.

        The smallest values of the first and second rows of
        ``IncShkDstn.atoms`` are used as the worst permanent and transitory
        shocks, respectively.
        """

        PermShkMinNext = np.min(self.IncShkDstn.atoms[0])
        TranShkMinNext = np.min(self.IncShkDstn.atoms[1])

        self.BoroCnstNat = (
            (self.solution_next.attrs["m_nrm_min"] - TranShkMinNext)
            * (self.params.PermGroFac * PermShkMinNext)
            / self.params.Rfree
        )

    def post_state_transition(self, post_state=None, shocks=None, params=None):
        """
        Post state to next state transition now depends on income shocks.

        Parameters
        ----------
        post_state : dict
            Post state variables.
        shocks : dict
            Shocks to income ("perm" and "tran").
        params : SimpleNamespace
            Parameters; ``Rfree`` and ``PermGroFac`` are read.

        Returns
        -------
        next_state : dict
            Next period's state variables ("mNrm").
        """

        next_state = {}  # pytree
        next_state["mNrm"] = (
            post_state["aNrm"] * params.Rfree / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation_transition(
        self, shocks=None, post_state=None, v_next=None, params=None
    ):
        """
        Continuation value function of post state.

        Parameters
        ----------
        shocks : dict
            Shocks to income.
        post_state : dict
            Post state variables.
        v_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters.

        Returns
        -------
        variables : dict
            Next state, the growth-adjusted permanent shock "psi", and the
            continuation value function and its derivative.
        """

        variables = {}  # pytree
        next_state = self.post_state_transition(post_state, shocks, params)
        variables.update(next_state)

        # Permanent income growth including the permanent shock.
        variables["psi"] = params.PermGroFac * shocks["perm"]

        variables["v"] = variables["psi"] ** (1 - params.CRRA) * v_next(next_state)

        variables["v_der"] = (
            params.Rfree
            * variables["psi"] ** (-params.CRRA)
            * v_next.derivative(next_state)
        )

        # for estimagic purposes

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def create_continuation_function(self):
        """
        Create the continuation function. Because of the income uncertainty
        in this model, we need to integrate over the income shocks to get the
        continuation value function. Depending on the natural borrowing constraint,
        we may also have to append the minimum allowable value of money resources.

        Returns
        -------
        wfunc : ValueFuncCRRALabeled
            Continuation value function.
        """

        # unpack next period's solution
        vfunc_next = self.solution_next.value

        v_end = self.IncShkDstn.expected(
            func=self.continuation_transition,
            post_state=self.post_state,
            v_next=vfunc_next,
            params=self.params,
        )

        # Inverses are taken after the expectation, not inside it.
        v_end["v_inv"] = self.u.inv(v_end["v"])
        v_end["v_der_inv"] = self.u.derinv(v_end["v_der"])

        if self.nat_boro_cnst:
            # Add the hand-filled point at the natural borrowing constraint.
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end])

        # NOTE: unlike the perfect foresight solver, "mNrm" (next period's m)
        # is not dropped from v_end here.
        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        return wfunc
903###############################################################################
# Constructors: the labeled perfect foresight set plus a labeled income
# shock distribution.
ind_shock_labeled_constructor_dict = PF_labeled_constructor_dict.copy()
ind_shock_labeled_constructor_dict.update(IncShkDstn=make_labeled_inc_shk_dstn)

# Parameters: the labeled perfect foresight defaults with those constructors.
init_ind_shock_labeled = init_perf_foresight_labeled.copy()
init_ind_shock_labeled["constructors"] = ind_shock_labeled_constructor_dict
class IndShockLabeledType(PerfForesightLabeledType):
    """
    Labeled counterpart of IndShockConsumerType. Derives from
    PerfForesightLabeledType and adds income uncertainty by using
    ConsIndShockLabeledSolver as its one-period solver.
    """

    default_ = dict(
        params=init_ind_shock_labeled,
        solver=make_one_period_oo_solver(ConsIndShockLabeledSolver),
        model="ConsIndShock.yaml",
    )
924###############################################################################
@dataclass
class ConsRiskyAssetLabeledSolver(ConsIndShockLabeledSolver):
    """
    Solver for an agent that can save in an asset that has a risky return.
    """

    solution_next: ConsumerSolutionLabeled  # solution to next period's problem
    ShockDstn: DiscreteDistributionLabeled  # distribution of shocks to income and returns
    LivPrb: float  # survival probability
    DiscFac: float  # intertemporal discount factor
    CRRA: float  # coefficient of relative risk aversion
    Rfree: float  # interest factor on assets
    PermGroFac: float  # permanent income growth factor
    BoroCnstArt: float  # artificial borrowing constraint
    aXtraGrid: np.ndarray  # grid of end-of-period assets

    def __post_init__(self):
        """
        Define utility functions.

        ``def_utility_funcs`` is inherited (not defined in this class).
        """

        self.def_utility_funcs()

    def calculate_borrowing_constraint(self):
        """
        Calculate the borrowing constraint by enforcing a 0.0 artificial borrowing
        constraint and setting the shocks to income to come from the shock distribution.

        Note that this overwrites whatever ``BoroCnstArt`` was passed in.
        """
        self.BoroCnstArt = 0.0
        # The parent reads the minimum shocks from self.IncShkDstn.
        self.IncShkDstn = self.ShockDstn
        return super().calculate_borrowing_constraint()

    def post_state_transition(self, post_state=None, shocks=None, params=None):
        """
        Post_state to next_state transition with risky asset return.

        Parameters
        ----------
        post_state : dict
            Post-state variables.
        shocks : dict
            Shocks to income and risky asset return ("perm", "tran", "risky").
        params : SimpleNamespace
            Parameters of the model; ``PermGroFac`` is read.

        Returns
        -------
        next_state : dict
            Next period's state variables ("mNrm").
        """

        next_state = {}  # pytree
        # All savings earn the risky return.
        next_state["mNrm"] = (
            post_state["aNrm"] * shocks["risky"] / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation_transition(
        self, shocks=None, post_state=None, v_next=None, params=None
    ):
        """
        Continuation value function of post_state with risky asset return.

        Parameters
        ----------
        shocks : dict
            Shocks to income and risky asset return.
        post_state : dict
            Post-state variables.
        v_next : ValueFuncCRRALabeled
            Value function of next period.
        params : SimpleNamespace
            Parameters of the model.

        Returns
        -------
        variables : dict
            Variables of the continuation value function.
        """

        variables = {}  # pytree
        next_state = self.post_state_transition(post_state, shocks, params)
        variables.update(next_state)

        # Permanent income growth including the permanent shock.
        variables["psi"] = params.PermGroFac * shocks["perm"]

        variables["v"] = variables["psi"] ** (1 - params.CRRA) * v_next(next_state)

        # The realized risky return multiplies the marginal value.
        variables["v_der"] = (
            shocks["risky"]
            * variables["psi"] ** (-params.CRRA)
            * v_next.derivative(next_state)
        )

        # for estimagic purposes

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def create_continuation_function(self):
        """
        Create the continuation value function taking expectation
        over the shock distribution which includes shocks to income and
        the risky asset return.

        Returns
        -------
        wfunc : ValueFuncCRRALabeled
            Continuation value function.
        """
        # unpack next period's solution
        vfunc_next = self.solution_next.value

        v_end = self.ShockDstn.expected(
            func=self.continuation_transition,
            post_state=self.post_state,
            v_next=vfunc_next,
            params=self.params,
        )

        # Inverses are taken after the expectation, not inside it.
        v_end["v_inv"] = self.u.inv(v_end["v"])
        v_end["v_der_inv"] = self.u.derinv(v_end["v_der"])

        if self.nat_boro_cnst:
            # Add the hand-filled point at the natural borrowing constraint.
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end])

        # Put "aNrm" first; remaining dimensions keep their order.
        v_end = v_end.transpose("aNrm", ...)

        # NOTE: "mNrm" (next period's m) is not dropped from v_end here.
        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        return wfunc
1068###############################################################################
# Constructors: the risky asset defaults with labeled distributions and the
# labeled terminal solution swapped in.
risky_asset_labeled_constructor_dict = (
    IndShockRiskyAssetConsumerType_constructor_default.copy()
)
risky_asset_labeled_constructor_dict.update(
    IncShkDstn=make_labeled_inc_shk_dstn,
    RiskyDstn=make_labeled_risky_dstn,
    ShockDstn=make_labeled_shock_dstn,
    solution_terminal=make_solution_terminal_labeled,
)
# The "solve_one_period" constructor is removed from the labeled set.
del risky_asset_labeled_constructor_dict["solve_one_period"]

# Parameters: the risky asset defaults with those constructors.
init_risky_asset_labeled = init_risky_asset.copy()
init_risky_asset_labeled["constructors"] = risky_asset_labeled_constructor_dict
1083###############################################################################
class RiskyAssetLabeledType(IndShockLabeledType, RiskyAssetConsumerType):
    """
    A labeled RiskyAssetConsumerType. This class derives from both
    IndShockLabeledType and RiskyAssetConsumerType, and inherits their
    methods and attributes.

    Risky asset consumers can only save on a risky asset that
    pays a stochastic return.
    """

    default_ = dict(
        params=init_risky_asset_labeled,
        solver=make_one_period_oo_solver(ConsRiskyAssetLabeledSolver),
        model="ConsRiskyAsset.yaml",
    )
1102###############################################################################
@dataclass
class ConsFixedPortfolioLabeledSolver(ConsRiskyAssetLabeledSolver):
    """
    Solver for an agent that can save in a risk-free and risky asset
    at a fixed proportion.
    """

    RiskyShareFixed: float  # share of risky assets in portfolio

    def create_params_namespace(self):
        """
        Create a namespace for parameters.

        Same as the parent namespace, with ``RiskyShareFixed`` added.
        """

        self.params = SimpleNamespace(
            Discount=self.DiscFac * self.LivPrb,
            CRRA=self.CRRA,
            Rfree=self.Rfree,
            PermGroFac=self.PermGroFac,
            RiskyShareFixed=self.RiskyShareFixed,
        )

    def post_state_transition(self, post_state=None, shocks=None, params=None):
        """
        Post_state to next_state transition with fixed portfolio share.

        Parameters
        ----------
        post_state : dict
            Post-state variables.
        shocks : dict
            Shocks to income and risky asset return ("perm", "tran", "risky").
        params : SimpleNamespace
            Parameters of the model; ``Rfree``, ``PermGroFac`` and
            ``RiskyShareFixed`` are read.

        Returns
        -------
        next_state : dict
            Next period's state variables: the excess return "rDiff", the
            portfolio return "rPort", and market resources "mNrm".
        """

        next_state = {}  # pytree
        # Excess return of the risky asset over the risk-free asset. The
        # sign matters: rPort must rise with the risky return, i.e.
        # rPort = Rfree + share * (risky - Rfree).
        next_state["rDiff"] = shocks["risky"] - params.Rfree
        next_state["rPort"] = (
            params.Rfree + next_state["rDiff"] * params.RiskyShareFixed
        )
        next_state["mNrm"] = (
            post_state["aNrm"]
            * next_state["rPort"]
            / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation_transition(
        self, shocks=None, post_state=None, v_next=None, params=None
    ):
        """
        Continuation value function of post_state with fixed portfolio share.

        Parameters
        ----------
        shocks : dict
            Shocks to income and risky asset return.
        post_state : dict
            Post-state variables.
        v_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Parameters of the model.

        Returns
        -------
        variables : dict
            Variables of the model.
        """

        variables = {}  # pytree
        next_state = self.post_state_transition(post_state, shocks, params)
        variables.update(next_state)

        # Permanent income growth including the permanent shock.
        variables["psi"] = params.PermGroFac * shocks["perm"]

        variables["v"] = variables["psi"] ** (1 - params.CRRA) * v_next(next_state)

        # The realized portfolio return multiplies the marginal value.
        variables["v_der"] = (
            next_state["rPort"]
            * variables["psi"] ** (-params.CRRA)
            * v_next.derivative(next_state)
        )

        # for estimagic purposes

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables
1204###############################################################################
@dataclass
class ConsPortfolioLabeledSolver(ConsFixedPortfolioLabeledSolver):
    """
    Solver for an agent that can save in a risk-free and risky asset
    at an optimal proportion.
    """

    ShareGrid: np.ndarray  # grid of risky shares

    def create_post_state(self):
        """
        Create post-state variables by adding risky share, called
        stigma, to the post-state variables.
        """

        super().create_post_state()

        self.post_state["stigma"] = xr.DataArray(
            self.ShareGrid, dims=["stigma"], attrs={"long_name": "risky share"}
        )

    def post_state_transition(self, post_state=None, shocks=None, params=None):
        """
        Post_state to next_state transition with optimal portfolio share.

        Parameters
        ----------
        post_state : dict
            Post-state variables; must provide "aNrm" and "stigma".
        shocks : dict
            Shocks to income and risky asset return; must provide
            "perm", "tran" and "risky".
        params : SimpleNamespace
            Parameters of the model, read by attribute (``Rfree``,
            ``PermGroFac``).

        Returns
        -------
        next_state : dict
            Next period's state variables: "rDiff" (excess risky return),
            "rPort" (portfolio return) and "mNrm".
        """

        next_state = {}  # pytree
        next_state["rDiff"] = shocks["risky"] - params.Rfree
        next_state["rPort"] = params.Rfree + next_state["rDiff"] * post_state["stigma"]
        next_state["mNrm"] = (
            post_state["aNrm"]
            * next_state["rPort"]
            / (params.PermGroFac * shocks["perm"])
            + shocks["tran"]
        )
        return next_state

    def continuation_transition(
        self, shocks=None, post_state=None, v_next=None, params=None
    ):
        """
        Continuation value function of post_state with optimal portfolio share.

        Parameters
        ----------
        shocks : dict
            Shocks to income and risky asset return.
        post_state : dict
            Post-state variables.
        v_next : ValueFuncCRRALabeled
            Continuation value function; called on the next state and
            queried for its derivative.
        params : SimpleNamespace
            Parameters of the model, read by attribute.

        Returns
        -------
        variables : dict
            Variables of the model: the next-state entries plus "psi",
            "v", "v_der", "dvda", "dvds", "contributions" and "value".
        """

        variables = {}  # pytree
        next_state = self.post_state_transition(post_state, shocks, params)
        variables.update(next_state)

        # Permanent income growth factor including the permanent shock.
        variables["psi"] = params.PermGroFac * shocks["perm"]

        variables["v"] = variables["psi"] ** (1 - params.CRRA) * v_next(next_state)

        variables["v_der"] = variables["psi"] ** (-params.CRRA) * v_next.derivative(
            next_state
        )

        # Marginal value with respect to assets and with respect to the
        # risky share, respectively.
        variables["dvda"] = next_state["rPort"] * variables["v_der"]
        variables["dvds"] = (
            next_state["rDiff"] * post_state["aNrm"] * variables["v_der"]
        )

        # for estimagic purposes

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def create_continuation_function(self):
        """
        Create continuation function with optimal portfolio share.
        The continuation function is a function of the post-state before
        the growth period, but only a function of assets in the
        allocation period.

        Therefore, the first continuation function is a function of
        assets and stigma. Given this, the agent makes an optimal
        choice of risky share of portfolio, and the second continuation
        function is a function of assets only.

        As a side effect, "stigma" is removed from ``self.post_state``.

        Returns
        -------
        wfunc : ValueFuncCRRALabeled
            Continuation value function.
        """

        wfunc = super().create_continuation_function()

        # Indexed below as (aNrm, share): axis 0 is assets, axis 1 is ShareGrid.
        dvds = wfunc.dataset["dvds"].values

        # For each value of aNrm, find the value of Share such that FOC-Share == 0.
        crossing = np.logical_and(dvds[:, 1:] <= 0.0, dvds[:, :-1] >= 0.0)
        # argmax picks the first crossing; rows without any crossing yield
        # index 0 and are overwritten by the corner cases further down.
        share_idx = np.argmax(crossing, axis=1)
        a_idx = np.arange(self.post_state["aNrm"].size)
        bot_s = self.ShareGrid[share_idx]
        top_s = self.ShareGrid[share_idx + 1]
        bot_f = dvds[a_idx, share_idx]
        top_f = dvds[a_idx, share_idx + 1]
        # Linear interpolation weight locating the zero of the FOC between
        # the two bracketing share grid points.
        alpha = 1.0 - top_f / (top_f - bot_f)
        opt_share = (1.0 - alpha) * bot_s + alpha * top_s

        # For values of aNrm at which the agent wants to put
        # more than 100% into risky asset, constrain them
        opt_share[dvds[:, -1] > 0.0] = 1.0
        # Likewise if he wants to put less than 0% into risky asset
        opt_share[dvds[:, 0] < 0.0] = 0.0

        if not self.nat_boro_cnst:
            # aNrm=0, so there's no way to "optimize" the portfolio
            opt_share[0] = 1.0

        opt_share = xr.DataArray(
            opt_share,
            coords={"aNrm": self.post_state["aNrm"].values},
            dims=["aNrm"],
            attrs={"long_name": "optimal risky share"},
        )

        # Evaluate the two-dimensional continuation function along the
        # optimal share path, leaving a function of aNrm only.
        v_end = wfunc.evaluate({"aNrm": self.post_state["aNrm"], "stigma": opt_share})

        # Turn the "stigma" coordinate into a regular data variable.
        v_end = v_end.reset_coords(names="stigma")

        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        # `drop` is deprecated in xarray; `drop_vars` is the supported
        # equivalent for removing a variable/coordinate by name.
        self.post_state = self.post_state.drop_vars("stigma")

        return wfunc
1367###############################################################################
# Constructors for the labeled portfolio model: start from a shallow copy of
# the standard portfolio constructors and swap in the labeled versions.
init_portfolio_labeled_constructors = init_portfolio["constructors"].copy()
init_portfolio_labeled_constructors.update(
    IncShkDstn=make_labeled_inc_shk_dstn,
    RiskyDstn=make_labeled_risky_dstn,
    ShockDstn=make_labeled_shock_dstn,
    solution_terminal=make_solution_terminal_labeled,
)

# Default parameters for the labeled portfolio model: a shallow copy of the
# standard portfolio parameters pointing at the labeled constructors.
init_portfolio_labeled = init_portfolio.copy()
init_portfolio_labeled.update(
    constructors=init_portfolio_labeled_constructors,
    RiskyShareFixed=[0.0],  # This shouldn't exist
)
class PortfolioLabeledType(PortfolioConsumerType):
    """
    Labeled version of PortfolioConsumerType.

    Subclasses PortfolioConsumerType and inherits all of its methods and
    attributes; only the class-level defaults are replaced so that the
    labeled parameters and the labeled portfolio solver are used.

    Portfolio consumers can save on a risk-free and
    risky asset at an optimal proportion.
    """

    # Default parameters, one-period solver and model file for this type.
    default_ = dict(
        params=init_portfolio_labeled,
        solver=make_one_period_oo_solver(ConsPortfolioLabeledSolver),
        model="ConsPortfolio.yaml",
    )