matplotlib get colorbar mappable from an axis - python

I want to add a colorbar WITHOUT what is returned by the axis on plotting things.
Sometimes I draw things to an axis inside a function, which returns nothing.
Is there a way to get the mappable for a colorbar from an axis after something has already been plotted on it?
I believe there is enough information about colormap and color range bound to the axis itself.
I'd like to do something like this:
import matplotlib.pyplot as plt
import numpy as np
def plot_something(ax):
    ax.plot(np.random.random(10), np.random.random(10), c=np.random.random(10))
fig, axs = plt.subplots(2)
plot_something(axs[0])
plot_something(axs[1])
mappable = axs[0].get_mappable()  # a hypothetical method I want to have.
fig.colorbar(mappable)
plt.show()
EDIT
The answer to the possible duplicate can partly solve my problem as is given in the code snippet. However, this question is more about retrieving a general mappable object from an axis, which seems to be impossible according to Diziet Asahi.

How you get your mappable depends on which plotting function you use in your plot_something() function.
For example:
plot() returns a list of Line2D objects. References to those objects are stored in the list ax.lines of the Axes object. That being said, I don't think a Line2D can be used as a mappable for colorbar().
scatter() returns a PathCollection object. This object is stored in the ax.collections list of the Axes object.
imshow() returns an AxesImage object, which is stored in ax.images.
You might have to look through those different lists until you find an appropriate object to use.
import matplotlib.pyplot as plt
import numpy as np
def plot_something(ax):
    x = np.random.random(size=(10,))
    y = np.random.random(size=(10,))
    c = np.random.random(size=(10,))
    ax.scatter(x, y, c=c)
fig, ax = plt.subplots()
plot_something(ax)
# scatter() stored its PathCollection in ax.collections
mappable = ax.collections[0]
fig.colorbar(mappable=mappable)
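The "look in those different lists" idea can be wrapped in a small helper. The following is only a sketch: find_mappable is a hypothetical name (Matplotlib has no Axes.get_mappable()), and it assumes the color-mapped artist is a collection or an image, so it will not find anything for a plain plot() call.

```python
import matplotlib.pyplot as plt
import numpy as np


def find_mappable(ax):
    """Return a color-mapped artist drawn on ax, or None if there is none.

    Hypothetical helper, not part of Matplotlib.
    """
    # Collections (scatter, pcolormesh, ...) and images (imshow) can carry
    # color-mapped data; get_array() is None when no such data was set.
    for artist in list(ax.collections) + list(ax.images):
        if artist.get_array() is not None:
            return artist
    return None


def plot_something(ax):
    rng = np.random.default_rng(0)
    ax.scatter(rng.random(10), rng.random(10), c=rng.random(10))


fig, axs = plt.subplots(2)
plot_something(axs[0])
plot_something(axs[1])

mappable = find_mappable(axs[0])
if mappable is not None:
    cbar = fig.colorbar(mappable, ax=axs)
```

If several color-mapped artists share one axis, this returns the first one it meets, which may not be the one you want; in that case it is safer to have the plotting function return its artist.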

Related

How can I return a matplotlib figure from a function?

I need to plot changing molecule numbers against time. But I'm also trying to investigate the effects of parallel processing, so I'm trying to avoid writing to global variables. At the moment I have the following two numpy arrays: tao_all, which contains all the time points to be plotted on the x-axis, and popul_num_all, which contains the changing molecule numbers to be plotted on the y-axis.
The current code I've got for plotting is as follows:
for i, label in enumerate(['Enzyme', 'Substrate', 'Enzyme-Substrate complex', 'Product']):
    figure1 = plt.plot(tao_all, popul_num_all[:, i], label=label)
plt.legend()
plt.tight_layout()
plt.show()
I need to encapsulate this in a function that takes the above arrays as input and returns the graph. I've read a couple of other posts on here that say I should draw on an axis and return the axis, but I can't quite get my head around applying that to my problem.
Cheers
import matplotlib.pyplot as plt
def plot_func(x, y):
    fig, ax = plt.subplots()
    ax.plot(x, y)
    return fig
Usage:
fig = plot_func([1,2], [3,4])
Alternatively you may want to return ax. For details about Figure and Axes see the docs. You can get the axes array from the figure by fig.axes and the figure from the axes by ax.get_figure().
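Applied to the arrays from the question, the same pattern might look like the sketch below. The function name plot_populations and the placeholder data are made up for illustration; only tao_all, popul_num_all and the four labels come from the question.

```python
import matplotlib.pyplot as plt
import numpy as np

SPECIES = ["Enzyme", "Substrate", "Enzyme-Substrate complex", "Product"]


def plot_populations(tao_all, popul_num_all, labels=SPECIES):
    """Plot each column of popul_num_all against tao_all; return (fig, ax)."""
    fig, ax = plt.subplots()
    for i, label in enumerate(labels):
        ax.plot(tao_all, popul_num_all[:, i], label=label)
    ax.legend()
    fig.tight_layout()
    return fig, ax


# Placeholder data, invented only to exercise the function.
tao_all = np.linspace(0.0, 1.0, 50)
popul_num_all = np.column_stack([tao_all * k for k in range(1, 5)])
fig, ax = plot_populations(tao_all, popul_num_all)
```

Because the function creates its own figure and touches no globals, each call is independent; the caller decides whether to call plt.show() or fig.savefig(...).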
In addition to the above answer, I suggest using matplotlib's animation.FuncAnimation if you are working with time series and want to improve your visualization.
You can find the details here https://matplotlib.org/api/_as_gen/matplotlib.animation.FuncAnimation.html

How do I change the Matplotlib axis limits for a plot given by a specific library(trackpy)?

This function gives me a plot. However I want to change the default axis. It says in documentation that ax refers to:
ax : matplotlib axes object, optional.
I tried to input the axis limits as ax=([0 100 0 500]), for example, but it is recognized as a tuple or a list. What is the correct way to input it?
Thanks!
The trackpy.plot_traj function returns a matplotlib axes object.
So, you'll want something like:
ax = trackpy.plot_traj(traj)
ax.set_xlim([0, 100])
ax.set_ylim([0, 500])
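The ax parameter expects an existing Axes to draw on, not a list of limits. The sketch below illustrates that convention with library_plot, a hypothetical stand-in written for this example (it is not trackpy's code), so it runs without trackpy installed.

```python
import matplotlib.pyplot as plt


def library_plot(data, ax=None):
    """Hypothetical stand-in for a library function that accepts `ax`."""
    if ax is None:
        ax = plt.gca()  # fall back to the current Axes
    ax.plot(data)
    return ax


fig, my_ax = plt.subplots()
returned = library_plot([0, 250, 500], ax=my_ax)  # pass an Axes, not limits

# The limits are then set on the Axes object itself.
returned.set_xlim(0, 100)
returned.set_ylim(0, 500)
```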

Matplotlib: Plotting multiple histograms in plt.subplots

Background of the problem:
I'm working on a class that takes an Axes object as constructor parameter and produces a (m,n) dimension figure with a histogram in each cell, kind of like the figure below:
There are two important things to note here that I'm not allowed to modify in any way:
The Figure object is not passed as a constructor parameter; only the Axes object is. So the subplots object cannot be modified in any way.
The Axes parameter is set to that of a (1,1) figure, by default (as below). All the modification required to make it an (m,n) figure are performed within the class (inside its methods)
_, ax = plt.subplots() # By default takes (1,1) dimension
cm = ClassName(model, ax=ax, histogram=True) # calling my class
What I'm stuck on:
Since I want to plot histograms within each cell, I decided to approach it by looping over each cell and creating a histogram for each.
results[col].hist(ax=self.ax[y,x], bins=bins)
However, I'm not able to specify the axes of the histogram in any way. This is because the Axes parameter passed is of default dimension (1,1) and hence not indexable. When I try this I get the following TypeError:
TypeError: 'AxesSubplot' object is not subscriptable
With all this considered, I would like to know of any possible ways I can add my histogram to the parent Axes object. Thanks for taking a look.
The requirement is pretty strict and maybe not the best design choice. Because you later want to plot several subplots at the position of a single subplot, this single subplot is only created for the sole purpose of dying and being replaced a few moments later.
So what you can do is obtain the position of the axes you pass in and create a new gridspec at that position. Then remove the original axes and create a new set of axes within that newly created gridspec.
The following would be an example. Note that it currently requires that the axes to be passed in is a Subplot (as opposed to any axes).
It also hardcodes the number of plots to be 2*2. In the real use case you would probably derive that number from the model you pass in.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec
class ClassName():
    def __init__(self, model, ax=None, **kwargs):
        ax = ax or plt.gca()
        # Only subplots carry a SubplotSpec describing their slot in the grid
        spec = ax.get_subplotspec() if hasattr(ax, "get_subplotspec") else None
        if spec is None:
            raise ValueError("Axes needs to be a subplot")
        # Geometry of subplots
        m, n = 2, 2
        # Nest a new m x n grid inside the slot of the original axes
        gs = gridspec.GridSpecFromSubplotSpec(m, n, subplot_spec=spec)
        fig = ax.figure
        ax.remove()
        self.axes = np.empty((m, n), dtype=object)
        for i in range(m):
            for j in range(n):
                self.axes[i, j] = fig.add_subplot(gs[i, j], label=f"{i}{j}")
    def plot(self, data):
        for ax, d in zip(self.axes.flat, data):
            ax.plot(d)
_, (ax, ax2) = plt.subplots(ncols=2)
cm = ClassName("mymodel", ax=ax2)  # calling my class
cm.plot(np.random.rand(4, 10))
plt.show()
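Since the question is about histograms, here is the same replace-the-subplot idea as a standalone function with a histogram drawn in each cell. This is a sketch under assumptions: replace_with_grid is a hypothetical name, the random data is a placeholder, and the axes passed in must have been created as a subplot (for example by plt.subplots()).

```python
import matplotlib.pyplot as plt
import numpy as np


def replace_with_grid(ax, m, n):
    """Swap a single subplot for an (m, n) array of Axes in the same slot.

    Hypothetical helper; assumes ax has a SubplotSpec.
    """
    spec = ax.get_subplotspec()
    if spec is None:
        raise ValueError("Axes needs to be a subplot")
    fig = ax.figure
    inner = spec.subgridspec(m, n)  # nested grid inside the original slot
    ax.remove()
    cells = np.empty((m, n), dtype=object)
    for i in range(m):
        for j in range(n):
            cells[i, j] = fig.add_subplot(inner[i, j])
    return cells


fig, ax = plt.subplots()
cells = replace_with_grid(ax, 2, 2)

rng = np.random.default_rng(0)
for cell in cells.flat:
    cell.hist(rng.normal(size=200), bins=20)  # placeholder data
```

With pandas, the resulting array can be indexed as in the question, for example results[col].hist(ax=cells[y, x], bins=bins).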

Matplotlib Python -> Plot_dendrogram Axis object

I'm building a plot_dendrogram graph and I'm struggling with the axis object. I need to create a Matplotlib axis object with limits from 0 to 0.5 on both the x and y axes. How do I do this?
current attempt:
plt.axis([0,.5,0,.5])
and I get the following error:
'list' object has no attribute 'set_ylim'
As the documentation states, plt.axis is a convenience function, which sets the axis limits of the current axes object.
To create the axes object, I would suggest something like this:
ax = plt.subplot(111)
ax.set_xlim([0, 0.5])
ax.set_ylim([0, 0.5])
However, I'm not sure how you got the error you posted. Even if I have not created any axes and just do:
import matplotlib.pyplot as plt
plt.axis([0,.5,0,.5])
I don't get an error.
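One possible explanation for the error, offered only as a guess because the question does not show the failing call: plt.axis returns the axis limits rather than an Axes, so passing its return value where an Axes is expected would fail on set_ylim. A minimal sketch of the difference:

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# plt.axis applies the limits to the current Axes and returns the limits.
limits = plt.axis([0, 0.5, 0, 0.5])

# The Axes object itself is available separately.
ax = plt.gca()

limits_is_axes = isinstance(limits, Axes)
ax_is_axes = isinstance(ax, Axes)
```

Here ax, not limits, is the object to hand to a function that takes an axes argument.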

Update the x-axis of a matplotlib subplot according to the y-axis of a different subplot

I would like to plot an orthogonal projection like this one:
using matplotlib, possibly including the 3D subplot. All the subplots should share common axes.
fig = plt.figure()
ax = fig.add_subplot(221, title="XZ")
bx = fig.add_subplot(222, title="YZ", sharey=ax)
cx = fig.add_subplot(223, title="XY", sharex=ax, sharey=[something like bx.Xaxis])
dx = fig.add_subplot(224, title="XYZ", projection="3d", sharex=ax, sharey=bx, sharez=[something like bx.Yaxis])
I can't figure out how to "link" the x-axis of one plot with the y-axis of another. Is there a way to accomplish this?
Late to the party but...
You should be able to accomplish what you want by manually updating one subplot's axis data with the other subplot's axis data.
Using the notation from your post, for example, you can match the ylim values of cx with the xlim values of bx using the get and set methods.
cx.set_ylim(bx.get_xlim())
Similarly, you can match tick labels and positions across subplots.
bx_xticks = bx.get_xticks()
bx_xticklabels = [label.get_text() for label in bx.get_xticklabels()]
cx.set_yticks(bx_xticks)
cx.set_yticklabels(bx_xticklabels)
You should be able to define any and all axis attributes and objects dynamically from an already instantiated subplot in this way.
Here is my approach to the problem, which is basically a condensed version of @elebard's answer. I just add update-limit methods to the axes objects, so they get access to the set_xlim / set_ylim methods. Then I connect these functions to the callbacks of the axis I want to synchronize with. When they are called, the event argument is the Axes whose limits changed.
import types
import matplotlib.pyplot as plt
def sync_y_with_x(self, event):
    # event is the Axes whose limits changed; copy its y-limits to our x-axis
    self.set_xlim(event.get_ylim(), emit=False)
def sync_x_with_y(self, event):
    # copy the changed Axes' x-limits to our y-axis
    self.set_ylim(event.get_xlim(), emit=False)
fig = plt.figure()
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)
ax1.update_xlim = types.MethodType(sync_y_with_x, ax1)
ax2.update_ylim = types.MethodType(sync_x_with_y, ax2)
# ax1's x-axis and ax2's y-axis follow each other
ax1.callbacks.connect("xlim_changed", ax2.update_ylim)
ax2.callbacks.connect("ylim_changed", ax1.update_xlim)
I solved [1] the problem by exploiting event handlers.
Listening for "*lim_changed" events and then calling get_*lim and set_*lim to synchronise the limits does the trick.
Note you also have to reverse the x-axis in the upper right plot YZ.
Here is a sample function to sync the x-axis with the y-axis:
def sync_x_with_y(self, axis):
    # check whether the axes orientation is not coherent
    if (axis.get_ylim()[0] > axis.get_ylim()[1]) != (self.get_xlim()[0] > self.get_xlim()[1]):
        self.set_xlim(axis.get_ylim()[::-1], emit=False)
    else:
        self.set_xlim(axis.get_ylim(), emit=False)
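To show how such a handler gets wired up, here is a minimal one-directional sketch using a closure instead of a method. link_x_to_y is a hypothetical helper name, and the sketch leaves out the reversed-axis check from the function above.

```python
import matplotlib.pyplot as plt


def link_x_to_y(source, target):
    """Keep target's y-limits equal to source's x-limits (one direction only)."""
    def on_xlim_changed(changed_axes):
        # Matplotlib passes the Axes whose limits changed to the callback.
        # emit=False keeps this update from firing further limit callbacks.
        target.set_ylim(changed_axes.get_xlim(), emit=False)
    return source.callbacks.connect("xlim_changed", on_xlim_changed)


fig, (ax_top, ax_bottom) = plt.subplots(2)
cid = link_x_to_y(ax_top, ax_bottom)

# Changing the x-limits of ax_top now updates the y-limits of ax_bottom.
ax_top.set_xlim(2, 7)
```

For a two-way link, a second connection in the opposite direction (listening for "ylim_changed" on the target) would be needed; emit=False is what prevents the two handlers from calling each other endlessly.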
I implemented a simple class, Orthogonal Projection, that makes it quite easy to produce this kind of plot.
[1] Starting from a hint that Benjamin Root gave me on the matplotlib mailing list almost a year ago. Sorry for not having posted the solution before.
