Coverage for HARK / metric.py: 91%

109 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-10 06:19 +0000

1from warnings import warn 

2 

3import numpy as np 

4 

5 

def distance_lists(list_a, list_b):
    """
    Compute the distance between two lists.

    If the lists have the same length, the distance is the maximum distance
    between corresponding elements, as measured by distance_metric. If they
    differ in length, the distance is the absolute difference in lengths.
    Two empty lists have nothing that differs, so their distance is zero.

    Parameters
    ----------
    list_a : list
        First list.
    list_b : list
        Second list.

    Returns
    -------
    distance : float or int
        The distance between the two lists.
    """
    len_a = len(list_a)
    len_b = len(list_b)
    if len_a != len_b:
        return np.abs(len_a - len_b)
    if len_a == 0:
        # np.max raises ValueError on an empty sequence, so handle the
        # vacuous comparison explicitly.
        return 0.0
    return np.max([distance_metric(a, b) for a, b in zip(list_a, list_b)])

17 

18 

def distance_dicts(dict_a, dict_b):
    """
    Compute the distance between two dictionaries.

    If both dictionaries have a "distance_criteria" entry and those entries
    match element by element, only the keys named there are compared.
    Otherwise all keys are compared. In both cases the distance is the maximum
    of distance_metric over the compared keys. If the dictionaries do not have
    the same keys, a warning is issued and 1000.0 is returned; nothing in HARK
    should ever hit that warning. Two empty dictionaries have distance zero.

    Parameters
    ----------
    dict_a : dict
        First dictionary.
    dict_b : dict
        Second dictionary.

    Returns
    -------
    distance : float
        The distance between the two dictionaries.

    Raises
    ------
    KeyError
        If a key named in matching distance_criteria is absent from either
        dictionary.
    ValueError
        If the matching distance_criteria are empty, because the maximum of
        an empty sequence is undefined.
    """
    # Check whether the dictionaries have matching distance_criteria
    if "distance_criteria" in dict_a and "distance_criteria" in dict_b:
        crit_a = dict_a["distance_criteria"]
        crit_b = dict_b["distance_criteria"]
        if len(crit_a) == len(crit_b) and np.all(
            [key_a == key_b for key_a, key_b in zip(crit_a, crit_b)]
        ):
            # Compare only their distance_criteria
            return np.max(
                [distance_metric(dict_a[key], dict_b[key]) for key in crit_a]
            )

    # Otherwise, compare all their keys
    if set(dict_a) != set(dict_b):
        warn("Dictionaries with keys that do not match are being compared.")
        return 1000.0
    if not dict_a:
        # Both dictionaries are empty (their key sets are equal); np.max would
        # raise ValueError on an empty sequence.
        return 0.0
    return np.max([distance_metric(dict_a[key], dict_b[key]) for key in dict_a])

45 

46 

def distance_arrays(arr_a, arr_b):
    """
    Compute the distance between two numpy arrays.

    If the arrays have the same shape, return the maximum absolute difference
    between corresponding elements (zero if the arrays have no elements). If
    they don't even have the same number of dimensions, return 10000 times the
    difference in dimensions. If they have the same number of dimensions but
    different shapes, return the sum of differences in size for each dimension.

    Parameters
    ----------
    arr_a : np.ndarray
        First array.
    arr_b : np.ndarray
        Second array.

    Returns
    -------
    distance : float or int
        The distance between the two arrays.
    """
    shape_a = arr_a.shape
    shape_b = arr_b.shape
    if shape_a == shape_b:
        if arr_a.size == 0:
            # np.max raises ValueError on a zero-size array; two empty arrays
            # of the same shape have no differing elements.
            return 0.0
        return np.max(np.abs(arr_a - arr_b))

    if len(shape_a) != len(shape_b):
        return 10000 * np.abs(len(shape_a) - len(shape_b))

    dim_diffs = np.abs(np.array(shape_a) - np.array(shape_b))
    return np.sum(dim_diffs)

65 

66 

def distance_class(cls_a, cls_b):
    """
    Fallback comparison for two objects of the same class: delegate to the
    distance method of the first object, passing it the second.

    Plain Python functions (including lambdas) cannot be compared this way;
    for those a warning is issued and a large distance of 1000.0 is returned.

    Parameters
    ----------
    cls_a : object
        Object whose distance method is called.
    cls_b : object
        Object passed to that distance method.

    Returns
    -------
    distance : float
        Whatever cls_a.distance(cls_b) returns, or 1000.0 for functions.
    """
    function_type = type(lambda: None)
    if not isinstance(cls_a, function_type):
        return cls_a.distance(cls_b)

    warn("Cannot compare lambda functions. Returning large distance.")
    return 1000.0

76 

77 

def distance_metric(thing_a, thing_b):
    """
    A "universal distance" metric that can be used as a default in many settings.

    The comparison is chosen by the types of the inputs: numbers (including
    numpy scalar types) use the absolute difference; pairs of lists, numpy
    arrays, and dictionaries are passed to distance_lists, distance_arrays,
    and distance_dicts respectively; otherwise, if thing_a is an instance of
    thing_b's type, distance_class is used. Any other pairing is treated as
    very far apart.

    Parameters
    ----------
    thing_a : object
        A generic object.
    thing_b : object
        Another generic object.

    Returns
    -------
    distance : float
        The "distance" between thing_a and thing_b.
    """
    # np.number covers numpy scalars such as np.int64 and np.float32, which
    # are not subclasses of the builtin int or float.
    number_types = (int, float, np.number)

    # If both inputs are numbers, return their difference
    if isinstance(thing_a, number_types) and isinstance(thing_b, number_types):
        return np.abs(thing_a - thing_b)

    if isinstance(thing_a, list) and isinstance(thing_b, list):
        return distance_lists(thing_a, thing_b)

    if isinstance(thing_a, np.ndarray) and isinstance(thing_b, np.ndarray):
        return distance_arrays(thing_a, thing_b)

    if isinstance(thing_a, dict) and isinstance(thing_b, dict):
        return distance_dicts(thing_a, thing_b)

    if isinstance(thing_a, type(thing_b)):
        return distance_class(thing_a, thing_b)

    # Failsafe: the inputs are very far apart
    return 1000.0

