Coverage for HARK / dcegm.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-07 05:16 +0000

1""" 

2Functions for working with the discrete-continuous EGM (DCEGM) algorithm as 

3described in "The endogenous grid method for discrete-continuous dynamic 

4choice models with (or without) taste shocks" by Iskhakov et al. (2016) 

5[https://doi.org/10.3982/QE643 and ijrsDCEGM2017 in our Zotero] 

6 

7Example can be found in https://github.com/econ-ark/DemARK/blob/master/notebooks/DCEGM-Upper-Envelope.ipynb 

8""" 

9 

10import numpy as np 

11from interpolation import interp 

12from numba import njit 

13 

14 

@njit("Tuple((float64,float64))(float64[:], float64[:], float64[:])", cache=True)
def calc_linear_crossing(x, left_y, right_y):  # pragma: no cover
    """
    Locate where two straight line segments intersect, given their values at
    two shared x coordinates. An intersection is only reported when it lies
    within the closed interval [x[0], x[1]].

    Parameters
    ----------
    x : np.array, length 2
        The two shared x coordinates. x[0] < x[1] is assumed.
    left_y : np.array, length 2
        Values of segment 0 and segment 1 at x[0].
    right_y : np.array, length 2
        Values of segment 0 and segment 1 at x[1].

    Returns
    -------
    (m_int, v_int): tuple of floats with the coordinates of the intersection.
        If the segments coincide, the left end (x[0], left_y[0]) is returned.
        If there is no intersection in [x[0], x[1]], (nan, nan) is returned.
    """
    x_left = x[0]
    width = x[1] - x[0]

    # Slopes of segment 0 and segment 1 over the interval.
    slope_0 = (right_y[0] - left_y[0]) / width
    slope_1 = (right_y[1] - left_y[1]) / width

    if slope_0 == slope_1:
        # Parallel lines only meet if they lie on top of each other; in that
        # case report the left end of the interval.
        if left_y[0] == left_y[1]:
            return (x_left, left_y[0])
        return (np.nan, np.nan)

    # Horizontal distance from x_left at which the two lines meet.
    offset = (left_y[0] - left_y[1]) / (slope_1 - slope_0)

    # Only accept intersections inside the bracketing interval.
    if offset >= 0 and offset <= width:
        return (x_left + offset, left_y[0] + offset * slope_0)
    return (np.nan, np.nan)

60 

61 

@njit(
    "Tuple((float64[:,:],int64[:,:]))(float64[:], float64[:,:], int64[:])", cache=True
)
def calc_cross_points(x_grid, cond_ys, opt_idx):  # pragma: no cover
    """
    Compute the points where a piecewise function switches from one segment
    to another.

    The piecewise function takes, at each grid point x_grid[k], the value of
    segment opt_idx[k]. Whenever the chosen segment differs between two
    consecutive grid points, the two segments involved are extended linearly
    across that gap and their intersection is computed. Adding these points
    to the piecewise approximation avoids interpolating between points that
    belong to different segments.

    Parameters
    ----------
    x_grid : np.array
        Grid of x values.
    cond_ys : 2-D np.array with one row per segment and len(x_grid) columns.
        cond_ys[i, j] is the value of segment i at x_grid[j]. Entries can be
        nan where a segment is not defined.
    opt_idx : np.array of ints with length len(x_grid).
        Segment used at each gridpoint: the piecewise function's value at
        x_grid[k] is cond_ys[opt_idx[k], k].

    Returns
    -------
    xing_points : 2-D np.array
        One crossing point per row, as an [x, y] pair. A row is [nan, nan]
        when the two segments do not intersect inside their gap.
    segments : 2-D np.array of ints with two columns, one row per crossing.
        Column 0 is the segment used to the left of the crossing, column 1
        the segment used to the right.
    """
    # Gridpoints k whose successor k+1 uses a different segment. The appended
    # 0 keeps the array the same length as opt_idx; the last point never
    # counts as a change.
    jumps = np.append(opt_idx[1:] - opt_idx[:-1], 0)
    change_at = np.where(jumps != 0)[0]
    n_changes = len(change_at)

    empty_points = np.zeros((0, 2), dtype=np.float64)
    empty_segments = np.zeros((0, 2), dtype=np.int64)

    if n_changes == 0:
        return empty_points, empty_segments

    # Column 0: left end of each gap, column 1: right end.
    bounds = np.stack((x_grid[change_at], x_grid[change_at + 1]), axis=1)
    # Column 0: segment chosen on the left, column 1: segment chosen on the right.
    seg_pairs = np.stack((opt_idx[change_at], opt_idx[change_at + 1]), axis=1)

    # Values of both segments at the left (y_lo) and right (y_hi) ends of each
    # gap. A gap is only usable if both segments are defined at both ends.
    y_lo = np.zeros((n_changes, 2), dtype=np.float64)
    y_hi = np.zeros((n_changes, 2), dtype=np.float64)
    keep = np.zeros(n_changes, dtype=np.bool_)
    for row in range(n_changes):
        col = change_at[row]
        for side in range(2):
            seg = seg_pairs[row, side]
            y_lo[row, side] = cond_ys[seg, col]
            y_hi[row, side] = cond_ys[seg, col + 1]
        keep[row] = not (
            np.isnan(y_lo[row, 0])
            or np.isnan(y_lo[row, 1])
            or np.isnan(y_hi[row, 0])
            or np.isnan(y_hi[row, 1])
        )

    if not keep.any():
        return empty_points, empty_segments

    seg_pairs = seg_pairs[keep, :]
    bounds = bounds[keep, :]
    y_lo = y_lo[keep, :]
    y_hi = y_hi[keep, :]

    # Intersect the two segments of every remaining gap.
    n_valid = seg_pairs.shape[0]
    xing_array = np.empty((n_valid, 2), dtype=np.float64)
    for row in range(n_valid):
        x_c, y_c = calc_linear_crossing(bounds[row, :], y_lo[row, :], y_hi[row, :])
        xing_array[row, 0] = x_c
        xing_array[row, 1] = y_c

    return xing_array, seg_pairs

