Coverage for HARK/mat_methods.py: 100%

3 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-02 05:14 +0000

1from typing import List 

2 

3import numpy as np 

4from numba import njit 

5 

6 

@njit
def ravel_index(
    ind_mat: np.ndarray, dims: np.ndarray
) -> np.ndarray:  # pragma: no cover
    """
    Convert a stack of per-dimension indices into flat row-major indices.

    Parameters
    ----------
    ind_mat : np.ndarray
        Index array whose first axis runs over dimensions: ``ind_mat[k]``
        holds the index along dimension ``k``. The remaining axes are
        arbitrary and are preserved in the output.
    dims : np.ndarray
        Size of each dimension.

    Returns
    -------
    np.ndarray
        int64 array of shape ``ind_mat.shape[1:]`` holding the flattened
        (C-order) indices.
    """
    # Row-major strides: dimension k advances by the product of the sizes of
    # all later dimensions, and the last dimension has stride 1.
    strides = np.concatenate((np.cumprod(dims[1:][::-1])[::-1], np.array([1])))

    flat = np.zeros(ind_mat.shape[1:], dtype=np.int64)
    for k in range(strides.shape[0]):
        flat += ind_mat[k] * strides[k]

    return flat

23 

24 

@njit
def multidim_get_lower_index(
    points: np.ndarray, grids: List[np.ndarray], dims: np.ndarray
) -> np.ndarray:  # pragma: no cover
    """
    Locate, for every point, the lower node of the grid cell containing it.

    Parameters
    ----------
    points : np.ndarray
        shape = (#points, #dims) Coordinates of the points.
    grids : List[np.ndarray]
        One grid per dimension. Each grid is assumed to be sorted in
        ascending order, as ``np.searchsorted`` requires.
    dims : np.ndarray
        Number of nodes in each grid.

    Returns
    -------
    np.ndarray
        int64 array of shape (#points, #dims). Entry ``[n, d]`` is the index
        of the last node of ``grids[d]`` that is <= ``points[n, d]``, capped at
        ``dims[d] - 2`` so that node ``index + 1`` always exists. A point below
        the first node yields -1, so callers should clip points to the grid
        range first.
    """
    lower = np.empty_like(points, dtype=np.int64)
    for d in range(len(grids)):
        right_pos = np.searchsorted(grids[d], points[:, d], side="right")
        # Cap at the second-to-last node so a point sitting exactly on the
        # upper edge still belongs to the last cell.
        lower[:, d] = np.minimum(right_pos - 1, dims[d] - 2)
    return lower

53 

54 

@njit
def fwd_and_bwd_diffs(
    points: np.ndarray, grids: List[np.ndarray], inds: np.ndarray
) -> np.ndarray:  # pragma: no cover
    """
    Measure each point's distance to the two grid nodes bracketing it.

    Parameters
    ----------
    points : np.ndarray
        shape = (#points, #dims) Coordinates of the points.
    grids : List[np.ndarray]
        One grid per dimension.
    inds : np.ndarray
        shape = (#dims, #points) Index of the lower bracketing node of each
        point along each dimension.

    Returns
    -------
    np.ndarray
        Array of shape (2, #dims, #points). Slice ``[0, d, :]`` holds the
        backward distance ``point - lower node``. Slice ``[1, d, :]`` holds
        the forward distance ``upper node - point``.
    """
    npts, ncols = points.shape
    out = np.empty((2, ncols, npts))

    for d in range(len(grids)):
        g = grids[d]
        lo = inds[d, :]
        x = points[:, d]
        out[0, d, :] = x - g[lo]
        out[1, d, :] = g[lo + 1] - x

    return out

86 

87 

@njit
def sum_weights(
    weights: np.ndarray, dims: np.ndarray, add_inds: np.ndarray
) -> np.ndarray:  # pragma: no cover
    """
    Sums the weights that correspond to each point in the grid.

    Parameters
    ----------
    weights : np.ndarray
        shape = (#corners, #points) The weights to be summed.
    dims : np.ndarray
        Number of nodes along each dimension of the grid. The output has
        ``np.prod(dims)`` entries.
    add_inds : np.ndarray
        shape = (#corners, #points) Flat grid index that receives each
        corresponding entry of ``weights``.

    Returns
    -------
    np.ndarray
        float64 array of length ``np.prod(dims)``. Each entry holds the sum of
        all weights whose flat index points to that grid node.
    """
    # Initialize array to hold weights.
    distr = np.zeros(np.prod(dims), dtype=np.float64)

    # Accumulate one scalar at a time. A fancy-indexed ``distr[idx] += w``
    # gathers, adds and scatters, so repeated indices within ``idx`` would
    # contribute only once. Scalar accumulation has ``np.add.at`` semantics,
    # and under numba it also avoids temporary arrays.
    ncorners, npoints = weights.shape
    for i in range(npoints):
        for j in range(ncorners):
            distr[add_inds[j, i]] += weights[j, i]

    return distr

117 

118 

