I am using Python/matplotlib to create a figure with three subplots, each returned from a different 'source' or class method.
For example, I have a script called 'plot_spectra.py' that contains the Spectra() class with method Plot().
So, calling Spectra('filename.ext').Plot() will return a tuple, as per the code below:
# create the plot
fig, ax = plt.subplots()
ax.contour(xx, yy, plane, levels=cl, cmap=cmap)
ax.set_xlim(ppm_1h_0, ppm_1h_1)
ax.set_ylim(ppm_13c_0, ppm_13c_1)
# return the contour plot
return fig, ax
It is my understanding that the 'figure' is the 'window' in matplotlib and the 'ax' is an individual plot. I would then like to, say, plot three of these 'ax' objects in the same figure, but I am struggling to do so because I keep getting an empty window, and I think I have misunderstood what each object actually is.
Calling:
hnca, hnca_ax = Spectra('data/HNCA.ucsf', type='sparky').Plot(plane_ppm=resi.N(), vline=resi.H())
plt.subplot(2,2,1)
plt.subplot(hnca_ax)
eucplot, barplot = PlotEucXYIntensity(scores, x='H', y='N')
plt.subplot(2,2,3)
plt.subplot(eucplot)
plt.subplot(2,2,4)
plt.subplot(barplot)
plt.show()
Ultimately, what I am trying to obtain is a single window that looks like this (image not included), where each plot has been returned from a different function or class method.
What 'object' do I need to return from my functions? And how do I incorporate these three objects into a single figure?
I would suggest this kind of approach, where you pass the ax on which you want to plot into the function:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def Spectra(data, ax):
    ax.plot(data)
def PlotIntensity(data, ax):
    ax.hist(data)
def SeabornScatter(data, ax):
    # seaborn >= 0.12 only accepts x and y as keyword arguments
    sns.scatterplot(x=data, y=data, ax=ax)
spectrum = np.random.random((1000,))
plt.figure()
ax1 = plt.subplot(1,3,1)
Spectra(spectrum, ax1)
ax2 = plt.subplot(1,3,2)
SeabornScatter(spectrum, ax2)
ax3 = plt.subplot(1,3,3)
PlotIntensity(spectrum, ax3)
plt.tight_layout()
plt.show()
You can specify the grid for the subplots in many different ways, and you probably also want to have a look at the gridspec module.
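As one example of specifying the grid differently, the sketch below uses plt.subplot_mosaic (assuming matplotlib 3.3 or newer) to build the layout from the question: one wide panel on top and two panels underneath. The three plotting functions are hypothetical stand-ins for Spectra().Plot() and PlotEucXYIntensity(); the point is that each one draws into the Axes it is handed instead of creating and returning its own figure.

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical plotting functions: each draws into the Axes it receives.
def plot_spectrum(ax, data):
    ax.plot(data)
    ax.set_title("spectrum")

def plot_scatter(ax, data):
    ax.scatter(data[:-1], data[1:], s=5)
    ax.set_title("scatter")

def plot_histogram(ax, data):
    ax.hist(data, bins=20)
    ax.set_title("histogram")

spectrum = np.random.default_rng(0).random(1000)

# One wide panel on top, two panels side by side underneath.
# subplot_mosaic returns the figure and a dict mapping each label to its Axes.
fig, axd = plt.subplot_mosaic([["top", "top"],
                               ["left", "right"]],
                              constrained_layout=True)
plot_spectrum(axd["top"], spectrum)
plot_scatter(axd["left"], spectrum)
plot_histogram(axd["right"], spectrum)
plt.show()
```

Because the Axes are created in one place, every function ends up drawing into the same figure, so there is only one window.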
One way to do this is:
f = plt.figure()
gs = f.add_gridspec(2,2)
ax = f.add_subplot(gs[0,:])
Think of the '2,2' as a grid of 2 rows x 2 columns.
On the third line, 'gs[0,:]' tells matplotlib to add a chart on row 0, spanning all columns. This places the chart across the top of your figure. Note that indices begin with 0, not 1.
To add the 'eucplot' you will have to call a different ax on row 1 and column 0:
ax2 = f.add_subplot(gs[1,0])
Lastly, the 'barplot' will go in yet a different ax on row 1, column 1:
ax3 = f.add_subplot(gs[1,1])
For further reference, see the matplotlib tutorial "Customizing Figure Layouts Using GridSpec and Other Functions".
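Putting those pieces together, here is a minimal end-to-end sketch of the same 2 x 2 GridSpec layout. The functions draw_contour, draw_euclidean and draw_bars, and their synthetic data, are hypothetical placeholders for the asker's own plotting code; each one receives the Axes it should draw on.

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical plotting functions that draw into a given Axes.
def draw_contour(ax):
    xx, yy = np.meshgrid(np.linspace(-2, 2, 50), np.linspace(-2, 2, 50))
    ax.contour(xx, yy, np.exp(-(xx**2 + yy**2)), levels=8)

def draw_euclidean(ax, scores):
    ax.plot(scores, marker="o", linestyle="none")

def draw_bars(ax, scores):
    ax.bar(range(len(scores)), scores)

scores = np.random.default_rng(1).random(10)

f = plt.figure()
gs = f.add_gridspec(2, 2)        # 2 rows x 2 columns
ax1 = f.add_subplot(gs[0, :])    # row 0, spanning both columns
ax2 = f.add_subplot(gs[1, 0])    # row 1, column 0
ax3 = f.add_subplot(gs[1, 1])    # row 1, column 1

draw_contour(ax1)
draw_euclidean(ax2, scores)
draw_bars(ax3, scores)
f.tight_layout()
plt.show()
```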
I'm writing a program which gets data and then uses time series forecasting to predict data values for the next, say, 300 data points.
However, only data that fulfils a certain condition will be plotted, so the number of subplots to pass to add_subplot() is not known in advance. I'm aware of the plt.subplots() function, but something such as
fig, (ax1, ax2) = plt.subplots(1, 2)
hard-codes exactly two graphs, whereas I need the number to be variable, e.g. set through a parameter.
Here is a simplified version of the current code which results in each plot being in separate windows:
fig = plt.figure() # creates a figure instance for the final graph output
plots = 1 # indicates the total number of plots to plot, starting from 1
# passed as a parameter to the add_subplot() function
for data in dataSet:
    forecast(data, fig, plots)
plt.figure(fig.number)
plt.show()
And the function:
import matplotlib.pyplot as plt
import matplotlib.ticker as tick
from pandas import Series
from statsmodels.tsa.holtwinters import ExponentialSmoothing
def forecast(data, superFigure, plotNumber):
    index = range(0, len(data))
    plotData = Series(data, index)
    # fit the data values into a specific model:
    modelFit = ExponentialSmoothing(plotData, trend="add").fit()
    # forecast for the next 300 points:
    modelForecast = modelFit.forecast(300)
    if [condition]:
        # plot the original data points:
        points = plotData.plot(marker='x', color='black', label='Base Data')
        points.set_xlim(0, len(data) + 300)
        # plot the forecast in a different colour:
        modelForecast.plot(marker='x', ax=points, color='blue', label='Forecasted Data')
        plt.title("Plot Title")
        plt.xlabel("X Axis")
        plt.ylabel("Y Axis")
        # format the axes, adding thousand separator
        points.get_xaxis().set_major_formatter(
            tick.FuncFormatter(lambda x, p: format(int(x), ',')))
        points.get_yaxis().set_major_formatter(
            tick.FuncFormatter(lambda x, p: format(int(x), ',')))
        plt.legend()
        plt.show()
This produces multiple graphs such as this (actual labels have been cut out).
Unfortunately you have to close each graph before viewing the next one, and I want every graph to be visible on one page.
I tried changing the code within the "if [condition]" to:
if [condition]:
    points = plotData.plot(marker='x', color='black', label='Base Data')
    modelForecast.plot(marker='x', ax=points, color='blue', label='Forecasted Data')
    dataLine = plt.gca().get_lines()[0]
    forecastLine = plt.gca().get_lines()[1]
    # put all x and y values into single lists by concatenating them
    totalXData = [*dataLine.get_xdata(), *forecastLine.get_xdata()]
    totalYData = [*dataLine.get_ydata(), *forecastLine.get_ydata()]
    subset = superFigure.add_subplot(10, 10, plotNumber)
    for i in range(0, len(totalXData)):
        subset.plot(totalXData[i], totalYData[i])
    plotNumber += 1
These changes produce a single graph that seems to have the other graphs squished into the top-left corner, and I get "MatplotlibDeprecationWarning: Adding an axes using the same arguments as a previous axes currently reuses the earlier instance" warnings.
If I change "superFigure.add_subplot(10, 10, plotNumber)" to "superFigure.add_subplot(20, 20, plotNumber)" I also get "UserWarning: Tight layout not applied. tight_layout cannot make axes width small enough to accommodate all axes decorations".
I then tried to change it to:
if [condition]:
    fig, ax = plt.subplots()
    plotData.plot(marker='x', ax=ax, color='black', label='Base Data')
    modelForecast.plot(marker='x', ax=ax, color='blue', label='Forecasted Data')
    ax.set([...])
    ax.legend()
    plt.show()
which doesn't produce the desired output, presumably because it creates a new figure on each call of forecast(), unless a figure window can contain multiple figures.
I also sometimes get the following warning:
RuntimeWarning: More than 20 figures have been opened. Figures created
through the pyplot interface (matplotlib.pyplot.figure) are retained
until explicitly closed and may consume too much memory.
fig, ax = plt.subplots()
How can I create subplots which include all the formatting and are displayed in one window all together?
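One common way to handle a variable number of plots, sketched here under the assumption that the condition can be evaluated before anything is drawn, is to filter the data first, size the grid from the number of series that pass, and hand each Axes to a plotting function. meets_condition, plot_series and the random-walk data are hypothetical stand-ins for [condition], forecast() and dataSet; the Holt-Winters forecast is left out to keep the example self-contained.

```python
import math

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import StrMethodFormatter

# Hypothetical stand-ins for "[condition]" and the plotting in forecast().
def meets_condition(series):
    return series[-1] > series[0]

def plot_series(ax, series):
    ax.plot(series, marker="x", color="black", label="Base Data")
    ax.set_title("Plot Title")
    ax.set_xlabel("X Axis")
    ax.set_ylabel("Y Axis")
    # thousand separators, like the question's FuncFormatter
    ax.yaxis.set_major_formatter(StrMethodFormatter("{x:,.0f}"))
    ax.legend()

rng = np.random.default_rng(2)
data_set = [np.cumsum(rng.normal(size=50)) * 1000 for _ in range(12)]

# Decide what will be plotted first, so the grid size is known up front.
selected = [s for s in data_set if meets_condition(s)]
n = max(len(selected), 1)
ncols = math.ceil(math.sqrt(n))
nrows = math.ceil(n / ncols)

# squeeze=False keeps `axes` 2-D even when the grid is 1 x 1.
fig, axes = plt.subplots(nrows, ncols, squeeze=False,
                         figsize=(4 * ncols, 3 * nrows))
for ax, series in zip(axes.flat, selected):
    plot_series(ax, series)
# Hide the empty cells left over in the last row.
for ax in axes.flat[len(selected):]:
    ax.set_visible(False)
fig.tight_layout()
plt.show()
```

Only one figure is created and plt.show() is called once, so all the subplots appear together in a single window, and the "More than 20 figures" warning no longer applies.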
I want to create a figure with two y-axes and add multiple curves to each of those axes at different points in my code (from different functions).
In a first function, I create a figure:
import matplotlib.pyplot as plt
from numpy import *
# Opens new figure with two axes
def function1():
    f = plt.figure(1)
    ax1 = plt.subplot(211)
    ax2 = ax1.twinx()
    x = linspace(0,2*pi,100)
    ax1.plot(x,sin(x),'b')
    ax2.plot(x,1000*cos(x),'g')
    # other stuff will be shown in subplot 212...
Now, in a second function I want to add a curve to each of the already created axes:
def function2():
    # Get handles of figure, which is already open
    f = plt.figure(1)
    ax3 = plt.subplot(211).axes # get handle to 1st axis
    ax4 = ax3.twinx() # get handle to 2nd axis (wrong?)
    # Add new curves
    x = linspace(0,2*pi,100)
    ax3.plot(x,sin(2*x),'m')
    ax4.plot(x,1000*cos(2*x),'r')
Now my problem is that the green curve added in the first code block is no longer visible after the second block is executed (all the others are).
I think the reason for this is the line
ax4 = ax3.twinx()
in my second code block. It probably creates a new twinx() instead of returning a handle to the existing one.
How would I correctly get the handles to already existing twinx-axes in a plot?
You can use get_shared_x_axes() (or get_shared_y_axes()) to get a handle to the axes created by twinx() (or twiny()):
# create some axes
f,a = plt.subplots()
# create axes that share their x-axes
atwin = a.twinx()
# in case you don't have access to atwin anymore, you can get a handle to it;
# get_siblings(a) also contains a itself, so exclude it
# (if a shares its x-axis with other subplots as well, those are returned too)
atwin_alt = [s for s in a.get_shared_x_axes().get_siblings(a) if s is not a][0]
atwin == atwin_alt # should be True
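A quick check of the reasoning in the question, assuming a recent matplotlib 3.x: calling twinx() a second time does not return the existing twin but adds another, separate Axes to the figure, while the sibling lookup still recovers the original twin. find_twinx is a hypothetical helper name.

```python
import matplotlib.pyplot as plt

def find_twinx(ax):
    """Hypothetical helper: return the first other Axes that shares
    ax's x-axis (e.g. one created by ax.twinx()), or None."""
    others = [s for s in ax.get_shared_x_axes().get_siblings(ax) if s is not ax]
    return others[0] if others else None

f, a = plt.subplots()
atwin = a.twinx()
assert find_twinx(a) is atwin  # the existing twin is found again

# A second twinx() call creates a new Axes; it is not a handle to atwin.
extra = a.twinx()
print(extra is atwin)  # False
print(len(f.axes))     # 3
```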
I would guess that the cleanest way would be to create the axes outside the functions. Then you can supply whatever axes you like to the function.
import matplotlib.pyplot as plt
import numpy as np
def function1(ax1, ax2):
    x = np.linspace(0,2*np.pi,100)
    ax1.plot(x,np.sin(x),'b')
    ax2.plot(x,1000*np.cos(x),'g')
def function2(ax1, ax2):
    x = np.linspace(0,2*np.pi,100)
    ax1.plot(x,np.sin(2*x),'m')
    ax2.plot(x,1000*np.cos(2*x),'r')
fig, (ax, bx) = plt.subplots(nrows=2)
axtwin = ax.twinx()
function1(ax, axtwin)
function2(ax, axtwin)
plt.show()
I'd like to do something like this:
import matplotlib.pyplot as plt
%matplotlib inline
fig1 = plt.figure(1)
plt.plot([1,2,3],[5,2,4])
plt.show()
In one cell, and then redraw the exact same plot in another cell, like so:
plt.figure(1) # attempting to reference the figure I created earlier...
# four things I've tried:
plt.show() # does nothing... :(
fig1.show() # throws warning about backend and does nothing
fig1.draw() # throws error about renderer
fig1.plot([1,2,3],[5,2,4]) # This also doesn't work (jupyter outputs some
# text saying matplotlib.figure.Figure at 0x..., changing the backend and
# using plot don't help with that either), but regardless in reality
# these plots have a lot going on and I'd like to recreate them
# without running all of the same commands over again.
I've messed around with some combinations of this stuff as well but nothing works.
This question is similar to IPython: How to show the same plot in different cells? but I'm not particularly looking to update my plot, I just want to redraw it.
I have found a solution. The trick is to create a figure and an axis with fig, ax = plt.subplots() and to plot on the axis. Then we can simply write fig as the last line of any other cell in which we want to replot the figure.
import matplotlib.pyplot as plt
import numpy as np
x_1 = np.linspace(-.5,3.3,50)
y_1 = x_1**2 - 2*x_1 + 1
fig, ax = plt.subplots()
plt.title('Reusing this figure', fontsize=20)
ax.plot(x_1, y_1)
ax.set_xlabel('x',fontsize=18)
ax.set_ylabel('y',fontsize=18, rotation=0, labelpad=10)
ax.legend(['Eq 1'])
ax.axis('equal');
This produces a plot of the parabola, with the title, axis labels and legend set above.
Now we can add more things by using the ax object:
t = np.linspace(0,2*np.pi,100)
h, a = 2, 2
k, b = 2, 3
x_2 = h + a*np.cos(t)
y_2 = k + b*np.sin(t)
ax.plot(x_2,y_2)
ax.legend(['Eq 1', 'Eq 2'])
fig
Note how I just wrote fig in the last line, making the notebook output the figure once again.
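Because the figure object persists after a cell finishes, it can also be re-rendered explicitly. In a notebook, IPython.display.display(fig) shows it even when fig is not the last expression of a cell. The sketch below assumes only matplotlib and NumPy: it saves the same figure to PNG before and after adding the ellipse, to show that each render reflects everything drawn on ax so far. render_png is a hypothetical helper name.

```python
import io

import matplotlib.pyplot as plt
import numpy as np

def render_png(figure):
    """Hypothetical helper: render a Figure to PNG bytes in memory."""
    buf = io.BytesIO()
    figure.savefig(buf, format="png")
    return buf.getvalue()

# First "cell": create the figure and draw the parabola on its Axes.
fig, ax = plt.subplots()
x_1 = np.linspace(-.5, 3.3, 50)
ax.plot(x_1, x_1**2 - 2*x_1 + 1)
first = render_png(fig)

# Later "cell": keep drawing on the same Axes, then render the same figure again.
t = np.linspace(0, 2*np.pi, 100)
ax.plot(2 + 2*np.cos(t), 2 + 3*np.sin(t))
second = render_png(fig)
```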
I hope this helps!
Example of scatterplot matrix
Is there such a function in matplotlib.pyplot?
For those who do not want to define their own functions, there is a great data analysis library in Python called pandas, which provides the scatter_matrix() function:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns = ['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha = 0.2, figsize = (6, 6), diagonal = 'kde')
plt.show()
Generally speaking, matplotlib doesn't contain plotting functions that operate on more than one axes object (subplot, in this case). The expectation is that you'd write a simple function to string things together however you'd like.
I'm not quite sure what your data looks like, but it's quite simple to build a function to do this from scratch. If you're always going to be working with structured or record arrays, you can simplify this a touch (i.e., there's always a name associated with each data series, so you can omit specifying the names).
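For the structured/record-array case mentioned above, a small sketch of how the names could be taken from the dtype instead of being passed separately. structured_to_rows is a hypothetical helper; it produces the (numvars, numdata) array and the list of names that the scatterplot_matrix function in this answer expects.

```python
import numpy as np

def structured_to_rows(rec):
    """Hypothetical helper: turn a structured/record array into a
    (numvars, numdata) float array plus the list of field names."""
    names = list(rec.dtype.names)
    data = np.vstack([rec[name] for name in names]).astype(float)
    return data, names

rec = np.zeros(10, dtype=[("mpg", float), ("disp", float),
                          ("drat", float), ("wt", float)])
rec["mpg"] = np.arange(10)
data, names = structured_to_rows(rec)
print(data.shape)  # (4, 10)
print(names)       # ['mpg', 'disp', 'drat', 'wt']
```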
As an example:
import itertools
import numpy as np
import matplotlib.pyplot as plt
def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()
def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots. Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names". Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid."""
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        # Set up ticks only on one side for the "edge" subplots...
        # (Axes.is_first_col() and friends were deprecated in matplotlib 3.4
        # and later removed; the SubplotSpec methods are the replacement)
        ss = ax.get_subplotspec()
        if ss.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ss.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ss.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ss.is_last_row():
            ax.xaxis.set_ticks_position('bottom')
    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)
    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')
    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)
    return fig
main()
You can also use Seaborn's pairplot function:
import seaborn as sns
sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
Thanks for sharing your code! You figured out all the hard stuff for us. As I was working with it, I noticed a few little things that didn't look quite right.
[FIX #1] The axis ticks weren't lining up as I would expect. In your example above, you should be able to draw a vertical and a horizontal line through any point across all plots, and the lines should cross through the corresponding point in the other plots, but as it stands this doesn't happen.
[FIX #2] If you are plotting an odd number of variables, the bottom-right corner axes doesn't pick up the correct x ticks or y ticks; it just keeps the default 0..1 ticks.
Not a fix, but I made it optional to explicitly input names, so that it puts a default label xi for variable i in the diagonal positions.
Below you'll find an updated version of your code that addresses these two points, otherwise preserving the beauty of your code.
import itertools
import numpy as np
import matplotlib.pyplot as plt
def scatterplot_matrix(data, names=None, **kwargs):
    """
    Plots a scatterplot matrix of subplots. Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names". Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid.
    """
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.0, wspace=0.0)
    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        # Set up ticks only on one side for the "edge" subplots...
        # (uses the SubplotSpec methods, since Axes.is_first_col() and
        # friends no longer exist in recent matplotlib)
        ss = ax.get_subplotspec()
        if ss.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ss.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ss.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ss.is_last_row():
            ax.xaxis.set_ticks_position('bottom')
    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)
    # Label the diagonal subplots...
    if not names:
        names = ['x'+str(i) for i in range(numvars)]
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')
    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)
    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)
    return fig
if __name__=='__main__':
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()
Thanks again for sharing this with us. I have used it many times! Oh, and I rearranged the main() part of the code so that it runs as example code when the file is executed as a script, but is not called when the file is imported into another piece of code.
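An alternative worth considering, sketched under the assumption of a reasonably recent matplotlib: passing sharex='col' and sharey='row' to plt.subplots keeps every column's x-limits and every row's y-limits in sync. That lines the points up across plots (the goal of FIX #1) and gives the diagonal panels the right limits for any number of variables (the goal of FIX #2) without copying limits by hand. scatterplot_matrix_shared is a hypothetical name, and the tick placement differs from the version above: with shared axes, plt.subplots shows tick labels only on the bottom row and the left column.

```python
import matplotlib.pyplot as plt
import numpy as np

def scatterplot_matrix_shared(data, names=None, **kwargs):
    """Hypothetical variant of scatterplot_matrix that relies on shared
    axes instead of copying limits by hand. Returns (fig, axes)."""
    numvars = data.shape[0]
    if not names:
        names = ['x' + str(i) for i in range(numvars)]
    # Column j shows variable j on its x-axis and row i shows variable i
    # on its y-axis, so share x-limits per column and y-limits per row.
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8, 8),
                             sharex='col', sharey='row', squeeze=False)
    fig.subplots_adjust(hspace=0.0, wspace=0.0)
    for i in range(numvars):
        for j in range(numvars):
            if i != j:
                axes[i, j].plot(data[j], data[i], **kwargs)
    # Diagonal panels hold only a label; their limits come from the shared groups.
    for i, label in enumerate(names):
        axes[i, i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                            ha='center', va='center')
    return fig, axes

np.random.seed(1977)
data = 10 * np.random.random((3, 10))  # an odd number of variables
fig, axes = scatterplot_matrix_shared(data, ['mpg', 'disp', 'drat'],
                                      linestyle='none', marker='o',
                                      color='black', mfc='none')
fig.suptitle('Simple Scatterplot Matrix (shared axes)')
plt.show()
```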
While reading the question I expected to see an answer including rpy. I think this is a nice option taking advantage of two beautiful languages. So here it is:
import rpy
import numpy as np
def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    rpy.set_default_mode(rpy.NO_CONVERSION)
    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)
    # Figure saved as eps
    rpy.r.postscript('pairsPlot.eps')
    rpy.r.pairs(R_data,
        main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()
    # Figure saved as png
    rpy.r.png('pairsPlot.png')
    rpy.r.pairs(R_data,
        main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()
    rpy.set_default_mode(rpy.BASIC_CONVERSION)
if __name__ == '__main__': main()
I can't post an image to show the result :( sorry!