Matplotlib: updating multiple scatter plots in a loop - python

I have two data sets that I would like to produce scatterplots for, with different colors.
Following the advice in MatPlotLib: Multiple datasets on the same scatter plot
I managed to plot them. However, I would like to be able to update the scatter plots inside of a loop that will affect both sets of data. I looked at the matplotlib animation package but it doesn't seem to fit the bill.
I cannot get the plot to update from within a loop.
The structure of the code looks like this:
fig = plt.figure()
ax1 = fig.add_subplot(111)
for g in range(gen):
    # some simulation work that affects the data sets
    peng_x, peng_y, bear_x, bear_y = generate_plot(population)
    ax1.scatter(peng_x, peng_y, color='green')
    ax1.scatter(bear_x, bear_y, color='red')
    # this doesn't refresh the plots
Here generate_plot() extracts the relevant plotting information ((x, y) coordinates) from a NumPy array that also holds additional info, and assigns the points to the correct data set so they can be colored differently.
I've tried clearing and redrawing but I can't seem to get it to work.
Edit: Slight clarification. What I'm looking to do basically is to animate two scatter plots on the same plot.

Here's some code that might fit your description:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
# Create new Figure and an Axes which fills it.
fig = plt.figure(figsize=(7, 7))
ax = fig.add_axes([0, 0, 1, 1], frameon=False)
ax.set_xlim(-1, 1), ax.set_xticks([])
ax.set_ylim(-1, 1), ax.set_yticks([])
# Create data
ndata = 50
data = np.zeros(ndata, dtype=[('peng', float, 2), ('bear', float, 2)])
# Initialize the position of data
data['peng'] = np.random.randn(ndata, 2)
data['bear'] = np.random.randn(ndata, 2)
# Construct the scatter which we will update during animation
scat1 = ax.scatter(data['peng'][:, 0], data['peng'][:, 1], color='green')
scat2 = ax.scatter(data['bear'][:, 0], data['bear'][:, 1], color='red')
def update(frame_number):
    # insert results from generate_plot(population) here
    data['peng'] = np.random.randn(ndata, 2)
    data['bear'] = np.random.randn(ndata, 2)
    # Update the scatter collections with the new positions.
    scat1.set_offsets(data['peng'])
    scat2.set_offsets(data['bear'])

# Construct the animation, using the update function as the animation
# director.
animation = FuncAnimation(fig, update, interval=10)
plt.show()
You might also want to take a look at You can learn more tweaks in animating a scatter plot there.
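If you would rather keep an explicit loop than hand control to FuncAnimation, the same idea works there too: create the two scatter collections once and only move their points on each pass. This is a minimal sketch, not the asker's actual simulation. The generate_plot() below is a hypothetical stand-in that returns random (N, 2) arrays, and run_loop() and its parameters are names made up for the illustration.

```python
import numpy as np
import matplotlib.pyplot as plt


def generate_plot(rng, n=50):
    """Hypothetical stand-in for the question's generate_plot(population)."""
    peng = rng.standard_normal((n, 2))
    bear = rng.standard_normal((n, 2))
    return peng, bear


def run_loop(generations=5, pause=0.0, seed=0):
    """Update two scatter collections in place inside an ordinary loop."""
    rng = np.random.default_rng(seed)
    fig, ax = plt.subplots()
    ax.set_xlim(-4, 4)
    ax.set_ylim(-4, 4)
    # Create both collections once, empty; the loop only moves their points.
    peng_scat = ax.scatter([], [], color="green")
    bear_scat = ax.scatter([], [], color="red")
    for _ in range(generations):
        peng, bear = generate_plot(rng)
        peng_scat.set_offsets(peng)  # expects an (N, 2) array of x, y pairs
        bear_scat.set_offsets(bear)
        fig.canvas.draw_idle()
        if pause > 0:
            plt.pause(pause)  # gives the GUI event loop time to redraw
    return fig, peng_scat, bear_scat
```

With an interactive backend, calling run_loop(generations=100, pause=0.05) should show the points moving; with pause=0.0 the figure is only updated in memory, which is enough if you just want to save it afterwards.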


multiple boxplots, side by side, using matplotlib from a dataframe

I'm trying to plot 60+ boxplots side by side from a dataframe and I was wondering if someone could suggest some possible solutions.
At the moment I have df_new, a dataframe with 66 columns, which I'm using to plot boxplots. The easiest way I found to plot the boxplots was to use the boxplot method built into pandas:
boxplot = df_new.boxplot(column=x, figsize = (100,50))
This gives me a very tiny chart with illegible axes, and I cannot seem to change their font size, so I'm trying to do this natively in matplotlib, but I cannot think of an efficient way of doing it. I'm trying to avoid creating 66 separate boxplots using something like:
fig, ax = plt.subplots(nrows=1,
                       ncols=66,
                       figsize=(10, 5),
                       sharex=True)
ax[0].boxplot(...)  # insert parameters here
I actually do not know how to get the data from df_new.describe() into the boxplot function, so any tips on this would be greatly appreciated! The documentation is confusing, and I'm not sure what the x vectors should be.
Ideally I'd like to just give the boxplot function the dataframe and for it to automatically create all the boxplots by working out all the quartiles, column separations etc on the fly - is this even possible?
I tried to replace the boxplot with a ridge plot, which takes up less space because:
it requires half of the width
you can partially overlap the ridges
it develops vertically, so you can scroll down all the plot
I took the code from the seaborn documentation and adapted it a little bit in order to have 60 different ridges, normally distributed; here is the code:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import itertools
sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
# Create the data
n = 20
x = list(np.random.randn(1, 60)[0])
g = [item[0] + item[1] for item in list(itertools.product(list('ABCDEFGHIJ'), list('123456')))]
df = pd.DataFrame({'x': n*x,
                   'g': n*g})
# Initialize the FacetGrid object
pal = sns.cubehelix_palette(10, rot=-.25, light=.7)
g = sns.FacetGrid(df, row="g", hue="g", aspect=15, height=.5, palette=pal)
# Draw the densities in a few steps
g.map(sns.kdeplot, "x", clip_on=False, shade=True, alpha=1, lw=1.5, bw=.2)
g.map(sns.kdeplot, "x", clip_on=False, color="w", lw=2, bw=.2)
g.map(plt.axhline, y=0, lw=2, clip_on=False)
# Define and use a simple function to label the plot in axes coordinates
def label(x, color, label):
    ax = plt.gca()
    ax.text(0, .2, label, fontweight="bold", color=color,
            ha="left", va="center", transform=ax.transAxes)
g.map(label, "x")
# Set the subplots to overlap
g.fig.subplots_adjust(hspace=-.25)
# Remove axes details that don't play well with overlap
g.set_titles("")
g.set(yticks=[])
g.despine(bottom=True, left=True)
This is the result I get:
I don't know if it will suit your needs. In any case, keep in mind that showing so many distributions next to each other will always require a lot of space (and a very big screen).
Maybe you could try dividing the distributions into smaller groups and plotting them a few at a time?
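If you want to stay with boxplots, the "smaller groups" idea can be sketched directly in matplotlib. This is only an illustration under some assumptions: boxplots_in_chunks(), chunk_size and fontsize are names invented here, and the function works on the raw columns of the dataframe (matplotlib computes the quartiles itself), not on the output of describe().

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def boxplots_in_chunks(df, chunk_size=22, fontsize=8):
    """Draw one row of boxplots per group of `chunk_size` columns."""
    columns = list(df.columns)
    chunks = [columns[i:i + chunk_size]
              for i in range(0, len(columns), chunk_size)]
    fig, axes = plt.subplots(nrows=len(chunks), ncols=1,
                             figsize=(12, 3 * len(chunks)), squeeze=False)
    for ax, cols in zip(axes[:, 0], chunks):
        # One box per column; NaNs are dropped so they do not break the stats.
        ax.boxplot([df[c].dropna().to_numpy() for c in cols])
        # Boxes sit at positions 1..len(cols); label them with column names.
        ax.set_xticks(range(1, len(cols) + 1))
        ax.set_xticklabels(cols, rotation=90, fontsize=fontsize)
        ax.tick_params(axis="y", labelsize=fontsize)
    fig.tight_layout()
    return fig, axes
```

For a 66-column dataframe, boxplots_in_chunks(df_new) would give three rows of 22 boxes each, and the fontsize argument controls the tick label size that was hard to change through pandas.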

Preserve content of fig after show() in matplotlib?

I'm creating a violinplot of some data and afterwards I render a scatterplot with individual data points (red points in example) to three subplots.
Since the generation of the violinplot is relatively time consuming, I'm generating the violinplot only once, then add the scatterplot for one data row, write the result file, remove the scatterplots from the axes and add the scatterplots for the next row.
Everything works, but I would like to add the option to show() each plot prior to saving it.
If I call plt.show(), the figure is shown correctly, but afterwards the figure seems to be cleared, and in the next iteration I get the plot without the violin plots.
Is there any way to preserve the content of the figure after plt.show()?
In short, my code is
fig = generate_plot(ws, show=False)  # returns the fig instance of the violin plot
# if I call plt.show() here (or in generate_plot()), the violin plots are gone
ax1, ax3, ax2 = fig.get_axes()
scatter1 = ax1.scatter(...)  # draw scatter plot for the first axes
[...]  # same for every axes
One possible option is to use the event loop to advance through the plots. The following defines an updating function that changes only the scatter points, draws the image, and saves it. We can manage this via a class with a callback on key_press_event, such that when you hit Space the next image is shown; upon pressing Space after the last image, the plot is closed.
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
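class NextPlotter(object):
    def __init__(self, fig, func, n):
        self.fig = fig    # figure whose key presses we listen to
        self.func = func  # callback invoked with the current index
        self.n = n        # number of images to step through
        self.i = 0
        self.cid = self.fig.canvas.mpl_connect("key_press_event", self.adv)

    def adv(self, evt):
        if evt.key == " " and self.i < self.n:
            self.func(self.i)
            self.i += 1
        elif self.i >= self.n:
            plt.close("all")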
#Start of code:
# Create data
pos = [1, 2, 4, 5, 7, 8]
data = [np.random.normal(0, std, size=100) for std in pos]
data2 = [np.random.rayleigh(std, size=100) for std in pos]
scatterdata = np.random.normal(0, 5, size=(10,len(pos)))
# Create plot
fig, axes = plt.subplots(ncols=2)
axes[0].violinplot(data, pos, points=40, widths=0.9,
                   showmeans=True, showextrema=True, showmedians=True)
axes[1].violinplot(data2, pos, points=40, widths=0.9,
                   showmeans=True, showextrema=True, showmedians=True)
scatter = axes[0].scatter(pos, scatterdata[0,:], c="crimson", s=60)
scatter2 = axes[1].scatter(pos, scatterdata[1,:], c="crimson", s=60)
# define updating function
def update(i):
    # change only the scatter points, then redraw and save the image
    scatter.set_offsets(np.c_[pos, scatterdata[2*i, :]])
    scatter2.set_offsets(np.c_[pos, scatterdata[2*i + 1, :]])
    fig.canvas.draw()
    fig.savefig("plot{}.png".format(i))  # the file name is a placeholder
# instantiate NextPlotter; press <space> to advance to the next image
c = NextPlotter(fig, update, len(scatterdata)//2)
plt.show()
A workaround could be to not remove the scatterplot.
Why not keep the scatter plot axes, and just update the data for that set of axes?
You will most likely need a plt.draw() after update of scatter plot data to force a new rendering.
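As a rough sketch of that workaround: draw the violins once, keep a single scatter artist, and only move its points for each data row before saving. The function save_rows() and its prefix argument are hypothetical names used for the illustration, and fig.canvas.draw() is used here in place of plt.draw() so the example does not depend on which figure is current.

```python
import numpy as np
import matplotlib.pyplot as plt


def save_rows(rows, positions, violin_data, prefix=None):
    """Draw the violins once, then reuse one scatter artist for every row."""
    fig, ax = plt.subplots()
    ax.violinplot(violin_data, positions)           # expensive part, done once
    scatter = ax.scatter(positions, rows[0], c="crimson", s=60)
    snapshots = []
    for i, row in enumerate(rows):
        scatter.set_offsets(np.c_[positions, row])  # move the points only
        fig.canvas.draw()                           # force a new rendering
        snapshots.append(np.array(scatter.get_offsets()))
        if prefix is not None:                      # prefix is a made-up name
            fig.savefig("{}{}.png".format(prefix, i))
    return fig, snapshots
```

Because nothing is removed from the axes, the violin plots stay in place for every saved image.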
I found a way to draw figures interactively. Calling plt.ion() and blocking the process with input() seem to be important.
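import matplotlib.pyplot as plt
plt.ion()  # interactive mode: figures are shown without blocking
fig = plt.figure()
ax = plt.subplot(1,1,1)
ax.set_xlim([-1, 5])
ax.set_ylim([-1, 5])
for i in range(5):
    lineObject = ax.plot(i,i,'ro')
    # plt.draw() # not necessary?
    input()  # block until Enter is pressed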
I also tried to block the process with time.sleep(1), but it does not work at all.

Adding second legend to scatter plot

Is there a way to add a secondary legend to a scatterplot, where the size of the scatter is proportional to some data?
I have written the following code that generates a scatterplot. The color of the scatter represents the year (and is taken from a user-defined df) while the size of the scatter represents variable 3 (also taken from a df but is raw data):
import pandas as pd
colors = pd.DataFrame({'1985':'red','1990':'b','1995':'k','2000':'g','2005':'m','2010':'y'}, index=[0,1,2,3,4,5])
fig = plt.figure()
ax = fig.add_subplot(111)
for i in df.keys():
    df[i].plot(kind='scatter',x='variable1',y='variable2',ax=ax,label=i,s=df[i]['variable3']/100, c=colors[i])
ax.legend(loc='upper right')
ax.set_xlabel("Variable 1")
ax.set_ylabel("Variable 2")
This code (with my data) produces the following graph:
So while the colors/years are well and clearly defined, the size of the scatter is not.
How can I add a secondary or additional legend that defines what the size of the scatter means?
You will need to create the second legend yourself, i.e. you need to create some artists to populate the legend with. In the case of a scatter we can use a normal plot and set the marker accordingly.
This is shown in the example below. To actually add a second legend, we need to add the first legend to the axes as an artist, so that the new legend does not overwrite the first one.
import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np; np.random.seed(1)
import pandas as pd
plt.rcParams["figure.subplot.right"] = 0.8
v = np.random.rand(30,4)
v[:,2] = np.random.choice(np.arange(1980,2015,5), size=30)
v[:,3] = np.random.randint(5,13,size=30)
df= pd.DataFrame(v, columns=["x","y","year","quality"])
df.year = df.year.values.astype(int)
fig, ax = plt.subplots()
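for i, (name, dff) in enumerate(df.groupby("year")):
    # one color per year; the colormap used here is an assumption
    c = matplotlib.colors.to_hex(plt.cm.jet(i/7.))
    dff.plot(kind='scatter', x='x', y='y', label=name, c=c,
             s=dff.quality**2, ax=ax)
leg = plt.legend(loc=(1.03,0), title="Year")
# add the first legend as an artist so the second one does not replace it
ax.add_artist(leg)
h = [plt.plot([],[], color="gray", marker="o", ms=i, ls="")[0] for i in range(5,13)]
plt.legend(handles=h, labels=range(5,13), loc=(1.03,0.5), title="Quality")
plt.show()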
Have a look at
It shows how to have multiple legends (about halfway down) and there is another example that shows how to set the marker size.
If that doesn't work, then you can also create a custom legend (last example).
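For the custom-legend route, a minimal sketch could look like the following. It builds the size legend from proxy Line2D handles and keeps the first legend alive with ax.add_artist(). The function name scatter_with_two_legends() and its arguments are invented for the example, and it assumes the marker size is passed as a diameter in points, so the scatter area is its square.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def scatter_with_two_legends(x, y, groups, sizes):
    """Scatter colored by group, with a second legend explaining marker size."""
    fig, ax = plt.subplots()
    for g in np.unique(groups):
        mask = groups == g
        # scatter's `s` is an area (points**2), so square the diameter
        ax.scatter(x[mask], y[mask], s=sizes[mask] ** 2, label=str(g))
    first = ax.legend(loc="upper left", title="Group")
    ax.add_artist(first)  # keep the first legend when the second is created
    levels = np.unique(sizes)
    # Proxy artists: empty lines that only carry a marker of the right size.
    handles = [Line2D([], [], color="gray", marker="o", ls="", markersize=s)
               for s in levels]
    second = ax.legend(handles=handles, labels=[str(s) for s in levels],
                       loc="lower right", title="Size")
    return fig, first, second
```

The second call to ax.legend() would normally replace the first legend; adding the first one back as a plain artist is what lets both stay visible.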

How to get rid of extra white space on subplots with shared axes?

I'm creating a plot using python 3.5.1 and matplotlib 1.5.1 that has two subplots (side by side) with a shared Y axis. A sample output image is shown below:
Notice the extra white space at the top and bottom of each set of axes. Try as I might, I can't seem to get rid of it. The overall goal of the figure is to have a waterfall-type plot on the left that shares its Y axis with the plot on the right.
Here's some sample code to reproduce the image above.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
%matplotlib inline
# create some X values
periods = np.linspace(1/1440, 1, 1000)
# create some Y values (will be datetimes, not necessarily evenly spaced
# like they are in this example)
day_ints = np.linspace(1, 100, 100)
days = pd.to_timedelta(day_ints, 'D') + pd.to_datetime('2016-01-01')
# create some fake data for the number of points
points = np.random.random(len(day_ints))
# create some fake data for the color mesh
Sxx = np.random.random((len(days), len(periods)))
# Create the plots
fig = plt.figure(figsize=(8, 6))
# create first plot
ax1 = plt.subplot2grid((1,5), (0,0), colspan=4)
im = ax1.pcolormesh(periods, days, Sxx, cmap='viridis', vmin=0, vmax=1)
ax1.autoscale(enable=True, axis='Y', tight=True)
# create second plot and use the same y axis as the first one
ax2 = plt.subplot2grid((1,5), (0,4), sharey=ax1)
ax2.scatter(points, days)
ax2.autoscale(enable=True, axis='Y', tight=True)
# Hide the Y axis scale on the second plot
plt.setp(ax2.get_yticklabels(), visible=False)
fig.colorbar(im, ax=ax1)
As you can see in the commented-out code, I've tried a number of approaches, as suggested by posts like Matplotlib: set axis tight only to x or y axis.
As soon as I remove the sharey=ax1 part of the second subplot2grid call the problem goes away, but then I also don't have a common Y axis.
Autoscale tends to add a buffer to the data so that all of the data points are easily visible and not part-way cut off by the axes.
ax1.autoscale(enable=True, axis='Y', tight=True)
ax2.autoscale(enable=True, axis='Y', tight=True)
To get:
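One way to remove that buffer, sketched here on the assumption that fixed limits are acceptable for your figure, is to set the y-limits explicitly from the data instead of relying on autoscale. Since the two axes share their Y axis, setting the limits on one applies to both. The function name shared_y_without_padding() and the numeric y values are simplifications made for the example; the question's data uses datetimes.

```python
import numpy as np
import matplotlib.pyplot as plt


def shared_y_without_padding(x, y, values, points):
    """Two side-by-side axes sharing y, with limits pinned to the data range."""
    fig = plt.figure(figsize=(8, 6))
    ax1 = plt.subplot2grid((1, 5), (0, 0), colspan=4, fig=fig)
    ax1.pcolormesh(x, y, values, shading="auto")
    ax2 = plt.subplot2grid((1, 5), (0, 4), sharey=ax1, fig=fig)
    ax2.scatter(points, y)
    # Explicit limits on one axes apply to both because the y-axis is shared.
    ax1.set_ylim(y.min(), y.max())
    # Hide the Y axis scale on the second plot
    plt.setp(ax2.get_yticklabels(), visible=False)
    return fig, ax1, ax2
```

Note that with limits pinned exactly to y.min() and y.max(), markers at the very first and last y value are drawn half outside the axes, which is the trade-off autoscale's buffer normally avoids.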

Is there a function to make scatterplot matrices in matplotlib?

Example of scatterplot matrix
Is there such a function in matplotlib.pyplot?
For those who do not want to define their own functions, there is a great data analysis library in Python called pandas, which provides the scatter_matrix() function:
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns = ['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha = 0.2, figsize = (6, 6), diagonal = 'kde')
Generally speaking, matplotlib doesn't usually contain plotting functions that operate on more than one axes object (subplot, in this case). The expectation is that you'd write a simple function to string things together however you'd like.
I'm not quite sure what your data looks like, but it's quite simple to just build a function to do this from scratch. If you're always going to be working with structured or rec arrays, then you can simplify this a touch. (i.e. There's always a name associated with each data series, so you can omit having to specify names.)
As an example:
import itertools
import numpy as np
import matplotlib.pyplot as plt
def main():
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
                             linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()

def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots. Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names". Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid."""
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Set up ticks only on one side for the "edge" subplots...
        # (Matplotlib 3.4+ spells these ax.get_subplotspec().is_first_col() etc.)
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)

    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                           ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    return fig

main()
You can also use Seaborn's pairplot function:
import seaborn as sns
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
Thanks for sharing your code! You figured out all the hard stuff for us. As I was working with it, I noticed a few little things that didn't look quite right.
[FIX #1] The axis ticks weren't lining up like I would expect (i.e., in your example above, you should be able to draw a vertical and horizontal line through any point across all plots, and the lines should cross through the corresponding point in the other plots, but as it stands this doesn't occur).
[FIX #2] If you are plotting an odd number of variables, the bottom right corner axes don't pull the correct xticks or yticks. They are just left with the default 0..1 ticks.
Not a fix, but I made it optional to explicitly input names, so that it puts a default xi for variable i in the diagonal positions.
Below you'll find an updated version of your code that addresses these two points, otherwise preserving the beauty of your code.
import itertools
import numpy as np
import matplotlib.pyplot as plt
def scatterplot_matrix(data, names=[], **kwargs):
    """
    Plots a scatterplot matrix of subplots. Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names". Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid.
    """
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.0, wspace=0.0)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Set up ticks only on one side for the "edge" subplots...
        # (Matplotlib 3.4+ spells these ax.get_subplotspec().is_first_col() etc.)
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)

    # Label the diagonal subplots...
    if not names:
        names = ['x'+str(i) for i in range(numvars)]

    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                           ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)

    return fig

if __name__=='__main__':
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
                             linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()
Thanks again for sharing this with us. I have used it many times! Oh, and I rearranged the main() part of the code so that it serves as a formal example but is not called when the module is imported into another piece of code.
While reading the question I expected to see an answer including rpy. I think this is a nice option taking advantage of two beautiful languages. So here it is:
import rpy
import numpy as np

def main():
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)

    # Figure saved as eps (the file name is a placeholder)
    rpy.r.postscript('pairs_plot.eps')
    rpy.r.pairs(R_data,
                main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    # Figure saved as png (the file name is a placeholder)
    rpy.r.png('pairs_plot.png')
    rpy.r.pairs(R_data,
                main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

if __name__ == '__main__': main()
I can't post an image to show the result :( sorry!
