# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2026 Antoine COLLET
"""Provide some tools for 2D plots."""
import copy
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import numpy.typing as npt
from matplotlib import colors
from matplotlib.axes import Axes
from matplotlib.colorbar import Colorbar
from matplotlib.figure import Figure, SubFigure
from matplotlib.image import AxesImage
# pylint: disable=C0103 # does not conform to snake case naming style
# pylint: disable=R0913 # too many arguments
# pylint: disable=R0914 # too many local variables
def _apply_default_imshow_kwargs(
    imshow_kwargs: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    """
    Merge the user imshow keyword arguments over the module defaults.

    Parameters
    ----------
    imshow_kwargs : Optional[Dict[str, Any]]
        User supplied keyword arguments for imshow. ``None`` means that only
        the defaults are used. The given dictionary is not modified.

    Returns
    -------
    Dict[str, Any]
        A new dictionary holding the defaults (``interpolation``, ``cmap``,
        ``aspect`` and ``origin``) overridden by the user values. A fresh
        ``colors.Normalize`` instance is stored under ``"norm"`` when none of
        ``vmin``, ``vmax`` or ``norm`` is present.
    """
    user_kwargs: Dict[str, Any] = imshow_kwargs if imshow_kwargs is not None else {}
    merged: Dict[str, Any] = {
        "interpolation": "nearest",
        "cmap": "bwr",
        "aspect": "auto",
        "origin": "lower",
        **user_kwargs,
    }
    # Only supply a norm when the caller gave no color scaling hint at all.
    if merged.keys().isdisjoint(("vmin", "vmax", "norm")):
        merged["norm"] = colors.Normalize()
    return merged
def _apply_default_colorbar_kwargs(
    colorbar_kwargs: Optional[Dict[str, Any]], axes: Sequence[Axes]
) -> Dict[str, Any]:
    """
    Merge the user colorbar keyword arguments over the module defaults.

    Parameters
    ----------
    colorbar_kwargs : Optional[Dict[str, Any]]
        User supplied keyword arguments for the colorbar. ``None`` means that
        only the defaults are used. The given dictionary is not modified.
    axes : Sequence[Axes]
        Axes used as default value for the ``ax`` entry.

    Returns
    -------
    Dict[str, Any]
        A new dictionary holding the defaults (``orientation``, ``aspect`` and
        ``ax``) overridden by the user values.
    """
    merged: Dict[str, Any] = dict(
        orientation="vertical",
        aspect=20,
        # The axes are always stored as a numpy array, whatever sequence type
        # was given.
        ax=np.array(axes),
    )
    merged.update(colorbar_kwargs or {})
    return merged
def add_2d_grid(
    ax: Axes, nx: int, ny: int, kwargs: Optional[Dict[str, Any]] = None
) -> None:
    """
    Add a grid of vertical and horizontal lines to the given axis.

    ``nx`` vertical lines are drawn at ``x = 0.5, 1.5, ..., nx - 0.5``, each
    spanning ``y`` from ``-0.5`` to ``ny - 0.5``. ``ny`` horizontal lines are
    drawn at ``y = 0.5, 1.5, ..., ny - 0.5``, each spanning ``x`` from ``-0.5``
    to ``nx - 0.5``.

    Parameters
    ----------
    ax : Axes
        The axis to which the grid is added.
    nx : int
        Number of vertical bars.
    ny : int
        Number of horizontal bars.
    kwargs : Optional[Dict[str, Any]], optional
        Optional arguments for vlines and hlines. They override the default
        style (grey lines with a 0.5 width). The default is None.

    Returns
    -------
    None.
    """
    line_style: Dict[str, Any] = {"color": "grey", "linewidths": 0.5}
    line_style.update(kwargs or {})

    # Vertical bars
    ax.vlines(
        x=np.arange(nx) + 0.5,
        ymin=np.full(nx, -0.5),
        ymax=np.full(nx, ny - 0.5),
        **line_style,
    )
    # Horizontal bars
    ax.hlines(
        y=np.arange(ny) + 0.5,
        xmin=np.full(ny, -0.5),
        xmax=np.full(ny, nx - 0.5),
        **line_style,
    )
def _get_vmin_vmax(
    data_list: List[npt.NDArray[np.float64]],
    is_symmetric_cbar: bool,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> Tuple[float, float]:
    """
    Get the lower and upper bounds used to scale the color bar.

    Parameters
    ----------
    data_list : List[np.ndarray]
        List of arrays containing the data. NaN values are ignored. It is only
        read for the bounds that are not provided.
    is_symmetric_cbar : bool
        Whether the scale must be symmetric and centered on zero.
    vmin: Optional[float]
        Minimum value for the scale. If not provided, it is automatically derived
        from the data. The default is None.
    vmax: Optional[float]
        Maximum value for the scale. If not provided, it is automatically derived
        from the data. The default is None.

    Returns
    -------
    Tuple[float, float]
        The bounds ``(vmin, vmax)``. With a symmetric scale, this is
        ``(-m, m)``, ``m`` being the largest absolute value of the two bounds.
    """
    lower = (
        vmin
        if vmin is not None
        else np.nanmin([np.nanmin(arr) for arr in data_list])
    )
    upper = (
        vmax
        if vmax is not None
        else np.nanmax([np.nanmax(arr) for arr in data_list])
    )
    if not is_symmetric_cbar:
        return lower, upper
    # Symmetric scale: mirror the largest magnitude around zero.
    half_range = max(abs(lower), abs(upper))
    return -half_range, half_range
def _check_axes_and_data_consistency(
    axes: Sequence[Axes], data: Dict[str, npt.NDArray[np.float64]]
) -> None:
    """
    Ensure that there are as many axes as entries in data.

    Parameters
    ----------
    axes : Sequence[Axes]
        Sequence of axes.
    data : Dict[str, npt.NDArray[np.float64]]
        Dictionary of data arrays.

    Raises
    ------
    ValueError
        If the number of axes differs from the number of entries in data.

    Returns
    -------
    None
    """
    n_axes, n_data = len(axes), len(data)
    if n_axes == n_data:
        return
    raise ValueError(
        f"The number of axes ({n_axes}), does not match the number "
        f"of data ({n_data})!"
    )
def _norm_data_and_cbar(
    images: List[AxesImage],
    data: List[npt.NDArray[np.float64]],
    _imshow_kwargs: Dict[str, Any],
    is_symmetric_cbar: bool,
) -> None:
    """
    Apply a proper scaling to the colorbar based on data and user provided norm.

    A single norm instance is shared by all the given images. Its bounds are
    taken, by order of priority, from the norm itself, from the ``vmin`` and
    ``vmax`` entries of `_imshow_kwargs` and finally from the data.

    Parameters
    ----------
    images : List[AxesImage]
        List of images for which to scale the colorbar.
    data : List[npt.NDArray[np.float64]]
        List of arrays containing the data.
    _imshow_kwargs: Dict[str, Any]
        Keywords arguments for imshow. If it holds a norm instance under
        ``"norm"``, that instance is updated in place (``vmin`` and ``vmax``).
        A missing or ``None`` norm is replaced by ``colors.Normalize(clip=True)``.
    is_symmetric_cbar : bool
        Whether the scale must be symmetric and centered on zero. It is ignored
        (with a warning) when the norm is a ``LogNorm``.

    Returns
    -------
    None
    """
    # `norm=None` is a valid value for imshow, so an explicit None must be
    # handled like a missing key instead of failing on `None.vmin`.
    user_norm: Optional[colors.Normalize] = _imshow_kwargs.get("norm")
    norm: colors.Normalize = (
        user_norm if user_norm is not None else colors.Normalize(clip=True)
    )

    if isinstance(norm, colors.LogNorm) and is_symmetric_cbar:
        warnings.warn(
            "You used a LogNorm norm instance which is incompatible with a"
            " symmetric colorbar. Symmetry is ignored. Use SymLogNorm for"
            " symmetrical logscale color bar.",
            UserWarning,
        )
        is_symmetric_cbar = False

    # Bounds already set on the norm take precedence over the vmin/vmax kwargs;
    # whatever is still missing is derived from the data.
    vmin, vmax = _get_vmin_vmax(
        data,
        is_symmetric_cbar,
        norm.vmin if norm.vmin is not None else _imshow_kwargs.get("vmin"),
        norm.vmax if norm.vmax is not None else _imshow_kwargs.get("vmax"),
    )
    norm.vmin = vmin
    norm.vmax = vmax

    # All images share the very same norm instance.
    for im in images:
        im.set_norm(norm)
def _get_argsort_im_data(
    data: Dict[str, npt.NDArray[np.float64]],
) -> npt.NDArray[np.int64]:
    """
    Get the argsort for images data by increasing dimensions.

    .. versionadded:: 2.0

    The sort is performed along the only dimension (x or y) that varies
    between the data values. It is stable: arrays with the same size keep
    their relative order, so data sharing a single shape yield the identity.

    If both x and y dimensions vary between data values, then data is returned
    as is and no sorting is performed (a warning is emitted).

    Parameters
    ----------
    data : Dict[str, npt.NDArray[np.float64]]
        Dictionary of two dimensional data arrays.

    Raises
    ------
    ValueError
        If one of the arrays is not two dimensional.

    Returns
    -------
    npt.NDArray[np.int64]
        Indices of the data entries (in dictionary order) sorted by increasing
        size. It is empty if data is empty.
    """
    shapes: Dict[str, Tuple[int, ...]] = {}
    for key, val in data.items():
        n_dims = len(val.shape)
        if n_dims != 2:
            raise ValueError(
                f'The given data for "{key}" has dimension {n_dims} '
                "whereas it should be two dimensional!"
            )
        shapes[key] = val.shape

    # Nothing to sort; also avoids column indexing on an empty shapes array.
    if not shapes:
        return np.arange(0)

    shapes_arr = np.array(list(shapes.values()))
    x_varies: bool = np.unique(shapes_arr[:, 0]).size != 1
    y_varies: bool = np.unique(shapes_arr[:, 1]).size != 1

    if x_varies and y_varies:
        warnings.warn(
            f"Data have different shapes: {shapes}. This might cause display"
            " issues if some axes share xaxis or yaxis!",
            UserWarning,
        )
        return np.arange(len(data))

    # At most one dimension varies: sort along it (along x if none varies).
    # A stable sort is required so that equal sizes keep the input order.
    sort_column = 1 if y_varies else 0
    return np.argsort(shapes_arr[:, sort_column], kind="stable")
# [docs]  (documentation-link residue from the HTML source view, not code)
def multi_imshow(
    axes: Sequence[Axes],
    fig: Union[Figure, SubFigure],
    data: Dict[str, npt.NDArray[np.float64]],
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    imshow_kwargs: Optional[Dict[str, Any]] = None,
    cbar_kwargs: Optional[Dict[str, Any]] = None,
    is_symmetric_cbar: bool = False,
    cbar_title: Optional[str] = None,
) -> Colorbar:
    """
    Plot multiple 2D field with imshow using a shared and scaled colorbar.

    The i-th entry of `data` is always drawn on the i-th axis of `axes`. Only
    the drawing order changes: images are drawn by increasing size.

    Parameters
    ----------
    axes : Sequence[Axes]
        Sequence of axes on which to plot the given data.
    fig : Figure
        The figure on which to plot the data. This is useful to position correctly
        the colorbar.
    data : Dict[str, npt.NDArray[np.float64]]
        Dictionary of two dimensional data arrays. Key are used as axis title.
    xlabel : Optional[str], optional
        Label to apply to all xaxes. The default is None.
    ylabel : Optional[str], optional
        Label to apply to all yaxes. The default is None.
    imshow_kwargs : Optional[Dict[str, Any]], optional
        Optional arguments for `plt.imshow`. The default is None.
    cbar_kwargs : Optional[Dict[str, Any]], optional
        Optional arguments for `fig.colorbar`. The default is None.
    is_symmetric_cbar : bool, optional
        Does the scale need to be symmetric and centered to zero. The default is False.
    cbar_title : Optional[str], optional
        Label to give to the colorbar. The default is None.

    Raises
    ------
    ValueError
        If the number of axes does not match the number of data arrays, or if
        the given data arrays do not have the required dimensionality (2).

    Returns
    -------
    Colorbar
        The color bar is returned so it can be further customized.
    """
    # The number of axes and data provided should be the same:
    _check_axes_and_data_consistency(axes, data)

    # Add some default values for imshow and colorbar
    _imshow_kwargs: Dict[str, Any] = _apply_default_imshow_kwargs(imshow_kwargs)
    _cbar_kwargs: Dict[str, Any] = _apply_default_colorbar_kwargs(cbar_kwargs, axes)

    items = list(data.items())
    images_dict: Dict[str, AxesImage] = {}

    # Draw by increasing size to make sure that the largest image is displayed
    # last (just in case sharex or sharey is active). The argsort gives the
    # positions of the data entries, so each entry stays on the axis found at
    # its own position. It is computed once so that its warning is not repeated.
    for idx in map(int, _get_argsort_im_data(data)):
        label, values = items[idx]
        ax: Axes = axes[idx]
        # Need to transpose because the dimensions (M, N) define the rows and
        # columns
        # Also, need to copy the _imshow_kwargs to avoid its update. Otherwise the
        # colorbar scaling does not work properly
        images_dict[label] = ax.imshow(values.T, **copy.deepcopy(_imshow_kwargs))

        ax.label_outer()
        ax.set_title(label, fontweight="bold")
        if xlabel is not None:
            ax.set_xlabel(xlabel, fontweight="bold")
        if ylabel is not None:
            ax.set_ylabel(ylabel, fontweight="bold")

    images: List[AxesImage] = list(images_dict.values())

    # norm both data and colorbar
    _norm_data_and_cbar(
        images,
        list(data.values()),
        _imshow_kwargs,
        is_symmetric_cbar,
    )

    # The images share the same norm, so any of them can back the colorbar.
    cbar: Colorbar = fig.colorbar(images[0], **_cbar_kwargs)
    if cbar_title is not None:
        cbar.ax.get_yaxis().labelpad = 20
        cbar.ax.set_ylabel(cbar_title, rotation=270)
    return cbar