166 

167 

@njit("Tuple((int64[:],int64[:]))(float64[:], float64[:])", cache=True)
def calc_nondecreasing_segments(x, y):  # pragma: no cover
    """
    Split a sequence of (x, y) points into its maximal runs of consecutive,
    non-decreasing points, and return the start and end index of each run.

    A non-decreasing run is a sub-sequence of consecutive points
    {(x_0, y_0), ..., (x_n, y_n)} such that for all 0 <= i <= j <= n,
    x_j >= x_i and y_j >= y_i. A new run begins right after any point where
    x or y decreases.

    Parameters
    ----------
    x : 1D np.array of floats
        x coordinates of the sequence of points.
    y : 1D np.array of floats
        y coordinates of the sequence of points.

    Returns
    -------
    starts : 1D np.array of ints
        Index of the first point of each run.
    ends : 1D np.array of ints
        Index of the last point of each run (inclusive).

    Raises
    ------
    ValueError
        If x or y is empty, or if they have different lengths.
    """
    # ValueError (a subclass of Exception, so existing handlers still catch
    # it) rather than a bare Exception, so callers can catch this narrowly.
    if len(x) == 0 or len(y) == 0 or len(y) != len(x):
        raise ValueError("x and y must be non-empty arrays of the same size.")

    # idx[k] marks a point after which x or y decreases, i.e. the last point
    # of a run.
    nd = np.logical_or(x[1:] < x[:-1], y[1:] < y[:-1])
    idx = np.where(nd)[0]

    # Build starts/ends by filling preallocated arrays rather than
    # concatenating, to keep the implementation simple under numba.
    starts = np.zeros(len(idx) + 1, dtype=np.int64)
    ends = np.zeros(len(idx) + 1, dtype=np.int64)

    # The first run starts at 0; every other run starts right after a break.
    starts[1:] = idx + 1

    # Every run but the last ends at a break; the last ends at the final point.
    ends[:-1] = idx
    ends[-1] = len(y) - 1

    return starts, ends

212 

213 

def upper_envelope(segments, calc_crossings=True):
    """
    Find the upper envelope of a list of non-decreasing segments.

    Every segment is evaluated on the union of all segments' x points. Values
    are linearly interpolated inside the segment's x range and set to nan
    outside it (no extrapolation). At every point the segment with the
    largest value is kept.

    Parameters
    ----------
    segments : list of segments. Segments are tuples of arrays, with item[0]
        containing the x coordinates and item[1] the y coordinates of the
        points that form the segment. x coordinates are presumably expected
        in increasing order (they are interpolated over) -- confirm with
        callers.
    calc_crossings : bool, optional
        Whether to compute the crossing points at which the "upper" segment
        changes and add them, plus the float immediately to their right, to
        the returned grid. The default is True.

    Returns
    -------
    x : np.array of floats
        x coordinates of the points that form the upper envelope.
    y : np.array of floats
        y coordinates of the points that form the upper envelope.
    env_inds : np.array of ints
        Same length as x and y. Index of the provided segment that is the
        "upper" one at every returned (x, y) point.
    """
    n_seg = len(segments)

    # Sorted union of every segment's x points, without duplicates.
    x = np.unique(np.concatenate([seg[0] for seg in segments]))

    # Evaluate every segment on the common grid, without extrapolating.
    y_cond = np.zeros((n_seg, len(x)))
    for i in range(n_seg):
        seg_x = segments[i][0]
        seg_y = segments[i][1]
        if len(seg_x) == 1:
            # A single-point segment is only known at its own point.
            row = np.repeat(np.nan, len(x))
            row[np.searchsorted(x, seg_x[0])] = seg_y[0]
        else:
            row = interp(seg_x, seg_y, x)
            row[np.logical_or(x < seg_x[0], x > seg_x[-1])] = np.nan

        y_cond[i, :] = row

    # The envelope is the pointwise maximum, ignoring undefined (nan) values.
    env_inds = np.nanargmax(y_cond, 0)
    y = y_cond[env_inds, np.arange(len(x))]

    if calc_crossings:
        xing_points, xing_lines = calc_cross_points(x, y_cond, env_inds)

        # calc_cross_points reports (nan, nan) when the two segments do not
        # intersect inside the gap (e.g. rounding at an endpoint tie).
        # Inserting those would put NaNs into x and break its ordering.
        found = ~np.isnan(xing_points).any(axis=1)
        xing_x = xing_points[found, 0]
        xing_y = xing_points[found, 1]
        xing_lines = xing_lines[found, :]

        if len(xing_x) > 0:
            # Add the float immediately to the right of each crossing so the
            # switch of segment is captured. Stepping toward +inf (rather
            # than toward xing_x + 1) advances even when |xing_x| is so large
            # that xing_x + 1 == xing_x.
            succ = np.nextafter(xing_x, np.inf)

            # The crossing is assigned to the left segment and its successor
            # to the right one; at a crossing both segments share the same y.
            xtra_x = np.concatenate([xing_x, succ])
            xtra_y = np.concatenate([xing_y, xing_y])
            xtra_lines = np.concatenate([xing_lines[:, 0], xing_lines[:, 1]])

            # Insert them into the grid, keeping x sorted.
            idx = np.searchsorted(x, xtra_x)
            x = np.insert(x, idx, xtra_x)
            y = np.insert(y, idx, xtra_y)
            env_inds = np.insert(env_inds, idx, xtra_lines)

    return x, y, env_inds