Coverage for HARK / dcegm.py: 100%
32 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-07 05:16 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-07 05:16 +0000
1"""
2Functions for working with the discrete-continuous EGM (DCEGM) algorithm as
3described in "The endogenous grid method for discrete-continuous dynamic
4choice models with (or without) taste shocks" by Iskhakov et al. (2016)
5[https://doi.org/10.3982/QE643 and ijrsDCEGM2017 in our Zotero]
7Example can be found in https://github.com/econ-ark/DemARK/blob/master/notebooks/DCEGM-Upper-Envelope.ipynb
8"""
10import numpy as np
11from interpolation import interp
12from numba import njit
@njit("Tuple((float64,float64))(float64[:], float64[:], float64[:])", cache=True)
def calc_linear_crossing(x, left_y, right_y):  # pragma: no cover
    """
    Intersect two straight segments that share the same two x coordinates.

    Segment k runs from (x[0], left_y[k]) to (x[1], right_y[k]), for k = 0, 1.
    The intersection is only reported when it lies inside [x[0], x[1]].

    Parameters
    ----------
    x : np.array, length 2
        The two common x coordinates. x[0] < x[1] is assumed.
    left_y : np.array, length 2
        y values of the two segments at x[0].
    right_y : np.array, length 2
        y values of the two segments at x[1].

    Returns
    -------
    (m_int, v_int): a tuple with the coordinates of the intercept.
        If both segments have the same slope and the same value at x[0]
        (they coincide), the left extreme (x[0], left_y[0]) is returned.
        If there is no intercept in the interval [x[0], x[1]], (nan, nan)
        is returned.
    """
    x_lo = x[0]
    width = x[1] - x_lo

    # Slopes of segment 0 and segment 1 over the shared interval.
    slope_a = (right_y[0] - left_y[0]) / width
    slope_b = (right_y[1] - left_y[1]) / width

    if slope_a == slope_b:
        # Parallel segments can only meet if they coincide everywhere; in
        # that case report the left extreme of the interval.
        if left_y[0] == left_y[1]:
            return (x_lo, left_y[0])
        return (np.nan, np.nan)

    # Horizontal offset from x[0] at which both segments take the same value.
    offset = (left_y[0] - left_y[1]) / (slope_b - slope_a)

    # Accept the crossing only if it falls within the interval.
    if offset >= 0 and offset <= width:
        return (x_lo + offset, left_y[0] + offset * slope_a)
    return (np.nan, np.nan)
