Coverage for HARK / Labeled / solvers.py: 91%
225 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
"""
Solvers for labeled consumption-saving models.

This module implements the Template Method pattern for the Endogenous Grid
Method (EGM) algorithm. The base solver, ``BaseLabeledSolver``, defines the
algorithm skeleton, while concrete solvers pick a transitions class and
override specific hook methods for their model type:

- ``ConsPerfForesightLabeledSolver``: perfect foresight, no shocks.
- ``ConsIndShockLabeledSolver``: idiosyncratic income shocks.
- ``ConsRiskyAssetLabeledSolver``: all savings earn a stochastic risky return.
- ``ConsFixedPortfolioLabeledSolver``: a fixed share of savings is risky.
- ``ConsPortfolioLabeledSolver``: the risky share is chosen optimally.
"""
9from __future__ import annotations
11import warnings
12from types import SimpleNamespace
13from typing import TYPE_CHECKING, Any
15import numpy as np
16import xarray as xr
18from HARK.metric import MetricObject
19from HARK.rewards import UtilityFuncCRRA
21from .solution import ConsumerSolutionLabeled, ValueFuncCRRALabeled
22from .transitions import (
23 FixedPortfolioTransitions,
24 IndShockTransitions,
25 PerfectForesightTransitions,
26 PortfolioTransitions,
27 RiskyAssetTransitions,
28)
30if TYPE_CHECKING:
31 from HARK.distributions import DiscreteDistributionLabeled
# Public API: the template-method base solver plus one concrete solver per
# supported model type.
__all__ = [
    "BaseLabeledSolver",
    "ConsPerfForesightLabeledSolver",
    "ConsIndShockLabeledSolver",
    "ConsRiskyAssetLabeledSolver",
    "ConsFixedPortfolioLabeledSolver",
    "ConsPortfolioLabeledSolver",
]
class BaseLabeledSolver(MetricObject):
    """
    Base solver implementing Template Method pattern for EGM algorithm.

    This class provides the algorithm skeleton for solving consumption-saving
    problems using the Endogenous Grid Method. Subclasses customize behavior
    by:

    1. Setting TransitionsClass to specify model-specific transitions
    2. Overriding hook methods for model-specific logic

    Template Method: solve()

    Hook Methods:
        - create_params_namespace(): Add model-specific parameters
        - calculate_borrowing_constraint(): Model-specific constraint logic
        - create_post_state(): Add extra post-decision dimensions
          (e.g., the risky share ``stigma``)
        - create_continuation_function(): Handle shock integration

    Parameters
    ----------
    solution_next : ConsumerSolutionLabeled
        Solution to next period's problem. Its ``attrs`` must contain
        ``"m_nrm_min"``.
    LivPrb : float
        Survival probability, in (0, 1].
    DiscFac : float
        Intertemporal discount factor (positive).
    CRRA : float
        Coefficient of relative risk aversion (finite, non-negative).
    Rfree : float
        Risk-free interest factor (positive).
    PermGroFac : float
        Permanent income growth factor (positive).
    BoroCnstArt : float or None
        Artificial borrowing constraint. ``None`` means no artificial
        constraint, so the natural borrowing constraint applies.
    aXtraGrid : np.ndarray
        One-dimensional, strictly increasing grid of finite, non-negative
        end-of-period asset values above the minimum.
    **kwargs
        Additional model-specific parameters, stored as attributes.

    Raises
    ------
    TypeError
        If solution_next is not a ConsumerSolutionLabeled.
    ValueError
        If solution_next is None or lacks ``m_nrm_min``, if CRRA or any other
        economic parameter is outside its valid range (NaN included), or if
        aXtraGrid is malformed.
    """

    # Class-level strategy specification - override in subclasses
    TransitionsClass: type = PerfectForesightTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # Input validation - solution_next
        if solution_next is None:
            raise ValueError("solution_next cannot be None")
        if not isinstance(solution_next, ConsumerSolutionLabeled):
            raise TypeError(
                f"solution_next must be ConsumerSolutionLabeled, got {type(solution_next)}"
            )
        if "m_nrm_min" not in solution_next.attrs:
            raise ValueError(
                "solution_next.attrs must contain 'm_nrm_min'. "
                "Use make_solution_terminal_labeled() to create valid terminal solutions."
            )

        # Input validation - CRRA
        if not np.isfinite(CRRA):
            raise ValueError(f"CRRA must be finite, got {CRRA}")
        if CRRA < 0:
            raise ValueError(f"CRRA must be non-negative, got {CRRA}")

        # Input validation - economic parameters. Each test is written as
        # "not (inside valid range)" so NaN, which fails every comparison,
        # is rejected instead of slipping through.
        if not 0 < LivPrb <= 1:
            raise ValueError(f"LivPrb must be in (0, 1], got {LivPrb}")
        if not DiscFac > 0:
            raise ValueError(f"DiscFac must be positive, got {DiscFac}")
        if not Rfree > 0:
            raise ValueError(f"Rfree must be positive, got {Rfree}")
        if not PermGroFac > 0:
            raise ValueError(f"PermGroFac must be positive, got {PermGroFac}")

        # Input validation - asset grid
        aXtraGrid = np.asarray(aXtraGrid)
        if aXtraGrid.ndim != 1:
            raise ValueError(
                f"aXtraGrid must be one-dimensional, got shape {aXtraGrid.shape}"
            )
        if len(aXtraGrid) == 0:
            raise ValueError("aXtraGrid cannot be empty")
        if not np.all(np.isfinite(aXtraGrid)):
            raise ValueError("aXtraGrid values must be finite")
        if np.any(aXtraGrid < 0):
            raise ValueError("aXtraGrid values must be non-negative")
        if not np.all(np.diff(aXtraGrid) > 0):
            raise ValueError("aXtraGrid must be strictly increasing")

        # Store parameters
        self.solution_next = solution_next
        self.LivPrb = LivPrb
        self.DiscFac = DiscFac
        self.CRRA = CRRA
        self.Rfree = Rfree
        self.PermGroFac = PermGroFac
        self.BoroCnstArt = BoroCnstArt
        self.aXtraGrid = aXtraGrid

        # Initialize utility function
        self.u = UtilityFuncCRRA(CRRA)

        # Initialize transitions strategy
        self.transitions = self.TransitionsClass()

        # Store additional kwargs (set last, so they can shadow any attribute
        # assigned above).
        for key, value in kwargs.items():
            setattr(self, key, value)

    # =========================================================================
    # TEMPLATE METHOD - The algorithm skeleton
    # =========================================================================

    def solve(self) -> ConsumerSolutionLabeled:
        """
        Solve the consumption-saving problem using EGM.

        This is the template method that defines the algorithm skeleton.
        It calls hook methods that subclasses can override.

        Returns
        -------
        ConsumerSolutionLabeled
            Solution containing value function, policy, and continuation.
        """
        self.prepare_to_solve()
        self.endogenous_grid_method()
        return self.solution

    # =========================================================================
    # HOOK METHODS - Override in subclasses for customization
    # =========================================================================

    def create_params_namespace(self) -> SimpleNamespace:
        """
        Create parameters namespace.

        Override in subclasses to add model-specific parameters.

        Returns
        -------
        SimpleNamespace
            Parameters for this period's problem. ``Discount`` is the
            survival-adjusted discount factor ``DiscFac * LivPrb``.
        """
        return SimpleNamespace(
            Discount=self.DiscFac * self.LivPrb,
            CRRA=self.CRRA,
            Rfree=self.Rfree,
            PermGroFac=self.PermGroFac,
        )

    def calculate_borrowing_constraint(self) -> None:
        """
        Calculate the natural borrowing constraint.

        Override in shock models to account for minimum shock realizations.
        Sets self.BoroCnstNat.
        """
        # The constant 1 stands where the shock models use the minimum
        # transitory shock: with no shocks, transitory income is fixed.
        self.BoroCnstNat = (
            (self.solution_next.attrs["m_nrm_min"] - 1)
            * self.params.PermGroFac
            / self.params.Rfree
        )

    def create_post_state(self) -> xr.Dataset:
        """
        Create the post-decision state grid.

        Override in portfolio models to add the risky share dimension.

        Returns
        -------
        xr.Dataset
            Post-decision state grid with a single ``aNrm`` variable.
        """
        if self.nat_boro_cnst:
            # The boundary point itself is merged in separately (with its
            # limiting values) by create_continuation_function.
            a_grid = self.aXtraGrid + self.m_nrm_min
        else:
            # Include the artificial constraint point (aXtra = 0) on the grid.
            a_grid = np.append(0.0, self.aXtraGrid) + self.m_nrm_min

        aVec = xr.DataArray(
            a_grid,
            name="aNrm",
            dims=("aNrm",),
            attrs={"long_name": "savings", "state": True},
        )
        return xr.Dataset({"aNrm": aVec})

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """
        Create the continuation value function.

        Override in shock models to integrate over shock distributions.

        Returns
        -------
        ValueFuncCRRALabeled
            Continuation value function over the post-decision state.
        """
        value_next = self.solution_next.value

        v_end = self.transitions.continuation(
            self.post_state, None, value_next, self.params, self.u
        )
        v_end = xr.Dataset(v_end).drop_vars(["mNrm"])

        # Inverted (linearized) forms used for interpolation.
        v_end["v_inv"] = self.u.inv(v_end["v"])
        v_end["v_der_inv"] = self.u.derinv(v_end["v_der"])

        if self.nat_boro_cnst:
            # Prepend the boundary point (v = -inf, v_der = inf) on aNrm.
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end], join="outer", compat="no_conflicts")

        return ValueFuncCRRALabeled(v_end, self.params.CRRA)

    # =========================================================================
    # CORE METHODS - Shared implementation, rarely overridden
    # =========================================================================

    def prepare_to_solve(self) -> None:
        """Prepare solver state before running EGM.

        Order matters: the borrowing constraint reads ``self.params``, the
        boundary reads ``self.BoroCnstNat``, and the post-state grid reads
        ``self.nat_boro_cnst`` / ``self.m_nrm_min``.
        """
        self.params = self.create_params_namespace()
        self.calculate_borrowing_constraint()
        self.define_boundary_constraint()
        self.post_state = self.create_post_state()

    def define_boundary_constraint(self) -> None:
        """Define borrowing constraint boundary conditions.

        Sets ``m_nrm_min``, ``nat_boro_cnst`` and ``borocnst``. The natural
        constraint is used when there is no artificial constraint or the
        artificial one is at or below it. In that case the boundary point
        carries the limiting values c = 0, v = -inf, v_der = inf.
        Otherwise the boundary point only fixes c = 0.
        """
        if self.BoroCnstArt is None or self.BoroCnstArt <= self.BoroCnstNat:
            self.m_nrm_min = self.BoroCnstNat
            self.nat_boro_cnst = True
            self.borocnst = xr.Dataset(
                coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min},
                data_vars={
                    "cNrm": 0.0,
                    "v": -np.inf,
                    "v_inv": 0.0,
                    "reward": -np.inf,
                    "marginal_reward": np.inf,
                    "v_der": np.inf,
                    "v_der_inv": 0.0,
                },
            )
        else:
            self.m_nrm_min = self.BoroCnstArt
            self.nat_boro_cnst = False
            self.borocnst = xr.Dataset(
                coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min},
                data_vars={"cNrm": 0.0},
            )

    def state_transition(
        self, state: dict[str, Any], action: dict[str, Any], params: SimpleNamespace
    ) -> dict[str, Any]:
        """Compute post-decision state from state and action (a = m - c)."""
        return {"aNrm": state["mNrm"] - action["cNrm"]}

    def reverse_transition(
        self,
        post_state: dict[str, Any],
        action: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Recover state from post-decision state and action (for EGM): m = a + c."""
        return {"mNrm": post_state["aNrm"] + action["cNrm"]}

    def egm_transition(
        self,
        post_state: dict[str, Any],
        continuation: ValueFuncCRRALabeled,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Compute optimal action from the first-order condition (EGM).

        Inverts u'(c) = Discount * w'(a) for c.
        """
        return {
            "cNrm": self.u.derinv(params.Discount * continuation.derivative(post_state))
        }

    def value_transition(
        self,
        action: dict[str, Any],
        state: dict[str, Any],
        continuation: ValueFuncCRRALabeled,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Compute value function variables from action, state, and continuation.

        Returns a dict with the post-decision state plus ``reward``, ``v``,
        ``v_inv``, ``marginal_reward``, ``v_der``, ``v_der_inv``,
        ``contributions`` and the scalar ``value`` (sum of ``v``).
        """
        variables = {}
        post_state = self.state_transition(state, action, params)
        variables.update(post_state)

        variables["reward"] = self.u(action["cNrm"])
        variables["v"] = variables["reward"] + params.Discount * continuation(
            post_state
        )
        variables["v_inv"] = self.u.inv(variables["v"])

        # Envelope condition: v'(m) = u'(c), so the inverted marginal value
        # is consumption itself.
        variables["marginal_reward"] = self.u.der(action["cNrm"])
        variables["v_der"] = variables["marginal_reward"]
        variables["v_der_inv"] = action["cNrm"]

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def _continuation_for_expectation(
        self,
        shocks: dict[str, Any],
        post_state: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Wrapper for continuation transition compatible with expected().

        This method adapts the transitions.continuation() interface to work
        with the expected() function from DiscreteDistributionLabeled.

        Parameters
        ----------
        shocks : dict[str, Any]
            Shock realizations (e.g., perm, tran, risky).
        post_state : dict[str, Any]
            Post-decision state (e.g., aNrm).
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Model parameters.

        Returns
        -------
        dict[str, Any]
            Continuation value variables.
        """
        return self.transitions.continuation(
            post_state, shocks, value_next, params, self.u
        )

    def endogenous_grid_method(self) -> None:
        """Execute the Endogenous Grid Method algorithm.

        Builds the continuation function, inverts the FOC on the post-state
        grid, recovers the endogenous mNrm grid, attaches the boundary point,
        evaluates values, and stores the result in ``self.solution``.
        Emits RuntimeWarning on non-finite continuation values or negative
        consumption.
        """
        wfunc = self.create_continuation_function()

        # Check for numerical issues in continuation function
        if not np.all(np.isfinite(wfunc.dataset["v_der_inv"].values)):
            warnings.warn(
                "Continuation value function contains NaN or Inf values. "
                "This may indicate invalid parameters (CRRA too high, "
                "PermGroFac issues, or extreme shock realizations).",
                RuntimeWarning,
                stacklevel=2,
            )

        # EGM: Get optimal actions from first-order condition
        acted = self.egm_transition(self.post_state, wfunc, self.params)
        state = self.reverse_transition(self.post_state, acted, self.params)

        # Check for numerical issues in EGM results
        if np.any(acted["cNrm"] < 0):
            warnings.warn(
                "EGM produced negative consumption values. "
                "Check discount factor and interest rate parameters.",
                RuntimeWarning,
                stacklevel=2,
            )

        # Swap dimensions for state-based indexing
        action = xr.Dataset(acted).swap_dims({"aNrm": "mNrm"})
        state = xr.Dataset(state).swap_dims({"aNrm": "mNrm"})

        egm_dataset = xr.merge([action, state])

        # Artificial constraint: the (c = 0) boundary point is added before
        # values are computed so it gets evaluated like any other point.
        if not self.nat_boro_cnst:
            egm_dataset = xr.concat(
                [self.borocnst, egm_dataset], dim="mNrm", data_vars="all"
            )

        # Compute values
        values = self.value_transition(egm_dataset, egm_dataset, wfunc, self.params)
        egm_dataset.update(values)

        # Natural constraint: the boundary point already carries its limiting
        # values, so it is attached after the evaluation.
        if self.nat_boro_cnst:
            egm_dataset = xr.concat(
                [self.borocnst, egm_dataset],
                dim="mNrm",
                data_vars="all",
                combine_attrs="no_conflicts",
            )

        egm_dataset = egm_dataset.drop_vars("aNrm")

        # Build solution
        vfunc = ValueFuncCRRALabeled(
            egm_dataset[["v", "v_der", "v_inv", "v_der_inv"]], self.params.CRRA
        )
        pfunc = egm_dataset[["cNrm"]]

        self.solution = ConsumerSolutionLabeled(
            value=vfunc,
            policy=pfunc,
            continuation=wfunc,
            attrs={"m_nrm_min": self.m_nrm_min, "dataset": egm_dataset},
        )
class ConsPerfForesightLabeledSolver(BaseLabeledSolver):
    """
    Perfect foresight consumption-saving solver.

    No shock distribution is integrated over: the continuation value comes
    straight from PerfectForesightTransitions, with savings earning only the
    risk-free return. Every hook method is inherited from BaseLabeledSolver.
    """

    # Restated explicitly (it matches the base default) so the model's
    # transition strategy is visible right at the class definition.
    TransitionsClass = PerfectForesightTransitions
class ConsIndShockLabeledSolver(BaseLabeledSolver):
    """
    Solver for consumption model with idiosyncratic income shocks.

    Uses IndShockTransitions and integrates continuation value over
    the income shock distribution.

    Additional Parameters
    ---------------------
    IncShkDstn : DiscreteDistributionLabeled
        Distribution of income shocks with 'perm' and 'tran' variables.

    Raises
    ------
    ValueError
        If IncShkDstn is None (in addition to BaseLabeledSolver's checks).
    """

    TransitionsClass = IndShockTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        IncShkDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # Fail at construction time rather than with an AttributeError
        # deep inside solve().
        if IncShkDstn is None:
            raise ValueError("IncShkDstn cannot be None")

        self.IncShkDstn = IncShkDstn
        super().__init__(
            solution_next=solution_next,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
            **kwargs,
        )

    def calculate_borrowing_constraint(self) -> None:
        """Calculate the natural constraint from minimum shock realizations.

        Sets ``self.BoroCnstNat``.
        """
        # NOTE(review): assumes atoms[0] holds the permanent and atoms[1] the
        # transitory shock ('perm', 'tran' order) — confirm against IncShkDstn.
        PermShkMinNext = np.min(self.IncShkDstn.atoms[0])
        TranShkMinNext = np.min(self.IncShkDstn.atoms[1])

        self.BoroCnstNat = (
            (self.solution_next.attrs["m_nrm_min"] - TranShkMinNext)
            * (self.params.PermGroFac * PermShkMinNext)
            / self.params.Rfree
        )

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """Create continuation function by integrating over income shocks.

        Returns
        -------
        ValueFuncCRRALabeled
            Expected continuation value over the post-decision state.
        """
        value_next = self.solution_next.value

        v_end = self.IncShkDstn.expected(
            func=self._continuation_for_expectation,
            post_state=self.post_state,
            value_next=value_next,
            params=self.params,
        )

        # Inverted (linearized) forms used for interpolation.
        v_end["v_inv"] = self.u.inv(v_end["v"])
        v_end["v_der_inv"] = self.u.derinv(v_end["v_der"])

        if self.nat_boro_cnst:
            # Prepend the natural-constraint boundary point on aNrm.
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end], join="outer", compat="no_conflicts")

        return ValueFuncCRRALabeled(v_end, self.params.CRRA)
