Retrieve all colorbars in figure - python

I'm trying to manipulate all colorbar instances contained in a figure. There is fig.get_axes() to obtain a list of axes, but I cannot find anything similar for colorbars.
This answer, https://stackoverflow.com/a/19817573/7042795, only applies to special situations, but not the general case.
Consider this MWE:
import matplotlib.pyplot as plt
import numpy as np
data = np.random.random((10,10)) # Generate some random data to plot
fig, axs = plt.subplots(1,2)
im1 = axs[0].imshow(data)
cbar1 = fig.colorbar(im1)
im2 = axs[1].imshow(2*data)
cbar2 = fig.colorbar(im2)
fig.show()
How can I get cbar1 and cbar2 from fig?
What I need is a function like:
def get_colorbars(fig):
    cbars = fig.get_colorbars()
    return cbars
cbars = get_colorbars(fig)

You have no choice but to check every object in the figure for an attached colorbar. This could look as follows:
def get_colorbars(fig):
    cbs = []
    for ax in fig.axes:
        cbs.extend(ax.findobj(lambda obj: hasattr(obj, "colorbar") and obj.colorbar))
    return [a.colorbar for a in cbs]
This will give you all the colorbars that are tied to an artist. There may be more colorbars in the figure though, e.g. colorbars created directly from a ScalarMappable, or multiple colorbars for the same object; those cannot be found this way.

Since the only place I'm reasonably sure that colorbar references are retained is as an attribute of the artist they are tied to, the best solution I could think of is to search all artists in a figure. This is best done recursively:
from matplotlib.colorbar import Colorbar

def get_colorbars(fig):
    def check_kids(obj, bars):
        for child in obj.get_children():
            if isinstance(getattr(child, 'colorbar', None), Colorbar):
                bars.append(child.colorbar)
            check_kids(child, bars)
        return bars
    return check_kids(fig, [])
I have not had a chance to test this code, but it should at least point you in the right direction.
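Since the answer above is untested, here is a minimal sketch that runs the recursive search against the MWE from the question. The Agg backend, the explicit ax= arguments and the duplicate check are my own additions for illustration, not part of the original answers.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colorbar import Colorbar


def get_colorbars(fig):
    """Collect every colorbar referenced by an artist somewhere in fig."""
    found = []

    def check_kids(obj):
        for child in obj.get_children():
            cb = getattr(child, "colorbar", None)
            # skip colorbars that were already collected
            if isinstance(cb, Colorbar) and cb not in found:
                found.append(cb)
            check_kids(child)

    check_kids(fig)
    return found


data = np.random.random((10, 10))
fig, axs = plt.subplots(1, 2)
im1 = axs[0].imshow(data)
cbar1 = fig.colorbar(im1, ax=axs[0])
im2 = axs[1].imshow(2 * data)
cbar2 = fig.colorbar(im2, ax=axs[1])

cbars = get_colorbars(fig)
print(len(cbars))
```

As noted above, this only finds colorbars that an artist still references through its colorbar attribute.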

Related

Matplotlib: Plotting multiple histograms in plt.subplots

Background of the problem:
I'm working on a class that takes an Axes object as constructor parameter and produces a (m,n) dimension figure with a histogram in each cell, kind of like the figure below:
There are two important things to note here that I'm not allowed to modify in any way:
The Figure object is not passed as a constructor parameter; only the Axes object is. So the subplots object cannot be modified in any way.
The Axes parameter is set to that of a (1,1) figure by default (as below). All the modifications required to make it an (m,n) figure are performed within the class (inside its methods).
_, ax = plt.subplots() # By default takes (1,1) dimension
cm = ClassName(model, ax=ax, histogram=True) # calling my class
What I'm stuck on:
Since I want to plot histograms within each cell, I decided to approach it by looping over each cell and creating a histogram for each.
results[col].hist(ax=self.ax[y,x], bins=bins)
However, I'm not able to specify the axes of the histogram in any way. This is because the Axes parameter passed in has the default dimension (1,1) and hence is not indexable. When I try this I get a TypeError:
TypeError: 'AxesSubplot' object is not subscriptable
With all this considered, I would like to know of any possible ways I can add my histogram to the parent Axes object. Thanks for taking a look.
The requirement is pretty strict and maybe not the best design choice. Because you later want to plot several subplots at the position of a single subplot, this single subplot is only created for the sole purpose of dying and being replaced a few moments later.
So what you can do is obtain the position of the axes you pass in and create a new gridspec at that position. Then remove the original axes and create a new set of axes within that newly created gridspec.
The following would be an example. Note that it currently requires the axes passed in to be a Subplot (as opposed to any axes).
It also hardcodes the number of plots to be 2*2. In the real use case you would probably derive that number from the model you pass in.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec
class ClassName():
    def __init__(self, model, ax=None, **kwargs):
        ax = ax or plt.gca()
        if not hasattr(ax, "get_subplotspec"):
            raise ValueError("Axes needs to be a subplot")
        # SubplotSpec of the axes passed in (ax.get_geometry() is not
        # available in recent matplotlib versions, so ask for the spec directly)
        parentspec = ax.get_subplotspec()
        # Geometry of subplots
        m, n = 2, 2
        gs = gridspec.GridSpecFromSubplotSpec(m,n, subplot_spec=parentspec)
        fig = ax.figure
        ax.remove()
        self.axes = np.empty((m,n), dtype=object)
        for i in range(m):
            for j in range(n):
                self.axes[i,j] = fig.add_subplot(gs[i,j], label=f"{i}{j}")
    def plot(self, data):
        for ax,d in zip(self.axes.flat, data):
            ax.plot(d)
_, (ax,ax2) = plt.subplots(ncols=2)
cm = ClassName("mymodel", ax=ax2) # calling my class
cm.plot(np.random.rand(4,10))
plt.show()
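The same idea can be sketched as a small standalone function using SubplotSpec.subgridspec, with histograms in each cell as the question asked. The helper name split_axes, the random data and the Agg backend are assumptions made for illustration only.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def split_axes(ax, m, n):
    """Replace ax by an (m, n) grid of axes occupying the same slot.

    Hypothetical helper; it only works for axes created from a gridspec.
    """
    spec = ax.get_subplotspec()
    if spec is None:
        raise ValueError("Axes needs to be a subplot")
    fig = ax.figure
    sub = spec.subgridspec(m, n)  # nested grid inside the parent slot
    ax.remove()                   # the placeholder axes is no longer needed
    axes = np.empty((m, n), dtype=object)
    for i in range(m):
        for j in range(n):
            axes[i, j] = fig.add_subplot(sub[i, j])
    return axes


_, (ax1, ax2) = plt.subplots(ncols=2)
grid = split_axes(ax2, 2, 2)
for a, d in zip(grid.flat, np.random.rand(4, 10)):
    a.hist(d, bins=5)
```

The returned array is indexable as grid[y, x], which is what the original hist(ax=self.ax[y,x], ...) call needed.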

How to check if colorbar exists on figure

Question: Is there a way to check if a color bar already exists?
I am making many plots with a loop. The issue is that the color bar is drawn every iteration!
If I could determine if the color bar exists then I can put the color bar function in an if statement.
if cb_exists:
    pass  # do nothing
else:
    plt.colorbar() #draw the colorbar
If I use multiprocessing to make the figures, is it possible to prevent multiple color bars from being added?
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing
def plot(number):
    a = np.random.random([5,5])*number
    plt.pcolormesh(a)
    plt.colorbar()
    plt.savefig('this_'+str(number))
# I want to make 50 plots
some_list = range(0,50)
num_proc = 5
p = multiprocessing.Pool(num_proc)
temps = p.map(plot, some_list)
I realize I can clear the figure with plt.clf() and plt.cla() before plotting the next iteration. But, I have data on my basemap layer I don't want to re-plot (that adds to the time it takes to create the plot). So, if I could remove the colorbar and add a new one I'd save some time.
It is actually not easy to remove a colorbar from a plot and later draw a new one in its place.
The best solution I can come up with at the moment is the following, which assumes that there is only one axes in the plot. If there is a second axes, it must be the colorbar. So by counting the axes in the figure, we can judge whether or not there is a colorbar.
Here we also respect the user's wish not to reference any named objects from outside. (Which does not make much sense, as we need to use plt anyway, but hey, that was the question.)
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="k", lw=6)
ax.set_title("Title")
cbar = plt.colorbar(im)
cbar.ax.set_ylabel("Label")
for i in range(10):
    # inside this loop we should not access any variables defined outside
    # why? no real reason, but questioner asked for it.
    #draw new colormesh
    im = plt.gcf().gca().pcolormesh(np.random.rand(2,2))
    #check if there is more than one axes
    if len(plt.gcf().axes) > 1:
        # if so, then the last axes must be the colorbar.
        # we get its extent
        pts = plt.gcf().axes[-1].get_position().get_points()
        # and its label
        label = plt.gcf().axes[-1].get_ylabel()
        # and then remove the axes
        plt.gcf().axes[-1].remove()
        # then we draw a new axes at the extents of the old one
        cax= plt.gcf().add_axes([pts[0][0],pts[0][1],pts[1][0]-pts[0][0],pts[1][1]-pts[0][1] ])
        # and add a colorbar to it
        cbar = plt.colorbar(im, cax=cax)
        cbar.ax.set_ylabel(label)
        # unfortunately the aspect is different between the initial call to colorbar
        # without cax argument. Try to reset it (but still it's somehow different)
        cbar.ax.set_aspect(20)
    else:
        plt.colorbar(im)
plt.show()
In general, a much better solution is to operate on the objects already present in the plot and only update them with the new data. This removes the need to remove and add axes, and is both cleaner and faster.
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="k", lw=6)
ax.set_title("Title")
cbar = plt.colorbar(im)
cbar.ax.set_ylabel("Label")
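for i in range(10):
    data = np.array(np.random.rand(2,2) )
    im.set_array(data.flatten())
    # set the color limits on the mappable and refresh the colorbar from it
    # (Colorbar.set_clim and Colorbar.draw_all are gone from recent matplotlib)
    im.set_clim(vmin=data.min(),vmax=data.max())
    cbar.update_normal(im)
    plt.draw()
plt.show()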
Update:
Actually, the latter approach of referencing objects from outside even works together with the multiprocess approach desired by the questioner.
So, here is code that updates the figure without the need to delete the colorbar.
import matplotlib.pyplot as plt
import numpy as np
import multiprocessing
import time
fig, ax = plt.subplots()
im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="w", lw=6)
ax.set_title("Title")
cbar = plt.colorbar(im)
cbar.ax.set_ylabel("Label")
tx = ax.text(0.2,0.8, "", fontsize=30, color="w")
tx2 = ax.text(0.2,0.2, "", fontsize=30, color="w")
def do(number):
    start = time.time()
    tx.set_text(str(number))
    data = np.array(np.random.rand(2,2)*(number+1) )
    im.set_array(data.flatten())
    # set the color limits on the mappable and refresh the colorbar from it
    # (Colorbar.set_clim and Colorbar.draw_all are gone from recent matplotlib)
    im.set_clim(vmin=data.min(),vmax=data.max())
    tx2.set_text("{m:.2f} < {ma:.2f}".format(m=data.min(), ma= data.max() ))
    cbar.update_normal(im)
    plt.draw()
    # note: the "multiproc" directory must already exist
    plt.savefig("multiproc/{n}.png".format(n=number))
    stop = time.time()
    return np.array([number, start, stop])
if __name__ == "__main__":
    multiprocessing.freeze_support()
    some_list = range(0,50)
    num_proc = 5
    p = multiprocessing.Pool(num_proc)
    nu = p.map(do, some_list)
    nu = np.array(nu)
    plt.close("all")
    fig, ax = plt.subplots(figsize=(16,9))
    ax.barh(nu[:,0], nu[:,2]-nu[:,1], height=np.ones(len(some_list)), left=nu[:,1], align="center")
    plt.show()
(The code at the end draws a timetable, which shows that multiprocessing has indeed taken place.)
If you have access to the axes and the image, the colorbar can be retrieved as a property of the image (or of whichever mappable the colorbar is associated with).
Following a previous answer (How to retrieve colorbar instance from figure in matplotlib), an example could be:
ax=plt.gca() #plt.gca() for current axis, otherwise set appropriately.
im=ax.images #this is a list of all images that have been plotted
if im[-1].colorbar is None: #here I assume we are interested in the last image plotted, otherwise use the appropriate index or loop over the list
    plt.colorbar() #plot a new colorbar
Note that for an image without a colorbar, im[-1].colorbar is None.
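Applied to the pcolormesh loop from the question, the same check can be wrapped in a small function. This is only a sketch: the helper name ensure_colorbar, the random data and the Agg backend are my assumptions, and it relies on matplotlib storing the colorbar on the mappable's colorbar attribute.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def ensure_colorbar(mappable):
    """Return the colorbar of mappable, creating it only if it is missing.

    Hypothetical helper, not a matplotlib function.
    """
    if mappable.colorbar is None:
        fig = mappable.axes.figure
        return fig.colorbar(mappable, ax=mappable.axes)
    return mappable.colorbar


fig, ax = plt.subplots()
mesh = ax.pcolormesh(np.random.rand(5, 5))

first = ensure_colorbar(mesh)   # draws the colorbar
second = ensure_colorbar(mesh)  # finds the existing one, draws nothing
print(first is second, len(fig.axes))
```

Calling it repeatedly on the same mappable leaves the figure with one plot axes and one colorbar axes.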
One approach is:
initially (prior to having any color bar drawn), set a variable
colorBarPresent = False
in the method for drawing the color bar, check to see if it's already drawn. If not, draw it and set the colorBarPresent variable True:
def drawColorBar():
    global colorBarPresent
    if colorBarPresent:
        # leave the function and don't draw the bar again
        return
    # draw the color bar
    plt.colorbar()
    colorBarPresent = True
There is an indirect way of guessing (with reasonable accuracy for most applications, I think) whether an Axes instance is home to a color bar. Depending on whether it is a horizontal or vertical color bar, either the X axis or Y axis (but not both) will satisfy all of these conditions:
No ticks
No tick labels
No axis label
Axis range is (0, 1)
So here's a function for you:
def is_colorbar(ax):
    """
    Guesses whether a set of Axes is home to a colorbar
    :param ax: Axes instance
    :return: bool
        True if the x xor y axis satisfies all of the following and thus looks like it's probably a colorbar:
        No ticks, no tick labels, no axis label, and range is (0, 1)
    """
    xcb = (len(ax.get_xticks()) == 0) and (len(ax.get_xticklabels()) == 0) and (len(ax.get_xlabel()) == 0) and \
        (ax.get_xlim() == (0, 1))
    ycb = (len(ax.get_yticks()) == 0) and (len(ax.get_yticklabels()) == 0) and (len(ax.get_ylabel()) == 0) and \
        (ax.get_ylim() == (0, 1))
    return xcb != ycb  # != is effectively xor in this case, since xcb and ycb are both bool
Thanks to this answer for the cool != xor trick: https://stackoverflow.com/a/433161/6605826
With this function, you can check whether a colorbar exists with:
colorbar_exists = any(is_colorbar(ax) for ax in plt.gcf().axes)
or, if you're sure the colorbar will always be last, you can get off easy with:
colorbar_exists = is_colorbar(plt.gcf().axes[-1])

How to do a range bar graph in matplotlib?

I'm trying to make a plot using matplotlib that resembles the following:
However, I'm not quite sure which type of graph to use. My data has the following form, where start x position is a positive value greater or equal to 0:
<item 1><start x position><end x position>
<item 2><start x position><end x position>
Looking at the docs, I see that there is barh and errorbar, but I'm not sure if it's possible to use barh with a start offset. What would be the best method to use, given my type of data? I'm not that familiar with the library, so I was hoping to get some insight.
Appetizer
Commented Code
As far as I know, the most direct way to do what you want is to draw your rectangles directly on the matplotlib canvas, using the patches module of matplotlib.
A simple implementation follows:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def plot_rect(data, delta=0.4):
    """data is a dictionary, {"Label":(low,hi), ... }
    return a drawing that you can manipulate, show, save etc"""
    yspan = len(data)
    yplaces = [.5+i for i in range(yspan)]
    ylabels = sorted(data.keys())
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_yticks(yplaces)
    ax.set_yticklabels(ylabels)
    ax.set_ylim((0,yspan))
    # later we'll need the min and max in the union of intervals
    low, hi = data[ylabels[0]]
    for pos, label in zip(yplaces,ylabels):
        start, end = data[label]
        ax.add_patch(patches.Rectangle((start,pos-delta/2.0),end-start,delta))
        if start<low : low=start
        if end>hi : hi=end
    # little small trick, draw an invisible line so that the x axis
    # limits are automatically adjusted...
    ax.plot((low,hi),(0,0))
    # now get the limits as automatically computed
    xmin, xmax = ax.get_xlim()
    # and use them to draw the hlines in your example
    ax.hlines(range(1,yspan),xmin,xmax)
    # the vlines are simply the x grid lines
    ax.grid(axis='x')
    # eventually return what we have done
    return ax
# this is the main script, note that we have imported pyplot as plt
# the data, inspired by your example,
data = {'A':(1901,1921),
'B':(1917,1935),
'C':(1929,1948),
'D':(1943,1963),
'E':(1957,1983),
'F':(1975,1991),
'G':(1989,2007)}
# call the function and give its result a name
ax = plot_rect(data)
# so that we can further manipulate it using the `axes` methods, e.g.
ax.set_xlabel('Whatever')
# finally save or show what we have
plt.show()
The result of our sufferings has been shown in the first paragraph of this post...
Addendum
Let's say that you feel that blue is a very dull color...
The patches you've placed in your drawing are accessible as a property (aptly named patches...) of the drawing and modifiable too, e.g.,
ax = plot_rect(data)
ax.set_xlabel('Whatever')
for rect in ax.patches:
    rect.set(facecolor=(0.9,0.9,0.2,1.0), # a tuple, RGBA
             edgecolor=(0.6,0.2,0.3,1.0),
             linewidth=3.0)
plt.show()
In my very humble opinion, a custom plotting function should do only the minimum needed to characterize the plot, as this kind of post-production is usually very easy in matplotlib.
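On the question's doubt about barh with a start offset: barh takes a left argument giving the start of each bar, so the same kind of chart can also be sketched without patches. The data below reuses the dictionary from the answer above; the bar height and the Agg backend are assumptions made for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

data = {'A': (1901, 1921),
        'B': (1917, 1935),
        'C': (1929, 1948),
        'D': (1943, 1963),
        'E': (1957, 1983),
        'F': (1975, 1991),
        'G': (1989, 2007)}

labels = sorted(data)
starts = [data[k][0] for k in labels]
widths = [data[k][1] - data[k][0] for k in labels]

fig, ax = plt.subplots()
# each bar starts at left=start and extends by its width (end - start)
bars = ax.barh(labels, widths, left=starts, height=0.4)
ax.grid(axis='x')
ax.set_xlabel('Whatever')
fig.savefig('ranges.png')
```

The returned container holds one Rectangle per item, so the restyling loop over patches shown in the addendum works on bars as well.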

In Python with Matplotlib how to check if a subplot is empty in the figure

I have some graphs created with NetworkX and show them on screen using Matplotlib. Specifically, since I don't know in advance how many graphs I need to show, I create a subplot on the figure on fly. That works fine. However, at some point in the script, some subplots are removed from the figure and the figure is shown with some empty subplots. I would like to avoid it, but I was not able to retrieve the subplots that are empty in the figure. Here is my code:
#instantiate a figure with size 12x12
fig = plt.figure(figsize=(12,12))
#when a graph is created, also a subplot is created:
ax = plt.subplot(3,4,count+1)
#and the graph is drawn inside it: N.B.: pe is the graph to be shown
nx.draw(pe, positions, labels=positions, font_size=8, font_weight='bold', node_color='yellow', alpha=0.5)
#many of them are created..
#under some conditions a subplot needs to be deleted, and so..
#condition here....and then retrieve the subplot to deleted. The graph contains the id of the ax in which it is shown.
for ax in fig.axes:
    if id(ax) == G.node[shape]['idax']:
        fig.delaxes(ax)
Up to here it works fine, but when I show the figure, the result looks like this:
You can see that there are two empty subplots, at the second position and at the fifth. How can I avoid this? Or how can I reorganize the subplots so that there are no more blanks in the figure?
Any help is appreciated! Thanks in advance.
So to do this I would keep a list of axes, and when I delete the contents of one I would swap it out with a full one. I think the example below solves the problem (or at least gives an idea of how to solve it):
import matplotlib.pyplot as plt
# this is just a helper class to keep things clean
class MyAxis(object):
    def __init__(self,ax,fig):
        # this flag tells me if there is a plot in these axes
        self.empty = False
        self.ax = ax
        self.fig = fig
        self.pos = self.ax.get_position()
    def del_ax(self):
        # delete the axes
        self.empty = True
        self.fig.delaxes(self.ax)
    def swap(self,other):
        # swap the positions of two axes
        #
        # THIS IS THE IMPORTANT BIT!
        #
        new_pos = other.ax.get_position()
        self.ax.set_position(new_pos)
        other.ax.set_position(self.pos)
        self.pos = new_pos
def main():
    # generate a figure and 10 subplots in a grid
    fig, axes = plt.subplots(ncols=5,nrows=2)
    # get these as a list of MyAxis objects
    my_axes = [MyAxis(ax,fig) for ax in axes.ravel()]
    for ax in my_axes:
        # plot some random stuff
        ax.ax.plot(range(10))
    # delete a couple of axes
    my_axes[0].del_ax()
    my_axes[6].del_ax()
    # count how many axes are dead
    dead = sum([ax.empty for ax in my_axes])
    # swap the dead plots for full plots in a row wise fashion
    for kk in range(dead):
        for ii,ax1 in enumerate(my_axes[kk:]):
            if ax1.empty:
                print(ii,"dead")
                for jj,ax2 in enumerate(my_axes[::-1][kk:]):
                    if not ax2.empty:
                        print("replace with",jj)
                        ax1.swap(ax2)
                        break
                break
    plt.draw()
    plt.show()
if __name__ == "__main__":
    main()
The extremely ugly for loop construct is really just a placeholder to give an example of how the axes can be swapped.
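A simpler variant of the same idea, as a rough sketch: remember the position of every grid slot up front, delete the unwanted axes, then hand the first slots to the surviving axes in order. The indices of the deleted axes and the Agg backend are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=5)
axes = list(axes.ravel())

# remember where every slot of the grid sits, in row-wise order
slots = [ax.get_position() for ax in axes]

for ax in axes:
    ax.plot(range(10))

# delete a couple of axes (indices chosen arbitrarily for the example)
for idx in (0, 6):
    fig.delaxes(axes[idx])

# the surviving axes take over the first slots, so the blanks end up last
remaining = [ax for ax in axes if ax in fig.axes]
for ax, pos in zip(remaining, slots):
    ax.set_position(pos)

fig.savefig("packed.png")
```

This moves every surviving axes forward, whereas the swap approach above only relocates as many axes as were deleted.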

Reusing patch objects in matplotlib without them moving position

I want to automatically generate a series of plots which are clipped to patches. If I try and reuse a patch object, it moves position across the canvas.
This script (based on an answer to a previous question by Yann) demonstrates what is happening.
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches
np.random.seed(100)
x = np.random.random(100)
y = np.random.random(100)
patch = patches.Circle((.75,.75),radius=.25,fc='none')
def doplot(x,y,patch,count):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    im = ax.scatter(x,y)
    ax.add_patch(patch)
    im.set_clip_path(patch)
    plt.savefig(str(count) + '.png')
for count in range(4):
    doplot(x,y,patch,count)
The first plot looks like this:
But in the second '1.png', the patch has moved..
However replotting again doesn't move the patch. '2.png' and '3.png' look exactly the same as '1.png'.
Could anyone point me in the right direction of what I'm doing wrong??
In reality, the patches I'm using are relatively complex and take some time to generate - I'd prefer to not have to remake them every frame if possible.
The problem can be avoided by using the same axes for each plot, with ax.cla() called to clear the plot after each iteration.
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches
np.random.seed(100)
patch = patches.Circle((.75,.75),radius=.25,fc='none')
fig = plt.figure()
ax = fig.add_subplot(111)
def doplot(patch,count):
    ax.set_xlim(-0.2,1.2)
    ax.set_ylim(-0.2,1.2)
    # new random data for every plot
    x = np.random.random(100)
    y = np.random.random(100)
    im = ax.scatter(x,y)
    ax.add_patch(patch)
    im.set_clip_path(patch)
    plt.savefig(str(count) + '.png')
    ax.cla()
for count in range(4):
    doplot(patch,count)
An alternative to unutbu's answer is to use the copy module, which can copy objects. It is hard to see what changes after add_patch is called, but something does: the axes, figure, extents, clip_box, transform and window_extent properties of the patch are changed. Unfortunately, printing each of these properties gives the same string before and after, so it looks as if nothing has changed. But the underlying attributes of some or all of them (e.g. extents is a Bbox) have probably changed.
The copy call gives you a unique patch for each figure you make, without needing to know what kind of patch it is. This still does not answer why this happens, but as I wrote above it's an alternative solution to the problem:
import copy
def doplot(x,y,patch,count):
    newPatch = copy.copy(patch)
    fig = plt.figure(dpi=50)
    ax = fig.add_subplot(111)
    im = ax.scatter(x,y)
    ax.add_patch(newPatch)
    im.set_clip_path(newPatch)
    plt.savefig(str(count) + '.png')
You can also use fig.savefig(str(count) + '.png'). This explicitly saves the figure fig, whereas the plt.savefig call saves the current figure, which happens to be the one you want.
