Preserve content of fig after show() in matplotlib? - python

I'm creating a violinplot of some data and afterwards I render a scatterplot with individual data points (red points in example) to three subplots.
Since the generation of the violinplot is relatively time consuming, I'm generating the violinplot only once, then add the scatterplot for one data row, write the result file, remove the scatterplots from the axes and add the scatterplots for the next row.
Everything works, but I would like to add the option, to show() each plot prior to saving it.
If I'm using plt.show(), the figure is shown correctly, but afterwards the figure seems to be cleared and in the next iteration I'm getting the plot without the violin plots.
Is there any way to preserve the content of the figure after plt.show()?
In short, my code is
fig = generate_plot(ws, show=False) #returns the fig instance of the violin plot
#if I do plt.show() here (or in "generate_plot()"), the violin plots are gone.
ax1, ax3, ax2 = fig.get_axes()
scatter1 = ax1.scatter(...) #draw scatter plot for first axes
[...] #same vor every axis
plt.savefig(...)
scatter1.remove()

I was thinking that a possible option is to use the event loop to advance through the plots. The following defines an updating function, which changes only the scatter points, draws the image and saves it. We can manage this via a class with a callback on the key_press event, such that when you hit Space the next image is shown; upon pressing Space on the last image, the plot is closed.
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import numpy as np
class NextPlotter(object):
    """Step through a sequence of plot updates with the space bar.

    Each press of <space> on the figure calls ``func(i)`` for the next index
    ``i`` in ``range(n)``.  One more press of <space> after the last update
    closes all figures.

    :param fig: figure whose canvas delivers the key-press events
    :param func: callable taking the integer index of the update to apply
    :param n: number of updates available
    """

    def __init__(self, fig, func, n):
        # Explicit assignments instead of ``self.__dict__.update(locals())``,
        # which also stored a circular ``self.self`` reference.
        self.fig = fig
        self.func = func
        self.n = n
        self.i = 0
        self.cid = self.fig.canvas.mpl_connect("key_press_event", self.adv)

    def adv(self, evt):
        """Handle a key press: apply the next update, or close when done."""
        # Only the space bar drives the sequence; any other key is ignored so
        # that it can never close the window by accident after the last image.
        if evt.key != " ":
            return
        if self.i < self.n:
            self.func(self.i)
            self.i += 1
        else:
            plt.close("all")
#Start of code:
# Create data
pos = [1, 2, 4, 5, 7, 8]
data = [np.random.normal(0, std, size=100) for std in pos]
data2 = [np.random.rayleigh(std, size=100) for std in pos]
scatterdata = np.random.normal(0, 5, size=(10,len(pos)))
#Create plot
fig, axes = plt.subplots(ncols=2)
axes[0].violinplot(data, pos, points=40, widths=0.9,
showmeans=True, showextrema=True, showmedians=True)
axes[1].violinplot(data2, pos, points=40, widths=0.9,
showmeans=True, showextrema=True, showmedians=True)
scatter = axes[0].scatter(pos, scatterdata[0,:], c="crimson", s=60)
scatter2 = axes[1].scatter(pos, scatterdata[1,:], c="crimson", s=60)
# define updating function
def update(i):
    """Show scatter rows ``2*i`` and ``2*i+1``, redraw the figure and save it.

    :param i: index of the image; selects two consecutive rows of
        ``scatterdata`` (one per axes) and names the output file.
    """
    left_row = scatterdata[2 * i, :]
    right_row = scatterdata[2 * i + 1, :]
    # Only the point positions change; the violin plots stay untouched.
    scatter.set_offsets(np.column_stack((pos, left_row)))
    scatter2.set_offsets(np.column_stack((pos, right_row)))
    fig.canvas.draw()
    plt.savefig("plot{i}.png".format(i=i))