@njit(
    "Tuple((float64[:,:],int64[:,:]))(float64[:], float64[:,:], int64[:])", cache=True
)
def calc_cross_points(x_grid, cond_ys, opt_idx):  # pragma: no cover
    """
    Given a grid of x values, a matrix with the values of different line segments
    evaluated on the x grid, and a vector indicating the choice of a segment
    at each grid point, this function computes the coordinates of the
    crossing points that happen when the choice of segment changes.

    The purpose of the function is to take (x,y) lines that are defined piece-
    wise, and at every gap in x where the "piece" changes, find the point where
    the two "pieces" involved in the change would intercept.

    Adding these points to our piece-wise approximation will improve it, since
    it will eliminate interpolation between points that belong to different
    "pieces".

    Parameters
    ----------
    x_grid : np.array
        Grid of x values.
    cond_ys : 2-D np.array. Must have as many rows as possible segments, and
        len(x_grid) columns.
        cond_ys[i,j] contains the value of segment (or "piece") i at x_grid[j].
        Entries can be nan if the segment is not defined at a particular point.
    opt_idx : np.array of indices, must have length len(x_grid).
        Indicates what segment is to be used at each x gridpoint. The value
        of the piecewise function at x_grid[k] is cond_ys[opt_idx[k],k].

    Returns
    -------
    xing_points: 2D np.array
        Crossing points, each in its own row as an [x, y] pair. A row is
        (nan, nan) when calc_linear_crossing finds no intercept inside the
        corresponding interval; such rows are not filtered here.

    segments: np.array with two columns and as many rows as xing points.
        Each row represents a crossing point. The first column is the index
        of the segment used to the left, and the second, to the right.

    Changes whose two segments are not both defined (non-nan) at both ends of
    the interval are dropped. If no change survives, both returned arrays
    have shape (0, 2).
    """

    # Compute differences in the optimal index,
    # to find positions of segment-changes.
    # A trailing 0 is appended so the last gridpoint never counts as a change;
    # this keeps idx_change + 1 within the bounds of x_grid and opt_idx.
    diff_max = np.append(opt_idx[1:] - opt_idx[:-1], 0)
    idx_change = np.where(diff_max != 0)[0]

    # If no changes, return empty arrays
    if len(idx_change) == 0:
        points = np.zeros((0, 2), dtype=np.float64)
        segments = np.zeros((0, 2), dtype=np.int64)
        return points, segments

    else:
        # To find the crossing points we need the extremes of the intervals in
        # which they happen, and the two candidate segments evaluated in both
        # extremes. Each row of switch_interv is one interval:
        # switch_interv[:, 0] has the left points and switch_interv[:, 1]
        # the right points of these intervals.
        switch_interv = np.stack((x_grid[idx_change], x_grid[idx_change + 1]), axis=1)

        # Store the indices of the two segments involved in the changes.
        # Columns are [0]: left extreme, [1]: right extreme,
        # Rows are individual crossing points.
        segments = np.stack((opt_idx[idx_change], opt_idx[idx_change + 1]), axis=1)

        # Get values of segments on both the left and the right.
        # left_y[i, k] / right_y[i, k]: value of segments[i, k] at the left /
        # right end of interval i.
        left_y = np.zeros_like(segments, dtype=np.float64)
        right_y = np.zeros_like(segments, dtype=np.float64)

        for i, idx in enumerate(idx_change):
            left_y[i, 0] = cond_ys[segments[i, 0], idx]
            left_y[i, 1] = cond_ys[segments[i, 1], idx]

            right_y[i, 0] = cond_ys[segments[i, 0], idx + 1]
            right_y[i, 1] = cond_ys[segments[i, 1], idx + 1]

        # A valid crossing must have both switching segments well defined at the
        # encompassing gridpoints. Filter those that do not.
        # NOTE(review): this relies on `~` acting as logical negation on the
        # boolean returned by .any() (NumPy np.bool_ semantics). On a plain
        # Python bool, ~True == -2 and ~False == -1, both truthy. Confirm it
        # holds under numba before refactoring.
        valid = np.repeat(False, len(idx_change))
        for i in range(len(valid)):
            valid[i] = np.logical_and(
                ~np.isnan(left_y[i, :]).any(), ~np.isnan(right_y[i, :]).any()
            )

        if not np.any(valid):
            # If there are no valid crossings, return empty arrays.
            points = np.zeros((0, 2), dtype=np.float64)
            segments = np.zeros((0, 2), dtype=np.int64)
            return points, segments

        else:
            # Otherwise, subset valid crossings
            segments = segments[valid, :]
            switch_interv = switch_interv[valid, :]
            left_y = left_y[valid, :]
            right_y = right_y[valid, :]

            # Find crossing points (one (x, y) tuple per valid change; may be
            # (nan, nan) if no intercept lies inside the interval).
            xing_points = [
                calc_linear_crossing(switch_interv[i, :], left_y[i, :], right_y[i, :])
                for i in range(segments.shape[0])
            ]

            xing_array = np.asarray(xing_points)

            return xing_array, segments
@njit("Tuple((int64[:],int64[:]))(float64[:], float64[:])", cache=True)
def calc_nondecreasing_segments(x, y):  # pragma: no cover
    """
    Given a sequence of (x,y) points, this function finds the start and end
    indices of its maximal non-decreasing segments.

    A non-decreasing segment is a sub-sequence of consecutive points
    {(x_0, y_0),...,(x_n,y_n)} such that for all 0 <= i,j <= n,
    if j >= i then x_j >= x_i and y_j >= y_i. A new segment starts right
    after every point at which either x or y decreases.

    Parameters
    ----------
    x : 1D np.array of floats
        x coordinates of the sequence of points.
    y : 1D np.array of floats
        y coordinates of the sequence of points.

    Returns
    -------
    starts : 1D np.array of ints
        Indices where a new non-decreasing segment starts. starts[0] is 0.
    ends : 1D np.array of ints
        Indices (inclusive) where a non-decreasing segment ends.
        ends[-1] is len(y) - 1.

    Raises
    ------
    ValueError
        If x or y is empty, or if they have different lengths.
    """

    if len(x) == 0 or len(y) == 0 or len(y) != len(x):
        # ValueError (a subclass of Exception) keeps existing
        # `except Exception` callers working.
        raise ValueError("x and y must be non-empty arrays of the same size.")

    # Positions i where moving from point i to point i + 1 decreases x or y.
    # Comparisons involving nan are False, so nan values never break a segment.
    nd = np.logical_or(x[1:] < x[:-1], y[1:] < y[:-1])
    idx = np.where(nd)[0]

    # Each break at i ends a segment at i and starts a new one at i + 1.
    # The first segment starts at 0 and the last one ends at the final point.
    starts = np.zeros(len(idx) + 1, dtype=np.int64)
    ends = np.zeros(len(idx) + 1, dtype=np.int64)

    starts[1:] = idx + 1

    ends[:-1] = idx
    ends[-1] = len(y) - 1

    return starts, ends
