Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 26 additions & 28 deletions index.ipynb
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's cool that Ruff works on notebooks. It's also cool the improved pattern it found for randomness. 👍

Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@
"\n",
"nx, ny = 6, 3\n",
"\n",
"np.random.seed(0)\n",
"orography = np.random.normal(1000, 600, size=(ny, nx)) - 400\n",
"sea_level_temp = np.random.normal(290, 5, size=(ny, nx))"
"rng = np.random.default_rng(0)\n",
"orography = rng.normal(1000, 600, size=(ny, nx)) - 400\n",
"sea_level_temp = rng.normal(290, 5, size=(ny, nx))"
]
},
{
Expand Down Expand Up @@ -85,20 +85,20 @@
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.set_cmap('viridis')\n",
"plt.set_cmap(\"viridis\")\n",
"fig = plt.figure(figsize=(8, 6))\n",
"\n",
"plt.subplot(1, 2, 1)\n",
"plt.pcolormesh(orography)\n",
"cbar = plt.colorbar(orientation='horizontal',\n",
" label='Orography (m)')\n",
"cbar = plt.colorbar(orientation=\"horizontal\",\n",
" label=\"Orography (m)\")\n",
"# Reduce the maximum number of ticks to 5.\n",
"cbar.ax.xaxis.get_major_locator().nbins = 5\n",
"\n",
"plt.subplot(1, 2, 2)\n",
"plt.pcolormesh(sea_level_temp)\n",
"cbar = plt.colorbar(orientation='horizontal',\n",
" label='Sea level temperature (K)')\n",
"cbar = plt.colorbar(orientation=\"horizontal\",\n",
" label=\"Sea level temperature (K)\")\n",
"# Reduce the maximum number of ticks to 5.\n",
"cbar.ax.xaxis.get_major_locator().nbins = 5\n",
"\n",
Expand Down Expand Up @@ -179,17 +179,17 @@
"source": [
"plt.figure(figsize=(8, 6))\n",
"plt.fill_between(np.arange(6), np.zeros(6), orography[1, :],\n",
" color='green', linewidth=2, label='Orography')\n",
" color=\"green\", linewidth=2, label=\"Orography\")\n",
"\n",
"plt.plot(np.zeros(nx),\n",
" color='blue', linewidth=1.2,\n",
" label='Sea level')\n",
" color=\"blue\", linewidth=1.2,\n",
" label=\"Sea level\")\n",
"\n",
"for i in range(9):\n",
" plt.plot(altitude[i, 1, :], color='gray', linestyle='--',\n",
" label='Model levels' if i == 0 else None)\n",
" plt.plot(altitude[i, 1, :], color=\"gray\", linestyle=\"--\",\n",
" label=\"Model levels\" if i == 0 else None)\n",
"\n",
"plt.ylabel('altitude / m')\n",
"plt.ylabel(\"altitude / m\")\n",
"plt.margins(0.1)\n",
"plt.legend()\n",
"plt.show()"
Expand Down Expand Up @@ -250,14 +250,12 @@
}
],
"source": [
"from matplotlib.colors import LogNorm\n",
"\n",
"fig = plt.figure(figsize=(8, 6))\n",
"norm = plt.Normalize(vmin=temperature.min(), vmax=temperature.max())\n",
"\n",
"for i in range(nz):\n",
" plt.subplot(3, 3, i + 1)\n",
" qm = plt.pcolormesh(temperature[i], cmap='viridis', norm=norm)\n",
" qm = plt.pcolormesh(temperature[i], cmap=\"viridis\", norm=norm)\n",
"\n",
"plt.subplots_adjust(right=0.84, wspace=0.3, hspace=0.3)\n",
"cax = plt.axes([0.85, 0.1, 0.03, 0.8])\n",
Expand Down Expand Up @@ -332,21 +330,21 @@
"source": [
"plt.figure(figsize=(8, 6))\n",
"plt.fill_between(np.arange(6), np.zeros(6), orography[1, :],\n",
" color='green', linewidth=2, label='Orography')\n",
" color=\"green\", linewidth=2, label=\"Orography\")\n",
"\n",
"for i in range(9):\n",
" plt.plot(altitude[i, 1, :],\n",
" color='gray', lw=1.2,\n",
" label=None if i > 0 else 'Source levels \\n(model levels)')\n",
" color=\"gray\", lw=1.2,\n",
" label=None if i > 0 else \"Source levels \\n(model levels)\")\n",
"for i, target in enumerate(target_altitudes):\n",
" plt.plot(np.repeat(target, 6),\n",
" color='gray', linestyle='--', lw=1.4, alpha=0.6,\n",
" label=None if i > 0 else 'Target levels \\n(altitude)')\n",
" color=\"gray\", linestyle=\"--\", lw=1.4, alpha=0.6,\n",
" label=None if i > 0 else \"Target levels \\n(altitude)\")\n",
"\n",
"plt.ylabel('height / m')\n",
"plt.ylabel(\"height / m\")\n",
"plt.margins(top=0.1)\n",
"plt.legend()\n",
"plt.savefig('summary.png')\n",
"plt.savefig(\"summary.png\")\n",
"plt.show()"
]
},
Expand Down Expand Up @@ -413,7 +411,7 @@
"plt.figure(figsize=(8, 6))\n",
"ax1 = plt.subplot(1, 2, 1)\n",
"plt.fill_between(np.arange(6), np.zeros(6), orography[1, :],\n",
" color='green', linewidth=2, label='Orography')\n",
" color=\"green\", linewidth=2, label=\"Orography\")\n",
"cs = plt.contourf(np.tile(np.arange(6), nz).reshape(nz, 6),\n",
" altitude[:, 1],\n",
" temperature[:, 1])\n",
Expand All @@ -423,7 +421,7 @@
"\n",
"plt.subplot(1, 2, 2, sharey=ax1)\n",
"plt.fill_between(np.arange(6), np.zeros(6), orography[1, :],\n",
" color='green', linewidth=2, label='Orography')\n",
" color=\"green\", linewidth=2, label=\"Orography\")\n",
"plt.contourf(np.arange(6), target_altitudes,\n",
" np.ma.masked_invalid(new_temperature[:, 1]),\n",
" cmap=cs.cmap, norm=cs.norm)\n",
Expand All @@ -432,9 +430,9 @@
" c=new_temperature[:, 1])\n",
"plt.scatter(np.tile(np.arange(nx), target_nz).reshape(target_nz, nx),\n",
" np.repeat(target_altitudes, nx).reshape(target_nz, nx),\n",
" s=np.isnan(new_temperature[:, 1]) * 15, marker='x')\n",
" s=np.isnan(new_temperature[:, 1]) * 15, marker=\"x\")\n",
"\n",
"plt.suptitle('Temperature cross-section before and after restratification')\n",
"plt.suptitle(\"Temperature cross-section before and after restratification\")\n",
"plt.show()"
]
}
Expand Down
69 changes: 13 additions & 56 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ preview = false

