Coverage for HARK / metric.py: 92%
107 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-25 05:22 +0000
1from warnings import warn
3import numpy as np
def distance_lists(list_a, list_b):
    """
    Distance between two lists.

    If both lists have the same length, the distance is the maximum distance
    (via ``distance_metric``) between corresponding elements. If they differ
    in length, the distance is the absolute difference in lengths. Two empty
    lists are identical, so their distance is 0.0.
    """
    len_a = len(list_a)
    len_b = len(list_b)
    if len_a != len_b:
        return np.abs(len_a - len_b)
    if len_a == 0:
        # np.max of an empty sequence raises ValueError; empty lists are equal.
        return 0.0
    return np.max([distance_metric(a, b) for a, b in zip(list_a, list_b)])
def distance_dicts(dict_a, dict_b):
    """
    Distance between two dictionaries: the largest distance among their values.

    If both dictionaries have a "distance_criteria" entry and those entries
    name the same keys in the same order, only those keys are compared. If a
    key named in the shared criteria is missing from either dictionary, a
    warning is issued and 1000.0 is returned.

    Otherwise every key is compared, including "distance_criteria" itself if
    present. If the dictionaries do not have the same set of keys, a warning
    is issued and 1000.0 is returned. Nothing in HARK should ever hit these
    warnings.

    When there is nothing to compare (both dictionaries empty, or an empty
    shared criteria list), the distance is 0.0.
    """
    # Check whether the dictionaries have matching distance_criteria
    if "distance_criteria" in dict_a and "distance_criteria" in dict_b:
        crit_a = dict_a["distance_criteria"]
        crit_b = dict_b["distance_criteria"]
        same_criteria = len(crit_a) == len(crit_b) and np.all(
            [a == b for a, b in zip(crit_a, crit_b)]
        )
        if same_criteria:
            # Indexing a missing criteria key would otherwise raise an
            # uncaught KeyError; treat it like the key-mismatch case below.
            if any(key not in dict_a or key not in dict_b for key in crit_a):
                warn("Dictionaries are missing keys named in their distance_criteria.")
                return 1000.0
            if len(crit_a) == 0:
                # np.max of an empty list raises ValueError; nothing differs.
                return 0.0
            # Compare only their distance_criteria
            return np.max(
                [distance_metric(dict_a[key], dict_b[key]) for key in crit_a]
            )

    # Otherwise, compare all their keys
    if set(dict_a.keys()) != set(dict_b.keys()):
        warn("Dictionaries with keys that do not match are being compared.")
        return 1000.0
    if not dict_a:
        # Two empty dictionaries are identical.
        return 0.0
    return np.max([distance_metric(dict_a[key], dict_b[key]) for key in dict_a])
def distance_arrays(arr_a, arr_b):
    """
    Distance between two NumPy arrays.

    - Same shape: the maximum absolute difference between corresponding
      elements (0.0 if the arrays are empty).
    - Different number of dimensions: 10000 times the difference in the
      number of dimensions.
    - Same number of dimensions but different shapes: the sum over dimensions
      of the absolute difference in size.
    """
    shape_a = arr_a.shape
    shape_b = arr_b.shape
    if shape_a == shape_b:
        if arr_a.size == 0:
            # np.max of an empty array raises ValueError; empty arrays are equal.
            return 0.0
        return np.max(np.abs(arr_a - arr_b))

    if len(shape_a) != len(shape_b):
        return 10000 * np.abs(len(shape_a) - len(shape_b))

    dim_diffs = np.abs(np.array(shape_a) - np.array(shape_b))
    return np.sum(dim_diffs)
def distance_class(cls_a, cls_b):
    """
    Fallback distance for two objects when no more specific rule applies.

    Python function objects (including lambdas) cannot be compared; for
    those a warning is issued and a large sentinel distance is returned.
    Any other object is expected to provide a ``distance`` method, which is
    called with the second object as its argument.
    """
    function_type = type(lambda: None)
    if not isinstance(cls_a, function_type):
        return cls_a.distance(cls_b)
    warn("Cannot compare lambda functions. Returning large distance.")
    return 1000.0
def distance_metric(thing_a, thing_b):
    """
    A "universal distance" metric that can be used as a default in many settings.

    The comparison depends on the input types:

    - two numbers (Python int/float or NumPy integer/floating scalars): the
      absolute difference of their values;
    - two lists: ``distance_lists``;
    - two NumPy arrays: ``distance_arrays``;
    - two dicts: ``distance_dicts``;
    - otherwise, if thing_a is an instance of thing_b's class:
      ``distance_class``;
    - anything else: 1000.0, i.e. "very far apart".

    Parameters
    ----------
    thing_a : object
        A generic object.
    thing_b : object
        Another generic object.

    Returns
    -------
    distance : float
        The "distance" between thing_a and thing_b.
    """
    # NumPy integer/floating scalars (e.g. values read out of an int array)
    # are numbers too; without this, np.int64(3) vs 3 was reported as 1000.0.
    number_types = (int, float, np.integer, np.floating)
    if isinstance(thing_a, number_types) and isinstance(thing_b, number_types):
        # Convert NumPy scalars to Python numbers first so that fixed-width or
        # unsigned integer subtraction cannot overflow / wrap around.
        a = thing_a.item() if isinstance(thing_a, np.generic) else thing_a
        b = thing_b.item() if isinstance(thing_b, np.generic) else thing_b
        return np.abs(a - b)

    if isinstance(thing_a, list) and isinstance(thing_b, list):
        return distance_lists(thing_a, thing_b)

    if isinstance(thing_a, np.ndarray) and isinstance(thing_b, np.ndarray):
        return distance_arrays(thing_a, thing_b)

    if isinstance(thing_a, dict) and isinstance(thing_b, dict):
        return distance_dicts(thing_a, thing_b)

    if isinstance(thing_a, type(thing_b)):
        return distance_class(thing_a, thing_b)

    # Failsafe: the inputs are very far apart
    return 1000.0
