Coverage for HARK / Labeled / config.py: 91%

53 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-25 05:22 +0000

1""" 

2Configuration management for labeled consumption-saving models. 

3 

4This module provides immutable configuration objects for model parameters, 

5replacing the module-level dict mutation pattern with a more robust approach. 

6""" 

7 

8from __future__ import annotations 

9 

10from copy import deepcopy 

11from dataclasses import dataclass, field 

12from typing import Any, Callable 

13 

14from HARK.ConsumptionSaving.ConsIndShockModel import ( 

15 IndShockConsumerType_aXtraGrid_default, 

16 init_idiosyncratic_shocks, 

17 init_perfect_foresight, 

18) 

19from HARK.ConsumptionSaving.ConsPortfolioModel import init_portfolio 

20from HARK.ConsumptionSaving.ConsRiskyAssetModel import ( 

21 IndShockRiskyAssetConsumerType_constructor_default, 

22 init_risky_asset, 

23) 

24from HARK.utilities import make_assets_grid 

25 

26from .factories import ( 

27 make_labeled_inc_shk_dstn, 

28 make_labeled_risky_dstn, 

29 make_labeled_shock_dstn, 

30 make_solution_terminal_labeled, 

31) 

32 

33__all__ = [ 

34 "ModelConfig", 

35 "PERF_FORESIGHT_CONFIG", 

36 "IND_SHOCK_CONFIG", 

37 "RISKY_ASSET_CONFIG", 

38 "PORTFOLIO_CONFIG", 

39 "get_config", 

40] 

41 

42 

@dataclass(frozen=True)
class ModelConfig:
    """
    Configuration for a labeled consumption model.

    A configuration bundles plain parameter overrides with constructor
    functions, and may inherit from a parent configuration through the
    ``parent`` field. The dataclass is frozen, so its fields cannot be
    rebound after creation. ``build_params`` merges a deep copy of
    ``base_params``, so changing the dictionary it returns does not alter
    the dictionaries stored on the configuration.

    Parameters
    ----------
    base_params : dict
        Model-specific parameter overrides. May itself carry a
        ``"constructors"`` entry; that entry is merged into the inherited
        constructors rather than replacing them.
    constructors : dict
        Constructor functions for computed parameters. These take precedence
        over constructors from ``base_params`` and from the parent chain.
    parent : ModelConfig, optional
        Parent configuration to inherit from.
    """

    base_params: dict[str, Any] = field(default_factory=dict)
    constructors: dict[str, Callable] = field(default_factory=dict)
    # Quoted so the forward reference is never evaluated while the class body
    # is still executing.
    parent: "ModelConfig | None" = None

    def build_params(self) -> dict[str, Any]:
        """
        Build the complete parameter dictionary by merging with the parent chain.

        Plain parameters are layered parent-first, so entries in this
        configuration's ``base_params`` override inherited ones. Constructors
        are merged key by key, lowest to highest precedence:

        1. constructors produced by the parent chain,
        2. a ``"constructors"`` entry inside ``base_params`` (if any),
        3. the ``constructors`` field of this configuration.

        Returns
        -------
        dict[str, Any]
            Complete parameter dictionary. It always contains a
            ``"constructors"`` key, which is an empty dict when no
            constructors are defined anywhere in the chain.
        """
        params = self.parent.build_params() if self.parent is not None else {}

        # Capture the inherited constructors before base_params is applied:
        # a "constructors" key in base_params would otherwise replace them
        # wholesale during the update below.
        inherited_ctors = params.get("constructors", {})

        own_params = deepcopy(self.base_params)
        declared_ctors = own_params.get("constructors", {})
        params.update(own_params)

        params["constructors"] = {
            **inherited_ctors,
            **declared_ctors,
            **self.constructors,
        }
        return params

89 

90 

91# ============================================================================= 

92# Default Configurations 

93# ============================================================================= 

94 

