# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2026 Antoine COLLET
"""
Provide plotter classes.
These classes allows to wrap the creation of figures with matplotlib and to use
a unified framework.
"""
from __future__ import annotations
import abc
import copy
import platform
import warnings
from collections import ChainMap
from itertools import product
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
import matplotlib as mpl
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure
from matplotlib.legend import Legend
from matplotlib.transforms import Bbox
from numpy.typing import ArrayLike
from packaging.version import Version
from typing_extensions import Literal
from nested_grid_plotter._utils import (
add_grid_and_tick_prams_to_axis,
object_or_object_sequence_to_list,
)
# pylint: disable=C0103 # does not confrom to snake case naming style
class NestedBuilder(abc.ABC):
"""
Abstract class for nested builders.
.. versionadded:: 2.0
"""
@abc.abstractmethod
def __call__(
self,
fig: Union[Figure, SubFigure],
figname: str,
grouped_sf_dict: Dict[str, Dict[str, SubFigure]],
grouped_ax_dict: Dict[str, Dict[str, Axes]],
) -> None: ...
[docs]
class SubplotsMosaicBuilder(NestedBuilder):
"""
Args and kwargs for Figure.subfigures routine.
.. versionadded:: 2.0
"""
[docs]
def __init__(
self,
mosaic,
*,
sharex: bool = False,
sharey: bool = False,
width_ratios: Optional[ArrayLike] = None,
height_ratios: Optional[ArrayLike] = None,
empty_sentinel: Any = ".",
subplot_kw: Optional[Dict[str, Any]] = None,
per_subplot_kw: Optional[Dict[str, Any]] = None,
gridspec_kw: Optional[Dict[str, Any]] = None,
) -> None:
"""
Build a layout of Axes based on ASCII art or nested lists.
This is a helper function to build complex GridSpec layouts visually.
See :ref:`mosaic`
for an example and full API documentation
Parameters
----------
mosaic : list of list of {hashable or nested} or str
A visual layout of how you want your Axes to be arranged
labeled as strings. For example ::
x = [["A panel", "A panel", "edge"], ["C panel", ".", "edge"]]
produces 4 Axes:
- 'A panel' which is 1 row high and spans the first two columns
- 'edge' which is 2 rows high and is on the right edge
- 'C panel' which in 1 row and 1 column wide in the bottom left
- a blank space 1 row and 1 column wide in the bottom center
Any of the entries in the layout can be a list of lists
of the same form to create nested layouts.
If input is a str, then it can either be a multi-line string of
the form ::
'''
AAE
C.E
'''
where each character is a column and each line is a row. Or it
can be a single-line string where rows are separated by ``;``::
"AB;CC"
The string notation allows only single character Axes labels and
does not support nesting but is very terse.
The Axes identifiers may be `str` or a non-iterable hashable
object (e.g. `tuple` s may not be used).
sharex, sharey : bool, default: False
If True, the x-axis (*sharex*) or y-axis (*sharey*) will be shared
among all subplots. In that case, tick label visibility and axis
units behave as for `subplots`. If False, each subplot's x- or
y-axis will be independent.
width_ratios : array-like of length *ncols*, optional
Defines the relative widths of the columns. Each column gets a
relative width of ``width_ratios[i] / sum(width_ratios)``.
If not given, all columns will have the same width. Equivalent
to ``gridspec_kw={'width_ratios': [...]}``. In the case of nested
layouts, this argument applies only to the outer layout.
height_ratios : array-like of length *nrows*, optional
Defines the relative heights of the rows. Each row gets a
relative height of ``height_ratios[i] / sum(height_ratios)``.
If not given, all rows will have the same height. Equivalent
to ``gridspec_kw={'height_ratios': [...]}``. In the case of nested
layouts, this argument applies only to the outer layout.
subplot_kw : dict, optional
Dictionary with keywords passed to the `.Figure.add_subplot` call
used to create each subplot. These values may be overridden by
values in *per_subplot_kw*.
per_subplot_kw : dict, optional
A dictionary mapping the Axes identifiers or tuples of identifiers
to a dictionary of keyword arguments to be passed to the
`.Figure.add_subplot` call used to create each subplot. The values
in these dictionaries have precedence over the values in
*subplot_kw*.
If *mosaic* is a string, and thus all keys are single characters,
it is possible to use a single string instead of a tuple as keys;
i.e. ``"AB"`` is equivalent to ``("A", "B")``.
gridspec_kw : dict, optional
Dictionary with keywords passed to the `.GridSpec` constructor used
to create the grid the subplots are placed on. In the case of
nested layouts, this argument applies only to the outer layout.
For more complex layouts, users should use `.Figure.subfigures`
to create the nesting.
empty_sentinel : object, optional
Entry in the layout to mean "leave this space empty". Defaults
to ``'.'``. Note, if *layout* is a string, it is processed via
`inspect.cleandoc` to remove leading white space, which may
interfere with using white-space as the empty sentinel.
"""
self.mosaic = mosaic
self.sharex: bool = sharex
self.sharey: bool = sharey
self.empty_sentinel = empty_sentinel
self.subplot_kw = subplot_kw
self.per_subplot_kw = per_subplot_kw
self.empty_sentinel = empty_sentinel
self.subplot_kw = subplot_kw
self.per_subplot_kw = per_subplot_kw
# height_ratios and width ratios were introduced in matplotlib 3.6
# for 3.5, it must be added to gridspec_kw
gridspec_kw = dict(gridspec_kw or {})
if height_ratios is not None:
if "height_ratios" in gridspec_kw:
raise ValueError(
"'height_ratios' must not be defined both as "
"parameter and as key in 'gridspec_kw'"
)
gridspec_kw["height_ratios"] = height_ratios
if width_ratios is not None:
if "width_ratios" in gridspec_kw:
raise ValueError(
"'width_ratios' must not be defined both as "
"parameter and as key in 'gridspec_kw'"
)
gridspec_kw["width_ratios"] = width_ratios
self.gridspec_kw = gridspec_kw
[docs]
def __call__(
self,
fig: Union[Figure, SubFigure],
figname: str,
grouped_sf_dict: Dict[str, Dict[str, SubFigure]],
grouped_ax_dict: Dict[str, Dict[str, Axes]],
) -> None:
grouped_ax_dict[figname] = fig.subplot_mosaic(
**_make_kwargs_retrocompatible(mpl.__version__, self)
)
def _make_kwargs_retrocompatible(
mpl_version: str, builder: SubplotsMosaicBuilder
) -> Dict[str, Any]:
kwargs = dict(builder.__dict__)
# per_subplot_kw was introduced in matplotlib 3.7
if Version(mpl_version) < Version("3.7"):
if kwargs.pop("per_subplot_kw") is not None:
warnings.warn(
'Parameter "per_subplot_kw" is supported from matplotlib 3.7 '
f"while you use version {mpl_version} and it is consequently "
"ignored."
)
return kwargs
[docs]
class SubfigsBuilder(NestedBuilder):
"""
Args and kwargs for Figure.subfigures routine.
.. versionadded:: 2.0
"""
[docs]
def __init__(
self,
*,
nrows: int = 1,
ncols: int = 1,
squeeze: bool = True,
wspace: Optional[float] = None,
hspace: Optional[float] = None,
width_ratios: Optional[ArrayLike] = None,
height_ratios: Optional[ArrayLike] = None,
sub_builders: Optional[Dict[str, NestedBuilder]] = None,
**kwargs,
) -> None:
"""
Initiate the instance.
Parameters
----------
nrows, ncols : int, default: 1
Number of rows/columns of the subfigure grid.
squeeze : bool, default: True
If True, extra dimensions are squeezed out from the returned
array of subfigures.
wspace, hspace : float, default: None
The amount of width/height reserved for space between subfigures,
expressed as a fraction of the average subfigure width/height.
If not given, the values will be inferred from rcParams if using
constrained layout (see `~.ConstrainedLayoutEngine`), or zero if
not using a layout engine.
width_ratios : array-like of length *ncols*, optional
Defines the relative widths of the columns. Each column gets a
relative width of ``width_ratios[i] / sum(width_ratios)``.
If not given, all columns will have the same width.
height_ratios : array-like of length *nrows*, optional
Defines the relative heights of the rows. Each row gets a
relative height of ``height_ratios[i] / sum(height_ratios)``.
If not given, all rows will have the same height.
"""
self.nrows: int = nrows
self.ncols: int = ncols
self.sub_builders = sub_builders
self.squeeze: bool = squeeze
self.wspace = wspace
self.hspace = hspace
self.width_ratios = width_ratios
self.height_ratios = height_ratios
self.kwargs: Dict[str, Any] = kwargs
[docs]
def __call__(
self,
fig: Union[Figure, SubFigure],
figname: str,
grouped_sf_dict: Dict[str, Dict[str, SubFigure]],
grouped_ax_dict: Dict[str, Dict[str, Axes]],
) -> None:
# treat exception first
if self.sub_builders is not None:
if len(self.sub_builders) != self.nrows * self.ncols:
raise Exception(
f"Error while creating subfigures for {figname}, "
f"{len(self.sub_builders)} builders have been "
f"provided, but there are {self.nrows} rows and {self.ncols} cols, "
f"i.e., {self.nrows * self.ncols} builders expected!"
)
# Note subfigure(...) returns a SubFigure instance or a numpy array of
# subfigure with shape (nrows, ncols)
# Here with ensure a flat sequence
new_subfigs: Sequence[SubFigure] = np.asarray(
object_or_object_sequence_to_list(
fig.subfigures(
nrows=self.nrows,
ncols=self.ncols,
squeeze=self.squeeze,
wspace=self.wspace,
hspace=self.hspace,
width_ratios=self.width_ratios,
height_ratios=self.height_ratios,
**self.kwargs,
)
)
).flatten()
grouped_sf_dict[figname] = {}
if self.sub_builders is not None:
for new_subfig, (sf_name, builder) in zip(
new_subfigs, self.sub_builders.items()
):
grouped_sf_dict[figname][sf_name] = new_subfig
builder(new_subfig, sf_name, grouped_sf_dict, grouped_ax_dict)
else:
# Here we need to treat the case with and without
for n, (i, j) in enumerate(product(range(self.nrows), range(self.ncols))):
sf_name = f"{figname if figname != 'fig' else 'subfig'}_{n + 1}"
grouped_sf_dict[figname][sf_name] = new_subfigs[n]
grouped_ax_dict[sf_name] = {
f"{sf_name}_ax1-1": new_subfigs[n].add_subplot()
}
[docs]
class Plotter:
"""
General class to wrap matplotlib plots.
The paticularity is that each axis is given a unique name and each subfigure is
given a unique name too.
"""
[docs]
def __init__(
self,
fig: Optional[Figure] = None,
builder: Optional[NestedBuilder] = None,
) -> None:
"""
Initiate the instance.
Parameters
----------
fig: Optional[Figure]
Figure for the plot. One can either use `matplotlib.pyplot.figure(...)` or
`matplotlib.figure.Figure`. If None, `matplotlib.pyplot.figure(...)`
is used. The default is None.
builder: Optional[NestedBuilder]
Builder to populate the figure with subfgures and axes. See
:class:`SubfigsBuilder`_ and :class:`SubplotsMosaicBuilder`_.
Attributes
----------
fig: :class:`matplotlib.figure.Figure`
Description.
subfigs: Dict[str, :class:`matplotlib.figure.SubFigure`]
Description.
ax_dict: Dict[str, Dict[str, Axes]]
Nested dict, first level with subfigures names and second with
axes names.
"""
if fig is None:
self.fig: Figure = plt.figure(constrained_layout=True)
else:
self.fig = fig
# initiate subfigures and axes references
self.grouped_ax_dict: Dict[str, Dict[str, Axes]] = {}
self.grouped_sf_dict: Dict[str, Dict[str, SubFigure]] = {}
# build subfigures and mosaic
if builder is None:
builder = SubplotsMosaicBuilder([["ax1-1"]])
builder(self.fig, "fig", self.grouped_sf_dict, self.grouped_ax_dict)
self._check_if_subplot_names_are_unique()
# Two dict to store the handles and labels to add to the legend
# The key is the number of the id of the axis matching the handles
self._additional_handles: Dict[str, Any] = {
ax_name: [] for ax_name in self.ax_dict.keys()
}
self._additional_labels: Dict[str, Any] = {
ax_name: [] for ax_name in self.ax_dict.keys()
}
@property
def ax_dict(self) -> Dict[str, Axes]:
"""Return a flatten version of `grouped_ax_dict`."""
# we cannot use reversed because of dicts are not reversible in py3.7
# so we convert to list and reverse instead
return dict(ChainMap(*list(self.grouped_ax_dict.values())[::-1]))
@property
def sf_dict(self) -> Dict[str, SubFigure]:
"""Return a flatten version of `grouped_sf_dict`."""
# we cannot use reversed because of dicts are not reversible in py3.7
# so we convert to list and reverse instead
return dict(ChainMap(*list(self.grouped_sf_dict.values())[::-1]))
@property
def axes(self) -> List[Axes]:
"""Return all axes as a list."""
return list(self.ax_dict.values())
@property
def axes_names(self) -> List[str]:
"""Return all axes names as a list."""
return list(self.ax_dict.keys())
[docs]
def close(self) -> None:
"""Close the current figure."""
plt.close(self.fig)
[docs]
def _check_if_subplot_names_are_unique(self) -> None:
"""
Check if a subplot name is not used in two different subfigures.
This is necessary otherwise, one or more subplots would be missing
from the `ax_dict`.
Raises
------
Exception
Indicate which axis names are duplicated and on which subfigures .
Returns
-------
None
"""
temp: Dict[str, List[str]] = {}
for k, v in self.grouped_ax_dict.items():
for k2 in v.keys():
temp.setdefault(k2, []).append(k)
non_unique_keys = [k for k, v in temp.items() if len(v) > 1]
if len(non_unique_keys) == 1:
raise Exception(
f"The name {non_unique_keys} has been used in more than one subfigures!"
)
if len(non_unique_keys) > 1:
raise Exception(
f"The names {non_unique_keys} have been used in "
"more than one subfigures!"
)
def _get_bbox_extra_artists(
self, kwargs: Optional[Dict[str, Any]] = None
) -> List[Artist]:
# make sure that if a fig legend as been added, it won't be cutoff by
# the figure box
if kwargs is None:
kwargs = {}
bbox_extra_artists = [
*kwargs.get("bbox_extra_artists", ()),
*self.fig.legends,
*[lgd for fig in self.sf_dict.values() for lgd in fig.legends],
*[ax.get_legend() for ax in self.axes if ax.get_legend() is not None],
*[
artist
for ax in self.axes
for artist in ax.get_default_bbox_extra_artists()
],
# *[artist for ax in self.axes for artist in ax.artists],
]
for fig in [self.fig, *self.sf_dict.values()]:
if fig._supxlabel is not None: # type: ignore
bbox_extra_artists.append(fig._supxlabel) # type: ignore
if fig._supylabel is not None: # type: ignore
bbox_extra_artists.append(fig._supylabel) # type: ignore
if fig._suptitle is not None: # type: ignore
bbox_extra_artists.append(fig._suptitle) # type: ignore
return bbox_extra_artists # ty: ignore[invalid-return-type]
[docs]
def savefig(self, *args: Any, **kwargs: Any) -> None:
"""
Save the current figure.
Parameters
----------
*args : Any
Positional arguments for :meth:`matplotlib.figure.Figure.savefig`.
**kwargs : Any
Keywords arguments for :meth:`matplotlib.figure.Figure.savefig`.
Returns
-------
None
"""
# Ensure that all artists are saved (nothing should be cropped)
# https://stackoverflow.com/questions/9651092/my-matplotlib-pyplot-legend-is-being-cut-off/42303455
bbox_inches = kwargs.pop("bbox_inches", "tight")
kwargs.update(
{"bbox_extra_artists": tuple(self._get_bbox_extra_artists(kwargs))}
)
if len(kwargs["bbox_extra_artists"]) != 0:
if Version(platform.python_version()) < Version("3.8"): # pragma: no cover
warnings.warn(
"There are bbox extra artists to save but this is not"
" supported for python 3.7. Please use python 3.8 or above. "
)
kwargs.pop("bbox_extra_artists")
self.fig.savefig(*args, **kwargs, bbox_inches=bbox_inches)
# need this if 'transparent=True' to reset colors
self.fig.canvas.draw_idle()
[docs]
def identify_axes(self, fontsize: int = 48) -> None:
"""
Draw the label in a large font in the center of the Axes.
Parameters
----------
ax_dict : Dict[str, Axes]
Mapping between the title / label and the Axes.
fontsize : int, optional
How big the label should be.
Returns
-------
None
"""
for ax_name, ax in self.ax_dict.items():
ax.text(
0.5,
0.5,
ax_name,
transform=ax.transAxes,
ha="center",
va="center",
fontsize=fontsize,
color="darkgrey",
)
[docs]
def get_axis(self, ax_name: str) -> Axes:
"""
Get an axis from the plotter.
Parameters
----------
ax_name : str
Name of the axis to get.
Returns
-------
Axes
The desired axis.
"""
if ax_name not in self.ax_dict.keys():
raise ValueError(f'The axis "{ax_name}" does not exists!')
return self.ax_dict[ax_name]
[docs]
def get_axes(self, ax_names: Sequence[str]) -> List[Axes]:
"""
Get a sequence of axes from the plotter.
Parameters
----------
ax_names : Sequence[str]
Name of the axes to get. Must be iterable.
Returns
-------
Axes
The desired axes.
"""
return [self.get_axis(axn) for axn in ax_names]
def _iterate_subfig_grouped_axes(
self, subfig_name: str
) -> Iterator[Dict[str, Axes]]:
if subfig_name in self.grouped_ax_dict.keys():
yield self.grouped_ax_dict[subfig_name]
return
for sf_name in self.grouped_sf_dict[subfig_name].keys():
for tmp in self._iterate_subfig_grouped_axes(sf_name):
yield tmp
def get_subfigure_ax_dict(self, subfig_name: str) -> Dict[str, Axes]:
return {
k: v
for _dict in self._iterate_subfig_grouped_axes(subfig_name)
for k, v in _dict.items()
}
[docs]
def add_grid_and_tick_prams_to_all_axes(
self, subfigure_name: Optional[str] = None, **kwargs: Any
) -> None:
"""
Add a grid to all the axes of the figure.
Parameters
----------
subfigure_name : Optional[str], optional
Name of the target subfigure. If None, apply to the all figure.
The default is None.
**kwargs : Any
Keywords for `add_grid_and_thick_prams_to_axis`.
Returns
-------
None
"""
if subfigure_name is not None:
for ax in self.grouped_ax_dict[subfigure_name].values():
add_grid_and_tick_prams_to_axis(ax, **kwargs)
else:
for ax in self.ax_dict.values():
add_grid_and_tick_prams_to_axis(ax, **kwargs)
[docs]
def _get_axis_legend_items(self, ax_name: str) -> Tuple[List[Artist], List[str]]:
"""
Get the given axis legend handles and labels.
Parameters
----------
ax_name : str
Name of the axis.
Returns
-------
Tuple[List[Artist], List[str]]
Handles and labels lists.
"""
ax: Axes = self.ax_dict[ax_name]
handles, labels = ax.get_legend_handles_labels()
# Handle twin axes
if ax.figure is not None:
for other_ax in ax.figure.axes:
if other_ax is ax:
continue
if other_ax.bbox.bounds == ax.bbox.bounds:
_handles, _labels = other_ax.get_legend_handles_labels()
handles += _handles
labels += _labels
# Add the additional handles and labels of the axis
handles += self._additional_handles.get(ax_name, [])
labels += self._additional_labels.get(ax_name, [])
return handles, labels
[docs]
@staticmethod
def _remove_dulicated_legend_items(
handles: Sequence[Artist], labels: Sequence[str]
) -> Tuple[List[Artist], List[str]]:
"""Remove the duplicated legend handles and labels."""
# remove the duplicates
by_label = dict(zip(labels, handles))
return list(by_label.values()), list(by_label.keys())
[docs]
def add_fig_legend(
self,
name: Optional[str] = None,
bbox_x_shift: float = 0.0,
bbox_y_shift: float = 0.0,
loc: Literal["left", "right", "top", "bottom"] = "bottom",
is_reverse_items: bool = False,
**kwargs: Any,
) -> Optional[Legend]:
"""
Add a legend to the plot.
To the main figure or to one subfigure.
Parameters
----------
name : Optional[str], optional
The subfigure to which add the legend. If no name is given, it applies to
the all figure. Otherwise to a specific subfigure. The default is None.
bbox_x_shift : float, optional
Legend vertical shift (up oriented). The default is 0.0.
bbox_y_shift : float, optional
Legend horizontal shift (right oriented). The default is 0.0.
loc : Literal["left", "right", "top", "bottom"], optional
Side on which to place the legend box. The default is "bottom".
is_reverse_items: bool
Whether to reverse the order of items in the legend. The default is False.
.. versionadded:: 2.0
**kwargs : Any
Additional arguments for `plt.legend`.
Returns
-------
Optional[Legend]
The created legend instance. None if no handles nor labels have been found.
"""
handles, labels = self._gather_figure_legend_items(name)
if len(handles) == 0:
return None
# Get the correct object
if name is None:
obj: Union[Figure, SubFigure] = self.fig
bbox_transform = obj.transFigure
else:
obj = self.sf_dict[name]
bbox_transform = obj.transSubfigure
# Make sure that the figure of the handles is the figure of the legend
# RunTimeError Can not put single artist in more than one figure
for i, _ in enumerate(handles):
if handles[i].figure is not obj:
handles[i] = copy.copy(handles[i])
# TODO: fix this
handles[i]._parent_figure = None # ty: ignore[unresolved-attribute]
handles[i].figure = obj # ty: ignore[invalid-assignment]
# Remove a potentially existing legend
obj.legends.clear()
bbox_to_anchor: List[float] = get_bbox_to_anchor(loc)
bbox_to_anchor[0] += bbox_x_shift
bbox_to_anchor[1] += bbox_y_shift
return obj.legend(
handles[::-1] if is_reverse_items else handles,
labels[::-1] if is_reverse_items else labels,
loc="center",
bbox_to_anchor=bbox_to_anchor,
bbox_transform=bbox_transform,
**kwargs,
)
[docs]
def add_axis_legend_outside_frame(
self,
ax_name: str,
bbox_x_shift: Optional[float] = None,
bbox_y_shift: Optional[float] = None,
loc: Literal["left", "right", "top", "bottom"] = "bottom",
borderaxespad: float = 1.0,
is_reverse_items: bool = False,
**kwargs: Any,
) -> Optional[Legend]:
"""
Add a legend to the ax outside the ax frame.
Parameters
----------
ax_name : str
The name of the ax for which to add a legend outside the frame.
bbox_x_shift : float, optional
Legend vertical shift (up oriented). The default is 0.0.
bbox_y_shift : float, optional
Legend horizontal shift (right oriented). The default is 0.0.
loc : Literal["left", "right", "top", "bottom"], optional
Side on which to place the legend box. The default is "bottom".
is_reverse_items: bool
Whether to reverse the order of items in the legend. The default is False.
.. versionadded:: 2.0
**kwargs : Any
Additional arguments for `plt.legend`.
Returns
-------
Optional[Legend]
The created legend instance. None if no handles nor labels have been found.
"""
handles, labels = self._remove_dulicated_legend_items(
*self._get_axis_legend_items(ax_name)
)
# No handles = no need for a legend
if len(handles) == 0:
return None
ax = self.ax_dict[ax_name]
# get default values
bbox_to_anchor: List[float] = get_bbox_to_anchor(loc)
# user adjustment
if bbox_x_shift is not None:
bbox_to_anchor[0] += bbox_x_shift
if bbox_y_shift is not None:
bbox_to_anchor[1] += bbox_y_shift
# Generate the figure a first time
lgd = ax.legend(
handles[::-1] if is_reverse_items else handles,
labels[::-1] if is_reverse_items else labels,
loc="center",
bbox_to_anchor=bbox_to_anchor,
bbox_transform=ax.transAxes,
borderaxespad=borderaxespad,
**kwargs,
)
# Handle cases with non automatic adjustment of the legend vertical/horizontal
# position.
if loc in ["bottom", "top"]:
if bbox_x_shift is not None:
return lgd
else:
if bbox_x_shift is not None:
return lgd
# The following deals with automatic position adjustment
# First we update the figure with the layout engine, ex: ConstrainedLayout
engine = self.fig.get_layout_engine()
if engine is not None:
engine.execute(self.fig)
else:
return lgd
# TODO: need to take twin axes into account
# ax.get_shared_x_axes().get_siblings(ax)[0]
# We estimate the bbox_to_anchor adjustment from tight_boxes
# This is not exact and depends on tight or constrained layout
tot_bbox: Bbox = ax.get_tightbbox() # type: ignore
ax_bbox: Bbox = ax.get_tightbbox(
bbox_extra_artists=[
elt
for elt in ax.get_default_bbox_extra_artists()
if not isinstance(elt, Legend)
]
) # type: ignore
frame_bbox = get_frame_bbox(ax)
lgd_bbox: Bbox = lgd.get_tightbbox() # type: ignore
# take the legend border axespad into account.
pad: float = lgd.borderaxespad * lgd.prop.get_size()
if loc in ["left", "right"]:
# total x extent for the axis
dx_tot = tot_bbox.xmax - tot_bbox.xmin
# x extent of the frame + ylabels + yticks etc. Not legend.
dx_ax = ax_bbox.xmax - ax_bbox.xmin
# x extent of ylabels + yticks etc.
dx_ax_no_frame = dx_ax - (frame_bbox.xmax - frame_bbox.xmin)
# x extent of the created legend
dx_lgd = lgd_bbox.xmax - lgd_bbox.xmin
# we must consider the transform (tight or constrained)
den: float = dx_tot - dx_lgd - dx_ax_no_frame
if loc in ["left"]:
bbox_to_anchor[0] -= (
dx_lgd / 2 + frame_bbox.xmin - ax_bbox.xmin + pad
) / den
else:
bbox_to_anchor[0] += (
dx_lgd / 2 + ax_bbox.xmax - frame_bbox.xmax + pad
) / den
else:
# total y extent for the axis
dy_tot = tot_bbox.ymax - tot_bbox.ymin
# y extent of the frame + xlabels + xticks etc. Not legend.
dy_ax = ax_bbox.ymax - ax_bbox.ymin
# y extent of xlabels + xticks etc.
dy_ax_no_frame = dy_ax - (frame_bbox.ymax - frame_bbox.ymin)
# y extent of the created legend
dy_lgd = lgd_bbox.ymax - lgd_bbox.ymin
# we must consider the transform (tight or constrained)
den = dy_tot - dy_lgd - dy_ax_no_frame
if loc in ["bottom"]:
bbox_to_anchor[1] -= (
dy_lgd / 2 + frame_bbox.ymin - ax_bbox.ymin + pad
) / den
else:
bbox_to_anchor[1] += (
dy_lgd / 2 + ax_bbox.ymax - frame_bbox.ymax + pad
) / den
# while is_lgd_overlapping_axis(ax, lgd):
# if loc in ["bottom"]:
# bbox_to_anchor[1] -= 0.025
lgd = ax.legend(
handles[::-1] if is_reverse_items else handles,
labels[::-1] if is_reverse_items else labels,
loc="center",
bbox_to_anchor=bbox_to_anchor,
bbox_transform=ax.transAxes,
borderaxespad=borderaxespad,
**kwargs,
)
# if lyte is not None:
# lyte.execute(self.fig)
return lgd
[docs]
def add_axis_legend(
self, ax_name: str, is_reverse_items: bool = False, **kwargs: Any
) -> Tuple[List[Any], List[str]]:
"""
Add a legend to the graphic.
Parameters
----------
ax_name : str
Ax for which to add the legend.
is_reverse_items: bool
Whether to reverse the order of items in the legend. The default is False.
.. versionadded:: 2.0
**kwargs : Any
Additional arguments for `plt.legend`.
Returns
-------
(Tuple[List[Any], List[str]])
Tuple of list of handles and list of associated labels.
"""
handles, labels = self._remove_dulicated_legend_items(
*self._get_axis_legend_items(ax_name)
)
self.ax_dict[ax_name].legend(
handles[::-1] if is_reverse_items else handles,
labels[::-1] if is_reverse_items else labels,
**kwargs,
)
return handles, labels
[docs]
def clear_all_axes(self) -> None:
"""
Clear all the axes from their data and layouts.
It also resets the additional handles and labels for the legend.
"""
for ax in self.ax_dict.values():
ax.clear()
# Also clear the elements of the legend
self._additional_handles = {}
self._additional_labels = {}
# Remove a potentially existing legends on fig and subfigs
self.clear_fig_legends()
[docs]
def clear_fig_legends(self) -> None:
"""Remove all added figure legends"""
# Remove a potentially existing legends on fig and subfigs
self.fig.legends.clear()
for subfig in self.sf_dict.values():
subfig.legends.clear()
[docs]
class NestedGridPlotter(Plotter):
"""
Alias for `Plotter` for backward compatibility.
"""
def get_bbox_to_anchor(loc: str) -> List[float]:
try:
return {
"left": [0.0, 0.5],
"right": [1.0, 0.5],
"top": [0.5, 1.0],
"bottom": [0.5, 0.0],
}[loc]
except KeyError as e:
raise ValueError(
'authorized values for loc are "left", "right", "top", "bottom"!'
) from e
def get_frame_bbox(ax: Axes) -> Bbox:
return Bbox(
[
[
ax.spines.right.get_window_extent().xmax,
ax.spines.top.get_window_extent().ymax,
],
[
ax.spines.left.get_window_extent().xmin,
ax.spines.bottom.get_window_extent().ymin,
],
]
)
def is_lgd_overlapping_axis(ax: Axes, lgd: Legend) -> bool:
ax_bbox = ax.get_tightbbox(
bbox_extra_artists=[
elt
for elt in ax.get_default_bbox_extra_artists()
if not isinstance(elt, Legend)
]
)
lgd_bbox = lgd.get_tightbbox()
if ax_bbox is None or lgd_bbox is None: # pragma: no cover
return False
# take the legend border axespad into account.
pad = lgd.borderaxespad * lgd.prop.get_size()
return (
ax_bbox.xmax >= lgd_bbox.xmin - pad and lgd_bbox.xmax + pad >= ax_bbox.xmin
) and (ax_bbox.ymax >= lgd_bbox.ymin - pad and lgd_bbox.ymax + pad >= ax_bbox.ymin)