How to specify linewidth in Seaborn's clustermap dendrograms - python

Normally I would increase matplotlib's global linewidths by editing the matplotlib.rcParams. This seems to work well directly with SciPy's dendrogram implementation but not with Seaborn's clustermap (which uses SciPy's dendrograms). Can anyone suggest a working method?
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 10
import seaborn as sns; sns.set()
flights = sns.load_dataset("flights")
flights = flights.pivot(index="month", columns="year", values="passengers")
g = sns.clustermap(flights)

This has now been addressed in a more robust way by the following merged pull request https://github.com/mwaskom/seaborn/pull/1935. I'm assuming it will be included in the release after v0.9.0.
You can control the LineCollection properties of the dendrogram by using the tree_kws parameter.
For example:
>>> import seaborn as sns
>>> iris = sns.load_dataset("iris")
>>> species = iris.pop("species")
>>> g = sns.clustermap(iris, tree_kws=dict(linewidths=1.5, colors=(0.2, 0.2, 0.4)))
This would create a clustermap with 1.5 pt thick lines for the tree in an alternative dark purple color.

For newer versions of seaborn (tested with 0.7.1 and 0.9.0), the dendrogram lines are drawn as a LineCollection rather than as individual lines, so their width can be changed as follows:
import seaborn as sns
import matplotlib.pyplot as plt
# load data and make clustermap
df = sns.load_dataset('iris')
g = sns.clustermap(df[['sepal_length', 'sepal_width']])
for a in g.ax_row_dendrogram.collections:
    a.set_linewidth(10)
for a in g.ax_col_dendrogram.collections:
    a.set_linewidth(10)
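The two approaches in this thread (the `.collections` loop above and the older `.lines` loop) can be folded into one helper that touches whichever artists the installed seaborn version created. This is only a sketch: `set_dendrogram_linewidth` is a hypothetical name, not part of seaborn, and the random data merely stands in for a real dataset so that nothing needs to be downloaded.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import pandas as pd
import seaborn as sns


def set_dendrogram_linewidth(grid, width):
    """Hypothetical helper: set the line width on both dendrogram axes.

    Covers LineCollection artists (ax.collections) as well as plain
    Line2D artists (ax.lines); returns how many artists were changed.
    """
    changed = 0
    for ax in (grid.ax_row_dendrogram, grid.ax_col_dendrogram):
        for artist in list(ax.collections) + list(ax.lines):
            artist.set_linewidth(width)
            changed += 1
    return changed


# Synthetic stand-in data (an assumption, not the flights or iris dataset)
rng = np.random.default_rng(0)
data = pd.DataFrame(rng.normal(size=(8, 4)), columns=list("abcd"))

g = sns.clustermap(data)
n_changed = set_dendrogram_linewidth(g, 3)
```

On a seaborn version that accepts `tree_kws`, passing the width directly to `clustermap` is simpler; the helper is mainly useful when the installed version predates that parameter.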

There may be an easier way to do it, but this seems to work:
import matplotlib
import seaborn as sns; sns.set()
flights = sns.load_dataset("flights")
flights = flights.pivot(index="month", columns="year", values="passengers")
g = sns.clustermap(flights)
for l in g.ax_row_dendrogram.lines:
    l.set_linewidth(10)
for l in g.ax_col_dendrogram.lines:
    l.set_linewidth(10)
Edit: This no longer works in Seaborn v0.7.1 (and probably some earlier versions as well); g.ax_col_dendrogram.lines now returns an empty list. I couldn't find a way to increase the line width, so I ended up temporarily modifying the Seaborn module. In the file matrix.py, in the class _DendrogramPlotter, the linewidth is hard-coded as 0.5; I changed it to 1.5:
line_kwargs = dict(linewidths=1.5, colors='k')
This worked but obviously isn't a very sustainable approach.

Related

Making seaborn.PairGrid() look like pairplot()

In the example below, how do I use seaborn.PairGrid() to reproduce the plots created by seaborn.pairplot()? Specifically, I'd like the diagonal distributions to span the vertical axis. Markers with white borders etc... would be great too. Thanks!
import seaborn as sns
import matplotlib.pyplot as plt
iris = sns.load_dataset('iris')
# pairplot() example
g = sns.pairplot(iris, kind='scatter', diag_kind='kde')
plt.show()
# PairGrid() example
g = sns.PairGrid(iris)
g.map_diag(sns.kdeplot)
g.map_offdiag(plt.scatter)
plt.show()
This is quite simple to achieve. The main differences between your plot and what pairplot does are:
the use of the diag_sharey parameter of PairGrid
using sns.scatterplot instead of plt.scatter
With that, we have:
iris = sns.load_dataset('iris')
g = sns.PairGrid(iris, diag_sharey=False)
g.map_diag(sns.kdeplot)
g.map_offdiag(sns.scatterplot)
To change the visual style:
import seaborn as sns
import matplotlib.pyplot as plt
iris = sns.load_dataset('iris')
g = sns.PairGrid(iris)
g.map_diag(sns.kdeplot, fill=True)  # older seaborn versions use shade=True instead
g.map_offdiag(plt.scatter, edgecolor="w")
plt.show()

Can you plot interquartile range as the error band on a seaborn lineplot?

I'm plotting time series data using seaborn lineplot (https://seaborn.pydata.org/generated/seaborn.lineplot.html), and plotting the median instead of mean. Example code:
import numpy as np
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
fmri = sns.load_dataset("fmri")
ax = sns.lineplot(x="timepoint", y="signal", estimator=np.median, data=fmri)
I want the error bands to show the interquartile range as opposed to the confidence interval. I know I can use ci = "sd" for standard deviation, but is there a simple way to add the IQR instead? I cannot figure it out.
Thank you!
I don't know if this can be done with seaborn alone, but here's one way to do it with matplotlib, keeping the seaborn style. The describe() method conveniently provides summary statistics for a DataFrame, among them the quartiles, which we can use to plot the medians with inter-quartile-ranges.
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
fmri = sns.load_dataset("fmri")
fmri_stats = fmri.groupby(['timepoint']).describe()
x = fmri_stats.index
medians = fmri_stats[('signal', '50%')]
medians.name = 'signal'
quartiles1 = fmri_stats[('signal', '25%')]
quartiles3 = fmri_stats[('signal', '75%')]
ax = sns.lineplot(x=x, y=medians)
ax.fill_between(x, quartiles1, quartiles3, alpha=0.3);
You can calculate the median within lineplot as you have done, set ci to None, and fill in the band using ax.fill_between():
import numpy as np
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
fmri = sns.load_dataset("fmri")
ax = sns.lineplot(x="timepoint", y="signal", estimator=np.median,
                  data=fmri, ci=None)
bounds = fmri.groupby('timepoint')['signal'].quantile((0.25,0.75)).unstack()
ax.fill_between(x=bounds.index,y1=bounds.iloc[:,0],y2=bounds.iloc[:,1],alpha=0.1)
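To see what the `bounds` frame from the groupby/quantile/unstack step contains, here is a minimal sketch on a tiny hand-made frame. The data is made up for illustration and only reuses the column names of the fmri dataset.

```python
import pandas as pd

# Made-up data: two timepoints with five signal values each
df = pd.DataFrame({
    "timepoint": [0] * 5 + [1] * 5,
    "signal": [1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0],
})

# One row per timepoint, one column per requested quantile (0.25 and 0.75)
bounds = df.groupby("timepoint")["signal"].quantile([0.25, 0.75]).unstack()
print(bounds)
```

The first column is the lower edge of the band and the second the upper edge, which is why the answer passes `bounds.iloc[:, 0]` and `bounds.iloc[:, 1]` to `fill_between`.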
This has been possible since version 0.12 of seaborn; see the seaborn documentation for details.
pip install --upgrade seaborn
The estimator parameter sets the statistic drawn as the line; it accepts the name of a pandas method (such as 'median' or 'mean') or a callable.
The errorbar parameter sets the spread drawn around that line; it accepts a string, a (string, number) tuple, or a callable. To draw the median and fill the area between the first and third quartiles, you would use:
import numpy as np
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
fmri = sns.load_dataset("fmri")
ax = sns.lineplot(data=fmri, x="timepoint", y="signal", estimator=np.median,
                  errorbar=lambda x: (np.quantile(x, 0.25), np.quantile(x, 0.75)))
You can now!
estimator="median", errorbar=("pi", 50)
https://seaborn.pydata.org/tutorial/error_bars

matplotlib (seaborn): plot correlations between one variable vs multiple others

When plotting correlations, this code
>>> import seaborn as sns
>>> iris = sns.load_dataset("iris")
>>> g = sns.pairplot(iris)
results in the following pairplot:
http://seaborn.pydata.org/_images/seaborn-pairplot-1.png
What if I just want to show the first row out of those four (i.e. correlations of 'sepal_length' vs all other features)? How can I plot that? Could pairplot be used but with some modifications?
Thanks
Using the x_vars and y_vars arguments of pairplot you can select which columns to correlate.
import matplotlib.pyplot as plt
import seaborn as sns
iris = sns.load_dataset("iris")
g = sns.pairplot(iris,
x_vars=["sepal_width","petal_length","petal_width"],
y_vars=["sepal_length"])
plt.show()

Stop seaborn plotting multiple figures on top of one another

I'm starting to learn a bit of python (been using R) for data analysis. I'm trying to create two plots using seaborn, but it keeps saving the second on top of the first. How do I stop this behavior?
import seaborn as sns
iris = sns.load_dataset('iris')
length_plot = sns.barplot(x='sepal_length', y='species', data=iris).get_figure()
length_plot.savefig('ex1.pdf')
width_plot = sns.barplot(x='sepal_width', y='species', data=iris).get_figure()
width_plot.savefig('ex2.pdf')
You have to start a new figure in order to do that. There are multiple ways to do that, assuming you have matplotlib. Also get rid of get_figure() and you can use plt.savefig() from there.
Method 1
Use plt.clf()
import seaborn as sns
import matplotlib.pyplot as plt
iris = sns.load_dataset('iris')
length_plot = sns.barplot(x='sepal_length', y='species', data=iris)
plt.savefig('ex1.pdf')
plt.clf()
width_plot = sns.barplot(x='sepal_width', y='species', data=iris)
plt.savefig('ex2.pdf')
Method 2
Call plt.figure() before each one
plt.figure()
length_plot = sns.barplot(x='sepal_length', y='species', data=iris)
plt.savefig('ex1.pdf')
plt.figure()
width_plot = sns.barplot(x='sepal_width', y='species', data=iris)
plt.savefig('ex2.pdf')
I agree with a previous comment that importing matplotlib.pyplot is not the best software engineering practice, as it exposes the underlying library. I was creating and saving plots in a loop and needed to clear the figure between iterations, and found that this can now be done easily by importing seaborn only:
since version 0.11:
import seaborn as sns
import numpy as np
data = np.random.normal(size=100)
path = "/path/to/img/plot.png"
plot = sns.displot(data) # figure-level, returns a FacetGrid; axes-level functions such as histplot() return an Axes, which has .figure rather than .fig
plot.fig.savefig(path)
plot.fig.clf() # this clears the figure
# ... continue with next figure
alternative example with a loop:
import seaborn as sns
import numpy as np
for i in range(3):
    data = np.random.normal(size=100)
    path = "/path/to/img/plot2_{0:01d}.png".format(i)
    plot = sns.displot(data)
    plot.fig.savefig(path)
    plot.fig.clf() # this clears the figure
before version 0.11 (original post):
import seaborn as sns
import numpy as np
data = np.random.normal(size=100)
path = "/path/to/img/plot.png"
plot = sns.distplot(data)
plot.get_figure().savefig(path)
plot.get_figure().clf() # this clears the figure
# ... continue with next figure
Create specific figures and plot onto them:
import seaborn as sns
import matplotlib.pyplot as plt
iris = sns.load_dataset('iris')
length_fig, length_ax = plt.subplots()
sns.barplot(x='sepal_length', y='species', data=iris, ax=length_ax)
length_fig.savefig('ex1.pdf')
width_fig, width_ax = plt.subplots()
sns.barplot(x='sepal_width', y='species', data=iris, ax=width_ax)
width_fig.savefig('ex2.pdf')
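The same idea extends to a loop: create a fresh figure per plot, draw onto its axes, save it, and close it so that figures do not pile up in memory. The sketch below uses a small made-up frame and a temporary output directory, both assumptions for illustration.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Made-up stand-in for the iris data
df = pd.DataFrame({
    "species": ["a", "a", "b", "b"],
    "sepal_length": [5.0, 5.2, 6.1, 6.3],
    "sepal_width": [3.4, 3.6, 2.8, 3.0],
})

out_dir = tempfile.mkdtemp()  # temporary directory for the output files
saved = []
for column in ["sepal_length", "sepal_width"]:
    fig, ax = plt.subplots()              # a new figure for every plot
    sns.barplot(x=column, y="species", data=df, ax=ax)
    path = os.path.join(out_dir, f"{column}.png")
    fig.savefig(path)
    plt.close(fig)                        # release the figure once saved
    saved.append(path)
```

Because each plot is drawn on its own explicitly created axes, nothing depends on which figure is "current", which is what caused the overlap in the question.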
I've found that if interactive mode is turned off, seaborn plots the heatmap normally.

Control tick labels in Python seaborn package

I have a scatter plot matrix generated using the seaborn package and I'd like to remove all the tick mark labels, as these are just cluttering up the graph (either that, or just remove those on the x-axis). I'm not sure how to do it and have had no success with Google searches. Any suggestions?
import seaborn as sns
import matplotlib.pyplot as plt
sns.pairplot(wheat[['area_planted',
                    'area_harvested',
                    'production',
                    'yield']])
plt.show()
import seaborn as sns
iris = sns.load_dataset("iris")
g = sns.pairplot(iris)
g.set(xticklabels=[])
You can use a list comprehension to loop through all columns and turn off visibility of the xaxis.
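import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
df = pd.DataFrame(np.random.randn(1000, 2)) * 1e6
plot = sns.pairplot(df)
# hide the x-axis of every subplot in the bottom row
for col in range(len(df.columns)):
    plot.axes[len(df.columns) - 1][col].xaxis.set_visible(False)
plt.show()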
You could also rescale your data to something more readable:
df /= 1e6
sns.pairplot(df)
Probably using the following is more appropriate
import seaborn as sns
iris = sns.load_dataset("iris")
g = sns.pairplot(iris)
g.set(xticks=[])