def upper_envelope(segments, calc_crossings=True):
    """
    Finds the upper envelope of a list of non-decreasing segments.

    Parameters
    ----------
    segments : list of segments. Segments are tuples of arrays, with item[0]
        containing the x coordinates and item[1] the y coordinates of the
        points that conform the segment item.
    calc_crossings : Bool, optional
        Indicates whether the crossing points at which the "upper" segment
        changes should be computed. The default is True.

    Returns
    -------
    x : np.array of floats
        x coordinates of the points that conform the upper envelope.
    y : np.array of floats
        y coordinates of the points that conform the upper envelope.
    env_inds : np array of ints
        Array of the same length as x and y. It indicates which of the
        provided segments is the "upper" one at every returned (x,y) point.

    Raises
    ------
    ValueError
        If `segments` is empty.
    """
    n_seg = len(segments)
    if n_seg == 0:
        raise ValueError("segments must contain at least one segment.")

    # Collect the x points of all segments in an ordered array, removing
    # duplicates.
    x = np.unique(np.concatenate([seg[0] for seg in segments]))

    # Evaluate every segment on the common grid, without extrapolating:
    # grid points outside a segment's x range are set to nan.
    y_cond = np.zeros((n_seg, len(x)))
    for i in range(n_seg):
        seg_x = segments[i][0]
        seg_y = segments[i][1]
        if len(seg_x) == 1:
            # A single-point segment is only known at its observed point,
            # which is guaranteed to be on the grid.
            row = np.repeat(np.nan, len(x))
            row[np.searchsorted(x, seg_x[0])] = seg_y[0]
        else:
            # If the segment has more than one point, we can interpolate.
            row = interp(seg_x, seg_y, x)
            extrap = np.logical_or(x < seg_x[0], x > seg_x[-1])
            row[extrap] = np.nan

        y_cond[i, :] = row

    # Take the (nan-ignoring) maximum to get the upper envelope.
    env_inds = np.nanargmax(y_cond, 0)
    y = y_cond[env_inds, np.arange(len(x))]

    # Get crossing points if needed
    if calc_crossings:
        xing_points, xing_lines = calc_cross_points(x, y_cond, env_inds)

        # calc_cross_points reports (nan, nan) when no intercept is found
        # inside an interval (e.g. through floating point rounding).
        # np.searchsorted places nan at the end of the grid, so such rows
        # would be appended as spurious nan points. Drop them.
        keep = ~np.isnan(xing_points).any(axis=1)
        xing_points = xing_points[keep]
        xing_lines = xing_lines[keep]

        if len(xing_points) > 0:
            # Extract x and y coordinates
            xing_x = xing_points[:, 0]
            xing_y = xing_points[:, 1]

            # To capture the discontinuity, add the successor of each crossing
            # to the grid. Stepping toward +inf (instead of xing_x + 1) still
            # moves when |x| is so large that x + 1 == x.
            succ = np.nextafter(xing_x, np.inf)

            # Collect points to add to grids. The crossing itself is assigned
            # to the left segment and its successor to the right one; at a
            # crossing, y is the same on both segments.
            xtra_x = np.concatenate([xing_x, succ])
            xtra_y = np.concatenate([xing_y, xing_y])
            xtra_lines = np.concatenate([xing_lines[:, 0], xing_lines[:, 1]])

            # Insert them at their sorted positions in the grid.
            idx = np.searchsorted(x, xtra_x)
            x = np.insert(x, idx, xtra_x)
            y = np.insert(y, idx, xtra_y)
            env_inds = np.insert(env_inds, idx, xtra_lines)

    return x, y, env_inds