Coverage for HARK / Labeled / solvers.py: 91%
214 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-10 06:19 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-10 06:19 +0000
"""
Solvers for labeled consumption-saving models.

This module implements the Template Method pattern for the Endogenous Grid
Method (EGM) algorithm. The base solver defines the algorithm skeleton,
while concrete solvers override specific hook methods for their model type.
"""
9from __future__ import annotations
11import warnings
12from types import SimpleNamespace
13from typing import TYPE_CHECKING, Any
15import numpy as np
16import xarray as xr
18from HARK.metric import MetricObject
19from HARK.rewards import UtilityFuncCRRA
21from .solution import ConsumerSolutionLabeled, ValueFuncCRRALabeled
22from .transitions import (
23 FixedPortfolioTransitions,
24 IndShockTransitions,
25 PerfectForesightTransitions,
26 PortfolioTransitions,
27 RiskyAssetTransitions,
28)
30if TYPE_CHECKING:
31 from HARK.distributions import DiscreteDistributionLabeled
# Public API: the template-method base solver followed by one concrete
# solver per model type.
__all__ = [
    "BaseLabeledSolver",
    # Concrete solvers
    "ConsPerfForesightLabeledSolver",
    "ConsIndShockLabeledSolver",
    "ConsRiskyAssetLabeledSolver",
    "ConsFixedPortfolioLabeledSolver",
    "ConsPortfolioLabeledSolver",
]
class BaseLabeledSolver(MetricObject):
    """
    Base solver implementing Template Method pattern for EGM algorithm.

    This class provides the algorithm skeleton for solving consumption-saving
    problems using the Endogenous Grid Method. Subclasses customize behavior
    by:

    1. Setting TransitionsClass to specify model-specific transitions
    2. Overriding hook methods for model-specific logic

    Template Method: solve()

    Hook Methods:
        - create_params_namespace(): Add model-specific parameters
        - calculate_borrowing_constraint(): Model-specific constraint logic
        - create_post_state(): Add extra post-decision state dimensions
          (e.g., the risky share ``stigma``)
        - create_continuation_function(): Handle shock integration

    Parameters
    ----------
    solution_next : ConsumerSolutionLabeled
        Solution to next period's problem. Its ``attrs`` must contain
        ``"m_nrm_min"``.
    LivPrb : float
        Survival probability, in (0, 1].
    DiscFac : float
        Intertemporal discount factor (positive).
    CRRA : float
        Coefficient of relative risk aversion (finite, non-negative).
    Rfree : float
        Risk-free interest factor (positive).
    PermGroFac : float
        Permanent income growth factor (positive).
    BoroCnstArt : float or None
        Artificial borrowing constraint. ``None`` means only the natural
        borrowing constraint applies.
    aXtraGrid : np.ndarray
        One-dimensional, strictly increasing grid of finite, non-negative
        end-of-period asset values above the minimum.
    **kwargs
        Additional model-specific parameters, stored as instance attributes.

    Raises
    ------
    TypeError
        If solution_next is not a ConsumerSolutionLabeled.
    ValueError
        If solution_next is None or lacks ``m_nrm_min``, if CRRA or any
        economic parameter is out of range (NaN included), or if aXtraGrid
        is malformed.
    """

    # Class-level strategy specification - override in subclasses
    TransitionsClass: type = PerfectForesightTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # Input validation - solution_next
        if solution_next is None:
            raise ValueError("solution_next cannot be None")
        if not isinstance(solution_next, ConsumerSolutionLabeled):
            raise TypeError(
                f"solution_next must be ConsumerSolutionLabeled, got {type(solution_next)}"
            )
        if "m_nrm_min" not in solution_next.attrs:
            raise ValueError(
                "solution_next.attrs must contain 'm_nrm_min'. "
                "Use make_solution_terminal_labeled() to create valid terminal solutions."
            )

        # Input validation - CRRA
        if not np.isfinite(CRRA):
            raise ValueError(f"CRRA must be finite, got {CRRA}")
        if CRRA < 0:
            raise ValueError(f"CRRA must be non-negative, got {CRRA}")

        # Input validation - economic parameters. Each check is written as
        # ``not (valid range)`` so that NaN, which fails every comparison,
        # is rejected instead of silently passing.
        if not (0 < LivPrb <= 1):
            raise ValueError(f"LivPrb must be in (0, 1], got {LivPrb}")
        if not (DiscFac > 0):
            raise ValueError(f"DiscFac must be positive, got {DiscFac}")
        if not (Rfree > 0):
            raise ValueError(f"Rfree must be positive, got {Rfree}")
        if not (PermGroFac > 0):
            raise ValueError(f"PermGroFac must be positive, got {PermGroFac}")

        # Input validation - asset grid
        aXtraGrid = np.asarray(aXtraGrid)
        if aXtraGrid.ndim != 1:
            # len() on a 0-d array raises TypeError; report a ValueError
            # as documented instead.
            raise ValueError(
                f"aXtraGrid must be one-dimensional, got shape {aXtraGrid.shape}"
            )
        if aXtraGrid.size == 0:
            raise ValueError("aXtraGrid cannot be empty")
        if not np.all(np.isfinite(aXtraGrid)):
            # A single NaN element would otherwise pass the checks below
            # (np.diff of a length-1 array is empty).
            raise ValueError("aXtraGrid values must be finite")
        if np.any(aXtraGrid < 0):
            raise ValueError("aXtraGrid values must be non-negative")
        if not np.all(np.diff(aXtraGrid) > 0):
            raise ValueError("aXtraGrid must be strictly increasing")

        # Store parameters
        self.solution_next = solution_next
        self.LivPrb = LivPrb
        self.DiscFac = DiscFac
        self.CRRA = CRRA
        self.Rfree = Rfree
        self.PermGroFac = PermGroFac
        self.BoroCnstArt = BoroCnstArt
        self.aXtraGrid = aXtraGrid

        # Initialize utility function
        self.u = UtilityFuncCRRA(CRRA)

        # Initialize transitions strategy
        self.transitions = self.TransitionsClass()

        # Store additional kwargs
        for key, value in kwargs.items():
            setattr(self, key, value)

    # =========================================================================
    # TEMPLATE METHOD - The algorithm skeleton
    # =========================================================================

    def solve(self) -> ConsumerSolutionLabeled:
        """
        Solve the consumption-saving problem using EGM.

        This is the template method that defines the algorithm skeleton.
        It calls hook methods that subclasses can override.

        Returns
        -------
        ConsumerSolutionLabeled
            Solution containing value function, policy, and continuation.
        """
        self.prepare_to_solve()
        self.endogenous_grid_method()
        return self.solution

    # =========================================================================
    # HOOK METHODS - Override in subclasses for customization
    # =========================================================================

    def create_params_namespace(self) -> SimpleNamespace:
        """
        Create parameters namespace.

        Override in subclasses to add model-specific parameters.

        Returns
        -------
        SimpleNamespace
            Parameters for this period's problem. ``Discount`` is the
            effective discount factor ``DiscFac * LivPrb``.
        """
        return SimpleNamespace(
            Discount=self.DiscFac * self.LivPrb,
            CRRA=self.CRRA,
            Rfree=self.Rfree,
            PermGroFac=self.PermGroFac,
        )

    def calculate_borrowing_constraint(self) -> None:
        """
        Calculate the natural borrowing constraint.

        Sets ``self.BoroCnstNat = (m_nrm_min_next - 1) * PermGroFac / Rfree``,
        where ``m_nrm_min_next`` is read from ``solution_next.attrs``.
        Override in shock models to account for minimum shock realizations.
        """
        self.BoroCnstNat = (
            (self.solution_next.attrs["m_nrm_min"] - 1)
            * self.params.PermGroFac
            / self.params.Rfree
        )

    def create_post_state(self) -> xr.Dataset:
        """
        Create the post-decision state grid.

        When the natural constraint binds, the grid is ``aXtraGrid`` shifted
        by ``m_nrm_min``. Otherwise a 0.0 point is prepended before shifting
        so the artificial constraint itself is on the grid.
        Override in portfolio models to add the risky share dimension.

        Returns
        -------
        xr.Dataset
            Post-decision state grid with a single ``aNrm`` variable.
        """
        if self.nat_boro_cnst:
            a_grid = self.aXtraGrid + self.m_nrm_min
        else:
            a_grid = np.append(0.0, self.aXtraGrid) + self.m_nrm_min

        aVec = xr.DataArray(
            a_grid,
            name="aNrm",
            dims=("aNrm",),
            attrs={"long_name": "savings", "state": True},
        )
        return xr.Dataset({"aNrm": aVec})

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """
        Create the continuation value function.

        Evaluates the transitions' continuation with no shocks (``None``).
        Override in shock models to integrate over shock distributions.

        Returns
        -------
        ValueFuncCRRALabeled
            Continuation value function.
        """
        value_next = self.solution_next.value

        v_end = self.transitions.continuation(
            self.post_state, None, value_next, self.params, self.u
        )
        v_end = xr.Dataset(v_end).drop_vars(["mNrm"])

        return self._finalize_value_func(v_end)

    def _finalize_value_func(self, v_end, transform=None) -> ValueFuncCRRALabeled:
        """
        Apply the post-processing shared by ``create_continuation_function``.

        Used across the solver subclasses. Adds the inverse value (``v_inv``)
        and inverse marginal value (``v_der_inv``) variables, merges in the
        borrowing-constraint boundary point when the natural constraint binds,
        and optionally applies ``transform`` (e.g. a ``.transpose(...)``
        reordering).

        Parameters
        ----------
        v_end : xr.Dataset
            End-of-period value data containing ``v`` and ``v_der``.
        transform : callable, optional
            Applied to the finished dataset before it is wrapped.

        Returns
        -------
        ValueFuncCRRALabeled
            The wrapped continuation value function.
        """
        v_end["v_inv"] = self.u.inv(v_end["v"])
        v_end["v_der_inv"] = self.u.derinv(v_end["v_der"])
        if self.nat_boro_cnst:
            # The boundary point is only used when the natural constraint
            # binds, so it is only built in that case.
            borocnst = self.borocnst.drop_vars(["mNrm"]).expand_dims("aNrm")
            v_end = xr.merge([borocnst, v_end], join="outer", compat="no_conflicts")
        if transform is not None:
            v_end = transform(v_end)
        return ValueFuncCRRALabeled(v_end, self.params.CRRA)

    def _natural_boro_cnst_with_shocks(self, perm_shk_min, tran_shk_min) -> float:
        """
        Compute ``BoroCnstNat`` accounting for minimum shock realizations.

        Returns ``(m_nrm_min_next - tran_shk_min) * PermGroFac * perm_shk_min
        / Rfree``. Shared by ``ConsIndShockLabeledSolver`` and
        ``ConsRiskyAssetLabeledSolver``.
        """
        return (
            (self.solution_next.attrs["m_nrm_min"] - tran_shk_min)
            * (self.params.PermGroFac * perm_shk_min)
            / self.params.Rfree
        )

    # =========================================================================
    # CORE METHODS - Shared implementation, rarely overridden
    # =========================================================================

    def prepare_to_solve(self) -> None:
        """
        Prepare solver state before running EGM.

        Builds ``self.params``, the natural and effective borrowing
        constraints, and ``self.post_state``. Order matters: each step reads
        attributes set by the previous one.
        """
        self.params = self.create_params_namespace()
        self.calculate_borrowing_constraint()
        self.define_boundary_constraint()
        self.post_state = self.create_post_state()

    def define_boundary_constraint(self) -> None:
        """
        Define borrowing constraint boundary conditions.

        If there is no artificial constraint, or it is looser than the natural
        one, the natural constraint binds. Then ``m_nrm_min = BoroCnstNat``
        and the boundary point carries limiting values: zero consumption,
        ``-inf`` value, and ``inf`` marginal value. Otherwise
        ``m_nrm_min = BoroCnstArt`` and the boundary point only fixes
        ``cNrm = 0``.

        Sets ``self.m_nrm_min``, ``self.nat_boro_cnst`` and ``self.borocnst``.
        """
        if self.BoroCnstArt is None or self.BoroCnstArt <= self.BoroCnstNat:
            self.m_nrm_min = self.BoroCnstNat
            self.nat_boro_cnst = True
            self.borocnst = xr.Dataset(
                coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min},
                data_vars={
                    "cNrm": 0.0,
                    "v": -np.inf,
                    "v_inv": 0.0,
                    "reward": -np.inf,
                    "marginal_reward": np.inf,
                    "v_der": np.inf,
                    "v_der_inv": 0.0,
                },
            )
        else:
            self.m_nrm_min = self.BoroCnstArt
            self.nat_boro_cnst = False
            self.borocnst = xr.Dataset(
                coords={"mNrm": self.m_nrm_min, "aNrm": self.m_nrm_min},
                data_vars={"cNrm": 0.0},
            )

    def state_transition(
        self, state: dict[str, Any], action: dict[str, Any], params: SimpleNamespace
    ) -> dict[str, Any]:
        """Compute post-decision state from state and action: ``a = m - c``."""
        return {"aNrm": state["mNrm"] - action["cNrm"]}

    def reverse_transition(
        self,
        post_state: dict[str, Any],
        action: dict[str, Any],
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """Recover state from post-decision state and action (for EGM): ``m = a + c``."""
        return {"mNrm": post_state["aNrm"] + action["cNrm"]}

    def egm_transition(
        self,
        post_state: dict[str, Any],
        continuation: ValueFuncCRRALabeled,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Compute optimal action using the first-order condition (EGM).

        Returns ``cNrm = u'^{-1}(Discount * w'(a))``, where ``w'`` is the
        derivative of the continuation function at ``post_state``.
        """
        return {
            "cNrm": self.u.derinv(params.Discount * continuation.derivative(post_state))
        }

    def value_transition(
        self,
        action: dict[str, Any],
        state: dict[str, Any],
        continuation: ValueFuncCRRALabeled,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Compute value function variables from action, state, and continuation.

        Returns a dict with the post-decision state, ``reward = u(c)``,
        ``v = reward + Discount * w(a)``, its inverse ``v_inv``, the marginal
        value ``v_der = u'(c)`` with inverse ``v_der_inv = c``, and the
        aggregate ``value = sum(v)``.
        """
        variables = {}
        post_state = self.state_transition(state, action, params)
        variables.update(post_state)

        variables["reward"] = self.u(action["cNrm"])
        variables["v"] = variables["reward"] + params.Discount * continuation(
            post_state
        )
        variables["v_inv"] = self.u.inv(variables["v"])

        # Envelope condition: marginal value of m equals marginal utility.
        variables["marginal_reward"] = self.u.der(action["cNrm"])
        variables["v_der"] = variables["marginal_reward"]
        variables["v_der_inv"] = action["cNrm"]

        variables["contributions"] = variables["v"]
        variables["value"] = np.sum(variables["v"])

        return variables

    def _continuation_for_expectation(
        self,
        shocks: dict[str, Any],
        post_state: dict[str, Any],
        value_next: ValueFuncCRRALabeled,
        params: SimpleNamespace,
    ) -> dict[str, Any]:
        """
        Wrapper for continuation transition compatible with expected().

        This method adapts the transitions.continuation() interface to work
        with the expected() function from DiscreteDistributionLabeled. It
        reorders the arguments, taking the shocks first.

        Parameters
        ----------
        shocks : dict[str, Any]
            Shock realizations (e.g., perm, tran, risky).
        post_state : dict[str, Any]
            Post-decision state (e.g., aNrm).
        value_next : ValueFuncCRRALabeled
            Next period's value function.
        params : SimpleNamespace
            Model parameters.

        Returns
        -------
        dict[str, Any]
            Continuation value variables.
        """
        return self.transitions.continuation(
            post_state, shocks, value_next, params, self.u
        )

    def endogenous_grid_method(self) -> None:
        """
        Execute the Endogenous Grid Method and store ``self.solution``.

        Steps:

        1. Build the continuation function.
        2. Invert the first-order condition on the post-decision grid to get
           consumption.
        3. Recover the endogenous ``mNrm`` grid.
        4. Add the borrowing-constraint point.
        5. Compute value variables and package the value and policy functions.

        Warns
        -----
        RuntimeWarning
            If the continuation's ``v_der_inv`` has NaN/Inf values, or if EGM
            produces negative consumption.
        """
        wfunc = self.create_continuation_function()

        # Check for numerical issues in continuation function
        if np.any(~np.isfinite(wfunc.dataset["v_der_inv"].values)):
            warnings.warn(
                "Continuation value function contains NaN or Inf values. "
                "This may indicate invalid parameters (CRRA too high, "
                "PermGroFac issues, or extreme shock realizations).",
                RuntimeWarning,
                stacklevel=2,
            )

        # EGM: Get optimal actions from first-order condition
        acted = self.egm_transition(self.post_state, wfunc, self.params)
        state = self.reverse_transition(self.post_state, acted, self.params)

        # Check for numerical issues in EGM results
        if np.any(acted["cNrm"] < 0):
            warnings.warn(
                "EGM produced negative consumption values. "
                "Check discount factor and interest rate parameters.",
                RuntimeWarning,
                stacklevel=2,
            )

        # Swap dimensions for state-based indexing
        action = xr.Dataset(acted).swap_dims({"aNrm": "mNrm"})
        state = xr.Dataset(state).swap_dims({"aNrm": "mNrm"})

        egm_dataset = xr.merge([action, state])

        # Artificial constraint: the boundary point only carries cNrm, so it
        # is added before values are computed and gets values like any other point.
        if not self.nat_boro_cnst:
            egm_dataset = xr.concat(
                [self.borocnst, egm_dataset], dim="mNrm", data_vars="all"
            )

        # Compute values
        values = self.value_transition(egm_dataset, egm_dataset, wfunc, self.params)
        egm_dataset.update(values)

        # Natural constraint: the boundary point already carries its limiting
        # values, so it is added after the value computation.
        if self.nat_boro_cnst:
            egm_dataset = xr.concat(
                [self.borocnst, egm_dataset],
                dim="mNrm",
                data_vars="all",
                combine_attrs="no_conflicts",
            )

        egm_dataset = egm_dataset.drop_vars("aNrm")

        # Build solution
        vfunc = ValueFuncCRRALabeled(
            egm_dataset[["v", "v_der", "v_inv", "v_der_inv"]], self.params.CRRA
        )
        pfunc = egm_dataset[["cNrm"]]

        self.solution = ConsumerSolutionLabeled(
            value=vfunc,
            policy=pfunc,
            continuation=wfunc,
            attrs={"m_nrm_min": self.m_nrm_min, "dataset": egm_dataset},
        )
class ConsPerfForesightLabeledSolver(BaseLabeledSolver):
    """
    Solver for the perfect foresight consumption-saving model.

    There are no income or return shocks. The continuation value is evaluated
    directly from next period's solution through PerfectForesightTransitions,
    and savings earn only the risk-free return. All algorithmic steps are
    inherited unchanged from BaseLabeledSolver.
    """

    TransitionsClass = PerfectForesightTransitions
class ConsIndShockLabeledSolver(BaseLabeledSolver):
    """
    Solver for the consumption model with idiosyncratic income shocks.

    Transitions come from IndShockTransitions. The continuation value is the
    expectation of next period's value over the income shock distribution.

    Additional Parameters
    ---------------------
    IncShkDstn : DiscreteDistributionLabeled
        Distribution of income shocks with 'perm' and 'tran' variables.
        The borrowing-constraint calculation reads the first atom row as the
        permanent shock and the second as the transitory shock.
    """

    TransitionsClass = IndShockTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        IncShkDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # The shock distribution is attached before the base constructor runs,
        # matching the order in which the attribute becomes available.
        self.IncShkDstn = IncShkDstn
        base_args = dict(
            solution_next=solution_next,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
        )
        super().__init__(**base_args, **kwargs)

    def calculate_borrowing_constraint(self) -> None:
        """Set ``BoroCnstNat`` using the worst permanent and transitory shocks."""
        shock_atoms = self.IncShkDstn.atoms
        worst_perm = np.min(shock_atoms[0])
        worst_tran = np.min(shock_atoms[1])
        self.BoroCnstNat = self._natural_boro_cnst_with_shocks(worst_perm, worst_tran)

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """Integrate next period's value over income shocks, then finalize it."""
        expectation_inputs = {
            "post_state": self.post_state,
            "value_next": self.solution_next.value,
            "params": self.params,
        }
        expected_value = self.IncShkDstn.expected(
            func=self._continuation_for_expectation, **expectation_inputs
        )
        return self._finalize_value_func(expected_value)
class ConsRiskyAssetLabeledSolver(BaseLabeledSolver):
    """
    Solver for the consumption model with a risky asset.

    Transitions come from RiskyAssetTransitions: all savings earn the
    stochastic risky return. The artificial borrowing constraint is always
    forced to 0.0 when the constraint is calculated, whatever value was passed
    to the constructor.

    Additional Parameters
    ---------------------
    ShockDstn : DiscreteDistributionLabeled
        Joint distribution of income and risky return shocks. The first atom
        row is treated as the permanent shock and the second as the
        transitory shock when computing the natural borrowing constraint.
    """

    TransitionsClass = RiskyAssetTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        ShockDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # Attach the joint shock distribution ahead of base initialization.
        self.ShockDstn = ShockDstn
        base_args = dict(
            solution_next=solution_next,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
        )
        super().__init__(**base_args, **kwargs)

    def calculate_borrowing_constraint(self) -> None:
        """
        Force ``BoroCnstArt`` to 0.0 and compute the natural constraint.

        Also aliases ``IncShkDstn`` to ``ShockDstn``. ``BoroCnstNat`` is then
        computed from the minimum realizations in the first two atom rows of
        ``ShockDstn``.
        """
        self.BoroCnstArt = 0.0
        self.IncShkDstn = self.ShockDstn
        shock_atoms = self.ShockDstn.atoms
        worst_perm = np.min(shock_atoms[0])
        worst_tran = np.min(shock_atoms[1])
        self.BoroCnstNat = self._natural_boro_cnst_with_shocks(worst_perm, worst_tran)

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """Integrate over the joint shock distribution, with ``aNrm`` as leading dim."""

        def _aNrm_first(dataset):
            # Put aNrm first so that later positional indexing (e.g. the
            # portfolio FOC search) can rely on it.
            return dataset.transpose("aNrm", ...)

        expectation_inputs = {
            "post_state": self.post_state,
            "value_next": self.solution_next.value,
            "params": self.params,
        }
        expected_value = self.ShockDstn.expected(
            func=self._continuation_for_expectation, **expectation_inputs
        )
        return self._finalize_value_func(expected_value, transform=_aNrm_first)
class ConsFixedPortfolioLabeledSolver(ConsRiskyAssetLabeledSolver):
    """
    Solver for consumption model with fixed portfolio allocation.

    Uses FixedPortfolioTransitions - agent allocates fixed share to risky asset.

    Additional Parameters
    ---------------------
    RiskyShareFixed : float
        Fixed share of savings allocated to risky asset, in [0, 1].

    Raises
    ------
    ValueError
        If RiskyShareFixed is outside [0, 1] or is NaN.
    """

    TransitionsClass = FixedPortfolioTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        ShockDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        RiskyShareFixed: float,
        **kwargs,
    ) -> None:
        # Validate RiskyShareFixed. Written as ``not (in range)`` so that NaN,
        # which fails every comparison, is rejected as well.
        if not (0 <= RiskyShareFixed <= 1):
            raise ValueError(
                f"RiskyShareFixed must be in [0, 1], got {RiskyShareFixed}"
            )

        self.RiskyShareFixed = RiskyShareFixed
        super().__init__(
            solution_next=solution_next,
            ShockDstn=ShockDstn,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
            **kwargs,
        )

    def create_params_namespace(self) -> SimpleNamespace:
        """Extend the base parameters with ``RiskyShareFixed``."""
        params = super().create_params_namespace()
        params.RiskyShareFixed = self.RiskyShareFixed
        return params