[tool.ruff.lint]
ignore = [

# flake8-annotations (ANN)
# https://docs.astral.sh/ruff/rules/#flake8-annotations-ann
"ANN001", # Missing type annotation for function argument {name}
Expand All @@ -156,21 +155,6 @@ ignore = [
"ANN204", # Missing return type annotation for special method {name}

"ARG002", # Unused method argument: {name}
"ARG003", # Unused class method argument: {name}

# flake8-bugbear (B)
# https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
"B028", # No explicit stacklevel keyword argument found

# flake8-comprehensions (C4)
# https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4
"C405", # Unnecessary {obj_type} literal (rewrite as a set literal)
"C419", # Unnecessary list comprehension

# flake8-commas (COM)
# https://docs.astral.sh/ruff/rules/#flake8-commas-com
"COM812", # Trailing comma missing.
"COM819", # Trailing comma prohibited.

# pydocstyle (D)
# https://docs.astral.sh/ruff/rules/#pydocstyle-d
Expand All @@ -185,46 +169,12 @@ ignore = [
# https://docs.astral.sh/ruff/rules/#eradicate-era
"ERA001", # Found commented-out code

# flake8-boolean-trap (FBT)
# https://docs.astral.sh/ruff/rules/#flake8-boolean-trap-fbt
"FBT002", # Boolean default positional argument in function definition

# flake8-implicit-str-concat (ISC)
# https://docs.astral.sh/ruff/rules/single-line-implicit-string-concatenation/
# NOTE: This rule may cause conflicts when used with "ruff format".
"ISC001", # Implicitly concatenate string literals on one line.

# pep8-naming (N)
# https://docs.astral.sh/ruff/rules/#pep8-naming-n
"N801", # Class name {name} should use CapWords convention
"PT011",

# Refactor (R)
# https://docs.astral.sh/ruff/rules/#refactor-r
"PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable

# flake8-pytest-style (PT)
# https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
"PT009", # Use a regular assert instead of unittest-style {assertion}
"PT027", # Use pytest.raises instead of unittest-style {assertion}

# flake8-return (RET)
# https://docs.astral.sh/ruff/rules/#flake8-return-ret
"RET504", # Unnecessary assignment to {name} before return statement

# Ruff-specific rules (RUF)
# https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
"RUF005", # Consider {expression} instead of concatenation
"RUF012", # Mutable class attributes should be annotated with typing.ClassVar

# flake8-self (SLF)
# https://docs.astral.sh/ruff/rules/#flake8-self-slf
"SLF001", # Private member accessed: {access}
"PLR2004",

# flake8-print (T20)
# https://docs.astral.sh/ruff/rules/#flake8-print-t20
"T201", # print found

]
"S101", # Use of assert detected.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You needed to add this one, presumably because you've been converting the tests to use PyTest? If you, could you move this exception down to the file-specific exceptions for src/stratify/tests/*.py?

]
preview = false
select = [
"ALL",
Expand All @@ -239,13 +189,20 @@ known-first-party = ["iris"]

[tool.ruff.lint.per-file-ignores]
# All test scripts

"src/stratify/tests/performance.py" = [
"T201", # print found
]
"setup.py" = [
"T201", # print found
"RUF012", # Mutable class attributes should be annotated with typing.ClassVar
]
# Change to match specific package path:
"lib/iris/tests/*.py" = [
"src/stratify/tests/*.py" = [
# https://docs.astral.sh/ruff/rules/undocumented-public-module/
"D100", # Missing docstring in public module
"D205", # 1 blank line required between summary line and description
"D401", # 1 First line of docstring should be in imperative mood
"SLF001", # Private member accessed: {access}
]

[tool.ruff.lint.pydocstyle]
Expand Down
10 changes: 6 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from Cython.Build import cythonize # isort:skip
except ImportError:
wmsg = "Cython unavailable, unable to build stratify extensions!"
warnings.warn(wmsg)
warnings.warn(wmsg, stacklevel=2)
cythonize = None


Expand Down Expand Up @@ -53,7 +53,7 @@ def run(self):
[
("CYTHON_TRACE", "1"),
("CYTHON_TRACE_NOGIL", "1"),
]
],
)
cython_directives.update({"linetrace": True})
if FLAG_COVERAGE in sys.argv:
Expand All @@ -70,9 +70,11 @@ def run(self):
)
extensions.append(extension)

