I am using matplotlib to make scatter plots. Each point on the scatter plot is associated with a named object. I would like to be able to see the name of an object when I hover my cursor over the point on the scatter plot associated with that object. In particular, it would be nice to be able to quickly see the names of the points that are outliers. The closest thing I have been able to find while searching here is the annotate command, but that appears to create a fixed label on the plot. Unfortunately, with the number of points that I have, the scatter plot would be unreadable if I labeled each point. Does anyone know of a way to create labels that only appear when the cursor hovers in the vicinity of that point?
It seems none of the other answers here actually answer the question. So here is some code that uses a scatter plot and shows an annotation when hovering over the scatter points.
import matplotlib.pyplot as plt
import numpy as np; np.random.seed(1)
x = np.random.rand(15)
y = np.random.rand(15)
names = np.array(list("ABCDEFGHIJKLMNO"))
c = np.random.randint(1,5,size=15)
norm = plt.Normalize(1,4)
cmap = plt.cm.RdYlGn
fig,ax = plt.subplots()
sc = plt.scatter(x,y,c=c, s=100, cmap=cmap, norm=norm)
annot = ax.annotate("", xy=(0,0), xytext=(20,20),textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
def update_annot(ind):
pos = sc.get_offsets()[ind["ind"][0]]
annot.xy = pos
text = "{}, {}".format(" ".join(list(map(str,ind["ind"]))),
" ".join([names[n] for n in ind["ind"]]))
def hover(event):
vis = annot.get_visible()
if event.inaxes == ax:
cont, ind = sc.contains(event)
if cont:
if vis:
fig.canvas.mpl_connect("motion_notify_event", hover)
Because people also want to use this solution for a line plot instead of a scatter, the following would be the same solution for plot (which works slightly differently).
import matplotlib.pyplot as plt
import numpy as np; np.random.seed(1)
x = np.sort(np.random.rand(15))
y = np.sort(np.random.rand(15))
names = np.array(list("ABCDEFGHIJKLMNO"))
norm = plt.Normalize(1,4)
cmap = plt.cm.RdYlGn
fig,ax = plt.subplots()
line, = plt.plot(x,y, marker="o")
annot = ax.annotate("", xy=(0,0), xytext=(-20,20),textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
def update_annot(ind):
x,y = line.get_data()
annot.xy = (x[ind["ind"][0]], y[ind["ind"][0]])
text = "{}, {}".format(" ".join(list(map(str,ind["ind"]))),
" ".join([names[n] for n in ind["ind"]]))
def hover(event):
vis = annot.get_visible()
if event.inaxes == ax:
cont, ind = line.contains(event)
if cont:
if vis:
fig.canvas.mpl_connect("motion_notify_event", hover)
In case someone is looking for a solution for lines in twin axes, refer to How to make labels appear when hovering over a point in multiple axis?
In case someone is looking for a solution for bar plots, please refer to e.g. this answer.
This solution works when hovering over a line, without the need to click it:
import matplotlib.pyplot as plt
# Need to create as global variable so our callback(on_plot_hover) can access
fig = plt.figure()
plot = fig.add_subplot(111)
# create some curves
for i in range(4):
# Giving unique ids to each data member
def on_plot_hover(event):
# Iterating over each data member plotted
for curve in plot.get_lines():
# Searching which data member corresponds to current mouse position
if curve.contains(event)[0]:
print("over %s" % curve.get_gid())
fig.canvas.mpl_connect('motion_notify_event', on_plot_hover)
From http://matplotlib.sourceforge.net/examples/event_handling/pick_event_demo.html :
from matplotlib.pyplot import figure, show
import numpy as npy
from numpy.random import rand
# picking on a scatter plot (matplotlib.collections.RegularPolyCollection)
x, y, c, s = rand(4, 100)


def onpick3(event):
    """Print the indices and the x/y coordinates of the picked scatter points."""
    picked = event.ind
    picked_x = npy.take(x, picked)
    picked_y = npy.take(y, picked)
    print('onpick3 scatter:', picked, picked_x, picked_y)


fig = figure()
ax1 = fig.add_subplot(111)
# picker=True makes the scatter collection emit pick events when clicked.
col = ax1.scatter(x, y, 100*s, c, picker=True)
fig.canvas.mpl_connect('pick_event', onpick3)
This recipe draws an annotation on picking a data point: http://scipy-cookbook.readthedocs.io/items/Matplotlib_Interactive_Plotting.html .
This recipe draws a tooltip, but it requires wxPython:
Point and line tooltips in matplotlib?
The easiest option is to use the mplcursors package.
mplcursors: read the docs
mplcursors: github
If using Anaconda, install with these instructions, otherwise use these instructions for pip.
This must be plotted in an interactive window, not inline.
For jupyter, executing something like %matplotlib qt in a cell will turn on interactive plotting. See How can I open the interactive matplotlib window in IPython notebook?
Tested in python 3.10, pandas 1.4.2, matplotlib 3.5.1, seaborn 0.11.2
import matplotlib.pyplot as plt
import pandas_datareader as web # only for test data; must be installed with conda or pip
from mplcursors import cursor # separate package must be installed
# reproducible sample data as a pandas dataframe
df = web.DataReader('aapl', data_source='yahoo', start='2021-03-09', end='2022-06-13')
plt.figure(figsize=(12, 7))
plt.plot(df.index, df.Close)
ax = df.plot(y='Close', figsize=(10, 7))
Works with axes-level plots like sns.lineplot, and figure-level plots like sns.relplot.
import seaborn as sns
# load sample data
tips = sns.load_dataset('tips')
sns.relplot(data=tips, x="total_bill", y="tip", hue="day", col="time")
The other answers did not address my need for properly showing tooltips in a recent version of Jupyter inline matplotlib figure. This one works though:
import matplotlib.pyplot as plt
import numpy as np
import mplcursors
fig, ax = plt.subplots()
ax.scatter(*np.random.random((2, 26)))
ax.set_title("Mouse over a point")
crs = mplcursors.cursor(ax,hover=True)
crs.connect("add", lambda sel: sel.annotation.set_text(
'Point {},{}'.format(sel.target[0], sel.target[1])))
Leading to something like the following picture when going over a point with mouse:
A slight edit on an example provided in http://matplotlib.org/users/shell.html:
import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_title('click on points')

# picker=5 enables picking on the line with a tolerance of 5 points.
line, = ax.plot(np.random.rand(100), '-', picker=5)


def onpick(event):
    """Print the (x, y) data pairs of the line vertices hit by a pick event."""
    picked_line = event.artist
    hit = event.ind
    xs = picked_line.get_xdata()[hit]
    ys = picked_line.get_ydata()[hit]
    print('onpick points:', *zip(xs, ys))


fig.canvas.mpl_connect('pick_event', onpick)
This plots a straight line plot, as Sohaib was asking
mpld3 solved it for me.
import matplotlib.pyplot as plt
import numpy as np
import mpld3
# The axes background is set with `facecolor`; the older `axisbg` keyword
# was removed from Matplotlib and raises an error on current versions.
fig, ax = plt.subplots(subplot_kw=dict(facecolor='#EEEEEE'))
N = 100
scatter = ax.scatter(np.random.normal(size=N),
s=1000 * np.random.random(size=N),
ax.grid(color='white', linestyle='solid')
ax.set_title("Scatter Plot (with tooltips!)", size=20)
labels = ['point {0}'.format(i + 1) for i in range(N)]
tooltip = mpld3.plugins.PointLabelTooltip(scatter, labels=labels)
mpld3.plugins.connect(fig, tooltip)
You can check this example
mplcursors worked for me. mplcursors provides clickable annotations for matplotlib. It is heavily inspired by mpldatacursor (https://github.com/joferkington/mpldatacursor), with a much simplified API.
import matplotlib.pyplot as plt
import numpy as np
import mplcursors
data = np.outer(range(10), range(1, 5))
fig, ax = plt.subplots()
lines = ax.plot(data)
ax.set_title("Click somewhere on a line.\nRight-click to deselect.\n"
"Annotations can be dragged.")
mplcursors.cursor(lines) # or just mplcursors.cursor()
showing object information in matplotlib statusbar
no extra libraries needed
clean plot
no overlap of labels and artists
supports multi artist labeling
can handle artists from different plotting calls (like scatter, plot, add_patch)
code in library style
### imports
import matplotlib as mpl
import matplotlib.pylab as plt
import numpy as np
# https://stackoverflow.com/a/47166787/7128154
# https://matplotlib.org/3.3.3/api/collections_api.html#matplotlib.collections.PathCollection
# https://matplotlib.org/3.3.3/api/path_api.html#matplotlib.path.Path
# https://stackoverflow.com/questions/15876011/add-information-to-matplotlib-navigation-toolbar-status-bar
# https://stackoverflow.com/questions/36730261/matplotlib-path-contains-point
# https://stackoverflow.com/a/36335048/7128154
class StatusbarHoverManager:
    """
    Manage hover information for mpl.axes.Axes object based on appearing
    artists.

    The information is shown in the status bar by replacing the axes'
    ``format_coord`` function on every mouse move inside the axes.

    Attributes
    ----------
    ax : mpl.axes.Axes
        subplot to show status information
    artists : list of mpl.artist.Artist
        elements on the subplot, which react to mouse over
    labels : list (list of strings) or strings
        each element on the top level corresponds to an artist.
        if the artist has items
        (i.e. second return value of contains() has key 'ind'),
        the element has to be of type list.
        otherwise the element is of type string
    cid : to reconnect motion_notify_event
    """

    def __init__(self, ax):
        """Connect a basic hover handler that only reports the cursor position.

        Parameters
        ----------
        ax : mpl.axes.Axes
            subplot whose status-bar text is managed
        """
        assert isinstance(ax, mpl.axes.Axes)

        def hover(event):
            # Ignore mouse moves outside the managed axes.
            if event.inaxes != ax:
                return
            info = 'x={:.2f}, y={:.2f}'.format(event.xdata, event.ydata)
            ax.format_coord = lambda x, y: info

        cid = ax.figure.canvas.mpl_connect("motion_notify_event", hover)

        self.ax = ax
        self.cid = cid
        self.artists = []
        self.labels = []

    def add_artist_labels(self, artist, label):
        """Register an artist and the label(s) to show while hovering over it.

        Parameters
        ----------
        artist : mpl.artist.Artist or list
            the artist to track; a one-element list (as returned by
            ``ax.plot``) is unwrapped
        label : str or list of str
            a single string, or one string per item for artists whose
            ``contains()`` result carries an ``'ind'`` entry
        """
        if isinstance(artist, list):
            assert len(artist) == 1
            artist = artist[0]

        self.artists += [artist]
        self.labels += [label]

        def hover(event):
            # Ignore mouse moves outside the managed axes.
            if event.inaxes != self.ax:
                return
            info = 'x={:.2f}, y={:.2f}'.format(event.xdata, event.ydata)
            for aa, artist in enumerate(self.artists):
                cont, dct = artist.contains(event)
                if not cont:
                    continue
                inds = dct.get('ind')
                if inds is not None:  # artist contains items
                    for ii in inds:
                        lbl = self.labels[aa][ii]
                        info += '; artist [{:d}, {:d}]: {:}'.format(
                            aa, ii, lbl)
                else:
                    lbl = self.labels[aa]
                    info += '; artist [{:d}]: {:}'.format(aa, lbl)
            self.ax.format_coord = lambda x, y: info

        # Drop the previously connected handler first; otherwise every call
        # would stack one more callback on the same motion_notify_event.
        self.ax.figure.canvas.mpl_disconnect(self.cid)
        self.cid = self.ax.figure.canvas.mpl_connect(
            "motion_notify_event", hover)
def demo_StatusbarHoverManager():
fig, ax = plt.subplots()
shm = StatusbarHoverManager(ax)
poly = mpl.patches.Polygon(
[[0,0], [3, 5], [5, 4], [6,1]], closed=True, color='green', zorder=0)
artist = ax.add_patch(poly)
shm.add_artist_labels(artist, 'polygon')
artist = ax.scatter([2.5, 1, 2, 3], [6, 1, 1, 7], c='blue', s=10**2)
lbls = ['point ' + str(ii) for ii in range(4)]
shm.add_artist_labels(artist, lbls)
artist = ax.plot(
[0, 0, 1, 5, 3], [0, 1, 1, 0, 2], marker='o', color='red')
lbls = ['segment ' + str(ii) for ii in range(5)]
shm.add_artist_labels(artist, lbls)
# --- main
if __name__== "__main__":
I have made a multi-line annotation system to add to: https://stackoverflow.com/a/47166787/10302020.
for the most up to date version:
Simply change the data in the bottom section.
import matplotlib.pyplot as plt
def update_annot(ind, line, annot, ydata):
x, y = line.get_data()
annot.xy = (x[ind["ind"][0]], y[ind["ind"][0]])
# Get x and y values, then format them to be displayed
x_values = " ".join(list(map(str, ind["ind"])))
y_values = " ".join(str(ydata[n]) for n in ind["ind"])
text = "{}, {}".format(x_values, y_values)
def hover(event, line_info):
line, annot, ydata = line_info
vis = annot.get_visible()
if event.inaxes == ax:
# Draw annotations if cursor in right position
cont, ind = line.contains(event)
if cont:
update_annot(ind, line, annot, ydata)
# Don't draw annotations
if vis:
def plot_line(x, y):
line, = plt.plot(x, y, marker="o")
# Annotation style may be changed here
annot = ax.annotate("", xy=(0, 0), xytext=(-20, 20), textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
line_info = [line, annot, y]
lambda event: hover(event, line_info))
# Your data values to plot
x1 = range(21)
y1 = range(0, 21)
x2 = range(21)
y2 = range(0, 42, 2)
# Plot line graphs
fig, ax = plt.subplots()
plot_line(x1, y1)
plot_line(x2, y2)
Based on the answers by "Markus Dutschke" and "ImportanceOfBeingErnest", I (imo) simplified the code and made it more modular.
Also this doesn't require additional packages to be installed.
import matplotlib.pylab as plt
import numpy as np
fh, ax = plt.subplots()
#Generate some data
y,x = np.histogram(np.random.randn(10000), bins=500)
x = x[:-1]
colors = ['#0000ff', '#00ff00','#ff0000']
x2, y2 = x,y/10
x3, y3 = x, np.random.randn(500)*10+40
h1 = ax.plot(x, y, color=colors[0])
h2 = ax.plot(x2, y2, color=colors[1])
h3 = ax.scatter(x3, y3, color=colors[2], s=1)
artists = h1 + h2 + [h3] #concatenating lists
labels = [list('ABCDE'*100),list('FGHIJ'*100),list('klmno'*100)] #define labels shown
#___ Initialize annotation arrow
annot = ax.annotate("", xy=(0,0), xytext=(20,20),textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
def on_plot_hover(event):
if event.inaxes != ax: #exit if mouse is not on figure
is_vis = annot.get_visible() #check if an annotation is visible
# x,y = event.xdata,event.ydata #coordinates of mouse in graph
for ii, artist in enumerate(artists):
is_contained, dct = artist.contains(event)
if('get_data' in dir(artist)): #for plot
data = list(zip(*artist.get_data()))
elif('get_offsets' in dir(artist)): #for scatter
data = artist.get_offsets().data
inds = dct['ind'] #get which data-index is under the mouse
#___ Set Annotation settings
xy = data[inds[0]] #get 1st position only
annot.xy = xy
if is_vis:
annot.set_visible(False) #disable when not hovering
fh.canvas.mpl_connect('motion_notify_event', on_plot_hover)
Giving the following result:
Maybe this helps somebody: I have adapted @ImportanceOfBeingErnest's answer to work with patches and classes. Features:
The entire framework is contained inside of a single class, so all of the used variables are only available within their relevant scopes.
Can create multiple distinct sets of patches
Hovering over a patch prints patch collection name and patch subname
Hovering over a patch highlights all patches of that collection by changing their edge color to black
Note: For my applications, the overlap is not relevant, thus only one object's name is displayed at a time. Feel free to extend to multiple objects if you wish, it is not too hard.
fig, ax = plt.subplots(tight_layout=True)
ap = annotated_patches(fig, ax)
ap.add_patches('Azure', 'circle', 'blue', np.random.uniform(0, 1, (4,2)), 'ABCD', 0.1)
ap.add_patches('Lava', 'rect', 'red', np.random.uniform(0, 1, (3,2)), 'EFG', 0.1, 0.05)
ap.add_patches('Emerald', 'rect', 'green', np.random.uniform(0, 1, (3,2)), 'HIJ', 0.05, 0.1)
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
class annotated_patches:
def __init__(self, fig, ax):
self.fig = fig
self.ax = ax
self.annot = self.ax.annotate("", xy=(0,0),
textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
self.collectionsDict = {}
self.coordsDict = {}
self.namesDict = {}
self.isActiveDict = {}
self.motionCallbackID = self.fig.canvas.mpl_connect("motion_notify_event", self.hover)
def add_patches(self, groupName, kind, color, xyCoords, names, *params):
if kind=='circle':
circles = [mpatches.Circle(xy, *params, ec="none") for xy in xyCoords]
thisCollection = PatchCollection(circles, facecolor=color, alpha=0.5, edgecolor=None)
elif kind == 'rect':
rectangles = [mpatches.Rectangle(xy, *params, ec="none") for xy in xyCoords]
thisCollection = PatchCollection(rectangles, facecolor=color, alpha=0.5, edgecolor=None)
raise ValueError('Unexpected kind', kind)
self.collectionsDict[groupName] = thisCollection
self.coordsDict[groupName] = xyCoords
self.namesDict[groupName] = names
self.isActiveDict[groupName] = False
def update_annot(self, groupName, patchIdxs):
self.annot.xy = self.coordsDict[groupName][patchIdxs[0]]
self.annot.set_text(groupName + ': ' + self.namesDict[groupName][patchIdxs[0]])
# Set edge color
self.isActiveDict[groupName] = True
def hover(self, event):
vis = self.annot.get_visible()
updatedAny = False
if event.inaxes == self.ax:
for groupName, collection in self.collectionsDict.items():
cont, ind = collection.contains(event)
if cont:
self.update_annot(groupName, ind["ind"])
updatedAny = True
if self.isActiveDict[groupName]:
self.isActiveDict[groupName] = True
if (not updatedAny) and vis:
Having the geographic points with values, I would like to encode the values with colormap and customize the legend position and colormap range.
Using geopandas, I have written the following function:
def plot_continuous(df, column_values, title):
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
df.plot(ax=ax, column=column_values, cmap='OrRd', legend=True);
The colorbar by default is vertical, but I would like to make it horizontal.
In order to have a horizontal colorbar, I have written the following function:
def plot_continuous(df, column_values, title, legend_title=None):
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
x = np.array(df.geometry.apply(lambda x: x.x))
y = np.array(df.geometry.apply(lambda x: x.y))
vals = np.array(df[column_values])
sc = ax.scatter(x, y, c=vals, cmap='OrRd')
cbar = plt.colorbar(sc, orientation="horizontal")
if legend_title is not None:
The image width and height in the latter case, however, is not proportional so the output looks distorted.
Does anyone know how to customize the geographic plot and keep the width-height ratio undistorted?
This gets far simpler if you use geopandas customisation of plot()
This is documented: https://geopandas.org/en/stable/docs/user_guide/mapping.html
Below I show an MWE using your function and then one using geopandas. The latter scales the data correctly.
MWE of your code
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
def plot_continuous(df, column_values, title, legend_title=None):
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
x = np.array(df.geometry.apply(lambda x: x.x))
y = np.array(df.geometry.apply(lambda x: x.y))
vals = np.array(df[column_values])
sc = ax.scatter(x, y, c=vals, cmap='OrRd')
cbar = plt.colorbar(sc, orientation="horizontal")
if legend_title is not None:
cities = gpd.read_file(gpd.datasets.get_path("naturalearth_cities"))
cities["color"] = np.random.randint(1,10, len(cities))
plot_continuous(cities, "color", "my title", "Color")
use geopandas
ax = cities.plot(
legend_kwds={"label": "Color", "orientation": "horizontal"},
ax.set_title("my title")
I have an issue with customizing the legend of my plot. I did lots of customizing but couldn't get my head around this one. I want the symbols (not the labels) to be equally spaced in the legend. As you can see in the example, the space between the circles in the legend gets smaller as the circles get bigger.
any ideas?
Also, how can I add a color bar (in addition to the size), with smaller circles being light red (for example) and bigger circles being blue (for example)?
here is my code so far:
import pandas as pd
import matplotlib.pyplot as plt
from vega_datasets import data as vega_data
gap = pd.read_json(vega_data.gapminder.url)
df = gap.loc[gap['year'] == 2000]
fig, ax = plt.subplots(1, 1,figsize=[14,12])
ax=ax.scatter(df['life_expect'], df['fertility'],
s = df['pop']/100000,alpha=0.7, edgecolor="black",cmap="viridis")
kw = dict(prop="sizes", num=6, color="lightgrey", markeredgecolor='black',markeredgewidth=2)
plt.legend(*ax.legend_elements(**kw),bbox_to_anchor=(1, 0),frameon=False,
loc="lower left",markerscale=1,ncol=1,borderpad=2,labelspacing=4,handletextpad=2)
It's a bit tricky, but you could measure the legend elements and reposition them to have a constant distance between them. Due to the pixel positioning, the plot can't be resized afterwards.
I tested the code inside PyCharm with the 'Qt5Agg' backend. And in a Jupyter notebook, both with %matplotlib inline and with %matplotlib notebook. I'm not sure whether it would work well in all environments.
Note that ax.scatter doesn't return an ax (contrary to e.g. sns.scatterplot) but a collection (a PathCollection) of the created scatter dots.
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.transforms import IdentityTransform
from vega_datasets import data as vega_data
gap = pd.read_json(vega_data.gapminder.url)
df = gap.loc[gap['year'] == 2000]
fig, ax = plt.subplots(1, 1, figsize=[14, 12])
scat = ax.scatter(df['life_expect'], df['fertility'],
s=df['pop'] / 100000, alpha=0.7, edgecolor="black", cmap="viridis")
x = 1.1
y = 0.1
is_first = True
kw = dict(prop="sizes", num=6, color="lightgrey", markeredgecolor='black', markeredgewidth=2)
handles, labels = scat.legend_elements(**kw)
inverted_transData = ax.transData.inverted()
for handle, label in zip(handles[::-1], labels[::-1]):
plt.setp(handle, clip_on=False)
for _ in range(1 if is_first else 2):
plt.setp(handle, transform=ax.transAxes)
if is_first:
xd, yd = x, y
xd, yd = inverted_transData.transform((x, y))
bbox = handle.get_window_extent(fig.canvas.get_renderer())
y += y - bbox.y0 + 15 # 15 pixels inbetween
x = (bbox.x0 + bbox.x1) / 2
if is_first:
xd_text, _ = inverted_transData.transform((bbox.x1+10, y))
ax.text(xd_text, yd, label, transform=ax.transAxes, ha='left', va='center')
y = bbox.y1
is_first = False
I'm trying to plot a big amount of curves in a stackplot with matplotlib, using python.
To read the graph, I need to show legends, but if I show it with the legend method, my graph is unreadable (because of the number of legends, and their size).
I have found that mplcursors could help me to do that with a popup in the graph itself. It works with "simple" plots, but not with a stackplot.
Here is the warning message with stackplots:
/usr/lib/python3.7/site-packages/mplcursors/_pick_info.py:141: UserWarning: Pick support for PolyCollection is missing.
warnings.warn(f"Pick support for {type(artist).__name__} is missing.")
And here is the code related to this error (it's only a proof of concept):
import matplotlib.pyplot as plt
import mplcursors
import numpy as np
data = np.outer(range(10), range(1, 5))
timestamp = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
tmp = list()
tmp.append(data[:, 0])
tmp.append(data[:, 1])
tmp.append(data[:, 2])
tmp.append(data[:, 3])
fig, ax = plt.subplots()
ax.stackplot(timestamp, tmp, labels=('curve1', 'line2', 'curvefever', 'whatever'))
cursor = mplcursors.cursor(hover=True)
def on_add(sel):
label = sel.artist.get_label()
Do you have an idea of how to fix that, or do you know another way to do something like that ?
It is not clear why mplcursors doesn't accept a stackplot. But you can replicate the behavior with more primitive matplotlib functionality:
import matplotlib.pyplot as plt
import numpy as np
def update_annot(label, x, y):
annot.xy = (x, y)
def on_hover(event):
visible = annot.get_visible()
is_outside_of_stackplot = True
if event.inaxes == ax:
for coll, label in zip(stckplt, labels):
contained, _ = coll.contains(event)
if contained:
update_annot(label, event.x, event.y)
is_outside_of_stackplot = False
if is_outside_of_stackplot and visible:
data = np.random.randint(1, 5, size=(4, 40))
fig, ax = plt.subplots()
labels = ('curve1', 'line2', 'curvefever', 'whatever')
stckplt = ax.stackplot(range(data.shape[1]), data, labels=labels)
ax.autoscale(enable=True, axis='x', tight=True)
# ax.legend()
annot = ax.annotate("", xy=(0, 0), xycoords="figure pixels",
xytext=(20, 20), textcoords="offset points",
bbox=dict(boxstyle="round", fc="yellow", alpha=0.6),
plt.connect('motion_notify_event', on_hover)
When placing the legend outside of the axes using bbox_to_anchor as in this answer, the space between the axes and the legend changes when the figure is resized. For static exported plots this is fine; you can simply tweak the numbers until you get it right. But for interactive plots that you might want to resize, this is a problem. As can be seen in this example:
import numpy as np
from matplotlib import pyplot as plt
x = np.arange(5)
y = np.random.randn(5)
fig, ax = plt.subplots(tight_layout=True)
ax.plot(x, y, label='data1')
ax.plot(x, y-1, label='data2')
legend = ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=2)
How can I make sure that the legend keeps the same distance from the axes even when the figure is resized?
The distance of a legend from the bounding box edge is set by the borderaxespad argument. The borderaxespad is in units of multiples of the fontsize - making it automatically independent of the axes size.
So in this case,
import matplotlib.pyplot as plt
import numpy as np
x = np.arange(5)
y = np.random.randn(5)
fig, ax = plt.subplots(constrained_layout=True)
ax.plot(x, y, label='data1')
ax.plot(x, y-1, label='data2')
legend = ax.legend(loc="upper center", bbox_to_anchor=(0.5,0), borderaxespad=2)
A similar question about showing a title below the axes at a constant distance is being asked in Place title at the bottom of the figure of an axes?
You can use the resize events of the canvas to update the values in bbox_to_anchor with each update. To calculate the new values you can use the inverse of the axes transformation (Bbox.inverse_transformed(ax.transAxes)), which translates from screen coordinates in pixels to the axes coordinates which are normally used in bbox_to_anchor.
Here is an example with support for putting the legend on all four sides of the axes:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.transforms import Bbox
class FixedOutsideLegend:
    """A legend placed at a fixed offset (in pixels) from the axes.

    The anchor point is recomputed on every canvas resize so that the gap
    between the axes and the legend stays the same number of pixels.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes the legend belongs to.
    location : {'right', 'left', 'upper', 'lower'}
        Side of the axes on which the legend is placed.
    pixel_offset : float
        Distance between the axes edge and the legend, in pixels.
    **kwargs
        Forwarded to ``ax.legend``.

    Raises
    ------
    ValueError
        If ``location`` is not one of the four supported sides.
    """

    # Legend reference point that has to sit on the anchor for each side.
    _LOC_BY_LOCATION = {
        'right': 'center left',
        'left': 'center right',
        'upper': 'lower center',
        'lower': 'upper center',
    }

    def __init__(self, ax, location, pixel_offset, **kwargs):
        try:
            self._loc = self._LOC_BY_LOCATION[location]
        except (KeyError, TypeError):
            raise ValueError(
                'Unknown location: {}'.format(location)) from None

        # Keep the axes on the instance; the anchor computation must not
        # depend on a module-level variable that happens to be called `ax`.
        self._ax = ax
        self._pixel_offset = pixel_offset
        self.location = location

        self.legend = ax.legend(
            loc=self._loc, bbox_to_anchor=self._get_bbox_to_anchor(), **kwargs)
        ax.figure.canvas.mpl_connect('resize_event', self.on_resize)

    def on_resize(self, event):
        """Re-anchor the legend after the canvas size has changed."""
        self.legend.set_bbox_to_anchor(self._get_bbox_to_anchor())

    def _get_bbox_to_anchor(self):
        """
        Find the lengths in axes units that correspond to the specified
        pixel_offset and return the matching ``bbox_to_anchor`` point.
        """
        screen_bbox = Bbox.from_bounds(
            0, 0, self._pixel_offset, self._pixel_offset)
        try:
            # Pixels -> axes coordinates via the inverted axes transform.
            ax_bbox = screen_bbox.transformed(self._ax.transAxes.inverted())
        except np.linalg.LinAlgError:
            # Presumably hit when the axes has collapsed to zero size and the
            # transform cannot be inverted; fall back to no offset.
            ax_width = 0
            ax_height = 0
        else:
            ax_width = ax_bbox.width
            ax_height = ax_bbox.height

        if self.location == 'right':
            return (1 + ax_width, 0.5)
        elif self.location == 'left':
            return (-ax_width, 0.5)
        elif self.location == 'upper':
            return (0.5, 1 + ax_height)
        elif self.location == 'lower':
            return (0.5, -ax_height)
x = np.arange(5)
y = np.random.randn(5)
fig, ax = plt.subplots(tight_layout=True)
ax.plot(x, y, label='data1')
ax.plot(x, y-1, label='data2')
legend = FixedOutsideLegend(ax, 'lower', 20, ncol=2)
I am using an artist animation method with 5 subplots. There is one static plot on the left, with 3 smaller animated imshow plots to the right (the colorbar is the 5th). I have successfully used ConnectionPatch to connect subplots to show where the data is coming from, but only on static plots. No matter what I try, I can't seem to get the patches to show up on the animation. I've tried to include the patch in the image artist list, tried to update the figure with the artist instead of the axis (which I guess doesn't make much sense), among other things. It will be very difficult to extract a working example due to the complexity of the plot, but maybe someone has a tip.
Could setting the facecolor to 'white' with the animation savefig_kwargs be covering up the connector lines? If so, how do I change the z order of the patch/facecolor?
Without a minimal working example, I can only tell you that it is possible to use a ConnectionPatch in an animation. However, as seen below, one has to recreate it for every frame.
import matplotlib.pyplot as plt
import numpy as np; np.random.seed(0)
import matplotlib.gridspec as gridspec
from matplotlib.patches import ConnectionPatch
import matplotlib.animation
plt.rcParams["figure.figsize"] = np.array([6,3.6])*0.7
x = np.linspace(-3,3)
X,Y = np.meshgrid(x,x)
f = lambda x,y: (1 - x / 2. + x ** 5 + y ** 3) * np.exp(-x ** 2 - y ** 2)+1.5
Z = f(X,Y)
bins=np.linspace(Z.min(), Z.max(), 16)
cols = plt.cm.PuOr((bins[:-1]-Z.min())/(Z.max()-Z.min()))
gs = gridspec.GridSpec(2, 2, height_ratios=[34,53], width_ratios=[102,53])
fig = plt.figure()
ax.imshow(Z, cmap="PuOr")
rec = plt.Rectangle([-.5,-.5], width=9, height=9, edgecolor="crimson", fill=False, lw=2)
conp = ConnectionPatch(xyA=[-0.5,0.5], xyB=[9.5,4], coordsA="data", coordsB="data",
axesA=ax3, axesB=ax, arrowstyle="-|>", zorder=25, shrinkA=0, shrinkB=1,
mutation_scale=20, fc="w", ec="crimson", lw=2)
im = ax3.imshow(Z[:9,:9], cmap="PuOr", vmin=Z.min(), vmax=Z.max())
ticks = np.array([0,4,8])
ax3.set_yticks(ticks); ax3.set_xticks(ticks)
ax2.hist(Z[:9,:9].flatten(), bins=bins)
def ins(px,py):
global rec, conp, histpatches
ll = [px-.5,py-.5]
conp = ConnectionPatch(xyA=[-0.5,0.5], xyB=[px+9.5,py+4], coordsA="data", coordsB="data",
axesA=ax3, axesB=ax, arrowstyle="-|>", zorder=25, shrinkA=0, shrinkB=1,
mutation_scale=20, fc="w", ec="crimson", lw=2)
data = Z[px:px+9,py:py+9]
h, b_, patches = ax2.hist(data.flatten(), bins=bins, ec="k", fc="#f1a142")
[pat.set_color(cols[i]) for i, pat in enumerate(patches)]
def func(p):
px,py = p
ins(px, py)
phi = np.linspace(0.,2*np.pi)
r = np.sin(2*phi)*20+np.pi/2
xr = (r*np.cos(phi)).astype(np.int8)
yr = (r*np.sin(phi)).astype(np.int8)
frames = np.c_[xr+20, yr+20]
ani = matplotlib.animation.FuncAnimation(fig, func, frames=frames, interval=300, repeat=True)