# instantiate NextPlotter; press <space> to advance to the next image
c = NextPlotter(fig, update, len(scatterdata)//2)
plt.show()

A workaround could be to not remove the scatterplot.
Why not keep the scatter plot axis, and just update the data for that set of axis?
You will most likely need a plt.draw() after update of scatter plot data to force a new rendering.

I found a way to draw figures interactively here. plt.ion() and block the process with input() seems to be important.
import matplotlib.pyplot as plt

plt.ion()  # interactive mode, so showing the figure does not block the script
fig = plt.figure()
ax = plt.subplot(1, 1, 1)
ax.set_xlim([-1, 5])
ax.set_ylim([-1, 5])
ax.grid(True)  # pass a real boolean rather than the string 'on'
for i in range(5):
    # ax.plot returns a list of lines; exactly one marker is drawn here.
    (marker,) = ax.plot(i, i, 'ro')
    fig.savefig('%02d.png' % i)
    # plt.draw() # not necessary?
    input()  # wait for <Enter> before moving on to the next frame
    marker.remove()  # remove only this marker; the rest of the axes is kept
I also tried to block the process with time.sleep(1), but it does not work at all.

Related

Plotting a variable number of series and data forecasts onto subplots, including axis labels, formatting and line colours

I'm writing a program which gets data and then uses time series forecasting to predict data values for the next, say, 300 data points.
However, only data which fulfills a certain condition will be plotted, so there is no defined number of subplots for the add_subplot() method. I'm aware of the plot.subplots() function, but something such as
fig, (ax1, ax2) = plt.subplots(1, 2)
implies that two graphs will definitely be plotted and I need to change the specific amount, like with a parameter.
Here is a simplified version of the current code which results in each plot being in separate windows:
fig = plt.figure() # creates a figure instance for the final graph output
plots = 1 # indicates the total number of plots to plot, starting from 1
# passed as a parameter to the add_subplot() function
for data in dataSet:
forecast(data, fig, plots)
plt.figure(fig.number)
plt.show()
And the function:
import matplotlib.pyplot as plt
import matplotlib.ticker as tick
from pandas import Series
from statsmodels.tsa.holtwinters import ExponentialSmoothing
def forecast(data, superFigure, plotNumber):
index = range(0, len(data))
plotData = Series(data, index)
# fit the data values into a specific model:
modelFit = ExponentialSmoothing(plotData, trend="add").fit()
# forecast for the next 300 points:
modelForecast = modelFit.forecast(300)
if [condition]:
# plot the original data points:
points = plotData.plot(marker='x', color='black', label='Base Data')
points.set_xlim(0, len(data) + 300)
# plot the forecast in a different colour:
modelForecast.plot(marker='x', ax=points, color='blue', label='Forecasted Data')
plt.title("Plot Title")
plt.xlabel("X Axis")
plt.ylabel("Y Axis")
# format the axes, adding thousand separator
points.get_xaxis().set_major_formatter(
tick.FuncFormatter(lambda x, p: format(int(x), ',')))
points.get_yaxis().set_major_formatter(
tick.FuncFormatter(lambda x, p: format(int(x), ',')))
plt.legend()
plt.show()
This produces multiple graphs such as this (actual labels have been cut out).
Unfortunately you have to close each graph before viewing the next one, and I want every graph to be visible on one page.
I tried changing the code within the "if [condition]" to:
if [condition]:
points = plotData.plot(marker='x', color='black', label='Base Data')
modelForecast.plot(marker='x', ax=points, color='blue', label='Forecasted Data')
dataLine = plt.gca().get_lines()[0]
forecastLine = plt.gca().get_lines()[1]
# put all x and y values into single lists by concatenating them
totalXData = [*dataLine.get_xdata(), *forecastLine.get_xdata()]
totalYData = [*dataLine.get_ydata(), *forecastLine.get_ydata()]
subset = superFigure.add_subplot(10, 10, plotNumber)
for i in range(0, len(totalXData)):
subset.plot(totalXData[i], totalYData[i])
plotNumber += 1
These changes produce this exact graph which seems to have the other graphs squished in the top-left corner, and I get "MatplotlibDeprecationWarning: Adding an axes using the same arguments as a previous axes currently reuses the earlier instance" warnings.
If I change "superFigure.add_subplot(10, 10, plotNumber)" to "superFigure.add_subplot(20, 20, plotNumber)" I also get "UserWarning: Tight layout not applied. tight_layout cannot make axes width small enough to accommodate all axes decorations".
I then tried to change it to:
if [condition]:
fig, ax = plt.subplots()
plotData.plot(marker='x', ax=ax, color='black', label='Base Data')
modelForecast.plot(marker='x', ax=ax, color='blue', label='Forecasted Data')
ax.set([...])
ax.legend()
plt.show()
which doesn't produce the desired output assumedly because it recreates the figure on each call of forecast(), unless a figure window can contain multiple figures.
I also sometimes get the following warning:
RuntimeWarning: More than 20 figures have been opened. Figures created
through the pyplot interface (matplotlib.pyplot.figure) are retained
until explicitly closed and may consume too much memory.
fig, ax = plt.subplots()
How can I create subplots which include all the formatting and are displayed in one window all together?

How to check if colorbar exists on figure

Question: Is there a way to check if a color bar already exists?
I am making many plots with a loop. The issue is that the color bar is drawn every iteration!
If I could determine if the color bar exists then I can put the color bar function in an if statement.
if cb_exists:
# do nothing
else:
plt.colorbar() #draw the colorbar
If I use multiprocessing to make the figures, is it possible to prevent multiple color bars from being added?
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing
def plot(number):
a = np.random.random([5,5])*number
plt.pcolormesh(a)
plt.colorbar()
plt.savefig('this_'+str(number))
# I want to make a 50 plots
some_list = range(0,50)
num_proc = 5
p = multiprocessing.Pool(num_proc)
temps = p.map(plot, some_list)
I realize I can clear the figure with plt.clf() and plt.cla() before plotting the next iteration. But, I have data on my basemap layer I don't want to re-plot (that adds to the time it takes to create the plot). So, if I could remove the colorbar and add a new one I'd save some time.
It is actually not easy to remove a colorbar from a plot and later draw a new one on it.
The best solution I can come up with at the moment is the following, which assumes that there is only one axes present in the plot. If there is a second axes, it must be the colorbar. So by checking how many axes we find on the plot, we can judge whether or not there is a colorbar.
Here we also mind the user's wish not to reference any named objects from outside. (Which does not makes much sense, as we need to use plt anyways, but hey.. so was the question)
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="k", lw=6)
ax.set_title("Title")
cbar = plt.colorbar(im)
cbar.ax.set_ylabel("Label")
for i in range(10):
# inside this loop we should not access any variables defined outside
# why? no real reason, but questioner asked for it.
#draw new colormesh
im = plt.gcf().gca().pcolormesh(np.random.rand(2,2))
#check if there is more than one axes
if len(plt.gcf().axes) > 1:
# if so, then the last axes must be the colorbar.
# we get its extent
pts = plt.gcf().axes[-1].get_position().get_points()
# and its label
label = plt.gcf().axes[-1].get_ylabel()
# and then remove the axes
plt.gcf().axes[-1].remove()
# then we draw a new axes a the extents of the old one
cax= plt.gcf().add_axes([pts[0][0],pts[0][1],pts[1][0]-pts[0][0],pts[1][1]-pts[0][1] ])
# and add a colorbar to it
cbar = plt.colorbar(im, cax=cax)
cbar.ax.set_ylabel(label)
# unfortunately the aspect is different between the initial call to colorbar
# without cax argument. Try to reset it (but still it's somehow different)
cbar.ax.set_aspect(20)
else:
plt.colorbar(im)
plt.show()
In general a much better solution would be to operate on the objects already present in the plot and only update them with the new data. Thereby, we suppress the need to remove and add axes and find a much cleaner and faster solution.
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="k", lw=6)
ax.set_title("Title")
cbar = plt.colorbar(im)
cbar.ax.set_ylabel("Label")
for i in range(10):
    data = np.array(np.random.rand(2, 2))
    # Reuse the existing mesh: swap in new values instead of re-plotting.
    im.set_array(data.flatten())
    # Set the colour limits on the mappable; the colorbar attached to it is
    # updated from the mappable.  (Colorbar.set_clim and Colorbar.draw_all
    # are not available in current Matplotlib releases.)
    im.set_clim(vmin=data.min(), vmax=data.max())
    plt.draw()