if cythonize and not any([arg in CMDS_NOCYTHONIZE for arg in sys.argv]):
if cythonize and not any(arg in CMDS_NOCYTHONIZE for arg in sys.argv):
extensions = cythonize(
extensions, compiler_directives=cython_directives, language_level=3
extensions,
compiler_directives=cython_directives,
language_level=3,
)

cmdclass = {"clean_cython": CleanCython}
Expand Down
26 changes: 15 additions & 11 deletions src/stratify/_bounded_vinterp.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def interpolate_conservative(z_target, z_src, fz_src, axis=-1):
)
raise ValueError(msg.format(tuple(dat_shape), tuple(src_shape)))

if z_src.shape[-1] != 2:
shape_2 = 2
if z_src.shape[-1] != shape_2:
msg = "Unexpected source and target bounds shape. shape[-1] != 2"
raise ValueError(msg)

Expand All @@ -100,8 +101,8 @@ def interpolate_conservative(z_target, z_src, fz_src, axis=-1):

# src_data
bdims = list(range(fz_src.ndim - (z_src.ndim - 1)))
data_vdims = [ind for ind in range(fz_src.ndim) if ind not in (bdims + [axis])]
data_transpose = bdims + [axis] + data_vdims
data_vdims = [ind for ind in range(fz_src.ndim) if ind not in ([*bdims, axis])]
data_transpose = [*bdims, axis, *data_vdims]
fz_src_reshaped = np.transpose(fz_src, data_transpose)
fz_src_orig = list(fz_src_reshaped.shape)
shape = (
Expand All @@ -113,21 +114,24 @@ def interpolate_conservative(z_target, z_src, fz_src, axis=-1):

# Define our src and target bounds in a consistent way.
# [axis_interpolation, z_varying, 2]
vdims = list(set(range(z_src.ndim)) - set([axis_relative]))
z_src_reshaped = np.transpose(z_src, [axis_relative] + vdims)
z_target_reshaped = np.transpose(z_target, [axis_relative] + vdims)
# vdims = list(set(range(z_src.ndim)) - set([axis_relative]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commented code

vdims = list(set(range(z_src.ndim)) - {axis_relative})
z_src_reshaped = np.transpose(z_src, [axis_relative, *vdims])
z_target_reshaped = np.transpose(z_target, [axis_relative, *vdims])

shape = int(np.prod(z_src_reshaped.shape[1:-1]))
z_src_reshaped = z_src_reshaped.reshape(
[z_src_reshaped.shape[0], shape, z_src_reshaped.shape[-1]]
[z_src_reshaped.shape[0], shape, z_src_reshaped.shape[-1]],
)
shape = int(np.prod(z_target_reshaped.shape[1:-1]))
z_target_reshaped = z_target_reshaped.reshape(
[z_target_reshaped.shape[0], shape, z_target_reshaped.shape[-1]]
[z_target_reshaped.shape[0], shape, z_target_reshaped.shape[-1]],
)

result = conservative_interpolation(
z_src_reshaped, z_target_reshaped, fz_src_reshaped
z_src_reshaped,
z_target_reshaped,
fz_src_reshaped,
)

# Turn the result into a shape consistent with the source.
Expand All @@ -136,5 +140,5 @@ def interpolate_conservative(z_target, z_src, fz_src, axis=-1):
shape[len(bdims)] = z_target.shape[axis_relative]
result = result.reshape(shape)
invert_transpose = [data_transpose.index(ind) for ind in list(range(result.ndim))]
result = result.transpose(invert_transpose)
return result
# result = result.transpose(invert_transpose)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commented code

return result.transpose(invert_transpose)
2 changes: 1 addition & 1 deletion src/stratify/tests/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import stratify


def src_data(shape=(400, 500, 100), lazy=False):
def src_data(shape=(400, 500, 100)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not quite sure of the design of this module, but lazy is very much still used. Seems like Ruff made a mistake?

lazy = "lazy" in sys.argv[1:]
interp_and_extrap(shape=(500, 600, 100), lazy=lazy)

def interp_and_extrap(
shape,
lazy,
interp=stratify.INTERPOLATE_LINEAR,
extrap=stratify.EXTRAPOLATE_NEAREST,
):
z, fz = src_data(shape, lazy)

z = np.tile(np.linspace(0, 100, shape[-1]), np.prod(shape[:2])).reshape(shape)
if lazy:
fz = da.arange(np.prod(shape), dtype=np.float64).reshape(shape)
Expand Down
Loading
Loading