# Perfect Foresight Labeled
# Layer the imported HARK defaults; later sources win on key collisions.
# Every source is deep-copied so the configuration never shares nested
# mutable values (such as a "constructors" dict) with the dictionaries
# imported from HARK.
_pf_base_params = deepcopy(init_idiosyncratic_shocks)
_pf_base_params.update(deepcopy(init_perfect_foresight))
_pf_base_params.update(deepcopy(IndShockConsumerType_aXtraGrid_default))

# Start from the idiosyncratic-shock constructors and swap in the labeled
# terminal solution and the asset grid builder.
_pf_constructors = deepcopy(init_idiosyncratic_shocks.get("constructors", {}))
_pf_constructors["solution_terminal"] = make_solution_terminal_labeled
_pf_constructors["aXtraGrid"] = make_assets_grid

PERF_FORESIGHT_CONFIG = ModelConfig(
    base_params=_pf_base_params,
    constructors=_pf_constructors,
)

108 

# IndShock Labeled
# Inherits everything from the perfect-foresight configuration and only
# overrides the ``IncShkDstn`` constructor.
_ind_shock_constructors = dict(IncShkDstn=make_labeled_inc_shk_dstn)

IND_SHOCK_CONFIG = ModelConfig(
    parent=PERF_FORESIGHT_CONFIG,
    constructors=_ind_shock_constructors,
    base_params={},
)

119 

# Risky Asset Labeled
# Copy the imported default constructors, then swap in the labeled factories.
_risky_constructors = deepcopy(IndShockRiskyAssetConsumerType_constructor_default)
_risky_constructors.update(
    IncShkDstn=make_labeled_inc_shk_dstn,
    RiskyDstn=make_labeled_risky_dstn,
    ShockDstn=make_labeled_shock_dstn,
    solution_terminal=make_solution_terminal_labeled,
)
# Drop any default one-period solver entry; a missing key is not an error.
_risky_constructors.pop("solve_one_period", None)

_risky_base = deepcopy(init_risky_asset)
# Remove solve_one_period from base_params constructors to avoid conflict with our solver
_risky_base.get("constructors", {}).pop("solve_one_period", None)

RISKY_ASSET_CONFIG = ModelConfig(
    constructors=_risky_constructors,
    base_params=_risky_base,
)

138 

# Portfolio Labeled
# Work on a private copy of the imported portfolio defaults and pin
# ``RiskyShareFixed`` to a single-element list.
_portfolio_base = deepcopy(init_portfolio)
_portfolio_base["RiskyShareFixed"] = [0.0]

# Labeled factories used in place of the corresponding constructors.
_portfolio_constructors = dict(
    IncShkDstn=make_labeled_inc_shk_dstn,
    RiskyDstn=make_labeled_risky_dstn,
    ShockDstn=make_labeled_shock_dstn,
    solution_terminal=make_solution_terminal_labeled,
)

PORTFOLIO_CONFIG = ModelConfig(
    constructors=_portfolio_constructors,
    base_params=_portfolio_base,
)

154 

155 

def get_config(model_type: str) -> ModelConfig:
    """
    Look up the default configuration registered under a model type name.

    Parameters
    ----------
    model_type : str
        One of 'perfect_foresight', 'ind_shock', 'risky_asset', 'portfolio'.

    Returns
    -------
    ModelConfig
        The module-level configuration object for that model type (the
        shared instance, not a copy).

    Raises
    ------
    ValueError
        If model_type is not one of the recognized names.
    """
    configs = {
        "perfect_foresight": PERF_FORESIGHT_CONFIG,
        "ind_shock": IND_SHOCK_CONFIG,
        "risky_asset": RISKY_ASSET_CONFIG,
        "portfolio": PORTFOLIO_CONFIG,
    }

    # Every registered value is a ModelConfig instance, so None can only mean
    # that the name is unknown.
    selected = configs.get(model_type)
    if selected is None:
        raise ValueError(
            f"Unknown model type '{model_type}'. Available: {list(configs.keys())}"
        )
    return selected