plt.show()
Update:
Actually, the latter approach of referencing objects from outside even works together with the multiprocess approach desired by the questioner.
So, here is a code that updates the figure, without the need to delete the colorbar.
import matplotlib.pyplot as plt
import numpy as np
import multiprocessing
import time
fig, ax = plt.subplots()
im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="w", lw=6)
ax.set_title("Title")
cbar = plt.colorbar(im)
cbar.ax.set_ylabel("Label")
tx = ax.text(0.2,0.8, "", fontsize=30, color="w")
tx2 = ax.text(0.2,0.2, "", fontsize=30, color="w")
def do(number):
    """Update the shared figure for ``number``, save it and time the work.

    The existing mesh, texts and colorbar are reused; only their data changes.

    :param number: index of the plot; shown in the figure, used to scale the
        random data and to name the output file.
    :return: array ``[number, start, stop]`` with ``time.time()`` stamps taken
        before and after the update.
    """
    import os  # local import: only needed to create the output directory

    start = time.time()
    tx.set_text(str(number))
    data = np.array(np.random.rand(2, 2) * (number + 1))
    im.set_array(data.flatten())
    # Set the colour limits on the mappable; the colorbar attached to it is
    # updated from the mappable.  (Colorbar.set_clim and Colorbar.draw_all
    # are not available in current Matplotlib releases.)
    im.set_clim(vmin=data.min(), vmax=data.max())
    tx2.set_text("{m:.2f} < {ma:.2f}".format(m=data.min(), ma=data.max()))
    plt.draw()
    # savefig does not create missing directories; exist_ok keeps this safe
    # when several worker processes get here at the same time.
    os.makedirs("multiproc", exist_ok=True)
    plt.savefig("multiproc/{n}.png".format(n=number))
    stop = time.time()
    return np.array([number, start, stop])
