Skip to content

Commit 64fe2ca

Browse files
committed
fix: set diag sparse matrix type to CSR to fix sp solve warning
1 parent 51c8e6c commit 64fe2ca

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

csaps/_sspumv.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
from typing import Literal, cast
6-
import functools
6+
from functools import partial
77

88
import numpy as np
99
from scipy.interpolate import PPoly
@@ -14,6 +14,9 @@
1414
from ._reshape import prod, to_2d
1515
from ._types import FloatNDArrayType, MultivariateDataType, UnivariateDataType
1616

17+
diags_csr = partial(sp.diags, format='csr')
18+
vpad = partial(np.pad, pad_width=[(1, 1), (0, 0)], mode='constant')
19+
1720

1821
class SplinePPForm(ISplinePPForm[np.ndarray, int], PPoly):
1922
"""The base class for univariate/multivariate spline in piecewise polynomial form
@@ -282,13 +285,13 @@ def _make_spline(x, y, w, smooth, shape, normalizedsmooth):
282285

283286
# Create diagonal sparse matrices
284287
diags_r = np.vstack((dx[1:], 2 * (dx[1:] + dx[:-1]), dx[:-1]))
285-
r = sp.spdiags(diags_r, [-1, 0, 1], pcount - 2, pcount - 2)
288+
r = sp.spdiags(diags_r, [-1, 0, 1], pcount - 2, pcount - 2, format='csr')
286289

287290
dx_recip = 1.0 / dx
288291
diags_qtw = np.vstack((dx_recip[:-1], -(dx_recip[1:] + dx_recip[:-1]), dx_recip[1:]))
289292
diags_sqrw_recip = 1.0 / np.sqrt(w)
290293

291-
qtw = sp.diags(diags_qtw, [0, 1, 2], (pcount - 2, pcount)) @ sp.diags(diags_sqrw_recip, 0, (pcount, pcount))
294+
qtw = diags_csr(diags_qtw, [0, 1, 2], (pcount - 2, pcount)) @ diags_csr(diags_sqrw_recip, 0, (pcount, pcount))
292295
qtw = qtw @ qtw.T
293296

294297
p = smooth
@@ -312,13 +315,11 @@ def _make_spline(x, y, w, smooth, shape, normalizedsmooth):
312315

313316
dx = dx[:, np.newaxis]
314317

315-
vpad = functools.partial(np.pad, pad_width=[(1, 1), (0, 0)], mode='constant')
316-
317318
d1 = np.diff(vpad(u), axis=0) / dx
318319
d2 = np.diff(vpad(d1), axis=0)
319320

320321
diags_w_recip = 1.0 / w
321-
w = sp.diags(diags_w_recip, 0, (pcount, pcount))
322+
w = diags_csr(diags_w_recip, 0, (pcount, pcount))
322323

323324
yi = y.T - (pp * w) @ d2
324325
pu = vpad(p * u)

0 commit comments

Comments
 (0)