diff --git a/index.ipynb b/index.ipynb index 32cc497..aaa90b3 100644 --- a/index.ipynb +++ b/index.ipynb @@ -46,9 +46,9 @@ "\n", "nx, ny = 6, 3\n", "\n", - "np.random.seed(0)\n", - "orography = np.random.normal(1000, 600, size=(ny, nx)) - 400\n", - "sea_level_temp = np.random.normal(290, 5, size=(ny, nx))" + "rng = np.random.default_rng(0)\n", + "orography = rng.normal(1000, 600, size=(ny, nx)) - 400\n", + "sea_level_temp = rng.normal(290, 5, size=(ny, nx))" ] }, { @@ -85,20 +85,20 @@ "\n", "import matplotlib.pyplot as plt\n", "\n", - "plt.set_cmap('viridis')\n", + "plt.set_cmap(\"viridis\")\n", "fig = plt.figure(figsize=(8, 6))\n", "\n", "plt.subplot(1, 2, 1)\n", "plt.pcolormesh(orography)\n", - "cbar = plt.colorbar(orientation='horizontal',\n", - " label='Orography (m)')\n", + "cbar = plt.colorbar(orientation=\"horizontal\",\n", + " label=\"Orography (m)\")\n", "# Reduce the maximum number of ticks to 5.\n", "cbar.ax.xaxis.get_major_locator().nbins = 5\n", "\n", "plt.subplot(1, 2, 2)\n", "plt.pcolormesh(sea_level_temp)\n", - "cbar = plt.colorbar(orientation='horizontal',\n", - " label='Sea level temperature (K)')\n", + "cbar = plt.colorbar(orientation=\"horizontal\",\n", + " label=\"Sea level temperature (K)\")\n", "# Reduce the maximum number of ticks to 5.\n", "cbar.ax.xaxis.get_major_locator().nbins = 5\n", "\n", @@ -179,17 +179,17 @@ "source": [ "plt.figure(figsize=(8, 6))\n", "plt.fill_between(np.arange(6), np.zeros(6), orography[1, :],\n", - " color='green', linewidth=2, label='Orography')\n", + " color=\"green\", linewidth=2, label=\"Orography\")\n", "\n", "plt.plot(np.zeros(nx),\n", - " color='blue', linewidth=1.2,\n", - " label='Sea level')\n", + " color=\"blue\", linewidth=1.2,\n", + " label=\"Sea level\")\n", "\n", "for i in range(9):\n", - " plt.plot(altitude[i, 1, :], color='gray', linestyle='--',\n", - " label='Model levels' if i == 0 else None)\n", + " plt.plot(altitude[i, 1, :], color=\"gray\", linestyle=\"--\",\n", + " label=\"Model levels\" if i == 0 else None)\n", "\n", - "plt.ylabel('altitude / m')\n", + "plt.ylabel(\"altitude / m\")\n", "plt.margins(0.1)\n", "plt.legend()\n", "plt.show()" @@ -250,14 +250,12 @@ } ], "source": [ - "from matplotlib.colors import LogNorm\n", - "\n", "fig = plt.figure(figsize=(8, 6))\n", "norm = plt.Normalize(vmin=temperature.min(), vmax=temperature.max())\n", "\n", "for i in range(nz):\n", " plt.subplot(3, 3, i + 1)\n", - " qm = plt.pcolormesh(temperature[i], cmap='viridis', norm=norm)\n", + " qm = plt.pcolormesh(temperature[i], cmap=\"viridis\", norm=norm)\n", "\n", "plt.subplots_adjust(right=0.84, wspace=0.3, hspace=0.3)\n", "cax = plt.axes([0.85, 0.1, 0.03, 0.8])\n", @@ -332,21 +330,21 @@ "source": [ "plt.figure(figsize=(8, 6))\n", "plt.fill_between(np.arange(6), np.zeros(6), orography[1, :],\n", - " color='green', linewidth=2, label='Orography')\n", + " color=\"green\", linewidth=2, label=\"Orography\")\n", "\n", "for i in range(9):\n", " plt.plot(altitude[i, 1, :],\n", - " color='gray', lw=1.2,\n", - " label=None if i > 0 else 'Source levels \\n(model levels)')\n", + " color=\"gray\", lw=1.2,\n", + " label=None if i > 0 else \"Source levels \\n(model levels)\")\n", "for i, target in enumerate(target_altitudes):\n", " plt.plot(np.repeat(target, 6),\n", - " color='gray', linestyle='--', lw=1.4, alpha=0.6,\n", - " label=None if i > 0 else 'Target levels \\n(altitude)')\n", + " color=\"gray\", linestyle=\"--\", lw=1.4, alpha=0.6,\n", + " label=None if i > 0 else \"Target levels \\n(altitude)\")\n", "\n", - "plt.ylabel('height / m')\n", + "plt.ylabel(\"height / m\")\n", "plt.margins(top=0.1)\n", "plt.legend()\n", - "plt.savefig('summary.png')\n", + "plt.savefig(\"summary.png\")\n", "plt.show()" ] }, @@ -413,7 +411,7 @@ "plt.figure(figsize=(8, 6))\n", "ax1 = plt.subplot(1, 2, 1)\n", "plt.fill_between(np.arange(6), np.zeros(6), orography[1, :],\n", - " color='green', linewidth=2, label='Orography')\n", + " color=\"green\", linewidth=2, label=\"Orography\")\n", "cs = plt.contourf(np.tile(np.arange(6), nz).reshape(nz, 6),\n", " altitude[:, 1],\n", " temperature[:, 1])\n", @@ -423,7 +421,7 @@ "\n", "plt.subplot(1, 2, 2, sharey=ax1)\n", "plt.fill_between(np.arange(6), np.zeros(6), orography[1, :],\n", - " color='green', linewidth=2, label='Orography')\n", + " color=\"green\", linewidth=2, label=\"Orography\")\n", "plt.contourf(np.arange(6), target_altitudes,\n", " np.ma.masked_invalid(new_temperature[:, 1]),\n", " cmap=cs.cmap, norm=cs.norm)\n", @@ -432,9 +430,9 @@ " c=new_temperature[:, 1])\n", "plt.scatter(np.tile(np.arange(nx), target_nz).reshape(target_nz, nx),\n", " np.repeat(target_altitudes, nx).reshape(target_nz, nx),\n", - " s=np.isnan(new_temperature[:, 1]) * 15, marker='x')\n", + " s=np.isnan(new_temperature[:, 1]) * 15, marker=\"x\")\n", "\n", - "plt.suptitle('Temperature cross-section before and after restratification')\n", + "plt.suptitle(\"Temperature cross-section before and after restratification\")\n", "plt.show()" ] } diff --git a/pyproject.toml b/pyproject.toml index 9adc4f4..61f4022 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,6 @@ preview = false [tool.ruff.lint] ignore = [ - # flake8-annotations (ANN) # https://docs.astral.sh/ruff/rules/#flake8-annotations-ann "ANN001", # Missing type annotation for function argument {name} @@ -156,21 +155,6 @@ ignore = [ "ANN204", # Missing return type annotation for special method {name} "ARG002", # Unused method argument: {name} - "ARG003", # Unused class method argument: {name} - - # flake8-bugbear (B) - # https://docs.astral.sh/ruff/rules/#flake8-bugbear-b - "B028", # No explicit stacklevel keyword argument found - - # flake8-comprehensions (C4) - # https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 - "C405", # Unnecessary {obj_type} literal (rewrite as a set literal) - "C419", # Unnecessary list comprehension - - # flake8-commas (COM) - # https://docs.astral.sh/ruff/rules/#flake8-commas-com - "COM812", # Trailing comma missing. - "COM819", # Trailing comma prohibited. # pydocstyle (D) # https://docs.astral.sh/ruff/rules/#pydocstyle-d @@ -185,46 +169,12 @@ ignore = [ # https://docs.astral.sh/ruff/rules/#eradicate-era "ERA001", # Found commented-out code - # flake8-boolean-trap (FBT) - # https://docs.astral.sh/ruff/rules/#flake8-boolean-trap-fbt - "FBT002", # Boolean default positional argument in function definition - - # flake8-implicit-str-concat (ISC) - # https://docs.astral.sh/ruff/rules/single-line-implicit-string-concatenation/ - # NOTE: This rule may cause conflicts when used with "ruff format". - "ISC001", # Implicitly concatenate string literals on one line. - - # pep8-naming (N) - # https://docs.astral.sh/ruff/rules/#pep8-naming-n - "N801", # Class name {name} should use CapWords convention + "PT011", - # Refactor (R) - # https://docs.astral.sh/ruff/rules/#refactor-r - "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable - - # flake8-pytest-style (PT) - # https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt - "PT009", # Use a regular assert instead of unittest-style {assertion} - "PT027", # Use pytest.raises instead of unittest-style {assertion} - - # flake8-return (RET) - # https://docs.astral.sh/ruff/rules/#flake8-return-ret - "RET504", # Unnecessary assignment to {name} before return statement - - # Ruff-specific rules (RUF) - # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf - "RUF005", # Consider {expression} instead of concatenation - "RUF012", # Mutable class attributes should be annotated with typing.ClassVar - - # flake8-self (SLF) - # https://docs.astral.sh/ruff/rules/#flake8-self-slf - "SLF001", # Private member accessed: {access} + "PLR2004", - # flake8-print (T20) - # https://docs.astral.sh/ruff/rules/#flake8-print-t20 - "T201", # print found - - ] + "S101", # Use of assert detected. +] preview = false select = [ "ALL", @@ -239,13 +189,20 @@ known-first-party = ["iris"] [tool.ruff.lint.per-file-ignores] # All test scripts - +"src/stratify/tests/performance.py" = [ + "T201", # print found +] +"setup.py" = [ + "T201", # print found + "RUF012", # Mutable class attributes should be annotated with typing.ClassVar +] # Change to match specific package path: -"lib/iris/tests/*.py" = [ +"src/stratify/tests/*.py" = [ # https://docs.astral.sh/ruff/rules/undocumented-public-module/ "D100", # Missing docstring in public module "D205", # 1 blank line required between summary line and description "D401", # 1 First line of docstring should be in imperative mood + "SLF001", # Private member accessed: {access} ] [tool.ruff.lint.pydocstyle] diff --git a/setup.py b/setup.py index 964cba9..91f036c 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ from Cython.Build import cythonize # isort:skip except ImportError: wmsg = "Cython unavailable, unable to build stratify extensions!" - warnings.warn(wmsg) + warnings.warn(wmsg, stacklevel=2) cythonize = None @@ -53,7 +53,7 @@ def run(self): [ ("CYTHON_TRACE", "1"), ("CYTHON_TRACE_NOGIL", "1"), - ] + ], ) cython_directives.update({"linetrace": True}) if FLAG_COVERAGE in sys.argv: @@ -70,9 +70,11 @@ def run(self): ) extensions.append(extension) -if cythonize and not any([arg in CMDS_NOCYTHONIZE for arg in sys.argv]): +if cythonize and not any(arg in CMDS_NOCYTHONIZE for arg in sys.argv): extensions = cythonize( - extensions, compiler_directives=cython_directives, language_level=3 + extensions, + compiler_directives=cython_directives, + language_level=3, ) cmdclass = {"clean_cython": CleanCython} diff --git a/src/stratify/_bounded_vinterp.py b/src/stratify/_bounded_vinterp.py index 465a96d..94874e8 100644 --- a/src/stratify/_bounded_vinterp.py +++ b/src/stratify/_bounded_vinterp.py @@ -91,7 +91,8 @@ def interpolate_conservative(z_target, z_src, fz_src, axis=-1): ) raise ValueError(msg.format(tuple(dat_shape), tuple(src_shape))) - if z_src.shape[-1] != 2: + shape_2 = 2 + if z_src.shape[-1] != shape_2: msg = "Unexpected source and target bounds shape. shape[-1] != 2" raise ValueError(msg) @@ -100,8 +101,8 @@ def interpolate_conservative(z_target, z_src, fz_src, axis=-1): # src_data bdims = list(range(fz_src.ndim - (z_src.ndim - 1))) - data_vdims = [ind for ind in range(fz_src.ndim) if ind not in (bdims + [axis])] - data_transpose = bdims + [axis] + data_vdims + data_vdims = [ind for ind in range(fz_src.ndim) if ind not in ([*bdims, axis])] + data_transpose = [*bdims, axis, *data_vdims] fz_src_reshaped = np.transpose(fz_src, data_transpose) fz_src_orig = list(fz_src_reshaped.shape) shape = ( @@ -113,21 +114,24 @@ def interpolate_conservative(z_target, z_src, fz_src, axis=-1): # Define our src and target bounds in a consistent way. # [axis_interpolation, z_varying, 2] - vdims = list(set(range(z_src.ndim)) - set([axis_relative])) - z_src_reshaped = np.transpose(z_src, [axis_relative] + vdims) - z_target_reshaped = np.transpose(z_target, [axis_relative] + vdims) + # vdims = list(set(range(z_src.ndim)) - set([axis_relative])) + vdims = list(set(range(z_src.ndim)) - {axis_relative}) + z_src_reshaped = np.transpose(z_src, [axis_relative, *vdims]) + z_target_reshaped = np.transpose(z_target, [axis_relative, *vdims]) shape = int(np.prod(z_src_reshaped.shape[1:-1])) z_src_reshaped = z_src_reshaped.reshape( - [z_src_reshaped.shape[0], shape, z_src_reshaped.shape[-1]] + [z_src_reshaped.shape[0], shape, z_src_reshaped.shape[-1]], ) shape = int(np.prod(z_target_reshaped.shape[1:-1])) z_target_reshaped = z_target_reshaped.reshape( - [z_target_reshaped.shape[0], shape, z_target_reshaped.shape[-1]] + [z_target_reshaped.shape[0], shape, z_target_reshaped.shape[-1]], ) result = conservative_interpolation( - z_src_reshaped, z_target_reshaped, fz_src_reshaped + z_src_reshaped, + z_target_reshaped, + fz_src_reshaped, ) # Turn the result into a shape consistent with the source. @@ -136,5 +140,5 @@ def interpolate_conservative(z_target, z_src, fz_src, axis=-1): shape[len(bdims)] = z_target.shape[axis_relative] result = result.reshape(shape) invert_transpose = [data_transpose.index(ind) for ind in list(range(result.ndim))] - result = result.transpose(invert_transpose) - return result + # result = result.transpose(invert_transpose) + return result.transpose(invert_transpose) diff --git a/src/stratify/tests/performance.py b/src/stratify/tests/performance.py index 88e6616..cd1a228 100644 --- a/src/stratify/tests/performance.py +++ b/src/stratify/tests/performance.py @@ -6,7 +6,7 @@ import stratify -def src_data(shape=(400, 500, 100), lazy=False): +def src_data(shape=(400, 500, 100)): z = np.tile(np.linspace(0, 100, shape[-1]), np.prod(shape[:2])).reshape(shape) if lazy: fz = da.arange(np.prod(shape), dtype=np.float64).reshape(shape) diff --git a/src/stratify/tests/test_bounded_vinterp.py b/src/stratify/tests/test_bounded_vinterp.py index 06ba948..d661356 100644 --- a/src/stratify/tests/test_bounded_vinterp.py +++ b/src/stratify/tests/test_bounded_vinterp.py @@ -2,6 +2,7 @@ import numpy as np from numpy.testing import assert_array_equal +import pytest import stratify._bounded_vinterp as bounded_vinterp @@ -24,7 +25,7 @@ def gen_bounds(self, start, stop, step): [ np.arange(start, stop - step, step), np.arange(start + step, stop, step), - ] + ], ) bounds = bounds.transpose((1, 0)) return bounds.copy() @@ -32,7 +33,10 @@ def gen_bounds(self, start, stop, step): def test_target_half_resolution(self): target_bounds = self.gen_bounds(0, 7, 2) res = bounded_vinterp.interpolate_conservative( - target_bounds, self.bounds, self.data, axis=1 + target_bounds, + self.bounds, + self.data, + axis=1, ) target_data = np.ones((4, 3)) * 2 assert_array_equal(res, target_data) @@ -40,7 +44,10 @@ def test_target_half_resolution(self): def test_target_double_resolution(self): target_bounds = self.gen_bounds(0, 6.5, 0.5) res = bounded_vinterp.interpolate_conservative( - target_bounds, self.bounds, self.data, axis=1 + target_bounds, + self.bounds, + self.data, + axis=1, ) target_data = np.ones((4, 12)) * 0.5 assert_array_equal(res, target_data) @@ -52,7 +59,10 @@ def test_no_broadcasting(self): data = self.data[0] target_bounds = self.gen_bounds(0, 7, 2) res = bounded_vinterp.interpolate_conservative( - target_bounds, self.bounds, data, axis=0 + target_bounds, + self.bounds, + data, + axis=0, ) target_data = np.ones(3) * 2 assert_array_equal(res, target_data) @@ -65,7 +75,10 @@ def test_source_with_nans(self): data[0] = data[-2:] = np.nan target_bounds = self.gen_bounds(0, 6.5, 0.5) res = bounded_vinterp.interpolate_conservative( - target_bounds, self.bounds, data, axis=0 + target_bounds, + self.bounds, + data, + axis=0, ) target_data = np.ones(12) * 0.5 target_data[:2] = np.nan @@ -78,7 +91,10 @@ def test_target_extends_above_source(self): source_bounds = self.gen_bounds(0, 7, 1) target_bounds = self.gen_bounds(0, 8, 1) res = bounded_vinterp.interpolate_conservative( - target_bounds, source_bounds, self.data, axis=1 + target_bounds, + source_bounds, + self.data, + axis=1, ) target_data = np.ones((4, 7)) target_data[:, -1] = np.nan @@ -91,10 +107,13 @@ def test_target_extends_above_source_non_equally_spaced_coords(self): target_bounds = self.gen_bounds(0, 8, 1) data = np.ones((4, 4)) res = bounded_vinterp.interpolate_conservative( - target_bounds, source_bounds, data, axis=1 + target_bounds, + source_bounds, + data, + axis=1, ) target_data = np.array( - [1 / 1.5, 1 + ((1 / 3.0) / 1), 0.25, 0.25, 0.25, 0.25, 1.0] + [1 / 1.5, 1 + ((1 / 3.0) / 1), 0.25, 0.25, 0.25, 0.25, 1.0], )[None] target_data = np.repeat(target_data, 4, 0) assert_array_equal(res, target_data) @@ -105,7 +124,10 @@ def test_target_extends_below_source(self): source_bounds = self.gen_bounds(0, 7, 1) target_bounds = self.gen_bounds(-1, 7, 1) res = bounded_vinterp.interpolate_conservative( - target_bounds, source_bounds, self.data, axis=1 + target_bounds, + source_bounds, + self.data, + axis=1, ) target_data = np.ones((4, 7)) target_data[:, 0] = np.nan @@ -124,7 +146,7 @@ def gen_bounds(self, start, stop, step): [ np.arange(start, stop - step, step), np.arange(start + step, stop, step), - ] + ], ) bounds = bounds.transpose((1, 0)) bounds = bounds[..., None, :].repeat(4, -2) @@ -134,7 +156,10 @@ def gen_bounds(self, start, stop, step): def test_target_half_resolution(self): target_bounds = self.gen_bounds(0, 7, 2) res = bounded_vinterp.interpolate_conservative( - target_bounds, self.bounds, self.data, axis=1 + target_bounds, + self.bounds, + self.data, + axis=1, ) target_data = np.ones((2, 3, 4, 3)) * 2 @@ -148,7 +173,10 @@ def test_target_half_resolution_alt_axis(self): target_bounds = target_bounds.transpose((1, 0, 2, 3)) res = bounded_vinterp.interpolate_conservative( - target_bounds, bounds, data, axis=2 + target_bounds, + bounds, + data, + axis=2, ) target_data = np.ones((2, 4, 3, 3)) * 2 assert_array_equal(res, target_data) @@ -156,7 +184,10 @@ def test_target_half_resolution_alt_axis(self): def test_target_double_resolution(self): target_bounds = self.gen_bounds(0, 6.5, 0.5) res = bounded_vinterp.interpolate_conservative( - target_bounds, self.bounds, self.data, axis=1 + target_bounds, + self.bounds, + self.data, + axis=1, ) target_data = np.ones((2, 12, 4, 3)) * 0.5 assert_array_equal(res, target_data) @@ -169,7 +200,7 @@ def test_mismatch_source_target_level_dimensionality(self): data = np.zeros((3, 4)) msg = "Expecting source and target levels dimensionality" - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): bounded_vinterp.interpolate_conservative(target_bounds, source_bounds, data) def test_mismatch_source_target_level_shape(self): @@ -184,9 +215,12 @@ def test_mismatch_source_target_level_shape(self): "the axis of interpolation to be identical. " r"\('-', 4, 2\) != \(2, 5, 2\)" ) - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): bounded_vinterp.interpolate_conservative( - target_bounds, source_bounds, data, axis=0 + target_bounds, + source_bounds, + data, + axis=0, ) def test_mismatch_between_source_levels_source_data(self): @@ -199,9 +233,12 @@ def test_mismatch_between_source_levels_source_data(self): "The provided data is not of compatible shape with the " r"provided source bounds. \('-', 3, 4\) != \(2, 4\)" ) - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): bounded_vinterp.interpolate_conservative( - target_bounds, source_bounds, data, axis=0 + target_bounds, + source_bounds, + data, + axis=0, ) def test_unexpected_bounds_shape(self): @@ -212,9 +249,12 @@ def test_unexpected_bounds_shape(self): data = np.zeros((3, 4)) msg = r"Unexpected source and target bounds shape. shape\[-1\] != 2" - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): bounded_vinterp.interpolate_conservative( - target_bounds, source_bounds, data, axis=0 + target_bounds, + source_bounds, + data, + axis=0, ) def test_not_conservative(self): @@ -226,7 +266,7 @@ def gen_bounds(start, stop, step): [ np.arange(start, stop - step, step), np.arange(start + step, stop, step), - ] + ], ) bounds = bounds.transpose((1, 0)) return bounds.copy() @@ -236,9 +276,12 @@ def gen_bounds(start, stop, step): data = np.ones((4, 6)) msg = "Weights calculation yields a less than conservative result." - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): bounded_vinterp.interpolate_conservative( - target_bounds, source_bounds, data, axis=1 + target_bounds, + source_bounds, + data, + axis=1, ) diff --git a/src/stratify/tests/test_vinterp.py b/src/stratify/tests/test_vinterp.py index 002bbc4..b41798e 100644 --- a/src/stratify/tests/test_vinterp.py +++ b/src/stratify/tests/test_vinterp.py @@ -1,8 +1,7 @@ -import unittest - import dask.array as da import numpy as np from numpy.testing import assert_array_almost_equal, assert_array_equal +import pytest import stratify import stratify._vinterp as vinterp @@ -18,7 +17,7 @@ def extrap_kernel(self, direction, z_src, fz_src, level, output_array): output_array[:] = np.inf if direction > 0 else -np.inf -class TestColumnInterpolation(unittest.TestCase): +class TestColumnInterpolation: def interpolate(self, x_target, x_src, rising=None): x_target = np.array(x_target) x_src = np.array(x_src) @@ -90,12 +89,12 @@ def test_interp_and_extrap(self): def test_nan_in_target(self): msg = "The target coordinate .* NaN" - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): self.interpolate([1, np.nan], [2, 4, 5]) def test_nan_in_src(self): msg = "The source coordinate .* NaN" - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): self.interpolate([1], [0, np.nan], rising=True) def test_all_nan_in_src(self): @@ -139,7 +138,7 @@ def test_length_one_interp(self): assert_array_equal(r, [-np.inf]) def test_auto_rising_not_enough_values(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = self.interpolate([1], [2]) def test_auto_rising_equal_values(self): @@ -149,7 +148,7 @@ def test_auto_rising_equal_values(self): assert_array_equal(r, [-np.inf]) -class Test_INTERPOLATE_LINEAR(unittest.TestCase): +class TestInterpolateLinear: def interpolate(self, x_target): interpolation = stratify.INTERPOLATE_LINEAR extrapolation = DirectionExtrapolator() @@ -172,19 +171,25 @@ def test_on_the_mark(self): def test_zero_gradient(self): assert_array_equal( stratify.interpolate( - [1], [0, 1, 1, 2], [10, 20, 30, 40], interpolation="linear" + [1], + [0, 1, 1, 2], + [10, 20, 30, 40], + interpolation="linear", ), [20], ) def test_inbetween(self): assert_array_equal( - self.interpolate([0.5, 1.25, 2.5, 3.75]), [5, 12.5, 25, 37.5] + self.interpolate([0.5, 1.25, 2.5, 3.75]), + [5, 12.5, 25, 37.5], ) def test_high_precision(self): assert_array_almost_equal( - self.interpolate([1.123456789]), [11.23456789], decimal=6 + self.interpolate([1.123456789]), + [11.23456789], + decimal=6, ) def test_single_point(self): @@ -202,10 +207,10 @@ def test_single_point(self): extrapolation=extrapolation, rising=True, ) - self.assertEqual(r, 20) + assert r == 20 -class Test_INTERPOLATE_NEAREST(unittest.TestCase): +class TestInterpolateNearest: def interpolate(self, x_target): interpolation = stratify.INTERPOLATE_NEAREST extrapolation = DirectionExtrapolator() @@ -233,7 +238,7 @@ def test_high_precision(self): assert_array_equal(self.interpolate([1.123456789]), [10]) -class Test_EXTRAPOLATE_NAN(unittest.TestCase): +class TestExtrapolateNan: def interpolate(self, x_target): interpolation = IndexInterpolator() extrapolation = stratify.EXTRAPOLATE_NAN @@ -257,7 +262,7 @@ def test_above(self): assert_array_equal(self.interpolate([5]), [np.nan]) -class Test_EXTRAPOLATE_NEAREST(unittest.TestCase): +class TestExtrapolateNearest: def interpolate(self, x_target): interpolation = IndexInterpolator() extrapolation = stratify.EXTRAPOLATE_NEAREST @@ -281,7 +286,7 @@ def test_above(self): assert_array_equal(self.interpolate([5]), [40]) -class Test_EXTRAPOLATE_LINEAR(unittest.TestCase): +class TestExtrapolateLinear: def interpolate(self, x_target): interpolation = IndexInterpolator() extrapolation = stratify.EXTRAPOLATE_LINEAR @@ -316,8 +321,8 @@ def test_npts(self): interpolation = IndexInterpolator() extrapolation = stratify.EXTRAPOLATE_LINEAR - msg = r"Linear extrapolation requires at least 2 " r"source points. Got 1." - with self.assertRaisesRegex(ValueError, msg): + msg = r"Linear extrapolation requires at least 2 source points. Got 1." + with pytest.raises(ValueError, match=msg): stratify.interpolate( [1, 3.0], [2], @@ -328,17 +333,17 @@ def test_npts(self): ) -class Test_custom_extrap_kernel(unittest.TestCase): - class my_kernel(vinterp.PyFuncExtrapolator): +class TestCustomExtrapKernel: + class MyKernel(vinterp.PyFuncExtrapolator): def __init__(self, *args, **kwargs): - super(Test_custom_extrap_kernel.my_kernel, self).__init__(*args, **kwargs) + super(TestCustomExtrapKernel.MyKernel, self).__init__(*args, **kwargs) def extrap_kernel(self, direction, z_src, fz_src, level, output_array): output_array[:] = -10 def test(self): interpolation = IndexInterpolator() - extrapolation = Test_custom_extrap_kernel.my_kernel() + extrapolation = TestCustomExtrapKernel.MyKernel() r = stratify.interpolate( [1, 3.0], @@ -351,42 +356,42 @@ def test(self): assert_array_equal(r, [0, -10]) -class Test_Interpolation(unittest.TestCase): +class TestInterpolation: def test_axis_m1(self): data = np.empty([5, 4, 23, 7, 3]) zdata = np.empty([5, 4, 23, 7, 3]) i = vinterp._Interpolation([1, 3], zdata, data) # 1288 == 5 * 4 * 23 * 7 - self.assertEqual(i._result_working_shape, (1, 3220, 2, 1)) - self.assertEqual(i.result_shape, (5, 4, 23, 7, 2)) - self.assertEqual(i._zp_reshaped.shape, (3220, 3, 1)) - self.assertEqual(i._fp_reshaped.shape, (1, 3220, 3, 1)) - self.assertEqual(i.axis, -1) - self.assertEqual(i.orig_shape, data.shape) - self.assertIsInstance(i.z_target, np.ndarray) - self.assertEqual(list(i.z_target), [1, 3]) + assert i._result_working_shape == (1, 3220, 2, 1) + assert i.result_shape == (5, 4, 23, 7, 2) + assert i._zp_reshaped.shape == (3220, 3, 1) + assert i._fp_reshaped.shape == (1, 3220, 3, 1) + assert i.axis == -1 + assert i.orig_shape == data.shape + assert isinstance(i.z_target, np.ndarray) + assert list(i.z_target) == [1, 3] def test_axis_0(self): data = zdata = np.empty([5, 4, 23, 7, 3]) i = vinterp._Interpolation([1, 3], data, zdata, axis=0) # 1932 == 4 * 23 * 7 *3 - self.assertEqual(i._result_working_shape, (1, 1, 2, 1932)) - self.assertEqual(i.result_shape, (2, 4, 23, 7, 3)) - self.assertEqual(i._zp_reshaped.shape, (1, 5, 1932)) + assert i._result_working_shape == (1, 1, 2, 1932) + assert i.result_shape == (2, 4, 23, 7, 3) + assert i._zp_reshaped.shape == (1, 5, 1932) def test_axis_2(self): data = zdata = np.empty([5, 4, 23, 7, 3]) i = vinterp._Interpolation([1, 3], data, zdata, axis=2) # 1932 == 4 * 23 * 7 *3 - self.assertEqual(i._result_working_shape, (1, 20, 2, 21)) - self.assertEqual(i.result_shape, (5, 4, 2, 7, 3)) - self.assertEqual(i._zp_reshaped.shape, (20, 23, 21)) + assert i._result_working_shape == (1, 20, 2, 21) + assert i.result_shape == (5, 4, 2, 7, 3) + assert i._zp_reshaped.shape == (20, 23, 21) def test_inconsistent_shape(self): data = np.empty([5, 4, 23, 7, 3]) zdata = np.empty([5, 4, 3, 7, 3]) emsg = "z_src .* is not a subset of fz_src" - with self.assertRaisesRegex(ValueError, emsg): + with pytest.raises(ValueError, match=emsg): vinterp._Interpolation([1, 3], data, zdata, axis=2) def test_axis_out_of_bounds_fz_src_relative(self): @@ -395,7 +400,7 @@ def test_axis_out_of_bounds_fz_src_relative(self): zdata = np.empty((5, 4)) axis = 4 emsg = "Axis {} out of range" - with self.assertRaisesRegex(ValueError, emsg.format(axis)): + with pytest.raises(ValueError, match=emsg.format(axis)): vinterp._Interpolation([1, 3], data, zdata, axis=axis) def test_axis_out_of_bounds_z_src_absolute(self): @@ -404,7 +409,7 @@ def test_axis_out_of_bounds_z_src_absolute(self): zdata = np.empty((3, 5, 4)) axis = 0 emsg = "Axis {} out of range" - with self.assertRaisesRegex(ValueError, emsg.format(axis)): + with pytest.raises(ValueError, match=emsg.format(axis)): vinterp._Interpolation([1, 3], data, zdata, axis=axis) def test_axis_greater_than_z_src_ndim(self): @@ -414,14 +419,14 @@ def test_axis_greater_than_z_src_ndim(self): zdata = np.empty((3, 5, 4)) axis = 2 result = vinterp._Interpolation(data.copy(), data, zdata, axis=axis) - self.assertEqual(result.result_shape, (3, 5, 4)) + assert result.result_shape == (3, 5, 4) def test_nd_inconsistent_ndims(self): z_target = np.empty((2, 3, 4)) z_src = np.empty((3, 4)) fz_src = np.empty((2, 3, 4)) emsg = "z_target and z_src must have the same number of dimensions" - with self.assertRaisesRegex(ValueError, emsg): + with pytest.raises(ValueError, match=emsg): vinterp._Interpolation(z_target, z_src, fz_src) def test_nd_inconsistent_shape(self): @@ -432,29 +437,33 @@ def test_nd_inconsistent_shape(self): "z_target and z_src have different shapes, " r"got \(3, :, 6\) != \(3, :, 5\)" ) - with self.assertRaisesRegex(ValueError, emsg): + with pytest.raises(ValueError, match=emsg): vinterp._Interpolation(z_target, z_src, fz_src, axis=2) def test_result_dtype_f4(self): interp = vinterp._Interpolation( - [17.5], np.arange(4) * 10, np.arange(4, dtype="f4") + [17.5], + np.arange(4) * 10, + np.arange(4, dtype="f4"), ) result = interp.interpolate() - self.assertEqual(interp._target_dtype, np.dtype("f4")) - self.assertEqual(result.dtype, np.dtype("f4")) + assert interp._target_dtype == np.dtype("f4") + assert result.dtype == np.dtype("f4") def test_result_dtype_f8(self): interp = vinterp._Interpolation( - [17.5], np.arange(4) * 10, np.arange(4, dtype="f8") + [17.5], + np.arange(4) * 10, + np.arange(4, dtype="f8"), ) result = interp.interpolate() - self.assertEqual(interp._target_dtype, np.dtype("f8")) - self.assertEqual(result.dtype, np.dtype("f8")) + assert interp._target_dtype == np.dtype("f8") + assert result.dtype == np.dtype("f8") -class Test__Interpolation_interpolate_z_target_nd(unittest.TestCase): +class TestInterpolationInterpolateZTargetNd: def test_target_z_3d_on_axis_0(self): z_target = z_source = f_source = np.arange(3) * np.ones([4, 2, 3]) interp = vinterp._Interpolation( @@ -535,25 +544,34 @@ def test_target_z_2d_over_3d_on_axis_m1(self): assert_array_equal(result, expected) -class Test_interpolate(unittest.TestCase): +class TestInterpolate: def test_target_z_3d_axis_0(self): z_target = z_source = f_source = np.arange(3) * np.ones([4, 2, 3]) result = vinterp.interpolate( - z_target, z_source, f_source, extrapolation="linear" + z_target, + z_source, + f_source, + extrapolation="linear", ) assert_array_equal(result, f_source) def test_dask(self): z_target = z_source = f_source = np.arange(3) * np.ones([4, 2, 3]) reference = vinterp.interpolate( - z_target, z_source, f_source, extrapolation="linear" + z_target, + z_source, + f_source, + extrapolation="linear", ) # Test with various combinations of lazy input f_src = da.asarray(f_source, chunks=(2, 1, 2)) for z_tgt in (z_target, z_target.tolist(), da.asarray(z_target)): for z_src in (z_source, da.asarray(z_source)): result = vinterp.interpolate( - z_tgt, z_src, f_src, extrapolation="linear" + z_tgt, + z_src, + f_src, + extrapolation="linear", ) assert_array_equal(reference, result.compute()) @@ -561,17 +579,21 @@ def test_dask_1d_target(self): z_target = np.array([0.5]) z_source = f_source = np.arange(3) * np.ones([4, 2, 3]) reference = vinterp.interpolate( - z_target, z_source, f_source, axis=1, extrapolation="linear" + z_target, + z_source, + f_source, + axis=1, + extrapolation="linear", ) # Test with various combinations of lazy input f_src = da.asarray(f_source, chunks=(2, 1, 2)) for z_tgt in (z_target, z_target.tolist(), da.asarray(z_target)): for z_src in (z_source, da.asarray(z_source)): result = vinterp.interpolate( - z_tgt, z_src, f_src, axis=1, extrapolation="linear" + z_tgt, + z_src, + f_src, + axis=1, + extrapolation="linear", ) assert_array_equal(reference, result.compute()) - - -if __name__ == "__main__": - unittest.main()