class ConsRiskyAssetLabeledSolver(BaseLabeledSolver):
    """
    Solver for consumption model with risky asset.

    Uses RiskyAssetTransitions - all savings earn stochastic risky return.

    Note that ``calculate_borrowing_constraint`` resets ``BoroCnstArt`` to
    0.0, so any artificial constraint passed to the constructor is
    overridden for this model family (and its subclasses).

    Additional Parameters
    ---------------------
    ShockDstn : DiscreteDistributionLabeled
        Joint distribution of income and risky return shocks.

    Raises
    ------
    ValueError
        If ShockDstn is None (in addition to BaseLabeledSolver's checks).
    """

    TransitionsClass = RiskyAssetTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        ShockDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # Fail at construction time rather than with an AttributeError
        # deep inside solve().
        if ShockDstn is None:
            raise ValueError("ShockDstn cannot be None")

        self.ShockDstn = ShockDstn
        super().__init__(
            solution_next=solution_next,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
            **kwargs,
        )

    def calculate_borrowing_constraint(self) -> None:
        """Calculate the natural constraint with a zero artificial constraint.

        Sets ``self.BoroCnstArt`` to 0.0, ``self.IncShkDstn`` to the joint
        shock distribution, and ``self.BoroCnstNat``.
        """
        # The artificial constraint is pinned at zero regardless of the value
        # passed to the constructor.
        self.BoroCnstArt = 0.0
        # NOTE(review): presumably an alias for code that expects an
        # ``IncShkDstn`` attribute; confirm before removing.
        self.IncShkDstn = self.ShockDstn

        # NOTE(review): assumes atoms[0] / atoms[1] hold the permanent and
        # transitory income shocks — confirm against ShockDstn construction.
        PermShkMinNext = np.min(self.ShockDstn.atoms[0])
        TranShkMinNext = np.min(self.ShockDstn.atoms[1])

        self.BoroCnstNat = (
            (self.solution_next.attrs["m_nrm_min"] - TranShkMinNext)
            * (self.params.PermGroFac * PermShkMinNext)
            / self.params.Rfree
        )

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """Create continuation function integrating over shock distribution.

        Returns
        -------
        ValueFuncCRRALabeled
            Expected continuation value, with ``aNrm`` as the leading dim.
        """
        value_next = self.solution_next.value

        v_end = self.ShockDstn.expected(
            func=self._continuation_for_expectation,
            post_state=self.post_state,
            value_next=value_next,
            params=self.params,
        )

        # Inverted (linearized) forms used for interpolation.
        v_end["v_inv"] = self.u.inv(v_end["v"])
        v_end["v_der_inv"] = self.u.derinv(v_end["v_der"])

        if self.nat_boro_cnst:
            # Prepend the natural-constraint boundary point on aNrm.
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end], join="outer", compat="no_conflicts")

        # Put aNrm first so row-wise indexing by asset level works for any
        # extra dimensions (e.g. the risky share in portfolio models).
        v_end = v_end.transpose("aNrm", ...)

        return ValueFuncCRRALabeled(v_end, self.params.CRRA)