if __name__ == "__main__":
multiprocessing.freeze_support()
some_list = range(0,50)
num_proc = 5
p = multiprocessing.Pool(num_proc)
nu = p.map(do, some_list)
nu = np.array(nu)
plt.close("all")
fig, ax = plt.subplots(figsize=(16,9))
ax.barh(nu[:,0], nu[:,2]-nu[:,1], height=np.ones(len(some_list)), left=nu[:,1], align="center")
plt.show()
(The code at the end shows a timetable which allows to see that multiprocessing has indeed taken place)
If you have access to the axes and image information, the colorbar can be retrieved
as a property of the image (or of the mappable the colorbar is associated with).
Following a previous answer (How to retrieve colorbar instance from figure in matplotlib), an example could be:
ax=plt.gca() #plt.gca() for current axis, otherwise set appropriately.
im=ax.images #this is a list of all images that have been plotted
if im[-1].colorbar is None: #in this case I assume to be interested to the last one plotted, otherwise use the appropriate index or loop over
plt.colorbar() #plot a new colorbar
Note that an image without colorbar returns None to im[-1].colorbar
One approach is:
initially (prior to having any color bar drawn), set a variable
colorBarPresent = False
in the method for drawing the color bar, check to see if it's already drawn. If not, draw it and set the colorBarPresent variable True:
def drawColorBar():
    """Draw the color bar once; later calls do nothing.

    The module-level flag ``colorBarPresent`` records whether the bar has
    already been drawn.
    """
    # Without this declaration the assignment below would make the name local
    # and the check would raise UnboundLocalError.
    global colorBarPresent
    if colorBarPresent:
        # leave the function and don't draw the bar again
        return
    # draw the color bar
    colorBarPresent = True
There is an indirect way of guessing (with reasonable accuracy for most applications, I think) whether an Axes instance is home to a color bar. Depending on whether it is a horizontal or vertical color bar, either the X axis or Y axis (but not both) will satisfy all of these conditions:
No ticks
No tick labels
No axis label
Axis range is (0, 1)
So here's a function for you:
def _axis_is_blank(ax, axis):
    """Return True if the 'x' or 'y' axis of ``ax`` has no ticks, no tick
    labels, no axis label and limits equal to (0, 1)."""
    def get(what):
        return getattr(ax, "get_%s%s" % (axis, what))()

    # tuple(): limits delivered as an ndarray or a list would otherwise compare
    # element-wise (or never equal) instead of yielding one bool.
    return (len(get("ticks")) == 0
            and len(get("ticklabels")) == 0
            and len(get("label")) == 0
            and tuple(get("lim")) == (0, 1))


def is_colorbar(ax):
    """
    Guesses whether a set of Axes is home to a colorbar

    :param ax: Axes instance
    :return: bool
        True if the x xor y axis satisfies all of the following and thus looks like it's probably a colorbar:
        No ticks, no tick labels, no axis label, and range is (0, 1)
    """
    xcb = _axis_is_blank(ax, "x")
    ycb = _axis_is_blank(ax, "y")
    return xcb != ycb  # != is effectively xor in this case, since xcb and ycb are both bool