113 

114 

def describe_metric(thing, n=0, label=None, D=100000):
    """
    Generate a description of an object's distance metric.

    Parameters
    ----------
    thing : object
        A generic object.
    n : int
        Recursive depth of this call.
    label : object or None
        Name/label of the thing, which might be a dictionary key, attribute name,
        list index, etc. It is converted with str() before being written.
    D : int
        Maximum recursive depth; if n > D, empty output is returned. At n == D
        the contents of containers are replaced by a SUPPRESSED OUTPUT line.

    Returns
    -------
    desc : str
        Description of this object's distance metric, indented 2n spaces.
    """
    pad = 2
    if n > D:
        return ""

    child_indent = pad * (n + 1) * " "
    suppressed = child_indent + "SUPPRESSED OUTPUT\n"

    if label is None:
        desc = ""
    else:
        # Labels may be non-string dictionary keys, so convert explicitly
        # rather than failing on str + non-str concatenation.
        desc = pad * n * " " + "- " + str(label) + " "

    # If both inputs are numbers, distance is their difference
    if isinstance(thing, (int, float)):
        desc += "(scalar): absolute difference of values\n"

    elif isinstance(thing, list):
        desc += "(list) largest distance among:\n"
        if n == D:
            desc += suppressed
        else:
            for j, item in enumerate(thing):
                desc += describe_metric(item, n + 1, label="[" + str(j) + "]", D=D)

    elif isinstance(thing, np.ndarray):
        desc += (
            "(array"
            + str(thing.shape)
            + "): greatest absolute difference among elements\n"
        )

    elif isinstance(thing, dict):
        if "distance_criteria" in thing:
            my_keys = thing["distance_criteria"]
        else:
            my_keys = thing.keys()
        desc += "(dict): largest distance among these keys:\n"
        if n == D:
            desc += suppressed
        else:
            for key in my_keys:
                try:
                    item = thing[key]
                except KeyError:
                    # Indent like every other child entry so the tree stays aligned.
                    desc += child_indent + "- " + str(key) + " (missing): CAN'T COMPARE\n"
                    continue
                desc += describe_metric(item, n + 1, label=key, D=D)

    elif isinstance(thing, MetricObject):
        my_keys = thing.distance_criteria
        desc += (
            "(" + type(thing).__name__ + "): largest distance among these attributes:\n"
        )
        if len(my_keys) == 0:
            desc += child_indent + "NO distance_criteria SPECIFIED\n"
        if n == D:
            desc += suppressed
        else:
            for key in my_keys:
                if hasattr(thing, key):
                    desc += describe_metric(getattr(thing, key), n + 1, label=key, D=D)
                else:
                    desc += child_indent + "- " + str(key) + " (missing): CAN'T COMPARE\n"

    else:
        # Something has gone wrong
        desc += "WARNING: INCOMPARABLE\n"

    return desc

203 

204 

class MetricObject:
    """
    A superclass for object classes in HARK. It provides a generic/universal
    distance method and a method that describes how that distance is computed.
    """

    # Names of the attributes that enter the distance metric; subclasses
    # should overwrite this.
    distance_criteria = []

    def distance(self, other):
        """
        Measure how far this instance is from another object, using the
        "universal distance" metric on each attribute named in
        distance_criteria and taking the largest result.

        Parameters
        ----------
        other : object
            Another object to compare this instance to.

        Returns
        -------
        (unnamed) : float
            The distance between this object and another. If an AttributeError
            or ValueError is raised during the computation (for example when
            an attribute is missing or distance_criteria is empty), 1000.0 is
            returned instead.
        """
        try:
            gaps = []
            for attr_name in self.distance_criteria:
                mine = getattr(self, attr_name)
                theirs = getattr(other, attr_name)
                gaps.append(distance_metric(mine, theirs))
            return np.max(gaps)
        except (AttributeError, ValueError):
            return 1000.0

    def describe_distance(self, display=True, max_depth=None):
        """
        Describe how this object's distance metric is computed. The text is
        built recursively, mirroring the metric itself, and is printed to
        screen by default.

        Parameters
        ----------
        display : bool, optional
            Whether the description should be printed to screen (default True).
            Otherwise, it is returned as a string.
        max_depth : int or None
            If specified, the maximum recursive depth of the description.

        Returns
        -------
        out : str
            Description of how this object's distance metric is computed, if
            display=False; nothing is returned when display=True.
        """
        depth_limit = np.inf if max_depth is None else max_depth

        if len(self.distance_criteria) == 0:
            text = "No distance criteria are specified; please name them in distance_criteria.\n"
        else:
            text = describe_metric(self, D=depth_limit)

        # Drop the trailing newline
        text = text[:-1]
        if not display:
            return text
        print(text)
        return None

272 return out