class ConsFixedPortfolioLabeledSolver(ConsRiskyAssetLabeledSolver):
    """
    Solver for consumption model with fixed portfolio allocation.

    Uses FixedPortfolioTransitions - agent allocates fixed share to risky asset.

    Additional Parameters
    ---------------------
    RiskyShareFixed : float
        Fixed share of savings allocated to risky asset, in [0, 1].

    Raises
    ------
    ValueError
        If RiskyShareFixed is outside [0, 1] or NaN.
    """

    TransitionsClass = FixedPortfolioTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        ShockDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        RiskyShareFixed: float,
        **kwargs,
    ) -> None:
        # Validate RiskyShareFixed. The negated range test also rejects NaN,
        # which would pass a "< 0 or > 1" check.
        if not 0 <= RiskyShareFixed <= 1:
            raise ValueError(
                f"RiskyShareFixed must be in [0, 1], got {RiskyShareFixed}"
            )

        self.RiskyShareFixed = RiskyShareFixed
        super().__init__(
            solution_next=solution_next,
            ShockDstn=ShockDstn,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
            **kwargs,
        )

    def create_params_namespace(self) -> SimpleNamespace:
        """Extend the base parameters with ``RiskyShareFixed``."""
        params = super().create_params_namespace()
        params.RiskyShareFixed = self.RiskyShareFixed
        return params
