Coverage for HARK / metric.py: 85%

53 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-07 05:16 +0000

1from warnings import warn 

2 

3import numpy as np 

4 

5 

def distance_lists(list_a, list_b):
    """
    Distance between two lists.

    If both lists have the same length, the distance is the maximum
    ``distance_metric`` between corresponding elements (0.0 for two empty
    lists). If they differ in length, the distance is the absolute
    difference in their lengths.

    Parameters
    ----------
    list_a : list
        First list.
    list_b : list
        Second list.

    Returns
    -------
    distance : float
        The distance between list_a and list_b.
    """
    len_a = len(list_a)
    len_b = len(list_b)
    if len_a != len_b:
        return np.abs(len_a - len_b)
    if len_a == 0:
        # np.max of an empty sequence raises ValueError; two empty lists are
        # identical, so their distance is zero.
        return 0.0
    return np.max([distance_metric(a, b) for a, b in zip(list_a, list_b)])

17 

18 

def distance_dicts(dict_a, dict_b):
    """
    Distance between two dictionaries.

    If both dictionaries have a "distance_criteria" entry and those entries
    name the same keys in the same order, only those keys are compared.
    Otherwise every key is compared. If the two dictionaries do not have the
    same set of keys, a warning is issued and 1000.0 is returned; nothing in
    HARK should ever hit that warning.

    The distance is the maximum ``distance_metric`` over the compared keys,
    or 0.0 when there are no keys to compare.

    Parameters
    ----------
    dict_a : dict
        First dictionary.
    dict_b : dict
        Second dictionary.

    Returns
    -------
    distance : float
        The distance between dict_a and dict_b.
    """
    # Check whether the dictionaries have matching distance_criteria
    if "distance_criteria" in dict_a and "distance_criteria" in dict_b:
        crit_a = dict_a["distance_criteria"]
        crit_b = dict_b["distance_criteria"]
        same_criteria = len(crit_a) == len(crit_b) and all(
            key_a == key_b for key_a, key_b in zip(crit_a, crit_b)
        )
        if same_criteria:
            # Compare only the keys named by the shared distance_criteria.
            distances = [distance_metric(dict_a[key], dict_b[key]) for key in crit_a]
            # np.max([]) raises ValueError; nothing to compare means distance 0.
            return np.max(distances) if distances else 0.0

    # Otherwise, compare all their keys
    if set(dict_a) != set(dict_b):
        warn("Dictionaries with keys that do not match are being compared.")
        return 1000.0
    distances = [distance_metric(dict_a[key], dict_b[key]) for key in dict_a]
    return np.max(distances) if distances else 0.0

45 

46 

def distance_arrays(arr_a, arr_b):
    """
    Distance between two numpy arrays.

    The result depends on how the shapes compare:

    - Same shape: the maximum absolute difference between corresponding
      elements (0.0 when the arrays are empty).
    - Different numbers of dimensions: 10000 times the difference in the
      number of dimensions.
    - Same number of dimensions but different shapes: the sum of the
      absolute differences in size along each dimension.

    Parameters
    ----------
    arr_a : np.ndarray
        First array.
    arr_b : np.ndarray
        Second array.

    Returns
    -------
    distance : number
        The distance between arr_a and arr_b.
    """
    shape_a = arr_a.shape
    shape_b = arr_b.shape
    if shape_a == shape_b:
        if arr_a.size == 0:
            # np.max has no identity for an empty reduction and would raise
            # ValueError; two empty arrays of the same shape are identical.
            return 0.0
        return np.max(np.abs(arr_a - arr_b))

    if len(shape_a) != len(shape_b):
        # Heavily penalize a mismatch in dimensionality.
        return 10000 * np.abs(len(shape_a) - len(shape_b))

    dim_diffs = np.abs(np.array(shape_a) - np.array(shape_b))
    return np.sum(dim_diffs)

65 

66 

def distance_class(cls_a, cls_b):
    """
    Fallback distance for two objects of the same class.

    Used when the objects are not numbers, lists, arrays, or dicts. It
    delegates to ``cls_a.distance(cls_b)``. Plain Python functions (lambdas
    included) cannot be compared, so for them a warning is issued and a large
    distance of 1000.0 is returned.
    """
    function_type = type(lambda: None)
    if not isinstance(cls_a, function_type):
        return cls_a.distance(cls_b)
    warn("Cannot compare lambda functions. Returning large distance.")
    return 1000.0

76 

77 

def distance_metric(thing_a, thing_b):
    """
    A "universal distance" metric that can be used as a default in many settings.

    Dispatches on the types of the two inputs:

    - Two numbers (Python ``int``/``float`` or numpy numeric scalars such as
      ``np.int64``/``np.float32``): their absolute difference.
    - Two lists: ``distance_lists``.
    - Two numpy arrays: ``distance_arrays``.
    - Two dicts: ``distance_dicts``.
    - ``thing_a`` is an instance of ``type(thing_b)``: ``distance_class``.
    - Anything else: 1000.0, i.e. "very far apart".

    Parameters
    ----------
    thing_a : object
        A generic object.
    thing_b : object
        Another generic object.

    Returns
    -------
    distance : float
        The "distance" between thing_a and thing_b.
    """
    # If both inputs are numbers, return their difference. np.number covers
    # numpy scalars (e.g. np.int64, np.float32) that are not subclasses of the
    # Python int/float types and would otherwise fall through to
    # distance_class and fail for lack of a .distance method.
    numeric_types = (int, float, np.number)
    if isinstance(thing_a, numeric_types) and isinstance(thing_b, numeric_types):
        return np.abs(thing_a - thing_b)

    if isinstance(thing_a, list) and isinstance(thing_b, list):
        return distance_lists(thing_a, thing_b)

    if isinstance(thing_a, np.ndarray) and isinstance(thing_b, np.ndarray):
        return distance_arrays(thing_a, thing_b)

    if isinstance(thing_a, dict) and isinstance(thing_b, dict):
        return distance_dicts(thing_a, thing_b)

    if isinstance(thing_a, type(thing_b)):
        return distance_class(thing_a, thing_b)

    # Failsafe: the inputs are very far apart
    return 1000.0

113 

114 

class MetricObject:
    """
    Base class for HARK objects that can be compared with a distance metric.

    Subclasses name, in ``distance_criteria``, the attributes that
    ``distance`` should compare.
    """

    # Names of the attributes compared by ``distance``; subclasses should
    # overwrite this.
    distance_criteria = []

    def distance(self, other):
        """
        Generic distance between this object and another one.

        For each attribute named in ``self.distance_criteria``, compute the
        "universal distance" between this object's value and ``other``'s
        value. The result is the maximum of those distances.

        Parameters
        ----------
        other : object
            Another object to compare this instance to.

        Returns
        -------
        (unnamed) : float
            The distance between this object and another, using the
            "universal distance" metric.

        Any AttributeError or ValueError raised along the way is reported as
        a large distance of 1000.0. Examples are a missing attribute, or an
        empty ``distance_criteria``, which makes ``np.max`` reduce an empty
        list.
        """
        try:
            gaps = [
                distance_metric(getattr(self, name), getattr(other, name))
                for name in self.distance_criteria
            ]
            return np.max(gaps)
        except (AttributeError, ValueError):
            return 1000.0