Matplotlib cropping secondary axis labels [duplicate] - python

I need to generate a whole bunch of vertically-stacked plots in matplotlib. The result will be saved using savefig and viewed on a webpage, so I don't care how tall the final image is, as long as the subplots are spaced so they don't overlap.
No matter how big I allow the figure to be, the subplots always seem to overlap.
My code currently looks like
import matplotlib.pyplot as plt
import my_other_module
titles, x_lists, y_lists = my_other_module.get_data()
fig = plt.figure(figsize=(10,60))
for i, y_list in enumerate(y_lists):
    plt.subplot(len(titles), 1, i + 1)  # subplot indices start at 1
    plt.xlabel("Some X label")
    plt.ylabel("Some Y label")
    plt.title(titles[i])
    plt.plot(x_lists[i], y_list)
fig.savefig('out.png', dpi=100)

Please review matplotlib: Tight Layout guide and try using matplotlib.pyplot.tight_layout, or matplotlib.figure.Figure.tight_layout
As a quick example:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(8, 8))
fig.tight_layout() # Or equivalently, "plt.tight_layout()"
plt.show()
Without Tight Layout
With Tight Layout
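For the vertically stacked case in the question, a minimal sketch might look like the following. The data is made up (the question's my_other_module is not available), and scaling the figure height with the number of rows is an assumption for illustration, not something the answer prescribes.

```python
import io

import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in data for the question's my_other_module.get_data()
titles = [f"Series {k}" for k in range(6)]
x = np.linspace(0, 10, 200)
y_lists = [np.sin((k + 1) * x) for k in range(len(titles))]

# One column, one row per series; the figure height grows with the row count
fig, axes = plt.subplots(nrows=len(titles), ncols=1,
                         figsize=(10, 3 * len(titles)))
for ax, title, y in zip(axes, titles, y_lists):
    ax.plot(x, y)
    ax.set_title(title)
    ax.set_xlabel("Some X label")
    ax.set_ylabel("Some Y label")

fig.tight_layout()  # call after all titles and labels have been set

buffer = io.BytesIO()  # stands in for a file name such as 'out.png'
fig.savefig(buffer, format="png", dpi=100)
```

Calling tight_layout last matters here, because it measures the titles and labels that exist at the time of the call.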

You can use plt.subplots_adjust to change the spacing between the subplots.
call signature:
subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)
The parameter meanings (and suggested defaults) are:
left = 0.125 # the left side of the subplots of the figure
right = 0.9 # the right side of the subplots of the figure
bottom = 0.1 # the bottom of the subplots of the figure
top = 0.9 # the top of the subplots of the figure
wspace = 0.2 # the amount of width reserved for blank space between subplots
hspace = 0.2 # the amount of height reserved for white space between subplots
The actual defaults are controlled by the rc file
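As a small sketch of where those defaults come from, the values can be read from rcParams under the figure.subplot.* keys, and the parameters currently in effect are available on the figure as fig.subplotpars. The printed numbers depend on the installed Matplotlib version and any local rc file, so none are shown here.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line for interactive use
import matplotlib.pyplot as plt

# Spacing defaults live under the 'figure.subplot.*' rc keys
for name in ("left", "right", "bottom", "top", "wspace", "hspace"):
    print(name, plt.rcParams[f"figure.subplot.{name}"])

fig, axes = plt.subplots(nrows=3, ncols=1)
fig.subplots_adjust(hspace=0.6)  # only hspace changes; the others keep their defaults
print(fig.subplotpars.hspace)
```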

Using subplots_adjust(hspace=0) or a very small number (hspace=0.001) will completely remove the whitespace between the subplots, whereas hspace=None does not.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as tic
fig = plt.figure(figsize=(8, 8))
x = np.arange(100)
y = 3.*np.sin(x*2.*np.pi/100.)
for i in range(1, 6):
    temp = 510 + i
    ax = plt.subplot(temp)
    plt.plot(x, y)
    plt.subplots_adjust(hspace=0)
    temp = tic.MaxNLocator(3)
    ax.yaxis.set_major_locator(temp)
    ax.set_xticklabels(())
    ax.title.set_visible(False)
plt.show()
hspace=0 or hspace=0.001
hspace=None

Similar to tight_layout, matplotlib now (as of version 2.2) provides constrained_layout. In contrast to tight_layout, which may be called at any time in the code for a single optimized layout, constrained_layout is a property which, when active, optimizes the layout before every drawing step.
Hence it needs to be activated before or during subplot creation, such as figure(constrained_layout=True) or subplots(constrained_layout=True).
Example:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(4,4, constrained_layout=True)
plt.show()
constrained_layout may also be set via rcParams:
plt.rcParams['figure.constrained_layout.use'] = True
See the what's new entry and the Constrained Layout Guide

import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10,60))
plt.subplots_adjust( ... )
The plt.subplots_adjust method:
def subplots_adjust(*args, **kwargs):
    """
    call signature::
      subplots_adjust(left=None, bottom=None, right=None, top=None,
                      wspace=None, hspace=None)
    Tune the subplot layout via the
    :class:`matplotlib.figure.SubplotParams` mechanism. The parameter
    meanings (and suggested defaults) are::
      left = 0.125 # the left side of the subplots of the figure
      right = 0.9 # the right side of the subplots of the figure
      bottom = 0.1 # the bottom of the subplots of the figure
      top = 0.9 # the top of the subplots of the figure
      wspace = 0.2 # the amount of width reserved for blank space between subplots
      hspace = 0.2 # the amount of height reserved for white space between subplots
    The actual defaults are controlled by the rc file
    """
    fig = gcf()
    fig.subplots_adjust(*args, **kwargs)
    draw_if_interactive()
or
fig = plt.figure(figsize=(10,60))
fig.subplots_adjust( ... )
The size of the picture matters.
"I've tried messing with hspace, but increasing it only seems to make all of the graphs smaller without resolving the overlap problem."
Thus, to add white space while keeping the subplot size, the total image needs to be bigger.
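The trade-off can be made concrete. hspace is expressed as a fraction of the average axes height, so for n stacked rows of height h the figure height H has to satisfy H * (top - bottom) = h * (n + (n - 1) * hspace). The helper below is only a sketch of that arithmetic; its name and its default values are assumptions chosen for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line for interactive use
import matplotlib.pyplot as plt


def stacked_figure_height(n_rows, axes_height, hspace=0.5, top=0.95, bottom=0.05):
    """Figure height (inches) that gives each of n_rows stacked axes the requested height."""
    return axes_height * (n_rows + (n_rows - 1) * hspace) / (top - bottom)


n_rows = 20
height = stacked_figure_height(n_rows, axes_height=2.0)

fig, axes = plt.subplots(nrows=n_rows, ncols=1, figsize=(10, height))
# Use the same spacing values that went into the height calculation
fig.subplots_adjust(hspace=0.5, top=0.95, bottom=0.05)

# Each axes is now 2 inches tall, regardless of how many rows there are
print(axes[0].get_position().height * fig.get_figheight())
```

With this approach, raising hspace makes the figure taller instead of squeezing the plots.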

You could try the .subplot_tool()
plt.subplot_tool()

Resolving this issue when plotting a dataframe with pandas.DataFrame.plot, which uses matplotlib as the default backend.
The following works for whichever kind= is specified (e.g. 'bar', 'scatter', 'hist', etc.).
Tested in python 3.8.12, pandas 1.3.4, matplotlib 3.4.3
Imports and sample data
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# sinusoidal sample data
sample_length = range(1, 15+1)
rads = np.arange(0, 2*np.pi, 0.01)
data = np.array([np.sin(t*rads) for t in sample_length])
df = pd.DataFrame(data.T, index=pd.Series(rads.tolist(), name='radians'), columns=[f'freq: {i}x' for i in sample_length])
# default plot with subplots; each column is a subplot
axes = df.plot(subplots=True)
Adjust the Spacing
Adjust the default parameters in pandas.DataFrame.plot
Change figsize: a width of 5 and a height of 4 for each subplot is a good place to start.
Change layout: (rows, columns) for the layout of subplots.
sharey=True and sharex=True so space isn't taken for redundant labels on each subplot.
The .plot method returns a numpy array of matplotlib.axes.Axes, which should be flattened to easily work with.
Use .get_figure() to extract the DataFrame.plot figure object from one of the Axes.
Use fig.tight_layout() if desired.
axes = df.plot(subplots=True, layout=(3, 5), figsize=(25, 16), sharex=True, sharey=True)
# flatten the axes array to easily access any subplot
axes = axes.flat
# extract the figure object
fig = axes[0].get_figure()
# use tight_layout
fig.tight_layout()
df
# display(df.head(3))
freq: 1x freq: 2x freq: 3x freq: 4x freq: 5x freq: 6x freq: 7x freq: 8x freq: 9x freq: 10x freq: 11x freq: 12x freq: 13x freq: 14x freq: 15x
radians
0.00 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
0.01 0.010000 0.019999 0.029996 0.039989 0.049979 0.059964 0.069943 0.079915 0.089879 0.099833 0.109778 0.119712 0.129634 0.139543 0.149438
0.02 0.019999 0.039989 0.059964 0.079915 0.099833 0.119712 0.139543 0.159318 0.179030 0.198669 0.218230 0.237703 0.257081 0.276356 0.295520

This answer shows using fig.tight_layout after creating the figure. However, tight_layout can be set directly when creating the figure, because matplotlib.pyplot.subplots accepts additional parameters with **fig_kw. All additional keyword arguments are passed to the pyplot.figure call.
See How to plot in multiple subplots for accessing and plotting in subplots.
import matplotlib.pyplot as plt
# create the figure with tight_layout=True
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(8, 8), tight_layout=True)

Related

Matplotlib: Plot on double y-axis plot misaligned

I'm trying to plot two datasets into one plot with matplotlib. One of the two plots is misaligned by 1 on the x-axis.
This MWE pretty much sums up the problem. What do I have to adjust to bring the box-plot further to the left?
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
titles = ["nlnd", "nlmd", "nlhd", "mlnd", "mlmd", "mlhd", "hlnd", "hlmd", "hlhd"]
plotData = pd.DataFrame(np.random.rand(25, 9), columns=titles)
failureRates = pd.DataFrame(np.random.rand(9, 1), index=titles)
color = {'boxes': 'DarkGreen', 'whiskers': 'DarkOrange', 'medians': 'DarkBlue',
'caps': 'Gray'}
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax2 = ax1.twinx()
plotData.plot.box(ax=ax1, color=color, sym='+')
failureRates.plot(ax=ax2, color='b', legend=False)
ax1.set_ylabel('Seconds')
ax2.set_ylabel('Failure Rate in %')
plt.xlim(-0.7, 8.7)
ax1.set_xticks(range(len(titles)))
ax1.set_xticklabels(titles)
fig.tight_layout()
fig.show()
Actual result. Note that it's only 8 box-plots instead of 9 and that they start at index 1.
The issue is a mismatch between how box() and plot() work: box() starts at x-position 1, while plot() depends on the index of the dataframe (which defaults to starting at 0). There are only 8 plots because the 9th is cut off by plt.xlim(-0.7, 8.7). There are several easy ways to fix this. As @Sheldore's answer indicates, you can explicitly set the positions for the boxplot. Another way is to change the indexing of the failureRates dataframe to start at 1 when constructing the dataframe, i.e.
failureRates = pd.DataFrame(np.random.rand(9, 1), index=range(1, len(titles)+1))
Note that you need not specify the xticks or the xlim for the question's MCVE, but you may need to for your complete code.
You can specify the positions on the x-axis where you want to have the box plots. Since you have 9 boxes, use the following which generates the figure below
plotData.plot.box(ax=ax1, color=color, sym='+', positions=range(9))
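The offset can be checked without looking at a figure. The sketch below uses Matplotlib's own Axes.boxplot with random data instead of the pandas wrapper from the question, and reads the box centers back from the median lines; the helper name centers is made up for this example.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

data = np.random.rand(25, 3)
fig, (ax_default, ax_zero) = plt.subplots(1, 2)

default_boxes = ax_default.boxplot(data)                 # positions default to 1..N
zero_boxes = ax_zero.boxplot(data, positions=range(3))   # positions 0..N-1


def centers(boxes):
    # Each median is a horizontal line; the mean of its x data is the box center
    return [float(np.mean(line.get_xdata())) for line in boxes["medians"]]


print(centers(default_boxes))
print(centers(zero_boxes))
```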

Scatter Matrix showing too many floating point values on graph

I'm trying to plot a scatter matrix using Python, but the ticks on the y-axis of the top left plot have a large number of unnecessary digits. I'm plotting the graph directly from pandas using the scatter_matrix function from pandas.plotting.
Also, I am quite new to Python so sorry if this is a stupid question but I just couldn't find the right answer to fit my needs.
I've tried to use different axis formatting options using yaxis.set_major_formatter (not sure if this doesn't work because I'm plotting from pandas, but yielding no results either way), pandas.set_option to customise display.
from pandas.plotting import scatter_matrix
scatter_matrix(df, alpha=0.3, figsize=(9,9), diagonal='kde')
df: Tesla Ret Ford Ret GM Ret
Date
2012-01-03 NaN NaN NaN
2012-01-04 -0.013177 0.015274 0.004751
2012-01-05 -0.021292 0.025664 0.048227
2012-01-06 -0.008481 0.010354 0.033829
2012-01-09 0.013388 0.007686 -0.003490
2012-01-10 0.013578 0.000000 0.017513
2012-01-11 0.022085 0.022881 0.052926
2012-01-12 0.000708 0.005800 0.008173
2012-01-13 -0.193274 -0.008237 -0.015403
2012-01-17 0.167179 -0.001661 -0.003705
...
I've tried to use:
plt.gca().yaxis.set_major_formatter(StrMethodFormatter('{x:,.2f}')) and ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f')) after importing the respective modules, to no avail.
Figure is available here
Everything else in the figure is just as it should be, just the y-axis of the top left plot. I would like it to show one or two decimal point values like the rest of the figure.
I'd greatly appreciate any help that could fix my issue.
Thanks.
P.S: I have edited this answer based on the problem pointed out by @ImportanceOfBeingEarnest (thanks to him). Please read the comments below the answer to see what I mean.
The new solution is to get the displayed ticks for that particular axis and format them up to 2 decimal places.
new_labels = [round(float(i.get_text()), 2) for i in axes[0,0].get_yticklabels()]
axes[0,0].set_yticklabels(new_labels)
OLD ANSWER (Still kept as a history as you will see that the y-ticks in the figure generated below are not correct)
The problem is that you are using an ax object to format the labels, but what scatter_matrix returns is not a single axes object. It is an array containing 9 axes (a 3x3 grid of subplots). You can verify this by printing the shape of the axes variable.
axes = scatter_matrix(df, alpha=0.3, figsize=(9,9), diagonal='kde')
print (axes.shape)
# (3, 3)
The solution is either to iterate through all the axes or to change the formatting only for the problematic one. P.S: The figure below doesn't match yours because I just used the small DataFrame you posted.
Following is how you can do it for all the y-axis
from pandas.plotting import scatter_matrix
from matplotlib.ticker import FormatStrFormatter
axes = scatter_matrix(df, alpha=0.3, figsize=(9,9), diagonal='kde')
for ax in axes.flatten():
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
Alternatively you can just choose a particular axis. Here your top left subfigure can be accessed using axes[0,0]
axes[0,0].yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
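To see what a formatter will print before attaching it to an axis, it can be called directly with a value and a tick position. This is only a quick check of the two format strings mentioned in the question; the sample values are arbitrary.

```python
from matplotlib.ticker import FormatStrFormatter, StrMethodFormatter

percent_style = FormatStrFormatter('%.2f')
brace_style = StrMethodFormatter('{x:,.2f}')

# A formatter is callable as formatter(value, position)
print(percent_style(0.0132, 0))     # two decimals, old-style % formatting
print(brace_style(1234.5678, 0))    # two decimals with a thousands separator
```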
pandas.scatter_matrix suffers from an unfortunate design choice: it plots the kde or histogram on the diagonal in the same axes that show the ticks for the rest of the row. The ticks and labels therefore have to be faked to fit the data, which is done with a FixedLocator and a FixedFormatter. The format of the tick labels is hence taken directly from the string representation of a number.
I would propose a completely different design here. That is, the diagonal axes should stay empty, and instead twin axes are used to show the histogram or kde curve. The problem from the question can hence not occur.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
def scatter_matrix(df, axes=None, **kw):
    n = df.columns.size
    diagonal = kw.pop("diagonal", "hist")
    if axes is None:  # "not axes" is ambiguous for a NumPy array of Axes
        fig, axes = plt.subplots(n,n, figsize=kw.pop("figsize", None),
                                 squeeze=False, sharex="col", sharey="row")
    else:
        flax = axes.flatten()
        fig = flax[0].figure
        assert len(flax) == n*n
    # no gaps between subplots
    fig.subplots_adjust(wspace=0, hspace=0)
    hist_kwds = kw.pop("hist_kwds", {})
    density_kwds = kw.pop("density_kwds", {})
    import itertools
    p = itertools.permutations(df.columns, r=2)
    n = itertools.permutations(np.arange(len(df.columns)), r=2)
    # off-diagonal scatter plots
    for (i,j), (y,x) in zip(n,p):
        axes[i,j].scatter(df[x].values, df[y].values, **kw)
        axes[i,j].tick_params(left=False, labelleft=False,
                              bottom=False, labelbottom=False)
    # histogram or kde on twin axes of the diagonal
    diagaxes = []
    for i, c in enumerate(df.columns):
        ax = axes[i,i].twinx()
        diagaxes.append(ax)
        if diagonal == 'hist':
            ax.hist(df[c].values, **hist_kwds)
        elif diagonal in ('kde', 'density'):
            from scipy.stats import gaussian_kde
            y = df[c].values
            gkde = gaussian_kde(y)
            ind = np.linspace(y.min(), y.max(), 1000)
            ax.plot(ind, gkde.evaluate(ind), **density_kwds)
        if i != 0:
            # share the y scale of all diagonal axes; the original used
            # get_shared_y_axes().join(...), which recent Matplotlib no longer supports
            ax.sharey(diagaxes[0])
        ax.axis("off")
    for i,c in enumerate(df.columns):
        axes[i,i].tick_params(left=False, labelleft=False,
                              bottom=False, labelbottom=False)
        axes[i,0].set_ylabel(c)
        axes[-1,i].set_xlabel(c)
        axes[i,0].tick_params(left=True, labelleft=True)
        axes[-1,i].tick_params(bottom=True, labelbottom=True)
    return axes, diagaxes
df = pd.DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
axes,diagaxes = scatter_matrix(df, diagonal='kde', alpha=0.5)
plt.show()

Eliminate white space between subplots in a matplotlib figure

I am trying to create a nice plot which joins a 4x4 grid of subplots (placed with gridspec, each subplot is 8x8 pixels ). I constantly struggle getting the spacing between the plots to match what I am trying to tell it to do. I imagine the problem is arising from plotting a color bar on the right side of the figure, and adjusting the location of the plots in the figure to accommodate. However, it appears that this issue crops up even without the color bar included, which has further confused me. It may also have to do with the margin spacing. The images shown below are produced by the associated code. As you can see, I am trying to get the space between the plots to go to zero, but it doesn't seem to be working. Can anyone advise?
fig = plt.figure('W Heat Map', (18., 15.))
gs = gridspec.GridSpec(4,4)
gs.update(wspace=0., hspace=0.)
for index in indices:
    loc = (i,j) #determined by the code
    ax = plt.subplot(gs[loc])
    c = ax.pcolor(physHeatArr[index,:,:], vmin=0, vmax=1500)
    # take off axes
    ax.axis('off')
    ax.set_aspect('equal')
fig.subplots_adjust(right=0.8,top=0.9,bottom=0.1)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
cbar = fig.colorbar(c, cax=cbar_ax)
cbar_ax.tick_params(labelsize=16)
fig.savefig("heatMap.jpg")
Similarly, in making a square figure without the color bar:
fig = plt.figure('W Heat Map', (15., 15.))
gs = gridspec.GridSpec(4,4)
gs.update(wspace=0., hspace=0.)
for index in indices:
    loc = (i,j) #determined by the code
    ax = plt.subplot(gs[loc])
    c = ax.pcolor(physHeatArr[index,:,:], vmin=0, vmax=400, cmap=plt.get_cmap("Reds_r"))
    # take off axes
    ax.axis('off')
    ax.set_aspect('equal')
fig.savefig("heatMap.jpg")
When the axes aspect ratio is set to not automatically adjust (e.g. using set_aspect("equal") or a numeric aspect, or in general using imshow), there might be some white space between the subplots, even if wspace and hspace are set to 0. In order to eliminate white space between figures, you may have a look at the following questions:
How to remove gaps between *images* in matplotlib?
How to combine gridspec with plt.subplots() to eliminate space between rows of subplots
How to remove the space between subplots in matplotlib.pyplot?
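The effect of a fixed aspect ratio can be measured directly. The sketch below assumes an 18x15 inch figure like the one in the question, filled with random 8x8 images, and compares the positions of neighbouring axes after drawing. Because the grid cells are wider than they are tall, the square images leave a horizontal gap even though wspace is 0.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(4, 4, figsize=(18, 15),
                         gridspec_kw={"wspace": 0, "hspace": 0})
for ax in axes.flat:
    ax.imshow(np.random.rand(8, 8))  # imshow forces an equal aspect ratio
    ax.axis("off")

fig.canvas.draw()  # the aspect adjustment is applied when the figure is drawn

first = axes[0, 0].get_position()
right_neighbor = axes[0, 1].get_position()
lower_neighbor = axes[1, 0].get_position()

h_gap = (right_neighbor.x0 - first.x1) * fig.get_figwidth()    # inches
v_gap = (first.y0 - lower_neighbor.y1) * fig.get_figheight()   # inches
print(f"horizontal gap: {h_gap:.2f} in, vertical gap: {v_gap:.2f} in")
```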
You may first consider this answer to the first question, where the solution is to build a single array out of the individual arrays and then plot this single array using pcolor, pcolormesh or imshow. This makes it especially comfortable to add a colorbar later on.
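A minimal sketch of that single-array approach, assuming the 16 images are stored as a (16, 8, 8) array in row-major grid order; the helper name tile_images is made up for this example and is not taken from the linked answer.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np


def tile_images(stack, nrows, ncols):
    """Arrange a (nrows*ncols, h, w) stack into one (nrows*h, ncols*w) array."""
    n, h, w = stack.shape
    assert n == nrows * ncols
    return (stack.reshape(nrows, ncols, h, w)
                 .transpose(0, 2, 1, 3)      # -> (grid row, y, grid col, x)
                 .reshape(nrows * h, ncols * w))


images = np.random.rand(16, 8, 8)
mosaic = tile_images(images, 4, 4)

fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(mosaic, cmap="RdBu", vmin=0, vmax=1)
ax.axis("off")
fig.colorbar(im, ax=ax)  # one image, so one colorbar covers every tile
```

Since there is only one axes, there is no spacing between tiles to tune at all.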
Otherwise, consider setting the figure size and subplot parameters such that no whitespace remains. Formulas for that calculation are found in this answer to the second question.
An adapted version with colorbar would look like this:
import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.cm
import numpy as np
image = np.random.rand(16,8,8)
aspect = 1.
n = 4 # number of rows
m = 4 # number of columns
bottom = 0.1; left=0.05
top=1.-bottom; right = 1.-0.18
fisasp = (1-bottom-(1-top))/float( 1-left-(1-right) )
#widthspace, relative to subplot size
wspace=0 # set to zero for no spacing
hspace=wspace/float(aspect)
#fix the figure height
figheight= 4 # inch
figwidth = (m + (m-1)*wspace)/float((n+(n-1)*hspace)*aspect)*figheight*fisasp
fig, axes = plt.subplots(nrows=n, ncols=m, figsize=(figwidth, figheight))
plt.subplots_adjust(top=top, bottom=bottom, left=left, right=right,
wspace=wspace, hspace=hspace)
#use a normalization to make sure the colormapping is the same for all subplots
norm=matplotlib.colors.Normalize(vmin=0, vmax=1 )
for i, ax in enumerate(axes.flatten()):
    ax.imshow(image[i, :,:], cmap = "RdBu", norm=norm)
    ax.axis('off')
# use a scalarmappable derived from the norm instance to create colorbar
sm = matplotlib.cm.ScalarMappable(cmap="RdBu", norm=norm)
sm.set_array([])
cax = fig.add_axes([right+0.035, bottom, 0.035, top-bottom])
fig.colorbar(sm, cax=cax)
plt.show()

Is there a function to make scatterplot matrices in matplotlib?

Example of scatterplot matrix
Is there such a function in matplotlib.pyplot?
For those who do not want to define their own functions, there is a great data analysis library in Python, called Pandas, where one can find the scatter_matrix() method:
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns = ['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha = 0.2, figsize = (6, 6), diagonal = 'kde')
Generally speaking, matplotlib doesn't usually contain plotting functions that operate on more than one axes object (subplot, in this case). The expectation is that you'd write a simple function to string things together however you'd like.
I'm not quite sure what your data looks like, but it's quite simple to just build a function to do this from scratch. If you're always going to be working with structured or rec arrays, then you can simplify this a touch. (i.e. There's always a name associated with each data series, so you can omit having to specify names.)
As an example:
import itertools
import numpy as np
import matplotlib.pyplot as plt
def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()
def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots. Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names". Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid."""
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        # Set up ticks only on one side for the "edge" subplots...
        # (Axes.is_first_col() etc. were removed from recent Matplotlib;
        # the same checks are available on the SubplotSpec)
        spec = ax.get_subplotspec()
        if spec.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if spec.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if spec.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if spec.is_last_row():
            ax.xaxis.set_ticks_position('bottom')
    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)
    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')
    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)
    return fig
main()
You can also use Seaborn's pairplot function:
import seaborn as sns
sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
Thanks for sharing your code! You figured out all the hard stuff for us. As I was working with it, I noticed a few little things that didn't look quite right.
[FIX #1] The axis ticks weren't lining up as I would expect (i.e., in your example above, you should be able to draw a vertical and a horizontal line through any point across all plots, and the lines should cross through the corresponding point in the other plots, but as it stands this doesn't occur).
[FIX #2] If you have an odd number of variables you are plotting with, the bottom right corner axes doesn't pull the correct xtics or ytics. It just leaves it as the default 0..1 ticks.
Not a fix, but I made it optional to explicitly input names, so that it puts a default xi for variable i in the diagonal positions.
Below you'll find an updated version of your code that addresses these two points, otherwise preserving the beauty of your code.
import itertools
import numpy as np
import matplotlib.pyplot as plt
def scatterplot_matrix(data, names=None, **kwargs):
    """
    Plots a scatterplot matrix of subplots. Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names". Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid.
    """
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.0, wspace=0.0)
    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        # Set up ticks only on one side for the "edge" subplots...
        # (Axes.is_first_col() etc. were removed from recent Matplotlib;
        # the same checks are available on the SubplotSpec)
        spec = ax.get_subplotspec()
        if spec.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if spec.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if spec.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if spec.is_last_row():
            ax.xaxis.set_ticks_position('bottom')
    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)
    # Label the diagonal subplots...
    if not names:
        names = ['x'+str(i) for i in range(numvars)]
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')
    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)
    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)
    return fig
if __name__=='__main__':
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()
Thanks again for sharing this with us. I have used it many times! Oh, and I re-arranged the main() part of the code so that it can be a formal example code or not get called if it is being imported into another piece of code.
While reading the question I expected to see an answer including rpy. I think this is a nice option taking advantage of two beautiful languages. So here it is:
import rpy
import numpy as np
def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    rpy.set_default_mode(rpy.NO_CONVERSION)
    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)
    # Figure saved as eps
    rpy.r.postscript('pairsPlot.eps')
    rpy.r.pairs(R_data,
                main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()
    # Figure saved as png
    rpy.r.png('pairsPlot.png')
    rpy.r.pairs(R_data,
                main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()
    rpy.set_default_mode(rpy.BASIC_CONVERSION)
if __name__ == '__main__': main()
I can't post an image to show the result :( sorry!