def describe_metric(thing, n=0, label=None, D=100000):
    """
    Generate a description of an object's distance metric.

    Parameters
    ----------
    thing : object
        A generic object.
    n : int
        Recursive depth of this call.
    label : object or None
        Name/label of the thing, which might be a dictionary key, attribute name,
        list index, etc. Non-string labels (e.g. integer dict keys) are
        converted with ``str``.
    D : int
        Maximum recursive depth; if n > D, empty output is returned. At depth
        n == D, the children of a container are replaced by a single
        "SUPPRESSED OUTPUT" line.

    Returns
    -------
    desc : str
        Description of this object's distance metric, indented 2n spaces.
    """
    pad = 2
    if n > D:
        return ""

    child_indent = pad * (n + 1) * " "

    def missing_line(key):
        # Same indentation and "- " prefix as any other child entry.
        return child_indent + "- " + str(key) + " (missing): CAN'T COMPARE\n"

    if label is None:
        desc = ""
    else:
        desc = pad * n * " " + "- " + str(label) + " "

    # If both inputs are numbers, distance is their difference
    if isinstance(thing, (int, float)):
        desc += "(scalar): absolute difference of values\n"

    elif isinstance(thing, list):
        desc += "(list) largest distance among:\n"
        if n == D:
            desc += child_indent + "SUPPRESSED OUTPUT\n"
        else:
            for j, item in enumerate(thing):
                desc += describe_metric(item, n + 1, label="[" + str(j) + "]", D=D)

    elif isinstance(thing, np.ndarray):
        desc += (
            "(array"
            + str(thing.shape)
            + "): greatest absolute difference among elements\n"
        )

    elif isinstance(thing, dict):
        if "distance_criteria" in thing:
            my_keys = thing["distance_criteria"]
        else:
            my_keys = thing.keys()
        desc += "(dict): largest distance among these keys:\n"
        if n == D:
            desc += child_indent + "SUPPRESSED OUTPUT\n"
        else:
            for key in my_keys:
                # Explicit membership test instead of a bare except, so errors
                # raised deeper in the recursion are not mislabeled "missing".
                if key in thing:
                    desc += describe_metric(thing[key], n + 1, label=key, D=D)
                else:
                    desc += missing_line(key)

    elif isinstance(thing, MetricObject):
        my_keys = thing.distance_criteria
        desc += (
            "(" + type(thing).__name__ + "): largest distance among these attributes:\n"
        )
        if len(my_keys) == 0:
            desc += child_indent + "NO distance_criteria SPECIFIED\n"
        if n == D:
            desc += child_indent + "SUPPRESSED OUTPUT\n"
        else:
            for key in my_keys:
                if hasattr(thing, key):
                    desc += describe_metric(getattr(thing, key), n + 1, label=key, D=D)
                else:
                    desc += missing_line(key)

    else:
        # Something has gone wrong
        desc += "WARNING: INCOMPARABLE\n"

    return desc
class MetricObject:
    """
    A superclass for object classes in HARK that provides a generic distance metric.

    Subclasses list the names of the attributes that matter for comparison in
    the class attribute ``distance_criteria``. Two methods are provided:

    - ``distance`` returns the largest "universal distance" between the
      corresponding named attributes of two objects;
    - ``describe_distance`` explains, recursively, how that distance is
      computed.
    """

    # Names of the attributes compared by ``distance``; this should be
    # overwritten by subclasses. It is a class-level list shared by every
    # subclass that does not override it, so assign a new list rather than
    # mutating this one.
    distance_criteria = []

    def distance(self, other):
        """
        A generic distance method, based on the attributes named in
        ``distance_criteria``.

        Parameters
        ----------
        other : object
            Another object to compare this instance to.

        Returns
        -------
        (unnamed) : float
            The largest "universal distance" (``distance_metric``) between the
            corresponding named attributes of self and other, or 1000.0 if
            the comparison cannot be made.
        """
        try:
            return np.max(
                [
                    distance_metric(getattr(self, attr_name), getattr(other, attr_name))
                    for attr_name in self.distance_criteria
                ]
            )
        except (AttributeError, ValueError):
            # AttributeError: a named attribute is missing on self or other (or
            # a nested comparison hit an object without a ``distance`` method).
            # ValueError: e.g. np.max of an empty list when distance_criteria
            # is empty. Either way, treat the objects as very far apart.
            return 1000.0

    def describe_distance(self, display=True, max_depth=None):
        """
        Generate a description for how this object's distance metric is computed.
        By default, the description is printed to screen, but it can be returned.

        Like the distance metric itself, the description is built recursively.

        Parameters
        ----------
        display : bool, optional
            Whether the description should be printed to screen (default True).
            Otherwise, it is returned as a string.
        max_depth : int or None
            If specified, the maximum recursive depth of the description;
            None means unlimited.

        Returns
        -------
        out : str or None
            Description of how this object's distance metric is computed, if
            display=False; None otherwise.
        """
        max_depth = np.inf if max_depth is None else max_depth

        if len(self.distance_criteria) == 0:
            out = "No distance criteria are specified; please name them in distance_criteria.\n"
        else:
            out = describe_metric(self, D=max_depth)
        out = out[:-1]  # drop the trailing newline

        if display:
            print(out)
            return None
        return out