diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index ab23946f3..a4388d3d3 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -11,7 +11,7 @@ import sys from collections.abc import Callable, Iterable from numbers import Integral, Number -from typing import Any, Iterable, Mapping, Optional, Sequence, Union +from typing import Any, Iterable, Mapping, Optional, Sequence, TypeAlias, Union import matplotlib as mpl import matplotlib.artist as martist @@ -29,6 +29,7 @@ import matplotlib.ticker as mticker import numpy as np import numpy.ma as ma +from numpy.typing import ArrayLike from packaging import version from .. import colors as pcolors @@ -64,6 +65,12 @@ # This is half of rc['patch.linewidth'] of 0.6. Half seems like a nice default. EDGEWIDTH = 0.3 +DataInput: TypeAlias = ArrayLike +ColorTupleRGB: TypeAlias = tuple[float, float, float] +ColorTupleRGBA: TypeAlias = tuple[float, float, float, float] +ColorInput: TypeAlias = DataInput | str | ColorTupleRGB | ColorTupleRGBA | None +ParsedColor: TypeAlias = DataInput | list[str] | str | None + # Data argument docstrings _args_1d_docstring = """ *args : {y} or {x}, {y} @@ -993,7 +1000,10 @@ : array-like or color-spec, optional The marker color(s). If this is an array matching the shape of `x` and `y`, the colors are generated using `cmap`, `norm`, `vmin`, and `vmax`. Otherwise, - this should be a valid matplotlib color. + this should be a valid matplotlib color. To pass explicit RGB(A) colors, + use an ``N x 3`` or ``N x 4`` array, or pass a single color with `color=`. + One-dimensional numeric arrays matching the point count are interpreted as + scalar values for colormapping. smin, smax : float, optional The minimum and maximum marker size area in units ``points ** 2``. Ignored if `absolute_size` is ``True``. Default value for `smin` is ``1`` and for @@ -3963,7 +3973,17 @@ def _parse_2d_format( zs = tuple(map(inputs._to_numpy_array, zs)) return (x, y, *zs, kwargs) - def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs): + def _parse_color( + self, + x: DataInput, + y: DataInput, + c: ColorInput, + *, + apply_cycle: bool = True, + infer_rgb: bool = False, + force_cmap: bool = False, + **kwargs: Any, + ) -> tuple[ParsedColor, dict[str, Any]]: """ Parse either a colormap or color cycler. Colormap will be discrete and fade to subwhite luminance by default. Returns a HEX string if needed so we don't @@ -3972,7 +3992,7 @@ def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs): # NOTE: This function is positioned above the _parse_cmap and _parse_cycle # functions and helper functions. parsers = (self._parse_cmap, *self._level_parsers) - if c is None or mcolors.is_color_like(c): + if c is None or (mcolors.is_color_like(c) and not force_cmap): if infer_rgb and c is not None and (isinstance(c, str) and c != "none"): c = pcolors.to_hex(c) # avoid scatter() ambiguous color warning if apply_cycle: # False for scatter() so we can wait to get correct 'N' @@ -4000,6 +4020,32 @@ def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs): warnings._warn_ultraplot(f"Ignoring unused keyword arg(s): {pop}") return (c, kwargs) + def _scatter_c_is_scalar_data( + self, x: DataInput, y: DataInput, c: ColorInput + ) -> bool: + """ + Return whether scatter ``c=`` should be treated as scalar data. + + Matplotlib treats 1D numeric arrays matching the point count as values to + be colormapped, even though short float sequences can also look like an + RGBA tuple to ``is_color_like``. Preserve explicit RGB/RGBA arrays via the + existing ``N x 3``/``N x 4`` path and reserve this override for the 1D + numeric case only. + """ + if c is None or isinstance(c, str): + return False + values = np.asarray(c) + if values.ndim != 1 or values.size <= 1: + return False + if not np.issubdtype(values.dtype, np.number): + return False + x = np.atleast_1d(inputs._to_numpy_array(x)) + y = np.atleast_1d(inputs._to_numpy_array(y)) + point_count = x.shape[0] + if y.shape[0] != point_count: + return False + return values.shape[0] == point_count + @warnings._rename_kwargs("0.6.0", centers="values") def _parse_cmap( self, @@ -5527,6 +5573,7 @@ def _apply_scatter(self, xs, ys, ss, cc, *, vert=True, **kwargs): # Only parse color if explicitly provided infer_rgb = True if cc is not None: + force_cmap = self._scatter_c_is_scalar_data(xs, ys, cc) if not isinstance(cc, str): test = np.atleast_1d(cc) if ( @@ -5542,6 +5589,7 @@ def _apply_scatter(self, xs, ys, ss, cc, *, vert=True, **kwargs): inbounds=inbounds, apply_cycle=False, infer_rgb=infer_rgb, + force_cmap=force_cmap, **kw, ) # Create the cycler object by manually cycling and sanitzing the inputs diff --git a/ultraplot/tests/test_1dplots.py b/ultraplot/tests/test_1dplots.py index 257da91a0..d63256a52 100644 --- a/ultraplot/tests/test_1dplots.py +++ b/ultraplot/tests/test_1dplots.py @@ -3,6 +3,8 @@ Test 1D plotting overrides. """ +import warnings + import numpy as np import numpy.ma as ma import pandas as pd @@ -378,6 +380,26 @@ def test_scatter_edgecolor_single_row(): return fig +def test_scatter_numeric_c_honors_cmap(): + """ + Numeric 1D ``c`` arrays should be treated as scalar data for colormapping. + """ + fig, ax = uplt.subplots() + values = np.array([0.1, 0.2, 0.3, 0.4]) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + obj = ax.scatter( + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 3.0, 4.0], + c=values, + cmap="turbo", + ) + messages = [str(item.message) for item in caught] + assert not any("Ignoring unused keyword arg(s)" in message for message in messages) + assert "turbo" in obj.get_cmap().name + np.testing.assert_allclose(obj.get_array(), values) + + @pytest.mark.mpl_image_compare def test_scatter_inbounds(): """