33"""
44
55from typing import Literal , cast
6- import functools
6+ from functools import partial
77
88import numpy as np
99from scipy .interpolate import PPoly
1414from ._reshape import prod , to_2d
1515from ._types import FloatNDArrayType , MultivariateDataType , UnivariateDataType
1616
17+ diags_csr = partial (sp .diags , format = 'csr' )
18+ vpad = partial (np .pad , pad_width = [(1 , 1 ), (0 , 0 )], mode = 'constant' )
19+
1720
1821class SplinePPForm (ISplinePPForm [np .ndarray , int ], PPoly ):
1922 """The base class for univariate/multivariate spline in piecewise polynomial form
@@ -282,13 +285,13 @@ def _make_spline(x, y, w, smooth, shape, normalizedsmooth):
282285
283286 # Create diagonal sparse matrices
284287 diags_r = np .vstack ((dx [1 :], 2 * (dx [1 :] + dx [:- 1 ]), dx [:- 1 ]))
285- r = sp .spdiags (diags_r , [- 1 , 0 , 1 ], pcount - 2 , pcount - 2 )
288+ r = sp .spdiags (diags_r , [- 1 , 0 , 1 ], pcount - 2 , pcount - 2 , format = 'csr' )
286289
287290 dx_recip = 1.0 / dx
288291 diags_qtw = np .vstack ((dx_recip [:- 1 ], - (dx_recip [1 :] + dx_recip [:- 1 ]), dx_recip [1 :]))
289292 diags_sqrw_recip = 1.0 / np .sqrt (w )
290293
291- qtw = sp . diags (diags_qtw , [0 , 1 , 2 ], (pcount - 2 , pcount )) @ sp . diags (diags_sqrw_recip , 0 , (pcount , pcount ))
294+ qtw = diags_csr (diags_qtw , [0 , 1 , 2 ], (pcount - 2 , pcount )) @ diags_csr (diags_sqrw_recip , 0 , (pcount , pcount ))
292295 qtw = qtw @ qtw .T
293296
294297 p = smooth
@@ -312,13 +315,11 @@ def _make_spline(x, y, w, smooth, shape, normalizedsmooth):
312315
313316 dx = dx [:, np .newaxis ]
314317
315- vpad = functools .partial (np .pad , pad_width = [(1 , 1 ), (0 , 0 )], mode = 'constant' )
316-
317318 d1 = np .diff (vpad (u ), axis = 0 ) / dx
318319 d2 = np .diff (vpad (d1 ), axis = 0 )
319320
320321 diags_w_recip = 1.0 / w
321- w = sp . diags (diags_w_recip , 0 , (pcount , pcount ))
322+ w = diags_csr (diags_w_recip , 0 , (pcount , pcount ))
322323
323324 yi = y .T - (pp * w ) @ d2
324325 pu = vpad (p * u )
0 commit comments