class ConsPortfolioLabeledSolver(ConsRiskyAssetLabeledSolver):
    """
    Solver for consumption model with optimal portfolio choice.

    Uses PortfolioTransitions - agent optimally chooses risky share each period.
    The optimal share is found by solving the portfolio first-order condition.

    Additional Parameters
    ---------------------
    ShareGrid : np.ndarray
        Grid of risky share values to search over. Must be one-dimensional,
        strictly increasing, within [0, 1], and contain at least two values,
        because the optimum is interpolated between adjacent grid points.

    Raises
    ------
    ValueError
        If ShareGrid violates any of the requirements above.
    """

    TransitionsClass = PortfolioTransitions

    def __init__(
        self,
        solution_next: ConsumerSolutionLabeled,
        ShockDstn: DiscreteDistributionLabeled,
        LivPrb: float,
        DiscFac: float,
        CRRA: float,
        Rfree: float,
        PermGroFac: float,
        BoroCnstArt: float | None,
        aXtraGrid: np.ndarray,
        ShareGrid: np.ndarray,
        **kwargs,
    ) -> None:
        # Validate ShareGrid
        ShareGrid = np.asarray(ShareGrid)
        if ShareGrid.ndim != 1:
            raise ValueError(
                f"ShareGrid must be one-dimensional, got shape {ShareGrid.shape}"
            )
        if ShareGrid.size == 0:
            raise ValueError("ShareGrid cannot be empty")
        if ShareGrid.size < 2:
            # The FOC search brackets the root between adjacent grid points;
            # with a single point it would fail at solve time.
            raise ValueError("ShareGrid must contain at least two values")
        # ``not all(in range)`` also rejects NaN entries.
        if not np.all((ShareGrid >= 0) & (ShareGrid <= 1)):
            raise ValueError("ShareGrid values must be in [0, 1]")
        if not np.all(np.diff(ShareGrid) > 0):
            raise ValueError("ShareGrid must be strictly increasing")

        self.ShareGrid = ShareGrid
        super().__init__(
            solution_next=solution_next,
            ShockDstn=ShockDstn,
            LivPrb=LivPrb,
            DiscFac=DiscFac,
            CRRA=CRRA,
            Rfree=Rfree,
            PermGroFac=PermGroFac,
            BoroCnstArt=BoroCnstArt,
            aXtraGrid=aXtraGrid,
            **kwargs,
        )

    def create_post_state(self) -> xr.Dataset:
        """Add risky share dimension ``stigma`` to post-decision state."""
        post_state = super().create_post_state()
        post_state["stigma"] = xr.DataArray(
            self.ShareGrid, dims=["stigma"], attrs={"long_name": "risky share"}
        )
        return post_state

    def create_continuation_function(self) -> ValueFuncCRRALabeled:
        """
        Create continuation function with optimal portfolio choice.

        First computes continuation value over the (aNrm, stigma) grid,
        then finds the optimal stigma for each aNrm level.

        The optimum is located by finding the first adjacent pair of share
        grid points where ``dvds`` goes from >= 0 to <= 0, and then linearly
        interpolating the root. Two kinds of point are handled separately:

        - Corner solutions: the share is set to 1.0 if ``dvds`` is still
          positive at the top of the grid, and to 0.0 if it is already
          negative at the bottom.
        - Without a natural constraint, the aNrm = 0 point is set to 1.0.

        Side effect: removes ``stigma`` from ``self.post_state`` so the
        subsequent EGM step works on aNrm only.

        Warns
        -----
        RuntimeWarning
            If the interpolation bracket is degenerate (near-zero FOC
            difference) at a point whose share is not overwritten by the
            rules above.
        """
        # Get continuation value over full (aNrm, stigma) grid
        wfunc = super().create_continuation_function()

        # The parent transposes aNrm to the front, so rows index aNrm.
        # Columns are assumed to index stigma (the only other dim) — confirm
        # if more dimensions are ever added.
        dvds = wfunc.dataset["dvds"].values

        # Find optimal share using linear interpolation on FOC
        crossing = np.logical_and(dvds[:, 1:] <= 0.0, dvds[:, :-1] >= 0.0)
        share_idx = np.argmax(crossing, axis=1)
        a_idx = np.arange(self.post_state["aNrm"].size)

        bottom_share = self.ShareGrid[share_idx]
        top_share = self.ShareGrid[share_idx + 1]
        bottom_foc = dvds[a_idx, share_idx]
        top_foc = dvds[a_idx, share_idx + 1]

        # Corner solutions (applied after interpolation below)
        wants_all_risky = dvds[:, -1] > 0.0  # Want more than 100% risky
        wants_no_risky = dvds[:, 0] < 0.0  # Want less than 0% risky

        # Points whose interpolated share is discarded later; a degenerate
        # bracket there is harmless and must not trigger a warning.
        overridden = wants_all_risky | wants_no_risky
        if not self.nat_boro_cnst:
            overridden[0] = True

        # Linear interpolation with division-by-zero protection
        denominator = top_foc - bottom_foc
        fallback_mask = np.abs(denominator) <= 1e-12
        n_fallbacks = int(np.sum(fallback_mask & ~overridden))
        if n_fallbacks:
            warnings.warn(
                f"Portfolio optimization used fallback interpolation for {n_fallbacks} "
                f"grid points due to near-zero FOC difference. "
                f"Consider refining ShareGrid for more accurate results.",
                RuntimeWarning,
                stacklevel=2,
            )
        # Divide by a safe denominator so numpy does not emit divide-by-zero /
        # invalid-value warnings for entries that take the 0.5 fallback anyway.
        safe_denominator = np.where(fallback_mask, 1.0, denominator)
        alpha = np.where(
            fallback_mask,
            0.5,
            1.0 - top_foc / safe_denominator,
        )
        opt_share = (1.0 - alpha) * bottom_share + alpha * top_share

        # Handle corner solutions (the lower corner takes precedence)
        opt_share[wants_all_risky] = 1.0
        opt_share[wants_no_risky] = 0.0

        if not self.nat_boro_cnst:
            # At aNrm = 0 the portfolio share is irrelevant; 1.0 is limit as a --> 0
            opt_share[0] = 1.0

        opt_share = xr.DataArray(
            opt_share,
            coords={"aNrm": self.post_state["aNrm"].values},
            dims=["aNrm"],
            attrs={"long_name": "optimal risky share"},
        )

        # Evaluate continuation at optimal share
        v_end = wfunc.evaluate({"aNrm": self.post_state["aNrm"], "stigma": opt_share})
        v_end = v_end.reset_coords(names="stigma")

        wfunc = ValueFuncCRRALabeled(v_end, self.params.CRRA)

        # Remove stigma from post_state for EGM
        self.post_state = self.post_state.drop_vars("stigma")

        return wfunc