Plot map in loop without plotting previous points - python

I am trying to plot points drifting in the sea. The following code works, but plots all the points of the previous plots as well:
Duration = 6 #hours
plt.figure(figsize=(20,10))#20,10
map = Basemap(width=300000,height=300000,projection='lcc',
resolution='c',lat_0=51.25,lon_0=-4)#lat_0lon_0 is left under
map.drawmapboundary(fill_color='turquoise')
map.fillcontinents(color='white',lake_color='aqua')
map.drawcountries(linestyle='--')
x=[]
y=[]
for i in range (0,Duration):
    x,y = map(Xpos1[i],Ypos1[i])
    map.scatter(x, y, s=1, c='k', marker='o', label = 'Aurelia aurita', zorder=2)
    plt.savefig('GIF10P%d' %i)
Xpos1 and Ypos1 are lists of masked arrays. Every array in the lists has a length of 10, so 10 points should be plotted in each map:
Xpos1=[[latitude0,lat1,lat2,lat3,..., lat9],
[latitude0,lat1,lat2,lat3,..., lat9],...]
This gives me six figures, I'll show you the first and last:
Every picture is supposed to have 10 points, but the last one is a combination of all the maps (so 60 points).
How do I still get 6 maps with only 10 points each?
Edit:
When I use the answer from "matplotlib.pyplot will not forget previous plots - how can I flush/refresh?", I get the error
ValueError: Can not reset the axes. You are probably trying to re-use an artist in more than one Axes which is not supported
A similar error pops up when I use the answer from: How to "clean the slate"?
namely,
plt.clf()
plt.cla()#after plt.show()
Any help is deeply appreciated!

Instead of plotting new scatter plots for each image it would make sense to update the scatter plot's data. The advantage is that the map only needs to be created once, saving some time.
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
import numpy as np
Duration = 6 #hours
Xpos1 = np.random.normal(-4, 0.6, size=(Duration,10))
Ypos1 = np.random.normal(51.25, 0.6, size=(Duration,10))
plt.figure(figsize=(20,10))
m = Basemap(width=300000,height=300000,projection='lcc',
resolution='c',lat_0=51.25,lon_0=-4)
m.drawmapboundary(fill_color='turquoise')
m.fillcontinents(color='white',lake_color='aqua')
m.drawcountries(linestyle='--')
scatter = m.scatter([], [], s=10, c='k', marker='o', label = 'Aurelia aurita', zorder=2)
for i in range (0,Duration):
    x,y = m(Xpos1[i],Ypos1[i])
    scatter.set_offsets(np.c_[x,y])
    plt.savefig('GIF10P%d' %i)
plt.show()
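The same update-in-place pattern can be checked without Basemap. The following is only a sketch on a plain Matplotlib axes: the random positions, the axis limits, the temporary output directory and the names duration, xpos, ypos and points_per_frame are assumptions made for illustration. After each set_offsets call the scatter holds only the 10 points of the current frame.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; figures are only written to disk
import matplotlib.pyplot as plt
import numpy as np

duration = 6  # number of frames (hypothetical, mirrors Duration above)
rng = np.random.default_rng(0)
xpos = rng.normal(-4, 0.6, size=(duration, 10))
ypos = rng.normal(51.25, 0.6, size=(duration, 10))

fig, ax = plt.subplots()
ax.set_xlim(-7, -1)
ax.set_ylim(48.5, 54)

# Create a single scatter artist once, then overwrite its points per frame.
scatter = ax.scatter(xpos[0], ypos[0], s=10, c='k')
out_dir = tempfile.mkdtemp()  # assumption: frames go to a temporary folder
points_per_frame = []
for i in range(duration):
    scatter.set_offsets(np.c_[xpos[i], ypos[i]])  # replaces the previous points
    points_per_frame.append(len(scatter.get_offsets()))
    fig.savefig(os.path.join(out_dir, 'frame%d.png' % i))
plt.close(fig)
```

Because the axes only ever contains one scatter artist, the background is drawn once and each saved frame shows a single set of points.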

As @Primusa said, moving the figure and map setup into the for loop also works, because the map is then redefined on every iteration.
The correct code is then:
for i in range (0,Duration,int(24/FramesPerDay)):
    plt.figure(figsize=(20,10))#20,10
    map = Basemap(width=300000,height=300000,projection='lcc',
                  resolution='c',lat_0=51.25,lon_0=-4)#lat_0lon_0 is left under
    map.drawmapboundary(fill_color='turquoise')
    map.fillcontinents(color='white',lake_color='aqua')
    map.drawcountries(linestyle='--')
    x,y = map(Xpos1[i],Ypos1[i])
    map.scatter(x, y, s=1, c='k', marker='o', label = 'Aurelia aurita', zorder=2)
    plt.savefig('GIF10P%d' %i)
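A further option, sketched here on a plain Matplotlib axes rather than a Basemap, is to keep one figure and remove the previous scatter artist before drawing the next one. The data, the axis limits and the names frames, previous and collections_per_frame are assumptions for illustration only; fig.canvas.draw() stands in for the plt.savefig call.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, nothing is shown on screen
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
frames = rng.normal(size=(6, 2, 10))  # hypothetical: 6 frames of 10 (x, y) points

fig, ax = plt.subplots()
ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)

previous = None
collections_per_frame = []
for x, y in frames:
    if previous is not None:
        previous.remove()  # take the last frame's points off the axes
    previous = ax.scatter(x, y, s=10, c='k')
    fig.canvas.draw()  # stand-in for saving the frame
    collections_per_frame.append(len(ax.collections))
plt.close(fig)
```

This avoids rebuilding the map for every frame while still leaving only one set of points on the axes at any time.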

Related

How to determine the x value on the edge of the violinplot for a mean line

I am trying to draw a mean line on violin plots. Since I was not able to find a way to make sns replace the "median" line that comes from "quartiles", I decided to draw the mean line on top for each case. I am planning on drawing horizontal lines using plt.plot at the mean value (y value) of each of the three graphs I have.
I have the exact y (height) values where I want each horizontal line to be drawn. However, I am having difficulty figuring out the bound of each violin graph at that specific y value. Since the violin is symmetric, the domain is (-x, x), so I need a way to find that "x" value in order to add 3 horizontal lines, each bounded by its violin graph.
Here is my code. The x value in plt.plot is -0.37, which I found by trial and error; I want Python to find it for me for a given y value.
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
data = [2.57e-05, 4.17e-06, -5.4e-06, -5.05e-06, 1.15e-05, -6.7e-06, 1.01e-05, 5.53e-06, 8.13e-06, 1.27e-05, 1.11e-06, -2.87e-06, -1.38e-06, -1.07e-05, -8.04e-06, 4.77e-06, 3.22e-07, 9.86e-06, 1.38e-05, 1.32e-05, -3.48e-06, -4.69e-06, 8.15e-06, 4.21e-07, 2.71e-06, 7.52e-08, 1.04e-06, -1.92e-06, -4.08e-06, 4.76e-06]
vg = sns.violinplot(data=data, inner="quartile", scale="width")
a = sns.pointplot(data=data, linestyles='-', join=False, ci=None, color='red')
for p in vg.lines:
    p.set_linestyle('-')
    p.set_linewidth(0.8) # Sets the thickness of the quartile lines
    p.set_color('white') # Sets the color of the quartile lines
    p.set_alpha(0.8)
for p in vg.lines[1::3]: # these are the median lines; not means
    p.set_linestyle('-')
    p.set_linewidth(0) # Sets the thickness of the median lines
    p.set_color('black') # Sets the color of the median lines
    p.set_alpha(0.8)
# add a mean line from the edge of the violin plot
plt.plot([-0.37, 0], [np.mean(data), np.mean(data)], 'k-', lw=1)
plt.show()
Refer to the picture where I removed the median point but left the quartile lines, where I want to draw mean lines across where the blue dots are visible
And here is a picture once I draw that plt.plot with the x value I found via trial and error: For case I only
You can draw a line that is too long, and then clip it with the polygon forming the violin.
Note that inner='quartile' shows the 25%, 50% and 75% lines. The 50% line is also known as the median. This is similar to how boxplots are usually drawn. It is rather confusing to show the mean in a too similar style. That's why seaborn (and many other libraries) prefer to show the mean as a point.
Here is some example code (note that the return value of sns.violinplot is an ax, and giving it a very different name makes it rather hard to find your way around the matplotlib and seaborn docs and examples).
import matplotlib.pyplot as plt
from matplotlib.patches import PathPatch
import seaborn as sns
import pandas as pd
import numpy as np
tips = sns.load_dataset('tips')
tips['day'] = pd.Categorical(tips['day'])
ax = sns.violinplot(data=tips, x='day', y='total_bill', hue='day', inner='quartile', scale='width', dodge=False)
sns.pointplot(data=tips, x='day', y='total_bill', join=False, ci=None, color='yellow', ax=ax)
ax.legend_.remove()
for p in ax.lines:
    p.set_linestyle('-')
    p.set_linewidth(0.8) # Sets the thickness of the quartile lines
    p.set_color('white') # Sets the color of the quartile lines
    p.set_alpha(0.8)
for x, (day, violin) in enumerate(zip(tips['day'].cat.categories, ax.collections)):
    line = ax.hlines(tips[tips['day'] == day]['total_bill'].mean(), x - 0.5, x + 0.5, color='black', ls=':', lw=2)
    patch = PathPatch(violin.get_paths()[0], transform=ax.transData)
    line.set_clip_path(patch) # clip the line by the form of the violin
plt.show()
Updated to use a list of lists of data:
data = [np.random.randn(10, 7).cumsum(axis=0).ravel() for _ in range(3)]
ax = sns.violinplot(data=data, inner='quartile', scale='width', palette='Set2')
# sns.pointplot(data=data, join=False, ci=None, color='red', ax=ax) # shows the means
ax.set_xticks(range(len(data)))
ax.set_xticklabels(['I' * (k + 1) for k in range(len(data))])
for p in ax.lines:
    p.set_linestyle('-')
    p.set_linewidth(0.8) # Sets the thickness of the quartile lines
    p.set_color('white') # Sets the color of the quartile lines
    p.set_alpha(0.8)
for x, (data_x, violin) in enumerate(zip(data, ax.collections)):
    line = ax.hlines(np.mean(data_x), x - 0.5, x + 0.5, color='black', ls=':', lw=2)
    patch = PathPatch(violin.get_paths()[0], transform=ax.transData)
    line.set_clip_path(patch)
plt.show()
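If the actual x value at the violin's edge is wanted rather than a clipped line, one possibility is to interpolate along the outline of the violin polygon. The sketch below is an assumption-laden illustration, not part of the answer above: it uses Matplotlib's own ax.violinplot (whose single violin sits at x = 1 by default) instead of seaborn, a random sample, and a hypothetical helper named half_width_at. With seaborn, the path would come from ax.collections as in the code above and the center would be the violin's x position.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

def half_width_at(path, center, y):
    """Interpolate the violin outline to get its half-width at height y."""
    verts = path.vertices
    right = verts[verts[:, 0] >= center]   # keep the right-hand side of the outline
    order = np.argsort(right[:, 1])        # np.interp needs increasing y values
    edge_x = np.interp(y, right[order, 1], right[order, 0])
    return edge_x - center

rng = np.random.default_rng(0)
data = rng.normal(size=200)                # hypothetical sample

fig, ax = plt.subplots()
parts = ax.violinplot([data], showextrema=False)  # one violin at x = 1
body = parts['bodies'][0]                  # polygon collection of the violin
mean = data.mean()
w = half_width_at(body.get_paths()[0], center=1, y=mean)
ax.hlines(mean, 1 - w, 1 + w, color='black', lw=1)  # mean line bounded by the violin
```

Here w plays the role of the 0.37 that was found by trial and error in the question; the exact value depends on the data and on the violin width settings.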
PS: Some further explanation about enumerate(zip(...))
for data_x in data: would loop through the entries of the list data, first assigning data[0] to data_x etc.
for x, data_x in enumerate(data): would loop through the entries of the list data and at the same time increment a variable x from 0 to 1 and finally to 2.
for data_x, violin in zip(data, ax.collections): would loop data_x through the entries of the list data and simultaneously loop a variable violin through the list stored in ax.collections (this is where matplotlib stores the shapes of the violins).
for x, (data_x, violin) in enumerate(zip(data, ax.collections)): combines the enumeration with zip.
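The combined loop can be tried on plain lists. In this small sketch the two lists are placeholders standing in for data and ax.collections:

```python
data = ['a', 'b', 'c']          # placeholder for the list of data sets
violins = ['v0', 'v1', 'v2']    # placeholder for ax.collections

rows = []
for x, (data_x, violin) in enumerate(zip(data, violins)):
    # x counts 0, 1, 2 while data_x and violin advance in lockstep
    rows.append((x, data_x, violin))
print(rows)
# [(0, 'a', 'v0'), (1, 'b', 'v1'), (2, 'c', 'v2')]
```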

How to get rid of extra white space on subplots with shared axes?

I'm creating a plot using python 3.5.1 and matplotlib 1.5.1 that has two subplots (side by side) with a shared Y axis. A sample output image is shown below:
Notice the extra white space at the top and bottom of each set of axes. Try as I might, I can't seem to get rid of it. The overall goal of the figure is to have a waterfall-type plot on the left that shares its Y axis with the plot on the right.
Here's some sample code to reproduce the image above.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
%matplotlib inline
# create some X values
periods = np.linspace(1/1440, 1, 1000)
# create some Y values (will be datetimes, not necessarily evenly spaced
# like they are in this example)
day_ints = np.linspace(1, 100, 100)
days = pd.to_timedelta(day_ints, 'D') + pd.to_datetime('2016-01-01')
# create some fake data for the number of points
points = np.random.random(len(day_ints))
# create some fake data for the color mesh
Sxx = np.random.random((len(days), len(periods)))
# Create the plots
fig = plt.figure(figsize=(8, 6))
# create first plot
ax1 = plt.subplot2grid((1,5), (0,0), colspan=4)
im = ax1.pcolormesh(periods, days, Sxx, cmap='viridis', vmin=0, vmax=1)
ax1.invert_yaxis()
ax1.autoscale(enable=True, axis='Y', tight=True)
# create second plot and use the same y axis as the first one
ax2 = plt.subplot2grid((1,5), (0,4), sharey=ax1)
ax2.scatter(points, days)
ax2.autoscale(enable=True, axis='Y', tight=True)
# Hide the Y axis scale on the second plot
plt.setp(ax2.get_yticklabels(), visible=False)
#ax1.set_adjustable('box-forced')
#ax2.set_adjustable('box-forced')
fig.colorbar(im, ax=ax1)
As you can see in the commented out code I've tried a number of approaches, as suggested by posts like https://github.com/matplotlib/matplotlib/issues/1789/ and Matplotlib: set axis tight only to x or y axis.
As soon as I remove the sharey=ax1 part of the second subplot2grid call the problem goes away, but then I also don't have a common Y axis.
Autoscale tends to add a buffer to the data so that all of the data points are easily visible and not part-way cut off by the axes.
Change:
ax1.autoscale(enable=True, axis='Y', tight=True)
to:
ax1.set_ylim(days.min(),days.max())
and
ax2.autoscale(enable=True, axis='Y', tight=True)
to:
ax2.set_ylim(days.min(),days.max())
This removes the extra white space at the top and bottom of both axes.
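A reduced sketch of the same fix, assuming plain numeric y values instead of datetimes and made-up random data, shows that setting the limits on one axes is enough when the y axis is shared:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

y = np.linspace(1, 100, 100)     # hypothetical numeric y values
rng = np.random.default_rng(0)

fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
ax1.plot(rng.random(len(y)), y)
ax2.scatter(rng.random(len(y)), y)

# Pin the shared y range to the data extent, so no autoscale margin remains.
ax1.set_ylim(y.min(), y.max())
limits = (ax1.get_ylim(), ax2.get_ylim())
plt.close(fig)
```

Note that set_ylim(bottom, top) also fixes the direction of the axis. If the y axis should stay inverted, as with invert_yaxis() in the question, the limits can be passed in reverse order, for example ax1.set_ylim(days.max(), days.min()).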

How to pick a point in a subplot and highlight it in adjacent subplots in matplotlib(extension to region of points)

I want to create a scatter plot matrix composed of several subplots. I have extracted my data from a .txt file and created an array with columns (x,y,z,p1,p2,p3). The first three columns are the x,y,z coordinates in the original image that the data come from, and the last three columns (p1, p2, p3) are some other parameters. Consequently, in each row of the array the parameters p1, p2, p3 share the same coordinates (x,y,z). In the scatter plot, I want to visualize the p1 parameter against the p2 and p3 parameters as a first stage. For every point I pick, I would like its (x,y,z) coordinates from the first three columns of my array to be annotated, and the point with the same coordinates in the adjacent subplot to be highlighted or have its color changed.
In my code, two subplots are created. When I pick a point, the terminal prints the (p1, p2 or p3) values of that point, the respective values of the same point in the adjacent subplot, and the (x,y,z) coordinates of this point.
Moreover, when I pick a point in the first subplot, the color of the corresponding point in the second subplot changes, but not vice versa. This color change is visible only if I manually resize the figure. How could I add interactivity for both subplots without having to tweak the figure in order to notice any changes? What modifications should I make for this interactivity to work in a reduced scatter plot matrix like the one in the question "Is there a function to make scatterplot matrices in matplotlib?"? I am not an experienced Python or matplotlib user, so any kind of help will be appreciated.
import numpy as np
import matplotlib.pyplot as plt
import pylab as pl
def main():
    #load data from file
    data = np.loadtxt(r"data.txt")
    plt.close("all")
    x = data[:, 3]
    y = data[:, 4]
    y1 = data[:, 5]
    fig1 = plt.figure(1)
    #subplot p1 vs p2
    plt.subplot(121)
    subplot1, = plt.plot(x, y, 'bo', picker=3)
    plt.xlabel('p1')
    plt.ylabel('p2')
    #subplot p1 vs p3
    plt.subplot(122)
    subplot2, = plt.plot(x, y1, 'bo', picker=3)
    plt.xlabel('p1')
    plt.ylabel('p3')
    plt.subplots_adjust(left=0.1, right=0.95, wspace=0.3, hspace=0.45)
    # art.getp(fig1.patch)
    def onpick(event):
        thisevent = event.artist
        valx = thisevent.get_xdata()
        valy = thisevent.get_ydata()
        ind = event.ind
        print 'index', ind
        print 'selected point:', zip(valx[ind], valy[ind])
        print 'point in the adjacent subplot', x[ind], y1[ind]
        print '(x,y,z):', data[:, 0][ind], data[:, 1][ind], data[:, 2][ind]
        for xcord,ycord in zip(valx[ind], valy[ind]):
            plt.annotate("(x,y,z):", xy = (x[ind], y1[ind]), xycoords = ('data' ),
                         xytext=(x[ind] - .5, y1[ind]- .5), textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"),
                         )
            subplot2, = plt.plot(x[ind], y[ind], 'ro', picker=3)
            subplot1 = plt.plot(x[ind], y[ind], 'ro', picker=3)
    fig1.canvas.mpl_connect('pick_event', onpick)
    plt.show()
main()
In conclusion, the information is printed in the terminal whenever I pick a point, regardless of the subplot. However, the color changes only for points in the right subplot when I pick a point in the left subplot, and not vice versa. Moreover, the color change is not noticeable until I tweak the figure (e.g. move or resize it), and when I choose a second point, the previous one remains colored.
Any kind of contribution will be appreciated. Thank you in advance.
You're already on the right track with your current code. You're basically just missing a call to plt.draw() in your onpick function.
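As a rough illustration of that point, the sketch below keeps one red overlay line per subplot and moves it on every pick, then asks the canvas to redraw. It is a sketch under assumptions: the data are random stand-ins for the p1, p2, p3 columns, the names highlight, hl_a and hl_b are made up, and fig.canvas.draw_idle() is used here in place of plt.draw().

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
p1, p2, p3 = rng.random((3, 20))   # hypothetical stand-ins for the data columns

fig, (ax_a, ax_b) = plt.subplots(1, 2)
ax_a.plot(p1, p2, 'bo', picker=5)
ax_b.plot(p1, p3, 'bo', picker=5)
# One initially empty red overlay per subplot, reused for every pick,
# so a previously selected point does not stay colored.
hl_a, = ax_a.plot([], [], 'ro')
hl_b, = ax_b.plot([], [], 'ro')

def highlight(ind):
    """Mark the points with index ind in both subplots and request a redraw."""
    hl_a.set_data(p1[ind], p2[ind])
    hl_b.set_data(p1[ind], p3[ind])
    fig.canvas.draw_idle()         # without a redraw the change stays invisible

def onpick(event):
    # event.ind holds the indices of the picked points, whichever subplot was clicked
    highlight(event.ind)

fig.canvas.mpl_connect('pick_event', onpick)
# plt.show()  # uncomment to try it interactively
```

Since the overlays carry no picker, clicking a highlighted point still picks the blue point underneath, and the handler works the same way from either subplot.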
However, in our discussion in the comments, mpldatacursor came up, and you asked about an example of doing things that way.
The current HighlightingDataCursor in mpldatacursor is set up around the idea of highlighting an entire Line2D artist, not just a particular index of it. (It's deliberately a bit limited, as there's no good way to draw an arbitrary highlight for any artist in matplotlib, so I kept the highlighting parts small.)
However, you could subclass things similar to this (assumes you're using plot and want the first thing you plot in each axes to be used). I'm also illustrating using point_labels, in case you want to have different labels for each point shown:
import numpy as np
import matplotlib.pyplot as plt
from mpldatacursor import HighlightingDataCursor, DataCursor
def main():
    fig, axes = plt.subplots(nrows=2, ncols=2)
    for ax, marker in zip(axes.flat, ['o', '^', 's', '*']):
        x, y = np.random.random((2,20))
        ax.plot(x, y, ls='', marker=marker)
    IndexedHighlight(axes.flat, point_labels=[str(i) for i in range(20)])
    plt.show()
class IndexedHighlight(HighlightingDataCursor):
    def __init__(self, axes, **kwargs):
        # Use the first plotted Line2D in each axes
        artists = [ax.lines[0] for ax in axes]
        kwargs['display'] = 'single'
        HighlightingDataCursor.__init__(self, artists, **kwargs)
        self.highlights = [self.create_highlight(artist) for artist in artists]
        plt.setp(self.highlights, visible=False)
    def update(self, event, annotation):
        # Hide all other annotations
        plt.setp(self.highlights, visible=False)
        # Highlight everything with the same index.
        artist, ind = event.artist, event.ind
        for original, highlight in zip(self.artists, self.highlights):
            x, y = original.get_data()
            highlight.set(visible=True, xdata=x[ind], ydata=y[ind])
        DataCursor.update(self, event, annotation)
main()
Again, this assumes you're using plot and not, say, scatter. It is possible to do this with scatter, but you need to change an annoyingly large amount of details. (There's no general way to highlight an arbitrary matplotlib artist, so you have to have a lot of very verbose code to deal with each type of artist individually.)
Hope it's useful, at any rate.

Instead of grid lines on a plot, can matplotlib print grid crosses?

I want to have some grid lines on a plot, but actually full-length lines are too much/distracting, even dashed light grey lines. I went and manually did some editing of the SVG output to get the effect I was looking for. Can this be done with matplotlib? I had a look at the pyplot api for grid, and the only thing I can see that might be able to get near it are the xdata and ydata Line2D kwargs.
This cannot be done through the basic API, because the grid lines are created using only two points. The grid lines would need a 'data' point at every tick mark for there to be a marker drawn. This is shown in the following example:
import matplotlib.pyplot as plt
ax = plt.subplot(111)
ax.grid(clip_on=False, marker='o', markersize=10)
plt.savefig('crosses.png')
plt.show()
This results in:
Notice how the 'o' markers are only at the beginning and the end of the Axes edges, because the grid lines only involve two points.
You could write a method to emulate what you want, creating the cross marks using a series of Artists, but it's quicker to just leverage the basic plotting capabilities to draw the cross pattern.
This is what I do in the following example:
import matplotlib.pyplot as plt
import numpy as np
NPOINTS=100
def set_grid_cross(ax, in_back=True):
    xticks = ax.get_xticks()
    yticks = ax.get_yticks()
    xgrid, ygrid = np.meshgrid(xticks, yticks)
    kywds = dict()
    if in_back:
        kywds['zorder'] = 0
    grid_lines = ax.plot(xgrid, ygrid, 'k+', **kywds)
xvals = np.arange(NPOINTS)
yvals = np.random.random(NPOINTS) * NPOINTS
ax1 = plt.subplot(121)
ax2 = plt.subplot(122)
ax1.plot(xvals, yvals, linewidth=4)
ax1.plot(xvals, xvals, linewidth=7)
set_grid_cross(ax1)
ax2.plot(xvals, yvals, linewidth=4)
ax2.plot(xvals, xvals, linewidth=7)
set_grid_cross(ax2, in_back=False)
plt.savefig('gridpoints.png')
plt.show()
This results in the following figure:
As you can see, I take the tick marks in x and y to define a series of points where I want grid marks ('+'). I use meshgrid to take two 1D arrays and make 2 2D arrays corresponding to the double loop over each grid point. I plot this with the mark style as '+', and I'm done... almost. This plots the crosses on top, and I added an extra keyword to reorder the list of lines associated with the plot. I adjust the zorder of the grid marks if they are to be drawn behind everything.*****
The example shows the left subplot where by default the grid is placed in back, and the right subplot disables this option. You can notice the difference if you follow the green line in each plot.
If you are bothered by having grid crosses on the border, you can remove the first and last tick marks for both x and y before you define the grid in set_grid_cross, like so:
xticks = ax.get_xticks()[1:-1] #< notice the slicing
yticks = ax.get_yticks()[1:-1] #< notice the slicing
xgrid, ygrid = np.meshgrid(xticks, yticks)
I do this in the following example, using a larger, different marker to make my point:
***** Thanks to the answer by @fraxel for pointing this out.
You can draw line segments at every intersection of the tick points. It's pretty easy to do: just grab the tick locations with get_ticklocs() for both axes, then loop through all combinations, drawing short line segments using axhline and axvline, thus creating a cross-hair at every intersection. I've set zorder=0 so the cross-hairs are drawn first and sit behind the plot data. It's easy to control the color/alpha and cross-hair size. A couple of slight 'gotchas': do the plot before you get the tick locations, and the xmin and xmax parameters seem to require normalisation.
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.plot((0,2,3,5,5,5,6,7,8,6,6,4,3,32,7,99), 'r-',linewidth=4)
x_ticks = ax.xaxis.get_ticklocs()
y_ticks = ax.yaxis.get_ticklocs()
for yy in y_ticks[1:-1]:
    for xx in x_ticks[1:-1]:
        plt.axhline(y=yy, xmin=xx / max(x_ticks) - 0.02,
                    xmax=xx / max(x_ticks) + 0.02, color='gray', alpha=0.5, zorder=0)
        plt.axvline(x=xx, ymin=yy / max(y_ticks) - 0.02,
                    ymax=yy / max(y_ticks) + 0.02, color='gray', alpha=0.5, zorder=0)
plt.show()
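On the normalisation gotcha: the xmin/xmax arguments of axhline (and ymin/ymax of axvline) are fractions of the axes, with 0 at one edge and 1 at the other, so dividing by the largest tick only lines up when the axis starts at 0. A sketch of a more general mapping follows, assuming linear axes; the helper to_axes_fraction, the sample data and the 0.02 cross size are illustrative choices, not part of the answer above.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt

def to_axes_fraction(value, limits):
    """Map a data coordinate to the 0..1 axes fraction used by axhline/axvline."""
    lo, hi = limits
    return (value - lo) / (hi - lo)

fig, ax = plt.subplots()
ax.plot([10, 12, 15, 11, 14], 'r-', linewidth=4)  # y data that does not start at 0
fig.canvas.draw()                                  # settle ticks and limits first
xlim, ylim = ax.get_xlim(), ax.get_ylim()
half = 0.02                                        # cross half-size in axes fraction

for yy in ax.get_yticks():
    for xx in ax.get_xticks():
        fx = to_axes_fraction(xx, xlim)
        fy = to_axes_fraction(yy, ylim)
        if 0 < fx < 1 and 0 < fy < 1:              # skip ticks on or outside the border
            ax.axhline(yy, xmin=fx - half, xmax=fx + half,
                       color='gray', alpha=0.5, zorder=0)
            ax.axvline(xx, ymin=fy - half, ymax=fy + half,
                       color='gray', alpha=0.5, zorder=0)
```

Because the fractions are computed from the current limits, the crosses would need to be redrawn if the limits change later, for example after zooming.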

Is there a function to make scatterplot matrices in matplotlib?

Example of scatterplot matrix
Is there such a function in matplotlib.pyplot?
For those who do not want to define their own functions, there is a great data analysis library in Python called pandas, which provides the scatter_matrix() function:
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns = ['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha = 0.2, figsize = (6, 6), diagonal = 'kde')
Generally speaking, matplotlib doesn't usually contain plotting functions that operate on more than one axes object (subplot, in this case). The expectation is that you'd write a simple function to string things together however you'd like.
I'm not quite sure what your data looks like, but it's quite simple to just build a function to do this from scratch. If you're always going to be working with structured or rec arrays, then you can simplify this a touch. (i.e. There's always a name associated with each data series, so you can omit having to specify names.)
As an example:
import itertools
import numpy as np
import matplotlib.pyplot as plt
def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()
def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots. Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names". Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid."""
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        # Set up ticks only on one side for the "edge" subplots...
        # (on Matplotlib >= 3.6 use ax.get_subplotspec().is_first_col() etc.)
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')
    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)
    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')
    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)
    return fig
main()
You can also use Seaborn's pairplot function:
import seaborn as sns
sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
Thanks for sharing your code! You figured out all the hard stuff for us. As I was working with it, I noticed a few little things that didn't look quite right.
[FIX #1] The axis ticks weren't lining up like I would expect (i.e., in your example above, you should be able to draw a vertical and horizontal line through any point across all plots and the lines should cross through the corresponding point in the other plots, but as it stands this doesn't occur).
[FIX #2] If you are plotting an odd number of variables, the bottom right corner axes doesn't pull the correct xticks or yticks. It just keeps the default 0..1 ticks.
Not a fix, but I made it optional to explicitly input names, so that it puts a default xi for variable i in the diagonal positions.
Below you'll find an updated version of your code that addresses these two points, otherwise preserving the beauty of your code.
import itertools
import numpy as np
import matplotlib.pyplot as plt
def scatterplot_matrix(data, names=[], **kwargs):
    """
    Plots a scatterplot matrix of subplots. Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names". Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid.
    """
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.0, wspace=0.0)
    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        # Set up ticks only on one side for the "edge" subplots...
        # (on Matplotlib >= 3.6 use ax.get_subplotspec().is_first_col() etc.)
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')
    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)
    # Label the diagonal subplots...
    if not names:
        names = ['x'+str(i) for i in range(numvars)]
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')
    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)
    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)
    return fig
if __name__=='__main__':
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()
Thanks again for sharing this with us. I have used it many times! Oh, and I re-arranged the main() part of the code so that it can be a formal example code or not get called if it is being imported into another piece of code.
While reading the question I expected to see an answer including rpy. I think this is a nice option taking advantage of two beautiful languages. So here it is:
import rpy
import numpy as np
def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    rpy.set_default_mode(rpy.NO_CONVERSION)
    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)
    # Figure saved as eps
    rpy.r.postscript('pairsPlot.eps')
    rpy.r.pairs(R_data,
        main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()
    # Figure saved as png
    rpy.r.png('pairsPlot.png')
    rpy.r.pairs(R_data,
        main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()
    rpy.set_default_mode(rpy.BASIC_CONVERSION)
if __name__ == '__main__': main()
I can't post an image to show the result :( sorry!