class ConsPortfolioLabeledSolver(ConsRiskyAssetLabeledSolver):
    """
    Solver for consumption model with optimal portfolio choice.

    Uses PortfolioTransitions - agent optimally chooses risky share each period.
    The optimal share is found by solving the portfolio first-order condition.

    Additional Parameters
    ---------------------
    ShareGrid : np.ndarray
        One-dimensional, strictly increasing grid of at least two risky
        share values in [0, 1] to search over.

    Raises
    ------
    ValueError
        If ShareGrid is not one-dimensional, is empty, has fewer than two
        points, has values outside [0, 1] (NaN included), or is not
        strictly increasing.
    """

    TransitionsClass = PortfolioTransitions

    # FOC differences at or below this magnitude are treated as zero, and a
    # midpoint fallback is used instead of linear interpolation.
    _FOC_DENOM_TOL = 1e-12

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        ShockDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        ShareGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # Validate ShareGrid
        ShareGrid = np.asarray(ShareGrid)
        if ShareGrid.ndim != 1:
            raise ValueError(
                f"ShareGrid must be one-dimensional, got shape {ShareGrid.shape}"
            )
        if len(ShareGrid) == 0:
            raise ValueError("ShareGrid cannot be empty")
        if len(ShareGrid) < 2:
            # The FOC root is bracketed between adjacent grid points, so a
            # single point would crash the search in solve().
            raise ValueError("ShareGrid must contain at least two points")
        # Written as "all inside" so NaN entries are rejected too.
        if not np.all((ShareGrid >= 0) & (ShareGrid <= 1)):
            raise ValueError("ShareGrid values must be in [0, 1]")
        if not np.all(np.diff(ShareGrid) > 0):
            raise ValueError("ShareGrid must be strictly increasing")

        self.ShareGrid = ShareGrid
        super().__init__(
            solution_next=solution_next,
            ShockDstn=ShockDstn,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
            **kwargs,
        )

    def create_post_state(self) -> xr.Dataset:
        """Add risky share dimension (``stigma``) to post-decision state."""
        post_state = super().create_post_state()
        post_state["stigma"] = xr.DataArray(
            self.ShareGrid, dims=["stigma"], attrs={"long_name": "risky share"}
        )
        return post_state

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """
        Create continuation function with optimal portfolio choice.

        First computes continuation value over the (aNrm, stigma) grid,
        then finds the optimal stigma for each aNrm level and evaluates the
        continuation there. Side effect: removes ``stigma`` from
        ``self.post_state`` so the subsequent EGM step runs over aNrm only.
        """
        # Get continuation value over full (aNrm, stigma) grid
        wfunc = super().create_continuation_function()

        # Rows are aNrm (the parent transposes aNrm first), columns stigma.
        dvds = wfunc.dataset["dvds"].values
        opt_share = xr.DataArray(
            self._optimal_share_from_foc(dvds),
            coords={"aNrm": self.post_state["aNrm"].values},
            dims=["aNrm"],
            attrs={"long_name": "optimal risky share"},
        )

        # Evaluate continuation at optimal share
        v_end = wfunc.evaluate({"aNrm": self.post_state["aNrm"], "stigma": opt_share})
        v_end = v_end.reset_coords(names="stigma")

        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        # Remove stigma from post_state for EGM
        self.post_state = self.post_state.drop_vars("stigma")

        return wfunc

    def _optimal_share_from_foc(self, dvds: np.ndarray) -> np.ndarray:
        """Return the optimal risky share for each row of ``dvds``.

        ``dvds`` holds dV/dstigma with one row per aNrm point and one column
        per ShareGrid point. The root of the FOC is located by linear
        interpolation between the first adjacent pair where the derivative
        goes from >= 0 to <= 0. Corner solutions (0 or 1) override rows where
        the FOC is positive at the top of the grid or negative at the bottom.
        Warns if the midpoint fallback was needed for any row that is not
        overridden afterwards.
        """
        share_grid = self.ShareGrid

        # First bracketing pair per row (argmax gives 0 when none exists;
        # such rows are resolved by the corner rules below).
        crossing = (dvds[:, :-1] >= 0.0) & (dvds[:, 1:] <= 0.0)
        share_idx = np.argmax(crossing, axis=1)
        rows = np.arange(dvds.shape[0])

        bottom_share = share_grid[share_idx]
        top_share = share_grid[share_idx + 1]
        bottom_foc = dvds[rows, share_idx]
        top_foc = dvds[rows, share_idx + 1]

        # Linear interpolation with division-by-zero protection. The safe
        # denominator keeps np.where from evaluating x / 0 on fallback rows.
        denominator = top_foc - bottom_foc
        fallback_mask = np.abs(denominator) <= self._FOC_DENOM_TOL
        safe_denominator = np.where(fallback_mask, 1.0, denominator)
        alpha = np.where(fallback_mask, 0.5, 1.0 - top_foc / safe_denominator)
        opt_share = (1.0 - alpha) * bottom_share + alpha * top_share

        # Handle corner solutions (the second rule wins when both hold)
        want_all_risky = dvds[:, -1] > 0.0  # Want more than 100% risky
        want_no_risky = dvds[:, 0] < 0.0  # Want less than 0% risky
        opt_share[want_all_risky] = 1.0
        opt_share[want_no_risky] = 0.0
        overridden = want_all_risky | want_no_risky

        if not self.nat_boro_cnst:
            # At aNrm = 0 the portfolio share is irrelevant; 1.0 is limit as a --> 0
            opt_share[0] = 1.0
            overridden[0] = True

        # Only rows whose interpolated value survives are worth reporting;
        # flat-FOC rows that were overridden above are expected, not a problem.
        n_fallbacks = int(np.sum(fallback_mask & ~overridden))
        if n_fallbacks:
            warnings.warn(
                f"Portfolio optimization used fallback interpolation for {n_fallbacks} "
                "grid points due to near-zero FOC difference. "
                "Consider refining ShareGrid for more accurate results.",
                RuntimeWarning,
                # One frame deeper than the original in-method call site.
                stacklevel=3,
            )

        return opt_share