Thanks to this answer for the cool != xor trick: https://stackoverflow.com/a/433161/6605826
With this function, you can see if a colorbar exists by:
colorbar_exists = any([is_colorbar(ax) for ax in np.atleast_1d(gcf().axes).flatten()])
or if you're sure the colorbar will always be last, you can get off easy with:
colorbar_exists = is_colorbar(gcf().axes[-1])

Update Existent Matplotlib Subplot with a user input

I am currently trying to implement a 'zoom' functionality into my code. By this I mean I would like to have two subplots side by side, one of which contains the initial data and the other which contains a 'zoomed in' plot which is decided by user input.
Currently, I can create two subplots side by side, but after calling for the user input, instead of updating the second subplot, my script is creating an entirely new figure below and not updating the second subplot. It is important that the graph containing data is plotted first so the user can choose the value for the input accordingly.
def plot_func(data):
    """Plot ``data`` and, beside it, a zoomed view chosen by the user.

    The overview is drawn first so the user can pick the zoom limit from it;
    the first ``int(zoom)`` samples are then plotted in the right-hand axes of
    the same figure.

    :param data: sliceable sequence of values to plot
    """
    plt.close('all')
    # Create both axes up front so the zoomed view sits next to the overview
    # in one figure instead of ending up in a new one.
    fig, (ax1, ax2) = plt.subplots(1, 2)
    # Subplot 1
    ax1.plot(data)
    # Show without blocking so the script can go on to ask for input while the
    # window stays open.  NOTE(review): this targets GUI backends; confirm the
    # behaviour with the inline notebook backend.
    plt.show(block=False)
    plt.pause(0.1)  # give the GUI event loop time to draw the overview
    zoom = input("Where would you like to zoom to: ")
    # Subplot 2
    ax2.plot(data[0:int(zoom)])
    fig.canvas.draw_idle()
    plt.show()
The code above is a simplified version of what I am hoping to do. Display a subplot, and let the user enter an input based on that subplot. Then either edit a subplot that is already created or create a new one that is next to the first. Again it is crucial that the 'zoomed in' subplot is alongside the first opposed to below.
I think it is not very convenient for the user to type in numbers for zooming. The more standard way would be mouse interaction as already provided by the various matplotlib tools.
There is no standard tool for zooming in a different plot, but we can easily provide this functionality using matplotlib.widgets.RectangleSelector as shown in the code below.
We need to plot the same data in two subplots and connect the RectangleSelector to one of the subplots (ax). Every time a selection is made, the data coordinates of the selection in the first subplot are simply used as axis limits on the second subplot, effectively providing zoom-in (or magnification) functionality.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import RectangleSelector
def onselect(eclick, erelease):
    """Zoom the second axes to the rectangle selected on the first one.

    :param eclick: mouse-press event; its ``xdata``/``ydata`` are read
    :param erelease: mouse-release event; its ``xdata``/``ydata`` are read
    """
    #http://matplotlib.org/api/widgets_api.html
    # The drag may go in any direction, so order each pair before using it
    # as axis limits.
    x_bounds = np.array([erelease.xdata, eclick.xdata])
    y_bounds = np.array([erelease.ydata, eclick.ydata])
    x_bounds.sort()
    y_bounds.sort()
    ax2.set_xlim(x_bounds)
    ax2.set_ylim(y_bounds)
def toggle_selector(event):
    """Key handler: on escape, reset the zoomed axes to the full view.

    Nothing happens unless the selector stored in ``toggle_selector.RS`` is
    active.

    :param event: key-press event; its ``key`` is read
    """
    # press escape to return to non-zoomed plot
    escape_pressed = event.key in ['escape']
    if not (escape_pressed and toggle_selector.RS.active):
        return
    ax2.set_xlim(ax.get_xlim())
    ax2.set_ylim(ax.get_ylim())
