Matplotlib: Plotting multiple histograms in plt.subplots - python

Background of the problem:
I'm working on a class that takes an Axes object as constructor parameter and produces a (m,n) dimension figure with a histogram in each cell, kind of like the figure below:
There are two important things to note here, which I'm not allowed to modify in any way:
The Figure object is not passed as a constructor parameter; only the Axes object is. So the subplots object cannot be modified in any way.
The Axes parameter is set to that of a (1,1) figure, by default (as below). All the modification required to make it an (m,n) figure are performed within the class (inside its methods)
_, ax = plt.subplots() # By default takes (1,1) dimension
cm = ClassName(model, ax=ax, histogram=True) # calling my class
What I'm stuck on:
Since I want to plot histograms within each cell, I decided to approach it by looping over each cell and creating a histogram for each.
results[col].hist(ax=self.ax[y,x], bins=bins)
However, I'm not able to specify the axes of the histogram in any way. This is because the Axes parameter passed in has the default dimension (1,1) and hence is not indexable. When I try this, I get the following TypeError:
TypeError: 'AxesSubplot' object is not subscriptable
With all this considered, I would like to know of any possible ways I can add my histogram to the parent Axes object. Thanks for taking a look.

The requirement is pretty strict and maybe not the best design choice. Because you later want to plot several subplots at the position of a single subplot, this single subplot is only created for the sole purpose of dying and being replaced a few moments later.
So what you can do is obtain the position of the axes you pass in and create a new gridspec at that position. Then remove the original axes and create a new set of axes within that newly created gridspec.
The following would be an example. Note that it currently requires that the axes to be passed in is a Subplot (as opposed to any axes).
It also hardcodes the number of plots to be 2*2. In the real use case you would probably derive that number from the model you pass in.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec
class ClassName():
    def __init__(self, model, ax=None, **kwargs):
        ax = ax or plt.gca()
        # Only axes placed through a gridspec carry a SubplotSpec
        subplotspec = ax.get_subplotspec() if hasattr(ax, "get_subplotspec") else None
        if subplotspec is None:
            raise ValueError("Axes needs to be a subplot")
        # Geometry of subplots
        m, n = 2, 2
        # Nested grid that occupies the cell of the original axes
        gs = gridspec.GridSpecFromSubplotSpec(m, n, subplot_spec=subplotspec)
        fig = ax.figure
        ax.remove()
        self.axes = np.empty((m, n), dtype=object)
        for i in range(m):
            for j in range(n):
                self.axes[i, j] = fig.add_subplot(gs[i, j], label=f"{i}{j}")
    def plot(self, data):
        for ax, d in zip(self.axes.flat, data):
            ax.plot(d)
_, (ax, ax2) = plt.subplots(ncols=2)
cm = ClassName("mymodel", ax=ax2) # calling my class
cm.plot(np.random.rand(4, 10))
plt.show()
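The same idea can be applied to the histograms from the question. The following is only a minimal sketch under a few assumptions: the helper name split_axes and the random sample data are made up for illustration, and the nested grid is built with SubplotSpec.subgridspec rather than GridSpecFromSubplotSpec.

```python
import matplotlib.pyplot as plt
import numpy as np


def split_axes(ax, m, n):
    """Replace `ax` with an (m, n) array of new axes in the same cell."""
    subplotspec = ax.get_subplotspec()
    if subplotspec is None:
        raise ValueError("Axes needs to be placed through a gridspec")
    fig = ax.figure                        # keep the figure before removing ax
    inner = subplotspec.subgridspec(m, n)  # nested grid inside the old cell
    ax.remove()
    axes = np.empty((m, n), dtype=object)
    for i in range(m):
        for j in range(n):
            axes[i, j] = fig.add_subplot(inner[i, j])
    return axes


rng = np.random.default_rng(0)
samples = rng.normal(size=(4, 200))  # hypothetical data: one row per histogram

_, ax = plt.subplots()               # the default (1, 1) axes from the question
axes = split_axes(ax, 2, 2)
for cell, values in zip(axes.flat, samples):
    cell.hist(values, bins=20)
```

Call plt.show() afterwards to display the figure. With a pandas column, results[col].hist(ax=axes[y, x], bins=bins) should work the same way once the axes array exists.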

Related

Plot only specific subplots of a grid of subplots

I have created a grid of subplots to my liking.
I initiated the plotting by defining fig,ax = plt.subplots(2,6,figsize=(24,8))
So far so good. I filled those subplots with their respective content. Now I want to plot a single or two particular subplot in isolation. I tried:
ax[idx][idx].plot()
This does not work and returns an empty list
I have tried:
fig_single,ax_single = plt.subplots(2,1)
ax_single[0]=ax[idx][0]
ax_single[1]=ax[idx][1]
This returns:
TypeError: 'AxesSubplot' object does not support item assignment
How do I proceed without plotting those subplots again by calling the respective plot functions?
You're close.
fig,ax = plt.subplots(nrows=2,ncols=6,sharex=False,sharey=False,figsize=(24,8))
#set either sharex=True or sharey=True if you wish axis limits to be shared
#=> very handy for interactive exploration of timeseries data, ...
r=0 #first row
c=0 #first column
ax[r,c].plot() #plot your data, instead of ax[r][c].plot()
ax[r,c].set_title('Title') #name title for a subplot
ax[r,c].set_ylabel('Ylabel ') #ylabel for a subplot
ax[r,c].set_xlabel('X axis label') #xlabel for a subplot
A more complete/flexible method is to assign r,c:
for i in range(nrows*ncols):
    r,c = np.divmod(i,ncols)
    ax[r,c].plot() #....
You can afterwards still make modifications, e.g. set_ylim, set_title, ...
So if you want to name the label of the 11th subplot:
ax[1,4].set_ylabel('11th subplot ylabel')
You will often want to make use of fig.tight_layout() at the end, so that the figure uses the available area correctly.
Complete example:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0,180,180) # angles in degrees
nrows = 2
ncols = 6
fig,ax = plt.subplots(nrows=nrows,ncols=ncols,sharex=False,sharey=False,figsize=(24,8))
for i in range(nrows*ncols):
    r,c = np.divmod(i,ncols)
    y = np.sin(np.deg2rad(x)*(i+1)) # convert degrees to radians before taking the sine
    ax[r,c].plot(x,y)
    ax[r,c].set_title('%s'%i)
fig.suptitle('Overall figure title')
fig.tight_layout()
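As an alternative to computing r and c by hand, the array of axes can be iterated directly. This is just a sketch with placeholder sine data; it relies on ax.flat visiting the grid row by row, so the running index matches the divmod arithmetic used above.

```python
import numpy as np
import matplotlib.pyplot as plt

nrows, ncols = 2, 6
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(24, 8))

x = np.linspace(0, 2 * np.pi, 100)
# ax.flat walks the grid row by row, so index i corresponds to divmod(i, ncols)
for i, cell in enumerate(ax.flat):
    cell.plot(x, np.sin((i + 1) * x))
    cell.set_title(f"subplot {i}")

fig.tight_layout()
```

Individual subplots remain reachable afterwards, for example ax[1, 4] for the 11th one.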

Retrieve all colorbars in figure

I'm trying to manipulate all colorbar instances contained in a figure. There is fig.get_axes() to obtain a list of axes, but I cannot find anything similar for colorbars.
This answer, https://stackoverflow.com/a/19817573/7042795, only applies to special situations, but not the general case.
Consider this MWE:
import matplotlib.pyplot as plt
import numpy as np
data = np.random.random((10,10)) # Generate some random data to plot
fig, axs = plt.subplots(1,2)
im1 = axs[0].imshow(data)
cbar1 = fig.colorbar(im1)
im2 = axs[1].imshow(2*data)
cbar2 = fig.colorbar(im2)
fig.show()
How can I get cbar1 and cbar2 from fig?
What I need is a function like:
def get_colorbars(fig):
    cbars = fig.get_colorbars() # hypothetical method I would like to have
    return cbars
cbars = get_colorbars(fig)
You would have no choice but to check each object present in the figure whether it has a colorbar or not. This could look as follows:
def get_colorbars(fig):
    cbs = []
    for ax in fig.axes:
        cbs.extend(ax.findobj(lambda obj: hasattr(obj, "colorbar") and obj.colorbar))
    return [a.colorbar for a in cbs]
This will give you all the colorbars that are tied to an artist. There may be more colorbars in the figure though, e.g. created directly from a ScalarMappable or multiple colorbars for the same object; those cannot be found.
Since the only place I'm reasonably sure that colorbar references are retained is as an attribute of the artist they are tied to, the best solution I could think of is to search all artists in a figure. This is best done recursively:
from matplotlib.colorbar import Colorbar
def get_colorbars(fig):
    def check_kids(obj, bars):
        # Walk the artist tree and collect every colorbar attached to a child
        for child in obj.get_children():
            if isinstance(getattr(child, 'colorbar', None), Colorbar):
                bars.append(child.colorbar)
            check_kids(child, bars)
        return bars
    return check_kids(fig, [])
I have not had a chance to test this code, but it should at least point you in the right direction.
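For reference, here is a compact sketch of the same search applied to the MWE from the question. It assumes that every colorbar of interest is still referenced by the colorbar attribute of the artist it was created for, and it uses Artist.findobj instead of hand-written recursion. The isinstance check matters because the Figure itself has a method named colorbar.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colorbar import Colorbar


def get_colorbars(fig):
    """Collect the colorbars referenced by artists anywhere in `fig`."""
    # findobj walks the whole artist tree, including the figure itself
    owners = fig.findobj(
        lambda obj: isinstance(getattr(obj, "colorbar", None), Colorbar)
    )
    bars = []
    for owner in owners:
        if not any(owner.colorbar is cb for cb in bars):  # skip duplicates
            bars.append(owner.colorbar)
    return bars


data = np.random.random((10, 10))
fig, axs = plt.subplots(1, 2)
cbar1 = fig.colorbar(axs[0].imshow(data), ax=axs[0])
cbar2 = fig.colorbar(axs[1].imshow(2 * data), ax=axs[1])

found = get_colorbars(fig)
print(len(found))
```

Colorbars that are not attached to any artist in the figure are still out of reach with this approach.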

matplotlib get colorbar mappable from an axis

I want to add a colorbar WITHOUT what is returned by the axis on plotting things.
Sometimes I draw things to an axis inside a function, which returns nothing.
Is there a way to get the mappable for a colorbar from an axis where a plotting has been done beforehand?
I believe there is enough information about colormap and color range bound to the axis itself.
I'd like to do something like this:
def plot_something(ax):
    ax.plot( np.random.random(10), np.random.random(10), c= np.random.random(10))
fig, axs = plt.subplots(2)
plot_something(axs[0])
plot_something(axs[1])
mappable = axs[0].get_mappable() # a hypothetical method I want to have.
fig.colorbar(mappable)
plt.show()
EDIT
The answer to the possible duplicate can partly solve my problem as is given in the code snippet. However, this question is more about retrieving a general mappable object from an axis, which seems to be impossible according to Diziet Asahi.
The way you get your mappable depends on which plotting function you are using in your plot_something() function.
For example:
plot() returns a list of Line2D objects. References to those objects are stored in the list ax.lines of the Axes object. That being said, I don't think a Line2D can be used as a mappable for colorbar().
scatter() returns a PathCollection object. This object is stored in the ax.collections list of the Axes object.
On the other hand, imshow() returns an AxesImage object, which is stored in ax.images.
You might have to look through those different lists until you find an appropriate object to use.
def plot_something(ax):
    x = np.random.random(size=(10,))
    y = np.random.random(size=(10,))
    c = np.random.random(size=(10,))
    ax.scatter(x,y,c=c)
fig, ax = plt.subplots()
plot_something(ax)
mappable = ax.collections[0]
fig.colorbar(mappable=mappable)
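The search through those lists can be wrapped in a small helper. The function name find_mappables is hypothetical, and the sketch assumes that only images and collections with an attached data array are useful for a colorbar.

```python
import matplotlib.pyplot as plt
import numpy as np


def find_mappables(ax):
    """Return the artists on `ax` that carry colour-mapped data."""
    candidates = list(ax.images) + list(ax.collections)
    # Artists without a data array (e.g. a plain fill) cannot drive a colorbar
    return [artist for artist in candidates if artist.get_array() is not None]


def plot_something(ax):
    rng = np.random.default_rng(0)
    ax.scatter(rng.random(10), rng.random(10), c=rng.random(10))


fig, ax = plt.subplots()
plot_something(ax)

mappables = find_mappables(ax)
if mappables:
    fig.colorbar(mappables[0], ax=ax)
```

If several artists qualify, it is up to the caller to decide which one the colorbar should describe.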

How to subplot multiple graphs when calling a function that plots the graph?

I have a function that plots a graph. I can call this function with different variables to alter the graph. I'd like to call it multiple times and plot the graphs alongside each other, but I'm not sure how to do so.
def plt_graph(x, graph_title, horiz_label):
    df[x].plot(kind='barh')
    plt.title(graph_title)
    plt.ylabel("")
    plt.xlabel(horiz_label)
plt_graph('gross','Total value','Gross (in millions)')
If you know the number of plots you want to produce beforehand, you can first create as many subplots as you need
fig, axes = plt.subplots(nrows=1, ncols=5)
(in this case 5) and then provide the axes to the function
def plt_graph(x, graph_title, horiz_label, ax):
    df[x].plot(kind='barh', ax=ax)
Finally, call every plot like this
plt_graph("framekey", "Some title", "some label", axes[4])
(where 4 stands for the fifth and last plot)
I have created a tool to do this really easily. I use it all the time in jupyter notebooks and find it so much neater than a big column of charts. Copy the Gridplot class from this file:
https://github.com/simonm3/analysis/blob/master/analysis/plot.py
Usage:
gridplot = Gridplot()
plt.plot(x)
plt.plot(y)
It shows each new plot in a grid with 4 plots to a row. You can change the size of the charts or the number per row. It works for plt.plot, plt.bar, plt.hist and plt.scatter. However, it does require that you use matplotlib directly rather than pandas.
If you want to turn it off:
gridplot.off()
If you want to reset the grid to position 1:
gridplot.on()
Here is a way that you can do it. First you create the figure which will contain the axes object. With those axes you have something like a canvas, which is where every graph will be drawn.
fig, ax = plt.subplots(1,2)
Here I have created one figure with two axes. This is a one row and two columns figure. If you inspect the ax variable you will see two objects. This is what we'll use for all the plotting. Now, going back to the function, let's start with a simple dataset and define the function.
df = pd.DataFrame({"a": np.random.random(10), "b": np.random.random(10)})
def plt_graph(x, graph_title, horiz_label, ax):
    df[x].plot(kind = 'barh', ax = ax)
    ax.set_xlabel(horiz_label)
    ax.set_title(graph_title)
Then, to call the function you will simply do this:
plt_graph("a", "a", "a", ax=ax[0])
plt_graph("b", "b", "b", ax=ax[1])
Note that you pass each graph that you want to create to any of the axes you have. In this case, as you have two, you pass the first to the first axes and so on. Note that if you include seaborn in your import (import seaborn as sns), automatically the seaborn style will be applied to your graphs.
This is what your graphs will look like.
When you are creating plotting functions, you want to look at matplotlib's object oriented interface.
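A minimal sketch of that object-oriented style, without pandas, might look as follows. The function name plot_barh and the numbers are placeholders standing in for the columns of df. The key point is that the function draws only on the axes it receives and never touches the global pyplot state.

```python
import matplotlib.pyplot as plt


def plot_barh(ax, labels, values, title, xlabel):
    """Draw one horizontal bar chart on the axes it is given."""
    ax.barh(labels, values)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    return ax


# Hypothetical values, standing in for the columns of `df`
labels = ["A", "B", "C"]
gross = [120.0, 85.5, 60.2]
budget = [40.0, 55.0, 20.5]

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
plot_barh(axes[0], labels, gross, "Total value", "Gross (in millions)")
plot_barh(axes[1], labels, budget, "Total budget", "Budget (in millions)")
fig.tight_layout()
```

Because the target axes is an explicit argument, the same function can fill any cell of any grid.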

Update the x-axis of a matplotlib subplot according to the y-axis of a different subplot

I would like to plot an orthogonal projection like this one:
using matplotlib, possibly including the 3D subplot. All the subplots should share common axes.
fig = plt.figure()
ax = fig.add_subplot(221, title="XZ")
bx = fig.add_subplot(222, title="YZ", sharey=ax)
cx = fig.add_subplot(223, title="XY", sharex=ax, sharey=[something like bx.Xaxis])
dx = fig.add_subplot(224, title="XYZ", projection="3d", sharex=ax, sharey=bx, sharez=[something like bx.Yaxis])
I can't figure out how to "link" the x-axis of one plot with the y-axis of another. Is there a way to accomplish this?
Late to the party but...
You should be able to accomplish what you want by manually updating one subplot's axis data with the other subplot's axis data.
Using the notation from your post, for example, you can match the ylim values of cx with the xlim values of bx using the get and set methods.
cx.set_ylim(bx.get_xlim())
Similarly, you can match tick labels and positions across subplots.
bx_xticks = bx.get_xticks()
bx_xticklabels = [label.get_text() for label in bx.get_xticklabels()]
cx.set_yticks(bx_xticks)
cx.set_yticklabels(bx_xticklabels)
You should be able to define any and all axis attributes and objects dynamically from an already instantiated subplot in this way.
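Put together as a runnable sketch with placeholder data, a one-off copy from the x-axis of bx to the y-axis of cx could look like this. The tick positions are set before the limits here because setting ticks may widen the limits to include every tick.

```python
import matplotlib.pyplot as plt

fig = plt.figure()
bx = fig.add_subplot(222, title="YZ")
cx = fig.add_subplot(223, title="XY")

bx.plot([0, 4, 8], [1, 3, 2])   # placeholder data
cx.plot([1, 2, 3], [0, 4, 8])

# Ticks first: set_yticks may extend the limits to cover all given ticks
cx.set_yticks(bx.get_xticks())
# Then copy the exact x-range of bx onto the y-axis of cx
cx.set_ylim(bx.get_xlim())
```

This copy is static. If bx is panned or zoomed later, the two axes drift apart again unless the copy is repeated.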
Here is my approach to the problem, which is basically a condensed version of #elebards answer. I add limit-update methods to the axes instances so that they have access to the set_xlim / set_ylim methods, then connect these methods to the callbacks of the axes I want to synchronize. When they are called, the event argument is the Axes whose limits changed.
import types
import matplotlib.pyplot as plt
def sync_y_with_x(self, event):
    # set this axes' x-limits from the y-limits of the axes that changed
    self.set_xlim(event.get_ylim(), emit=False)
def sync_x_with_y(self, event):
    # set this axes' y-limits from the x-limits of the axes that changed
    self.set_ylim(event.get_xlim(), emit=False)
fig = plt.figure()
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)
ax1.update_xlim = types.MethodType(sync_y_with_x, ax1)
ax2.update_ylim = types.MethodType(sync_x_with_y, ax2)
# x-limits of ax1 drive the y-limits of ax2, and vice versa
ax1.callbacks.connect("xlim_changed", ax2.update_ylim)
ax2.callbacks.connect("ylim_changed", ax1.update_xlim)
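The same linkage can also be written without attaching new methods to the axes. This is only a sketch; the helper name link_x_to_y is made up, and plain nested functions are used as callbacks.

```python
import matplotlib.pyplot as plt

fig = plt.figure()
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)


def link_x_to_y(x_axes, y_axes):
    """Keep the x-limits of `x_axes` equal to the y-limits of `y_axes`."""

    def on_xlim_changed(changed):
        # emit=False keeps this update from firing the opposite callback
        y_axes.set_ylim(changed.get_xlim(), emit=False)

    def on_ylim_changed(changed):
        x_axes.set_xlim(changed.get_ylim(), emit=False)

    x_axes.callbacks.connect("xlim_changed", on_xlim_changed)
    y_axes.callbacks.connect("ylim_changed", on_ylim_changed)


link_x_to_y(ax1, ax2)

ax1.set_xlim(2, 5)   # the y-limits of ax2 follow
ax2.set_ylim(0, 1)   # the x-limits of ax1 follow
```

Each callback receives the Axes whose limits changed, which is why it can read the new limits directly from its argument.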
I solved the problem [1] by exploiting event handlers.
Listening for "*lim_changed" events and then calling get_*lim and set_*lim to synchronise the limits does the trick.
Note you also have to reverse the x-axis in the upper right plot YZ.
Here is a sample function to sync the x-axis with the y-axis:
def sync_x_with_y(self, axis):
    # check whether the axes orientation is not coherent
    if (axis.get_ylim()[0] > axis.get_ylim()[1]) != (self.get_xlim()[0] > self.get_xlim()[1]):
        self.set_xlim(axis.get_ylim()[::-1], emit=False)
    else:
        self.set_xlim(axis.get_ylim(), emit=False)
I implemented a simple class, Orthogonal Projection, that makes it quite easy to create this kind of plot.
[1] Starting from a hint that Benjamin Root gave me on the matplotlib mailing list almost a year ago. Sorry for not having posted the solution earlier.
