Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions ultraplot/axes/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
from collections.abc import Callable, Iterable
from numbers import Integral, Number
from typing import Any, Iterable, Mapping, Optional, Sequence, Union
from typing import Any, Iterable, Mapping, Optional, Sequence, TypeAlias, Union

import matplotlib as mpl
import matplotlib.artist as martist
Expand All @@ -29,6 +29,7 @@
import matplotlib.ticker as mticker
import numpy as np
import numpy.ma as ma
from numpy.typing import ArrayLike
from packaging import version

from .. import colors as pcolors
Expand Down Expand Up @@ -64,6 +65,12 @@
# This is half of rc['patch.linewidth'] of 0.6. Half seems like a nice default.
EDGEWIDTH = 0.3

DataInput: TypeAlias = ArrayLike
ColorTupleRGB: TypeAlias = tuple[float, float, float]
ColorTupleRGBA: TypeAlias = tuple[float, float, float, float]
ColorInput: TypeAlias = DataInput | str | ColorTupleRGB | ColorTupleRGBA | None
ParsedColor: TypeAlias = DataInput | list[str] | str | None

# Data argument docstrings
_args_1d_docstring = """
*args : {y} or {x}, {y}
Expand Down Expand Up @@ -993,7 +1000,10 @@
: array-like or color-spec, optional
The marker color(s). If this is an array matching the shape of `x` and `y`,
the colors are generated using `cmap`, `norm`, `vmin`, and `vmax`. Otherwise,
this should be a valid matplotlib color.
this should be a valid matplotlib color. To pass explicit RGB(A) colors,
use an ``N x 3`` or ``N x 4`` array, or pass a single color with `color=`.
One-dimensional numeric arrays matching the point count are interpreted as
scalar values for colormapping.
smin, smax : float, optional
The minimum and maximum marker size area in units ``points ** 2``. Ignored
if `absolute_size` is ``True``. Default value for `smin` is ``1`` and for
Expand Down Expand Up @@ -3963,7 +3973,17 @@ def _parse_2d_format(
zs = tuple(map(inputs._to_numpy_array, zs))
return (x, y, *zs, kwargs)

def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs):
def _parse_color(
self,
x: DataInput,
y: DataInput,
c: ColorInput,
*,
apply_cycle: bool = True,
infer_rgb: bool = False,
force_cmap: bool = False,
**kwargs: Any,
) -> tuple[ParsedColor, dict[str, Any]]:
"""
Parse either a colormap or color cycler. Colormap will be discrete and fade
to subwhite luminance by default. Returns a HEX string if needed so we don't
Expand All @@ -3972,7 +3992,7 @@ def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs):
# NOTE: This function is positioned above the _parse_cmap and _parse_cycle
# functions and helper functions.
parsers = (self._parse_cmap, *self._level_parsers)
if c is None or mcolors.is_color_like(c):
if c is None or (mcolors.is_color_like(c) and not force_cmap):
if infer_rgb and c is not None and (isinstance(c, str) and c != "none"):
c = pcolors.to_hex(c) # avoid scatter() ambiguous color warning
if apply_cycle: # False for scatter() so we can wait to get correct 'N'
Expand Down Expand Up @@ -4000,6 +4020,32 @@ def _parse_color(self, x, y, c, *, apply_cycle=True, infer_rgb=False, **kwargs):
warnings._warn_ultraplot(f"Ignoring unused keyword arg(s): {pop}")
return (c, kwargs)

def _scatter_c_is_scalar_data(
self, x: DataInput, y: DataInput, c: ColorInput
) -> bool:
"""
Return whether scatter ``c=`` should be treated as scalar data.

Matplotlib treats 1D numeric arrays matching the point count as values to
be colormapped, even though short float sequences can also look like an
RGBA tuple to ``is_color_like``. Preserve explicit RGB/RGBA arrays via the
existing ``N x 3``/``N x 4`` path and reserve this override for the 1D
numeric case only.
"""
if c is None or isinstance(c, str):
return False
values = np.asarray(c)
if values.ndim != 1 or values.size <= 1:
return False
if not np.issubdtype(values.dtype, np.number):
return False
x = np.atleast_1d(inputs._to_numpy_array(x))
y = np.atleast_1d(inputs._to_numpy_array(y))
point_count = x.shape[0]
if y.shape[0] != point_count:
return False
return values.shape[0] == point_count

@warnings._rename_kwargs("0.6.0", centers="values")
def _parse_cmap(
self,
Expand Down Expand Up @@ -5527,6 +5573,7 @@ def _apply_scatter(self, xs, ys, ss, cc, *, vert=True, **kwargs):
# Only parse color if explicitly provided
infer_rgb = True
if cc is not None:
force_cmap = self._scatter_c_is_scalar_data(xs, ys, cc)
if not isinstance(cc, str):
test = np.atleast_1d(cc)
if (
Expand All @@ -5542,6 +5589,7 @@ def _apply_scatter(self, xs, ys, ss, cc, *, vert=True, **kwargs):
inbounds=inbounds,
apply_cycle=False,
infer_rgb=infer_rgb,
force_cmap=force_cmap,
**kw,
)
# Create the cycler object by manually cycling and sanitzing the inputs
Expand Down
22 changes: 22 additions & 0 deletions ultraplot/tests/test_1dplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Test 1D plotting overrides.
"""

import warnings

import numpy as np
import numpy.ma as ma
import pandas as pd
Expand Down Expand Up @@ -378,6 +380,26 @@ def test_scatter_edgecolor_single_row():
return fig


def test_scatter_numeric_c_honors_cmap():
"""
Numeric 1D ``c`` arrays should be treated as scalar data for colormapping.
"""
fig, ax = uplt.subplots()
values = np.array([0.1, 0.2, 0.3, 0.4])
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
obj = ax.scatter(
[1.0, 2.0, 3.0, 4.0],
[1.0, 2.0, 3.0, 4.0],
c=values,
cmap="turbo",
)
messages = [str(item.message) for item in caught]
assert not any("Ignoring unused keyword arg(s)" in message for message in messages)
assert "turbo" in obj.get_cmap().name
np.testing.assert_allclose(obj.get_array(), values)


@pytest.mark.mpl_image_compare
def test_scatter_inbounds():
"""
Expand Down
Loading