Coverage for HARK/metric.py: 93%

41 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-02 05:14 +0000

1from warnings import warn 

2 

3import numpy as np 

4 

5 

def distance_lists(list_a, list_b):
    """
    Compute the distance between two lists.

    If the lists have the same length, the distance is the maximum distance
    between corresponding elements, as measured by distance_metric. If they
    differ in length, the distance is the absolute difference in lengths.
    Two empty lists are identical, so their distance is zero.

    Parameters
    ----------
    list_a : list
        First list to compare.
    list_b : list
        Second list to compare.

    Returns
    -------
    distance : float
        The distance between list_a and list_b.
    """
    len_a = len(list_a)
    len_b = len(list_b)
    if len_a != len_b:
        return np.abs(len_a - len_b)
    if len_a == 0:
        # np.max raises ValueError on an empty sequence; two empty lists
        # have no differing elements, so they are at distance zero.
        return 0.0
    return np.max([distance_metric(a, b) for a, b in zip(list_a, list_b)])

17 

18 

def distance_dicts(dict_a, dict_b):
    """
    Compute the distance between two dictionaries.

    If both dictionaries have the same keys, the distance is the maximum
    distance between the values stored under each key, as measured by
    distance_metric. If the keys do not match, a warning is issued and 1000.0
    is returned. Two empty dictionaries are identical, so their distance is
    zero.

    Parameters
    ----------
    dict_a : dict
        First dictionary to compare.
    dict_b : dict
        Second dictionary to compare.

    Returns
    -------
    distance : float
        The distance between dict_a and dict_b.
    """
    if set(dict_a) != set(dict_b):
        warn("Dictionaries with keys that do not match are being compared.")
        return 1000.0
    if not dict_a:
        # np.max raises ValueError on an empty sequence; two empty dicts
        # have no differing values, so they are at distance zero.
        return 0.0
    return np.max(
        [distance_metric(value_a, dict_b[key]) for key, value_a in dict_a.items()]
    )

29 

30 

def distance_arrays(arr_a, arr_b):
    """
    Compute the distance between two arrays.

    If the arrays have the same shape, the distance is the maximum absolute
    difference between corresponding elements (zero when both arrays are
    empty). If the shapes differ, the distance is the absolute difference in
    their total number of elements.

    Parameters
    ----------
    arr_a : np.ndarray
        First array to compare.
    arr_b : np.ndarray
        Second array to compare.

    Returns
    -------
    distance : float
        The distance between arr_a and arr_b.
    """
    if arr_a.shape != arr_b.shape:
        # NOTE(review): arrays with different shapes but equal size (e.g.
        # (2, 3) vs (3, 2)) get distance 0 here; kept as documented behavior.
        return np.abs(arr_a.size - arr_b.size)
    if arr_a.size == 0:
        # np.max raises ValueError on a zero-size array; two empty arrays of
        # the same shape have no differing elements.
        return 0.0
    return np.max(np.abs(arr_a - arr_b))

40 

41 

def distance_class(cls_a, cls_b):
    """
    Compare two objects by delegating to the first object's distance method.

    Plain Python functions (including lambdas) cannot be compared this way;
    for those a warning is issued and a large distance is returned.

    Parameters
    ----------
    cls_a : object
        Object whose ``distance`` method is called.
    cls_b : object
        Object passed to ``cls_a.distance``.

    Returns
    -------
    distance : float
        Result of ``cls_a.distance(cls_b)``, or 1000.0 if cls_a is a function.
    """
    # The type of any lambda is the ordinary Python function type.
    function_type = type(lambda: None)
    if not isinstance(cls_a, function_type):
        return cls_a.distance(cls_b)
    warn("Cannot compare lambda functions. Returning large distance.")
    return 1000.0

51 

52 

def distance_metric(thing_a, thing_b):
    """
    A "universal distance" metric that can be used as a default in many settings.

    The comparison is chosen by the types of the inputs: numbers (Python or
    NumPy scalars) use their absolute difference; lists, arrays and
    dictionaries are dispatched to distance_lists, distance_arrays and
    distance_dicts; otherwise, if thing_a is an instance of thing_b's type,
    distance_class is used. Any other combination returns 1000.0.

    Parameters
    ----------
    thing_a : object
        A generic object.
    thing_b : object
        Another generic object.

    Returns
    -------
    distance : float
        The "distance" between thing_a and thing_b.
    """
    # np.number covers NumPy scalars (np.int64, np.float32, ...) that are not
    # subclasses of int or float; without it they would fall through to
    # distance_class and fail for lack of a distance method.
    numeric_types = (int, float, np.number)

    # If both inputs are numbers, return their difference
    if isinstance(thing_a, numeric_types) and isinstance(thing_b, numeric_types):
        return np.abs(thing_a - thing_b)

    if isinstance(thing_a, list) and isinstance(thing_b, list):
        return distance_lists(thing_a, thing_b)

    if isinstance(thing_a, np.ndarray) and isinstance(thing_b, np.ndarray):
        return distance_arrays(thing_a, thing_b)

    if isinstance(thing_a, dict) and isinstance(thing_b, dict):
        return distance_dicts(thing_a, thing_b)

    if isinstance(thing_a, type(thing_b)):
        return distance_class(thing_a, thing_b)

    # Failsafe: the inputs are very far apart
    return 1000.0

88 

89 

class MetricObject:
    """
    A superclass for object classes in HARK that provides a generic
    "universal distance" method driven by the distance_criteria attribute.
    """

    # Names of the attributes compared by distance(); subclasses should
    # overwrite this.
    distance_criteria = []

    def distance(self, other):
        """
        Measure how far this object is from another one.

        Each attribute named in distance_criteria is read from both objects
        and compared with distance_metric; the largest of those values is the
        result.

        Parameters
        ----------
        other : object
            Another object to compare this instance to.

        Returns
        -------
        (unnamed) : float
            The largest attribute-wise distance, or 1000.0 if an
            AttributeError or ValueError is raised while computing it.
        """
        gaps = []
        try:
            for attr_name in self.distance_criteria:
                mine = getattr(self, attr_name)
                theirs = getattr(other, attr_name)
                gaps.append(distance_metric(mine, theirs))
            largest = np.max(gaps)
        except (AttributeError, ValueError):
            return 1000.0
        return largest