x = np.arange(100)/(100.)*7.*np.pi
y = np.sin(x)**2
fig = plt.figure()
ax = fig.add_subplot(121)
ax2 = fig.add_subplot(122)
#plot identical data in both axes
ax.plot(x,y, lw=2)
ax.plot([5,14,21],[.3,.6,.1], marker="s", color="red", ls="none")
ax2.plot(x,y, lw=2)
ax2.plot([5,14,21],[.3,.6,.1], marker="s", color="red", ls="none")
ax.set_title("Select region with your mouse.\nPress escape to deactivate zoom")
ax2.set_title("Zoomed Plot")
# The selector is stored on the handler function so it stays referenced and
# the handler can query its state.  ``drawtype='box'`` is no longer passed:
# the keyword was removed from RectangleSelector, and a box was its default.
toggle_selector.RS = RectangleSelector(ax, onselect, interactive=True)
fig.canvas.mpl_connect('key_press_event', toggle_selector)
plt.show()
%matplotlib inline
import mpld3
mpld3.enable_notebook()

Matplotlib: updating multiple scatter plots in a loop

I have two data sets that I would like to produce scatterplots for, with different colors.
Following the advice in MatPlotLib: Multiple datasets on the same scatter plot
I managed to plot them. However, I would like to be able to update the scatter plots inside of a loop that will affect both sets of data. I looked at the matplotlib animation package but it doesn't seem to fit the bill.
I cannot get the plot to update from within a loop.
The structure of the code looks like this:
fig = plt.figure()
ax1 = fig.add_subplot(111)
for g in range(gen):
# some simulation work that affects the data sets
peng_x, peng_y, bear_x, bear_y = generate_plot(population)
ax1.scatter(peng_x, peng_y, color = 'green')
ax1.scatter(bear_x, bear_y, color = 'red')
# this doesn't refresh the plots
Where generate_plot() extracts the relevant plotting information (x,y) coords from a numpy array with additional info and assigns them to the correct data set so they can be colored differently.
I've tried clearing and redrawing but I can't seem to get it to work.
Edit: Slight clarification. What I'm looking to do basically is to animate two scatter plots on the same plot.
Here's a code that might fit your description:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
# Create new Figure and an Axes which fills it.
fig = plt.figure(figsize=(7, 7))
ax = fig.add_axes([0, 0, 1, 1], frameon=False)
ax.set_xlim(-1, 1), ax.set_xticks([])
ax.set_ylim(-1, 1), ax.set_yticks([])
# Create data
ndata = 50
data = np.zeros(ndata, dtype=[('peng', float, 2), ('bear', float, 2)])
# Initialize the position of data
data['peng'] = np.random.randn(ndata, 2)
data['bear'] = np.random.randn(ndata, 2)
# Construct the scatter which we will update during animation
scat1 = ax.scatter(data['peng'][:, 0], data['peng'][:, 1], color='green')
scat2 = ax.scatter(data['bear'][:, 0], data['bear'][:, 1], color='red')
def update(frame_number):
    """Animation callback: give both scatter plots new random positions.

    :param frame_number: frame index supplied by the animation; not used
    """
    # insert results from generate_plot(population) here
    for species, artist in (('peng', scat1), ('bear', scat2)):
        data[species] = np.random.randn(ndata, 2)
        # Update the scatter collection with the new positions.
        artist.set_offsets(data[species])
# Construct the animation, using the update function as the animation
# director.
animation = FuncAnimation(fig, update, interval=10)
plt.show()
You might also want to take a look at http://matplotlib.org/examples/animation/rain.html. You can learn more tweaks in animating a scatter plot there.

Is there a function to make scatterplot matrices in matplotlib?

