Coverage for HARK / Labeled / config.py: 91%
53 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
1"""
2Configuration management for labeled consumption-saving models.
4This module provides immutable configuration objects for model parameters,
5replacing the module-level dict mutation pattern with a more robust approach.
6"""
8from __future__ import annotations
10from copy import deepcopy
11from dataclasses import dataclass, field
12from typing import Any, Callable
14from HARK.ConsumptionSaving.ConsIndShockModel import (
15 IndShockConsumerType_aXtraGrid_default,
16 init_idiosyncratic_shocks,
17 init_perfect_foresight,
18)
19from HARK.ConsumptionSaving.ConsPortfolioModel import init_portfolio
20from HARK.ConsumptionSaving.ConsRiskyAssetModel import (
21 IndShockRiskyAssetConsumerType_constructor_default,
22 init_risky_asset,
23)
24from HARK.utilities import make_assets_grid
26from .factories import (
27 make_labeled_inc_shk_dstn,
28 make_labeled_risky_dstn,
29 make_labeled_shock_dstn,
30 make_solution_terminal_labeled,
31)
33__all__ = [
34 "ModelConfig",
35 "PERF_FORESIGHT_CONFIG",
36 "IND_SHOCK_CONFIG",
37 "RISKY_ASSET_CONFIG",
38 "PORTFOLIO_CONFIG",
39 "get_config",
40]
@dataclass(frozen=True)
class ModelConfig:
    """
    Immutable configuration for a labeled consumption model.

    This class represents a complete configuration including base parameters
    and constructor functions. Configurations can inherit from parent configs
    through the parent field.

    ``frozen=True`` only prevents rebinding the three fields; the dictionaries
    they hold are ordinary mutable dicts, so callers should treat them as
    read-only. ``build_params`` always hands back a fresh dictionary.

    Parameters
    ----------
    base_params : dict
        Model-specific parameter overrides. May itself contain a
        ``"constructors"`` entry, which is merged (not substituted) into the
        constructors inherited from ``parent``.
    constructors : dict
        Constructor functions for computed parameters. These take precedence
        over every other source of constructors.
    parent : ModelConfig, optional
        Parent configuration to inherit from.
    """

    base_params: dict[str, Any] = field(default_factory=dict)
    constructors: dict[str, Callable] = field(default_factory=dict)
    parent: ModelConfig | None = None

    def build_params(self) -> dict[str, Any]:
        """
        Build complete parameter dictionary by merging with parent chain.

        Plain parameters are layered parent-first, so this configuration's
        ``base_params`` override inherited values. Constructors are merged
        key by key with increasing precedence:

        1. constructors produced by the parent chain,
        2. ``base_params["constructors"]`` of this configuration (if present),
        3. ``self.constructors``.

        Returns
        -------
        dict[str, Any]
            Complete parameter dictionary with constructors. The result always
            contains a ``"constructors"`` key (possibly an empty dict).
        """
        params: dict[str, Any] = (
            self.parent.build_params() if self.parent is not None else {}
        )

        # Remember what the parent chain contributed before update() below can
        # replace the whole "constructors" entry with this config's own one.
        inherited = params.get("constructors", {})

        # Deep copy so callers mutating the result never touch base_params.
        params.update(deepcopy(self.base_params))

        # After the update, params["constructors"] is this config's own
        # (copied) entry only if base_params actually defined one.
        local = params["constructors"] if "constructors" in self.base_params else {}

        # Merge constructors
        params["constructors"] = {**inherited, **local, **self.constructors}

        return params
# =============================================================================
# Default Configurations
# =============================================================================

# Perfect Foresight Labeled
# Parameters are layered: a deep copy of init_idiosyncratic_shocks first, then
# init_perfect_foresight, then IndShockConsumerType_aXtraGrid_default; entries
# from later layers override earlier ones.
_pf_base_params = deepcopy(init_idiosyncratic_shocks)
_pf_base_params.update(
    {**init_perfect_foresight, **IndShockConsumerType_aXtraGrid_default}
)

# Constructors start from a copy of the idiosyncratic-shock ones (empty if
# that dict has none) and then get the labeled terminal solution and
# make_assets_grid for aXtraGrid.
_pf_constructors = deepcopy(init_idiosyncratic_shocks.get("constructors", {}))
_pf_constructors.update(
    solution_terminal=make_solution_terminal_labeled,
    aXtraGrid=make_assets_grid,
)

PERF_FORESIGHT_CONFIG = ModelConfig(
    constructors=_pf_constructors,
    base_params=_pf_base_params,
)

# IndShock Labeled
# Inherits from the perfect-foresight configuration and only overrides the
# IncShkDstn constructor.
_ind_shock_constructors = dict(IncShkDstn=make_labeled_inc_shk_dstn)

IND_SHOCK_CONFIG = ModelConfig(
    parent=PERF_FORESIGHT_CONFIG,
    constructors=_ind_shock_constructors,
    base_params={},
)
# Risky Asset Labeled
# Work on a copy of the default constructors so the imported dict stays intact,
# then install the labeled replacements.
_risky_constructors = deepcopy(IndShockRiskyAssetConsumerType_constructor_default)
_risky_constructors.update(
    IncShkDstn=make_labeled_inc_shk_dstn,
    RiskyDstn=make_labeled_risky_dstn,
    ShockDstn=make_labeled_shock_dstn,
    solution_terminal=make_solution_terminal_labeled,
)
# Drop the solve_one_period constructor if one is present (no-op otherwise).
_risky_constructors.pop("solve_one_period", None)

_risky_base = deepcopy(init_risky_asset)
# Remove solve_one_period from base_params constructors to avoid conflict with our solver
_risky_base.get("constructors", {}).pop("solve_one_period", None)

RISKY_ASSET_CONFIG = ModelConfig(
    constructors=_risky_constructors,
    base_params=_risky_base,
)
# Portfolio Labeled
# Base parameters are a deep copy of init_portfolio with RiskyShareFixed
# replaced by [0.0].
_portfolio_base = deepcopy(init_portfolio)
_portfolio_base.update(RiskyShareFixed=[0.0])

# Labeled constructors layered on top of the base parameters.
_portfolio_constructors = dict(
    IncShkDstn=make_labeled_inc_shk_dstn,
    RiskyDstn=make_labeled_risky_dstn,
    ShockDstn=make_labeled_shock_dstn,
    solution_terminal=make_solution_terminal_labeled,
)

PORTFOLIO_CONFIG = ModelConfig(
    constructors=_portfolio_constructors,
    base_params=_portfolio_base,
)
def get_config(model_type: str) -> ModelConfig:
    """
    Look up the predefined configuration registered under a model name.

    Parameters
    ----------
    model_type : str
        One of 'perfect_foresight', 'ind_shock', 'risky_asset', 'portfolio'.

    Returns
    -------
    ModelConfig
        The module-level configuration object for that model type (the shared
        instance, not a copy).

    Raises
    ------
    ValueError
        If model_type is not one of the recognized names.
    """
    configs = {
        "perfect_foresight": PERF_FORESIGHT_CONFIG,
        "ind_shock": IND_SHOCK_CONFIG,
        "risky_asset": RISKY_ASSET_CONFIG,
        "portfolio": PORTFOLIO_CONFIG,
    }

    # Every registered value is a ModelConfig instance, so None can only mean
    # that the name is not registered.
    selected = configs.get(model_type)
    if selected is None:
        raise ValueError(
            f"Unknown model type '{model_type}'. Available: {list(configs.keys())}"
        )

    return selected