"""Tests for spline filtering.""" | |
import pytest | |
import numpy as np | |
from scipy._lib._array_api import assert_almost_equal | |
from scipy import ndimage | |
from scipy.conftest import array_api_compatible | |
# Shorthand for the project's ``skip_xp_backends`` marker.
skip_xp_backends = pytest.mark.skip_xp_backends
# A module-level ``pytestmark`` applies these marks to every test in this file:
#   - ``array_api_compatible``: presumably parametrizes tests over array
#     backends and supplies the ``xp`` argument used below (TODO confirm in
#     scipy.conftest);
#   - ``usefixtures("skip_xp_backends")``: activates the fixture that acts on
#     the skip marker;
#   - ``skip_xp_backends(cpu_only=True, exceptions=[...])``: restricts the run
#     to CPU backends, with cupy and jax.numpy listed as exceptions.
pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends"),
skip_xp_backends(cpu_only=True, exceptions=['cupy', 'jax.numpy'],)]
def get_spline_knot_values(order):
    """Return the knot values at and to the right of a B-spline's center.

    The first entry is the value at the center. The following entries are
    the values at integer offsets 1, 2, ... to the right. The spline is
    symmetric, so the left side mirrors these.

    The numbers are integer-scaled: each list is proportional to the
    centered B-spline of degree ``order`` sampled at integer offsets, e.g.
    ``[4, 1]`` for the cubic (2/3, 1/6). Callers must normalize them.

    Parameters
    ----------
    order : int
        Spline order; only 0 through 5 are tabulated.

    Returns
    -------
    list of int
        A freshly built list on every call.

    Raises
    ------
    KeyError
        If ``order`` is not one of the tabulated orders.
    """
    values_by_order = {
        0: [1],
        1: [1],
        2: [6, 1],
        3: [4, 1],
        4: [230, 76, 1],
        5: [66, 26, 1],
    }
    knots = values_by_order[order]
    return knots
def make_spline_knot_matrix(xp, n, order, mode='mirror'):
    """Matrix to invert to find the spline coefficients.

    Builds an ``n x n`` banded matrix and normalizes it:
    - the B-spline center value goes on the main diagonal;
    - the k-th value to the right of center goes on the k-th super- and
      sub-diagonal;
    - entries that would fall outside the matrix near the first and last
      rows are folded back in according to ``mode``;
    - the result is divided by the full knot sum (center + 2 * rest).

    Parameters
    ----------
    xp : array namespace
        Namespace used only to convert the final NumPy result
        (``xp.asarray``).
    n : int
        Size of the square matrix.
    order : int
        Spline order, looked up via ``get_spline_knot_values``; unsupported
        orders raise ``KeyError`` from that lookup.
    mode : {'mirror', 'reflect', 'grid-wrap'}
        Boundary handling used to fold out-of-range neighbours back in.

    Returns
    -------
    array
        The normalized matrix, converted with ``xp.asarray``.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported boundary modes.
    """
    knots = get_spline_knot_values(order)

    # NB: do computations with numpy, convert to xp as the last step only
    mat = np.zeros((n, n))

    # Place every knot value on its two symmetric diagonals. For offset 0
    # both assignments hit the main diagonal with the same value. An offset
    # >= n gives an empty index array, so nothing is written.
    for offset, value in enumerate(knots):
        band = np.arange(n - offset)
        mat[band + offset, band] = value
        mat[band, band + offset] = value

    # Same as center + 2 * sum(rest), written as 2 * sum(all) - center.
    total = 2 * sum(knots) - knots[0]

    # (start, step) map the idx-th out-of-range neighbour of a boundary row
    # to the column it folds onto.
    boundary_rules = (
        ('mirror', (1, 1)),
        ('reflect', (0, 1)),
        ('grid-wrap', (-1, -1)),
    )
    for name, rule in boundary_rules:
        if mode == name:
            start, step = rule
            break
    else:
        raise ValueError(f'unsupported mode {mode}')

    # Fold the neighbours that spill past either edge back into the matrix.
    # The last rows are the mirror image of the first, so their column is
    # the negative-index reflection ``-1 - col``.
    for row in range(len(knots) - 1):
        spill = knots[row + 1:]
        for idx in range(len(spill)):
            col = start + step * idx
            mat[row, col] += spill[idx]
            mat[-row - 1, -1 - col] += spill[idx]

    return xp.asarray(mat / total)
@pytest.mark.parametrize('order', [0, 1, 2, 3, 4, 5])
@pytest.mark.parametrize('mode', ['mirror', 'grid-wrap', 'reflect'])
def test_spline_filter_vs_matrix_solution(order, mode, xp):
    """spline_filter1d must invert the explicit knot matrix.

    Filtering the identity column-wise (axis 0) produces M^-1. Filtering it
    row-wise (axis 1) produces (M^-1)^T. Here M is the normalized knot
    matrix for the same order and mode. Right-multiplying by M, or by M^T,
    must therefore give back the identity.

    ``order`` and ``mode`` are supplied by the parametrize marks above.
    They cover the orders tabulated in ``get_spline_knot_values`` and the
    modes accepted by ``make_spline_knot_matrix``. Without the marks,
    pytest fails at setup with "fixture 'order' not found".
    """
    n = 100
    eye = xp.eye(n, dtype=xp.float64)
    spline_filter_axis_0 = ndimage.spline_filter1d(eye, axis=0, order=order,
                                                   mode=mode)
    spline_filter_axis_1 = ndimage.spline_filter1d(eye, axis=1, order=order,
                                                   mode=mode)
    matrix = make_spline_knot_matrix(xp, n, order, mode=mode)
    # Column-wise filtering yields M^-1, so M^-1 @ M == I.
    assert_almost_equal(eye, spline_filter_axis_0 @ matrix)
    # Row-wise filtering yields (M^-1)^T, so (M^-1)^T @ M^T == I. M is not
    # symmetric for every mode, so the transpose matters.
    assert_almost_equal(eye, spline_filter_axis_1 @ matrix.T)