Example of scatterplot matrix
Is there such a function in matplotlib.pyplot?
For those who do not want to define their own functions, there is a great data analysis library in Python, called Pandas, where one can find the scatter_matrix() method:
from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns = ['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha = 0.2, figsize = (6, 6), diagonal = 'kde')
Generally speaking, matplotlib doesn't usually contain plotting functions that operate on more than one axes object (subplot, in this case). The expectation is that you'd write a simple function to string things together however you'd like.
I'm not quite sure what your data looks like, but it's quite simple to just build a function to do this from scratch. If you're always going to be working with structured or rec arrays, then you can simplify this a touch. (i.e. There's always a name associated with each data series, so you can omit having to specify names.)
As an example:
import itertools
import numpy as np
import matplotlib.pyplot as plt
def main():
np.random.seed(1977)
numvars, numdata = 4, 10
data = 10 * np.random.random((numvars, numdata))
fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
linestyle='none', marker='o', color='black', mfc='none')
fig.suptitle('Simple Scatterplot Matrix')
plt.show()
def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid."""
    numvars, numdata = data.shape
    # squeeze=False keeps "axes" two-dimensional even for a single variable.
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8),
                             squeeze=False)
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Grid-position queries live on the SubplotSpec in current Matplotlib
        # (the Axes.is_first_col() family was removed); fall back to the Axes
        # methods on old releases whose SubplotSpec lacks them.
        spec = ax.get_subplotspec()
        where = spec if hasattr(spec, 'is_first_col') else ax

        # Set up ticks only on one side for the "edge" subplots...
        if where.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if where.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if where.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if where.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)

    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    # Columns alternate between the bottom (-1) and top (0) row for their
    # x ticks; rows alternate between the right and left column for y ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    return fig
main()
You can also use Seaborn's pairplot function:
import seaborn as sns
sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
Thanks for sharing your code! You figured out all the hard stuff for us. As I was working with it, I noticed a few little things that didn't look quite right.
[FIX #1] The axis ticks weren't lining up as I would expect (i.e., in your example above, you should be able to draw a vertical and horizontal line through any point across all plots and the lines should cross through the corresponding point in the other plots, but as it stands this doesn't occur).
[FIX #2] If you have an odd number of variables you are plotting with, the bottom right corner axes doesn't pull the correct xtics or ytics. It just leaves it as the default 0..1 ticks.
Not a fix, but I made it optional to explicitly input names, so that it puts a default xi for variable i in the diagonal positions.
Below you'll find an updated version of your code that addresses these two points, otherwise preserving the beauty of your code.
import itertools
import numpy as np
import matplotlib.pyplot as plt
def scatterplot_matrix(data, names=None, **kwargs):
    """
    Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  If "names" is omitted or empty,
    the diagonal is labeled 'x0', 'x1', ...  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid.
    """
    numvars, numdata = data.shape
    # squeeze=False keeps "axes" two-dimensional even for a single variable.
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8),
                             squeeze=False)
    fig.subplots_adjust(hspace=0.0, wspace=0.0)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Grid-position queries live on the SubplotSpec in current Matplotlib
        # (the Axes.is_first_col() family was removed); fall back to the Axes
        # methods on old releases whose SubplotSpec lacks them.
        spec = ax.get_subplotspec()
        where = spec if hasattr(spec, 'is_first_col') else ax

        # Set up ticks only on one side for the "edge" subplots...
        if where.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if where.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if where.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if where.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)

    # Label the diagonal subplots...
    # None replaces the former mutable default ``[]``; an explicitly passed
    # empty list still falls back to the generated labels.
    if names is None or len(names) == 0:
        names = ['x'+str(i) for i in range(numvars)]

    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)

    return fig
if __name__=='__main__':
np.random.seed(1977)
numvars, numdata = 4, 10
data = 10 * np.random.random((numvars, numdata))
fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
linestyle='none', marker='o', color='black', mfc='none')
fig.suptitle('Simple Scatterplot Matrix')
plt.show()
Thanks again for sharing this with us. I have used it many times! Oh, and I re-arranged the main() part of the code so that it can be a formal example code or not get called if it is being imported into another piece of code.
While reading the question I expected to see an answer including rpy. I think this is a nice option taking advantage of two beautiful languages. So here it is:
import rpy
import numpy as np
def main():
    """Draw a pairs plot of random data through R (rpy), as EPS and as PNG.

    Writes 'pairsPlot.eps' and 'pairsPlot.png' to the current directory.
    """
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    # NOTE(review): presumably NO_CONVERSION keeps the data frame as an R
    # object so it can be handed straight to pairs() -- confirm with rpy docs.
    rpy.set_default_mode(rpy.NO_CONVERSION)
    try:
        R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)
        # Same plot on two graphics devices: EPS first, then PNG.
        devices = ((rpy.r.postscript, 'pairsPlot.eps'),
                   (rpy.r.png, 'pairsPlot.png'))
        for open_device, filename in devices:
            open_device(filename)
            rpy.r.pairs(R_data,
                        main="Simple Scatterplot Matrix Via RPy")
            rpy.r.dev_off()
    finally:
        # Restore the conversion mode even if plotting fails.
        rpy.set_default_mode(rpy.BASIC_CONVERSION)
if __name__ == '__main__': main()
I can't post an image to show the result :( sorry!

Categories