@njit
def denominators(
    inds: np.ndarray, grids: List[np.ndarray]
) -> np.ndarray:  # pragma: no cover
    """
    Compute the volume of the grid cell that contains each point.

    These volumes are the denominators of the multilinear interpolation
    weights.

    Parameters
    ----------
    inds : np.ndarray
        shape = (#dims, #points) Index of the lower node of each point's cell
        along each dimension.
    grids : List[np.ndarray]
        One grid per dimension.

    Returns
    -------
    np.ndarray
        float64 array of length #points with the product, over dimensions, of
        the width of the cell containing each point.
    """
    n = inds.shape[1]
    volume = np.ones(n, dtype=np.float64)
    for d in range(len(grids)):
        # Width of every cell along dimension d (equivalent to np.diff).
        widths = grids[d][1:] - grids[d][:-1]
        volume *= widths[inds[d, :]]
    return volume

144 

145 

@njit
def get_combinations(ndim: int) -> np.ndarray:  # pragma: no cover
    """
    Enumerate every length-``ndim`` vector of 0s and 1s.

    Row ``r`` is the binary representation of ``r``, with the most-significant
    bit in column 0. Each row selects, per dimension, either the backward (0)
    or the forward (1) difference.

    Parameters
    ----------
    ndim : int
        The number of dimensions.

    Returns
    -------
    np.ndarray
        int64 array of shape (2**ndim, ndim).
    """
    ncomb = 2**ndim
    bits = np.zeros((ncomb, ndim), dtype=np.int64)
    for row in range(ncomb):
        value = row
        # Peel bits off the low end and fill columns right-to-left.
        for col in range(ndim - 1, -1, -1):
            bits[row, col] = value & 1
            value >>= 1
    return bits

168 

169 

@njit
def numerators(
    diffs: np.ndarray, comb_inds: np.ndarray, ndims: int, npoints: int
) -> np.ndarray:  # pragma: no cover
    """
    Compute the sub-box volumes that form the numerators of the interpolation
    weights.

    For each combination row and each point, this multiplies the chosen
    backward or forward difference along every dimension.

    Parameters
    ----------
    diffs : np.ndarray
        shape = (2, #dims, #points) Backward (``[0]``) and forward (``[1]``)
        differences between each point and its bracketing grid nodes.
    comb_inds : np.ndarray
        shape = (2**ndims, #dims) 0/1 selector per dimension for each
        combination.
    ndims : int
        The number of dimensions.
    npoints : int
        The number of points.

    Returns
    -------
    np.ndarray
        float64 array of shape (2**ndims, npoints).
    """
    ncomb = 2**ndims
    volumes = np.ones((ncomb, npoints), dtype=np.float64)
    for row in range(ncomb):
        for dim in range(comb_inds.shape[1]):
            side = comb_inds[row, dim]
            volumes[row, :] *= diffs[side, dim, :]
    return volumes

201 

202 

@njit
def mass_to_grid(
    points: np.ndarray, mass: np.ndarray, grids: List[np.ndarray]
) -> np.ndarray:  # pragma: no cover
    """
    Distributes the mass of a set of R^n points to a rectangular R^n grid,
    following the 'lottery' method.

    Points are first clipped to the bounding box of the grid. Each point's
    mass is then split among the 2**n corners of the grid cell containing it,
    using multilinear interpolation. The share a corner receives is the volume
    of the sub-box between the point and the opposite corner, divided by the
    volume of the cell.

    Parameters
    ----------
    points : np.ndarray
        shape = (#points, #dims) The points to be distributed.
    mass : np.ndarray
        shape = (#points) The mass of each point.
    grids : List[np.ndarray]
        The grids for each dimension.

    Returns
    -------
    np.ndarray
        The mass assigned to each grid node, flattened in row-major (C) order.
        The length is the product of the grid sizes.
    """
    dims = np.array([len(g) for g in grids])
    ndims = len(grids)
    npoints = points.shape[0]

    # Clip points to the grid's bounding box so every point lies inside a cell.
    grid_inf_lims = np.expand_dims(np.array([x[0] for x in grids]), 0)
    grid_sup_lims = np.expand_dims(np.array([x[-1] for x in grids]), 0)
    points = np.clip(points, grid_inf_lims, grid_sup_lims)

    # Lower-corner index of each point's cell, shape (ndims, npoints).
    inds = multidim_get_lower_index(points, grids, dims).T

    # Backward (point - lower node) and forward (upper node - point)
    # differences, shape (2, ndims, npoints).
    diffs = fwd_and_bwd_diffs(points, grids, inds)

    # All 2**ndims choices of backward/forward difference per dimension.
    comb_inds = get_combinations(ndims)

    # Numerators: sub-box volumes, one per combination and point.
    numers = numerators(diffs, comb_inds, ndims, npoints)
    # Denominators: volume of the cell containing each point.
    denoms = denominators(inds, grids)

    # Per-point scale factor: mass divided by cell volume.
    fact = mass / denoms

    # Mass-scaled weights. The normalized weights (numers / denoms) of each
    # point sum to 1, so each column of `weights` sums to that point's mass.
    weights = numers * np.expand_dims(fact, 0)

    # A (ndims, 2**ndims, npoints) array: [:, c, n] is the multi-index of the
    # grid node that receives weights[c, n]. A backward difference along a
    # dimension (bit 0) weights the upper node (lower + 1). A forward
    # difference (bit 1) weights the lower node.
    add_inds = np.expand_dims(inds, axis=1) + (1 - np.expand_dims(comb_inds.T, -1))

    # Make indices unidimensional (avoids *inds multi-dim indexing in numba).
    add_inds = ravel_index(add_inds, dims)
    distr = sum_weights(weights, dims, add_inds)

    return distr