I've got a script wherein I have two functions, makeplots() which makes a figure of blank subplots arranged in a particular way (depending on the number of subplots to be drawn), and drawplots() which is called later, drawing the plots (obviously). The functions are posted below.
The script does some analysis of data for a given number of 'targets' (which can number anywhere from one to nine) and creates plots of the linear regression for each target. When there are multiple targets, this works great. But when there's a single target (i.e. a single 'subplot' in the figure), the Y-axis label overlaps the axis itself (this does not happen when there are multiple targets).
Ideally, each subplot would be square, no labels would overlap, and it would work the same for one target as for multiple targets. But when I tried to decrease the size of the y-axis label and shift it over a bit, it appeared that the actual axes object was drawn over the previously blank, square plot (whose axes ranged from 0 to 1), and the old tick mark labels were still visible. I'd like to have those old tick marks removed when calling drawplots(). I've tried changing the subplot_kw={} arguments in makeplots, as well as removing ax.set_aspect('auto') from drawplots, both to no avail. Screenshots of the various behaviors are included at the end.
def makeplots(targets, active=actwindow):
    """Create a figure of blank, titled, gridded subplots, one per target.

    The grid uses at most three columns; four targets are laid out 2x2.
    A single target gets one 8x6 inch figure; otherwise every grid cell is
    allotted 6x6 inches.

    Args:
        targets: Non-empty sequence of target names used in subplot titles.
        active: Only used in the error message; kept for interface
            compatibility.

    Returns:
        Tuple ``(fig, axs)`` where ``axs`` is the flattened array of every
        axes in the grid. Axes beyond ``len(targets)`` are hidden.

    Raises:
        ValueError: If ``targets`` is empty.
    """
    numsubs = len(targets)
    if numsubs < 1:
        raise ValueError(f'Error generating plots [call: makeplots({targets},{active}) - invalid numsubs value]')

    numrow = -(-numsubs // 3)  # ceiling division: three plots per row
    if numsubs <= 3:
        numcol = numsubs
    elif numsubs == 4:
        numcol = 2
    else:
        numcol = 3

    # Exactly one subplots() call per figure: calling it twice with the same
    # ``num`` stacks a second set of axes on top of the first.
    figsize = [8, 6] if numsubs == 1 else [numcol*6, numrow*6]
    fig, axs = plt.subplots(num='LOD-95 Plots', nrows=numrow, ncols=numcol,
                            figsize=figsize,
                            subplot_kw={'adjustable': 'box', 'aspect': 1})
    # Raw string for the mathtext part so that ``\s`` is not an invalid escape.
    fig.text(0.02, 0.5, 'Probit score\n ' r'$(\sigma + 5)$',
             va='center', rotation='vertical', size='16')

    axs = np.ravel(axs)
    for ax, target in zip(axs, targets):
        ax.set_title(f'Limit of Detection: {target}', size=11)
        ax.grid()
    # Grids such as 2x3 for five targets leave spare cells; hide them instead
    # of indexing past the end of ``targets``.
    for ax in axs[numsubs:]:
        ax.set_visible(False)
    return fig, axs
and
def drawplots(ax, dftables, color1, color2):
    """Fit a linear regression of probit score and draw the LoD plot on ``ax``.

    Two regressions are fitted (probit vs ``qty`` and probit vs ``log_qty``);
    the one with the larger r squared is plotted. The linear fit wins ties.

    Args:
        ax: Matplotlib axes to draw on.
        dftables: Object exposing ``qty``, ``log_qty`` and ``probit`` columns
            as attributes (a pandas DataFrame in the calling code).
        color1: Colour of the regression line and data points.
        color2: Colour of the LoD marker and crosshairs.

    Returns:
        Tuple ``(r2, lod, regeqn, logreg)``: r squared of the chosen fit, the
        limit of detection (converted back with ``10 ** lod`` when the log fit
        was used), the fitted equation as text, and whether the log fit was
        chosen.
    """
    y = dftables.probit
    y95 = 6.6448536269514722  # equals norm.ppf(0.95) + 5, the 95 % probit level

    regq = scipy.stats.linregress(dftables.qty, y)
    regl = scipy.stats.linregress(dftables.log_qty, y)
    # Plain if/else: with the former ``elif regq... < regl...`` neither branch
    # ran when an r value was NaN, leaving ``regression``/``x`` unbound.
    if regq.rvalue**2 >= regl.rvalue**2:
        regression = regq
        x_label = 'input quantity'
        x = dftables.qty
        logreg = False
    else:
        regression = regl
        x_label = '$log_{10}$(input quantity)'
        x = dftables.log_qty
        logreg = True

    slope, intercept, r = regression.slope, regression.intercept, regression.rvalue
    r2 = r**2
    lod = (y95-intercept)/slope  # x at which the fitted line reaches y95
    xr = [0, lod*1.2]
    yr = [intercept, slope*xr[1] + intercept]
    regeqn = f"y = {slope:.2e}x + {intercept:.3f}"

    lod_name = 'log(LOD)' if logreg else 'LOD'
    lodstr = f'{lod_name} = {lod:.2f}' if lod <= 100 else f'{lod_name} = {lod:.2e}'

    ax.set_xlabel(x_label, fontweight='bold')
    ax.plot(xr, yr, color=color1, linestyle='--')                # plot regression line
    ax.plot(lod, y95, marker='D', color=color2, markersize=7)    # plot point for LoD
    ax.plot(xr, [y95, y95], color=color2, linestyle=':')         # horizontal crosshair
    ax.plot([lod, lod], [0, 7.1], color=color2, linestyle=':')   # vertical crosshair
    ax.scatter(x, y, s=81, color=color1, marker='.')             # actual data points
    ax.annotate(f"{lodstr}", xy=(lod, 0.1),
                xytext=(0.9*lod, 0.5), fontsize=8,
                arrowprops=dict(facecolor='black', headlength=5, width=2, headwidth=5))
    ax.set_aspect('auto')
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)

    if logreg:
        lod = 10 ** lod
    return r2, lod, regeqn, logreg
The context they're called in:
# Driver fragment: builds the blank figure, then fills one subplot per target.
# NOTE(review): indentation was lost in this listing; the nesting of the
# ``with`` / ``for`` bodies below must be restored before this can run.
# makeplots() returns the figure and a flattened array of axes.
fig, axs = makeplots(targets)
# presumably brings the console window back to the front after the figure
# opens; ``wg`` and ``actwindow`` are defined elsewhere - TODO confirm
wg.SetForegroundWindow(actwindow)
# Append mode: earlier reports in the output file are kept.
with open(outName, 'a+') as f:
print(f"Lower Limit of Detection Analysis on {dt} at {tm}\n", file=f)
for i, tars in enumerate(targets):
data[tars] = stripThousands(data[tars])
# logans = checkyn(f"Analyze {tars} using log10(concentration/quantity)? (y/n): ")
for idx, val in enumerate(qtys):
# column 2 holds the value returned by hitrate(); column 3 is its probit
# (normal quantile shifted by +5)
tables[i,idx,2] = hitrate(val,data,tars)
tables[i,idx,3] = norm.ppf(tables[i,idx,2])+5
# unfiltered copy of the table (rows with infinite probit values are kept)
printtables[tars] = pd.DataFrame(tables[i,:,:], columns=["qty","log_qty","probability","probit"])
# construct dataframes from np.arrays and drop
# rows with infinite probit values:
dftables[tars] = pd.DataFrame(tables[i,:,:], columns=["qty","log_qty","probability","probit"])
# NOTE(review): in-place replace on an attribute-accessed column may act on a
# copy under pandas copy-on-write - verify the infinities are really removed
dftables[tars].probit.replace([np.inf,-np.inf],np.nan, inplace=True)
dftables[tars].dropna(inplace=True)
# cbcolors[i] colours the data/regression, cbcolors[i+5] the LoD markers
r2, lod, eqn, logreg = drawplots(axs[i], dftables[tars], cbcolors[i], cbcolors[i+5])
You should clear the axes in each iteration using pyplot.cla().
You posted a lot of code, so I'm not 100% sure of the best location to place it in your code, but the general idea is to clear the axes before each new plot.
Here is a minimal demo without cla():
# Two series drawn on the same axes; nothing is cleared in between, so both
# lines remain visible.
x = [[1, 2, 3], [3, 2, 1]]
fig, ax = plt.subplots()
for data in x:  # the index was never used, so iterate the series directly
    ax.plot(data)
And with cla():
for data in x:  # the index was never used, so iterate the series directly
    ax.cla()  # clear the axes before drawing the next series
    ax.plot(data)
Related
I have a small problem with matplotlib and Python: the lines don't appear in the plot. I am trying to graph a custom function. My code is below:
fig, ax = plt.subplots(figsize=(8,4))

# Define the x axis values:
x = np.linspace(2000,32000)


# Creating the functions that we will plot
def pmgc(x):
    """Return the constant 0.853 regardless of ``x``."""
    return 0.853


def pmec(x):
    """Return ``pmgc(x) - 124.84 / x``; works elementwise on NumPy arrays."""
    return (-124.84/(x)) + pmgc(x)


# Plotting
# Evaluate on the whole array. The former ``for x in range(...)`` loop rebound
# ``x`` to a single integer, so each plot() call received one marker-less point
# and nothing was visible. pmgc is constant, hence the explicit broadcast.
ax.plot(x, np.full_like(x, pmgc(x)), color = 'blue',linewidth = 3)
ax.plot(x, pmec(x), color = 'red',linewidth = 3)
plt.rcParams["figure.autolayout"] = True
ax.set_xlabel("Renda")
plt.legend(labels = ['Propensão Marginal a Cosumir','Propensão Média a Cosumir'],loc = 'upper left', borderaxespad = 0,bbox_to_anchor=(1.02, 1))
plt.title('Gráfico da Questão 6, item c\nFeito por Luiz Mario. Fonte: Autor', loc='center')
Every time I run the code, the graph appears without the lines. Could someone please help me?
Thank you for the attention :)
A few things. You are defining x as np.linspace(2000,32000) so use another variable in your for loop instead (such as i). Then, you want to create empty lists for your pmgc and pmec values to append to in your for loop. Lastly, you don't want to do for x in range(2000,32000): you want to do for i in np.linspace(2000, 32000): to match the length of your x list. But you've already defined np.linspace(2000, 32000) above in your code when you set x equal to it. So just do for i in x:. Put it all together, and you get your lines:
# Creating the functions that we will plot
def pmgc(x):
    """Return the constant 0.853 regardless of ``x``."""
    return 0.853


def pmec(x):
    """Return ``pmgc(x) - 124.84 / x``."""
    return (-124.84/(x)) + pmgc(x)


# Define the x axis values:
x = np.linspace(2000,32000)

# One y value per x sample, so both lists have the same length as x.
pmgc_list = [pmgc(sample) for sample in x]
pmec_list = [pmec(sample) for sample in x]

# Plotting
fig, ax = plt.subplots(figsize=(8,4))
ax.plot(x, pmgc_list, color='blue', linewidth=3)
ax.plot(x, pmec_list, color='red', linewidth=3)
plt.rcParams["figure.autolayout"] = True
ax.set_xlabel("Renda")
plt.legend(
    labels=['Propensão Marginal a Cosumir','Propensão Média a Cosumir'],
    loc='upper left',
    borderaxespad=0,
    bbox_to_anchor=(1.02, 1),
)
plt.title('Gráfico da Questão 6, item c\nFeito por Luiz Mario. Fonte: Autor', loc='center')
Output:
You can create two lists that contain the info this way
# Creating the functions that we will plot
def pmgc(x):
    """Return the constant 0.853 regardless of ``x``."""
    return 0.853


def pmec(x):
    """Return ``pmgc(x) - 124.84 / x``."""
    return (-124.84/(x)) + pmgc(x)


# Define the x axis values and the matching y values for both curves.
# The x values come straight from range(); the earlier np.linspace array was
# overwritten by the loop variable and never used.
x_list = list(range(2000, 32000))
y_list1 = [pmgc(value) for value in x_list]
y_list2 = [pmec(value) for value in x_list]

# Plotting
# Create the axes here so the snippet does not depend on an ``ax`` defined
# elsewhere.
fig, ax = plt.subplots(figsize=(8,4))
ax.plot(x_list,y_list1, color = 'blue',linewidth = 3)
ax.plot(x_list,y_list2, color = 'red',linewidth = 3)
plt.rcParams["figure.autolayout"] = True
ax.set_xlabel("Renda")
plt.legend(labels = ['Propensão Marginal a Cosumir','Propensão Média a Cosumir'],loc = 'upper left', borderaxespad = 0,bbox_to_anchor=(1.02, 1))
plt.title('Gráfico da Questão 6, item c\nFeito por Luiz Mario. Fonte: Autor', loc='center')
plt.show()
this might not be the best way to do it, but it will work.
My code is like this, at the moment:
# Load the table and pick the plotted columns.
df = pd.read_csv("Table.csv")
x, y = df['Fe'], df['V']
z = df['HIP']  # here, is a column of strings

# NOTE(review): left, bottom, width and height are not defined in this snippet.
rect_scatter = [left, bottom, width, height]
fig = plt.figure(figsize=(10, 8))
ax_scatter = plt.axes(rect_scatter)
ax_scatter.tick_params(direction='in', top=True, right=True)

# Separate the dots into two classes: 1.0 for the listed names, 0.0 otherwise.
classes = np.where(
    z.isin(['KOI-2', 'KOI-10', 'KOI-17', 'KOI-18', 'KOI-22', 'KOI-94', 'KOI-97']),
    1.0,
    0.0,
)

# Two-colour map: class 0 -> green, class 1 -> red.
colors = ['green', 'red']
cm = LinearSegmentedColormap.from_list('custom', colors, N=len(colors))

# The scatter plot:
scatter = ax_scatter.scatter(x, y, c=classes, s=10, cmap=cm)
lines = scatter.legend_elements()[0]

# Legend with custom labels replacing the automatically generated ones.
labels = [r'Hypatia', r'CKS']
legend = ax_scatter.legend(
    lines,
    labels,
    loc="upper left",
    title="Stars with giant planets",
)
ax_scatter.add_artist(legend)

ax_scatter.set_xlabel('[Fe/H]')
ax_scatter.set_ylabel('[V/H]')
My data, however, has a lot of values other than these 7 I've set as classes=1. Due to that, when I plot the scatter, these 3 values are overlapped by the other hundreds. How can I make these 7 dots appear in front of the others in the plot? Is there a way of giving preference to a class over the other?
In addition to jfaccioni's answer, you can explicitly set the plotting order with the zorder parameter. See the docs.
For each scatter-command, you can specify its order with:
ax.scatter(x, y, s=12, zorder=2)
In your case it's simpler to divide the data prior to plotting, and then call ax.scatter twice. The last call will have Z-index priority by default.
I can't properly test it without access to your data, but something like this should work:
# Names that belong to class 1; every other row is class 0.
class_one_strings = ['KOI-2', 'KOI-10', 'KOI-17', 'KOI-18', 'KOI-22', 'KOI-94', 'KOI-97']
# Vectorised membership test instead of a per-row Python lambda.
df['Classes'] = df['HIP'].isin(class_one_strings).astype(int)

# Select rows and column in a single .loc call (no chained indexing).
is_class_one = df['Classes'] == 1
class_zero_x = df.loc[~is_class_one, 'Fe']
class_zero_y = df.loc[~is_class_one, 'V']
class_one_x = df.loc[is_class_one, 'Fe']
class_one_y = df.loc[is_class_one, 'V']

# Class 1 is drawn last so its points end up on top of class 0.
ax_scatter.scatter(class_zero_x, class_zero_y, c='green', s=10)
ax_scatter.scatter(class_one_x, class_one_y, c='red', s=10)
I use a bar graph to indicate the data of each group. Some of these bars differ significantly from each other. How can I indicate the significant difference in the bar plot?
import numpy as np
import matplotlib.pyplot as plt
# Group means with their standard deviations (drawn as error bars).
menMeans = (5, 15, 30, 40)
menStd = (2, 3, 4, 5)

width = 0.35
ind = np.arange(4)  # the x locations for the groups
tick_positions = ind + width/2.

p1 = plt.bar(ind, menMeans, width=width, color='r', yerr=menStd)
plt.xticks(tick_positions, ('A', 'B', 'C', 'D'))
I am aiming for
The answer above inspired me to write a small but flexible function myself:
def barplot_annotate_brackets(num1, num2, data, center, height, yerr=None, dh=.05, barh=.05, fs=None, maxasterix=None):
    """
    Annotate barplot with p-values.

    Draws a bracket over two bars of the current axes and writes a label
    above it.

    :param num1: number of left bar to put bracket over
    :param num2: number of right bar to put bracket over
    :param data: string to write or number for generating asterixes
    :param center: centers of all bars (like plt.bar() input)
    :param height: heights of all bars (like plt.bar() input)
    :param yerr: yerrs of all bars (like plt.bar() input); list, tuple or array
    :param dh: height offset over bar / bar + yerr in axes coordinates (0 to 1)
    :param barh: bar height in axes coordinates (0 to 1)
    :param fs: font size
    :param maxasterix: maximum number of asterixes to write (for very small p-values)
    :raises ValueError: if ``data`` is a negative number
    """
    if isinstance(data, str):
        text = data
    else:
        if data < 0:
            # A negative value would never leave the loop below.
            raise ValueError('p-value must be non-negative')
        # * is p < 0.05
        # ** is p < 0.005
        # *** is p < 0.0005
        # etc.
        text = ''
        p = .05
        while data < p:
            text += '*'
            p /= 10.
            if maxasterix and len(text) == maxasterix:
                break
        if len(text) == 0:
            text = 'n. s.'

    lx, ly = center[num1], height[num1]
    rx, ry = center[num2], height[num2]

    # ``if yerr:`` raised for NumPy arrays (ambiguous truth value); test for
    # presence and non-emptiness explicitly instead.
    if yerr is not None and len(yerr) > 0:
        ly += yerr[num1]
        ry += yerr[num2]

    # Convert the axes-relative offsets into data units of the current y range.
    ax_y0, ax_y1 = plt.gca().get_ylim()
    dh *= (ax_y1 - ax_y0)
    barh *= (ax_y1 - ax_y0)

    y = max(ly, ry) + dh
    barx = [lx, lx, rx, rx]
    bary = [y, y+barh, y+barh, y]
    mid = ((lx+rx)/2, y+barh)

    plt.plot(barx, bary, c='black')

    kwargs = dict(ha='center', va='bottom')
    if fs is not None:
        kwargs['fontsize'] = fs
    plt.text(*mid, text, **kwargs)
which allows me to get some nice annotations relatively simple, e.g.:
# Three bars on a fixed y range, then one bracket per compared pair.
heights = [1.8, 2, 3]
bars = np.arange(len(heights))

plt.figure()
plt.bar(bars, heights, align='center')
plt.ylim(0, 5)

# (left bar, right bar, p-value or label, extra keyword arguments)
for left, right, label, extra in (
    (0, 1, .1, {}),
    (1, 2, .001, {}),
    (0, 2, 'p < 0.0075', {'dh': .2}),
):
    barplot_annotate_brackets(left, right, label, bars, heights, **extra)
I've done a couple of things here that I suggest when working with complex plots. Pull out the custom formatting into a dictionary, it makes life simple when you want to change a parameter - and you can pass this dictionary to multiple plots. I've also written a custom function to annotate the itervalues, as a bonus it can annotate between (A,C) if you really want to (I stand by my comment that this isn't the right visual approach however). It may need some tweaking once the data changes but this should put you on the right track.
import numpy as np
import matplotlib.pyplot as plt
menMeans = (5, 15, 30, 40)
menStd = (2, 3, 4, 5)
ind = np.arange(4)  # the x locations for the groups
width = 0.7
labels = ('A', 'B', 'C', 'D')

# Pull the formatting out here
bar_kwargs = {'width': width, 'color': 'y', 'linewidth': 2, 'zorder': 5}
# 'fmt' must be the string 'none' to suppress data markers; passing None is
# not accepted by current matplotlib.
err_kwargs = {'zorder': 0, 'fmt': 'none', 'linewidth': 2, 'ecolor': 'k'}

fig, ax = plt.subplots()
ax.p1 = ax.bar(ind, menMeans, **bar_kwargs)
ax.errs = ax.errorbar(ind, menMeans, yerr=menStd, **err_kwargs)


# Custom function to draw the diff bars
def label_diff(i, j, text, X, Y):
    """Draw a bracket between bars ``i`` and ``j`` and label it with ``text``.

    ``X`` holds the bar positions and ``Y`` the bar heights. The bracket sits
    at 1.1 times the taller bar; the text is placed 7 data units above that,
    starting at the x position of bar ``i``.
    """
    y = 1.1*max(Y[i], Y[j])

    props = {'connectionstyle': 'bar', 'arrowstyle': '-',
             'shrinkA': 20, 'shrinkB': 20, 'linewidth': 2}
    ax.annotate(text, xy=(X[i], y+7), zorder=10)
    ax.annotate('', xy=(X[i], y), xytext=(X[j], y), arrowprops=props)


# Call the function
label_diff(0, 1, 'p=0.0370', ind, menMeans)
label_diff(1, 2, 'p<0.0001', ind, menMeans)
label_diff(2, 3, 'p=0.0025', ind, menMeans)

# ``top`` is the supported keyword for the upper limit (``ymax`` was removed).
plt.ylim(top=60)
plt.xticks(ind, labels, color='k')
plt.show()
If you are using matplotlib and seeking boxplot annotation, use my code as a function:
statistical annotation
def AnnoMe(x1, x2, ARRAY, TXT):
    """Draw a bracket with the label TXT between positions x1 and x2.

    x1 and x2 are 1-based positions, so their data are ARRAY[x1-1] and
    ARRAY[x2-1]. The bracket starts 2 units above the larger of the two
    maxima and is 2 units tall.
    """
    col = 'k'
    h = 2
    top_left = max(ARRAY[x1-1])
    top_right = max(ARRAY[x2-1])
    y = max(top_left, top_right) + 2
    bracket_x = [x1, x1, x2, x2]
    bracket_y = [y, y+h, y+h, y]
    plt.plot(bracket_x, bracket_y, lw=1.5, c=col)
    plt.text((x1+x2)*.5, y+h, TXT, ha='center', va='bottom', color=col)
where 'x1' and 'x2' are two columns you want to compare, 'ARRAY' is the list of lists you are using for illustrating the boxplot. And, 'TXT' is your text like p-value or significant/not significant in string format.
Accordingly, call it with:
AnnoMe(1, 2, MyArray, "p-value=0.02")
Grouped bar plot from pandas dataframe
Annotate significant difference between bars
I have modified @cheersmate's solution so that it also accepts pandas DataFrames as input. This function was tested with matplotlib 3.5.1.
def annotate_barplot_dataframe(bar0, bar1, text, patches, dh=0.2):
    """Annotate a grouped barplot from a pandas dataframe

    A bracket with a label is drawn from bar0 to bar1, where the indices
    refer to the bars ordered from left to right.

    Args:
        bar0 (int): index of first bar
        bar1 (int): index of second bar
        text (string): what to write on the annotation
        patches (list of matplotlib.patches.Rectangle): the bars
        dh (float): gap between the taller bar and the bracket, in data units
    """
    # Order bars by x position on a copy, so the caller's list is untouched.
    ordered = sorted(patches, key=lambda patch: patch.get_x())
    left = ordered[bar0]
    right = ordered[bar1]

    # Public accessors rather than the private _height/_width attributes.
    y = max(left.get_height(), right.get_height()) + dh

    l_mid = left.get_bbox().x1 - left.get_width() / 2
    r_mid = right.get_bbox().x1 - right.get_width() / 2

    barh = 0.07  # height of the bracket's vertical legs, in data units

    # lower-left, upper-left, upper-right, lower-right
    barx = [l_mid, l_mid, r_mid, r_mid]
    bary = [y, y + barh, y + barh, y]
    plt.plot(barx, bary, c="black")

    kwargs = dict(ha="center", va="bottom")
    mid = ((l_mid + r_mid) / 2, y + 0.01)
    plt.text(*mid, text, **kwargs)
def prepare_df(filename):
    """Load filename if it exists and prepare it for the plot

    Reads columns H:W of the first sheet, groups the rows by the
    ``Columnkey`` column and returns the per-group column means.

    Args:
        filename (string): must be a .xlsx file

    Returns:
        pandas.DataFrame: one column per group (``C1``, ``C2``, ``C3``)
        holding the means of groups 1, 2 and 3

    Raises:
        AssertionError: if filename does not end with "xlsx"
        ValueError: if the file cannot be read; the original error is chained
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not filename.endswith("xlsx"):
        raise AssertionError("Check file extension")

    try:
        df = pd.read_excel(filename, sheet_name=0, usecols="H:W", engine="openpyxl")
    except Exception as e:
        raise ValueError(e) from e

    # Columnkey is the variable by which we want to group
    # e.g. in this example columnskey's entries have 3 different values
    grouped = df.groupby(df["Columnkey"])
    return pd.concat(
        [grouped.get_group(key).mean().rename(f"C{key}") for key in (1, 2, 3)],
        axis=1,
    )
So the input to the function should look something like this.
if __name__ == "__main__":
    filename = "Data.xlsx"
    dataframe = prepare_df(filename)

    width = 0.7
    ax = dataframe.plot.bar(width=width, figsize=(9, 2))

    # this plot will group in sets of 3
    # Collect the bars of the first three containers into a new list via the
    # public ``ax.containers``; the containers' own patch lists stay untouched.
    patches = [
        patch
        for container in ax.containers[:3]
        for patch in container.patches
    ]

    annotate_barplot_dataframe(0, 1, "*", patches, 0.1)
    annotate_barplot_dataframe(1, 2, "*", patches, 0.1)

    plt.savefig(fname="filename.pdf", bbox_inches="tight")
    plt.show()
The outcome will save to disk a picture like
I have a code:
import math
import numpy as np
import pylab as plt1
from matplotlib import pyplot as plt
# Script: reads two rate tables, computes a scaling value C per row and plots
# C against Delta with a first-order fit.
# NOTE(review): Python 2 only (print statements; map() is expected to return
# a list). Indentation of the loop body was lost in this listing.
uH2 = 1.90866638
uHe = 3.60187307
eH2 = 213.38
eHe = 31.96
# Ratio of the (u * e) products for H2 and He; used as the log base below.
R = float(uH2*eH2)/(uHe*eHe)
# Accumulators filled in the loop below.
C_Values = []
Delta = []
kHeST = []
J_f21 = []
# unpack=True yields one array per file column.
data = np.genfromtxt("Lamda_HeHCL.txt", unpack=True);
J_i1=data[1];
J_f1=data[2];
kHe=data[7]
data = np.genfromtxt("Basecol_Basic_New_1.txt", unpack=True);
J_i2=data[0];
J_f2=data[1];
kH2=data[5]
print kHe
print kH2
kHe = map(float, kHe)
kH2 = map(float, kH2)
kHe = np.array(kHe)
kH2= np.array(kH2)
g = len(kH2)
for n in range(0,g):
# keep only rows whose second-file final-state column equals 1
if J_f2[n] == 1:
Jf21 = J_f2[n]
J_f21.append(Jf21)
# NOTE(review): ratio is computed but not used in this snippet
ratio = kHe[n]/kH2[n]
# C = (log10(kH2) - log10(kHe)) / log10(R)
C = (((math.log(float(kH2[n]),10)))-(math.log(float(kHe[n]),10)))/math.log(R,10)
C_Values.append(C)
# Delta = |J_f1 - J_i1| taken from the first file
St = abs(J_f1[n] - J_i1[n])
Delta.append(St)
print C_Values
print Delta
print J_f21
fig, ax = plt.subplots()
ax.scatter(Delta,C_Values)
# label every point with its J_f21 entry
for i, txt in enumerate(J_f21):
ax.annotate(txt, (Delta[i],C_Values[i]))
# first-order fit evaluated at the unique Delta values
plt.plot(np.unique(Delta), np.poly1d(np.polyfit(Delta, C_Values, 1))(np.unique(Delta)))
# connects the raw points with a line, in addition to the scatter above
plt.plot(Delta, C_Values)
fit = np.polyfit(Delta,C_Values,1)
fit_fn = np.poly1d(fit)
# fit_fn is now a function which takes in x and returns an estimate for y
# NOTE(review): the 3rd and 4th positional arguments of scatter are presumably
# marker size and colour, not a second x/y pair - confirm the intent
plt.scatter(Delta,C_Values, Delta, fit_fn(Delta))
plt.xlim(0, 12)
plt.ylim(-3, 3)
In this code, I am trying to plot a linear regression that extends past the data and touches the x-axis. I am also trying to add a legend to the plot that shows the slope of the plot. Using the code, I was able to plot this graph.
Here is some trash data I have been using to try and extend the line and add a legend to my code.
x =[5,7,9,15,20]
y =[10,9,8,7,6]
I would also like it to be a scatter except for the linear regression line.
Given that you don't provide the data you're loading from files I was unable to test this, but off the top of my head:
To extend the line past the plot, you could turn this line
plt.plot(np.unique(Delta), np.poly1d(np.polyfit(Delta, C_Values, 1))(np.unique(Delta)))
Into something like
# Evaluate the first-order fit on a fixed x range instead of the data's range.
x = np.linspace(0, 12, 50) # both 0 and 12 are from visually inspecting the plot
plt.plot(
    x,
    np.poly1d(np.polyfit(Delta, C_Values, 1))(x),
)
But if you want the line extended to the x-axis,
# Fit once, then draw the line from x = 0 up to the point where it crosses y = 0.
polynomial = np.polyfit(Delta, C_Values, 1)
x = np.linspace(0, *np.roots(polynomial))  # a first-order fit has a single root
plt.plot(
    x,
    np.poly1d(polynomial)(x),
)
As for the scatter plot thing, it seems to me you could just remove this line:
plt.plot(Delta, C_Values)
Oh right, as for the legend, add a label to the plots you make, like this:
plt.plot(x, np.poly1d(polynomial)(x), label='Linear regression')
and add a call to plt.legend() just before plt.show().
I know pandas supports a secondary Y axis, but I'm curious if anyone knows a way to put a tertiary Y axis on plots. Currently I am achieving this with numpy+pyplot, but it is slow with large data sets.
This is to plot different measurements with distinct units on the same graph for easy comparison (eg: Relative Humidity/Temperature/ and Electrical Conductivity).
So really just curious if anyone knows if this is possible in pandas without too much work.
[Edit] I doubt that there is a way to do this(without too much overhead) however I hope to be proven wrong, as this may be a limitation of matplotlib.
I think this might work:
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame
# Random demo data: five rows, one column per y axis.
df = DataFrame(np.random.randn(5, 3), columns=['A', 'B', 'C'])

fig, ax = plt.subplots()
fig.subplots_adjust(right=0.7)  # leave room on the right for the extra axis

# Third axis: a twin of ``ax`` whose right spine is pushed outwards.
ax3 = ax.twinx()
rspine = ax3.spines['right']
rspine.set_position(('axes', 1.15))
ax3.set_frame_on(True)
ax3.patch.set_visible(False)

df['A'].plot(ax=ax, style='b-')
# same ax as above since it's automatically added on the right
df['B'].plot(ax=ax, style='r-', secondary_y=True)
df['C'].plot(ax=ax3, style='g-')

# add legend --> take advantage of pandas providing us access
# to the line associated with the right part of the axis
handles = [
    ax.get_lines()[0],
    ax.right_ax.get_lines()[0],
    ax3.get_lines()[0],
]
ax3.legend(handles, ['A', 'B', 'C'], bbox_to_anchor=(1.5, 0.5))
Output:
A simpler solution without plt:
# First frame on the base axes; each further frame gets its own twin axis
# whose right spine is placed at the given axes-relative position.
ax1 = df1.plot()
twins = []
for frame, spine_pos in ((df2, 1.0), (df3, 1.1)):
    twin = ax1.twinx()
    twin.spines['right'].set_position(('axes', spine_pos))
    frame.plot(ax=twin)
    twins.append(twin)
ax2, ax3 = twins
....
Using function to achieve this:
def plot_multi(data, cols=None, spacing=.1, **kwargs):
    """Plot several columns of ``data``, each against its own y axis.

    The first column is drawn on the base axes; every further column gets a
    twin axis whose right spine is shifted outwards by ``spacing`` per axis.
    ``cols`` defaults to all columns, extra keyword arguments are forwarded
    to every ``plot`` call, and the base axes are returned (``None`` when
    there is nothing to plot).
    """
    from pandas.plotting._matplotlib.style import get_standard_colors

    # Get default color style from pandas - can be changed to any other color list
    if cols is None:
        cols = data.columns
    if len(cols) == 0:
        return
    colors = get_standard_colors(num_colors=len(cols))

    # First axis
    first = cols[0]
    ax = data.loc[:, first].plot(label=first, color=colors[0], **kwargs)
    ax.set_ylabel(ylabel=first)
    lines, labels = ax.get_legend_handles_labels()

    for n, name in enumerate(cols[1:], start=1):
        # Multiple y-axes
        twin = ax.twinx()
        offset = 1 + spacing * (n - 1)
        twin.spines['right'].set_position(('axes', offset))
        data.loc[:, name].plot(ax=twin, label=name, color=colors[n % len(colors)], **kwargs)
        twin.set_ylabel(ylabel=name)

        # Proper legend position
        twin_lines, twin_labels = twin.get_legend_handles_labels()
        lines.extend(twin_lines)
        labels.extend(twin_labels)

    ax.legend(lines, labels, loc=0)
    return ax
Example:
from random import randrange
# Three random series of very different magnitude, 100 samples each.
data = pd.DataFrame({
    name: [randrange(-bound, bound) for _ in range(100)]
    for name, bound in (('s1', 1000), ('s2', 100), ('s3', 10))
})
cumulative = data.cumsum()
plot_multi(cumulative, figsize=(10, 5))
Output:
I modified the above answer a bit to make it accept custom x column, well-documented, and more flexible.
You can copy this snippet and use it as a function:
from typing import List, Union
import matplotlib.axes
import pandas as pd
def plot_multi(
    data: pd.DataFrame,
    x: Union[str, None] = None,
    y: Union[List[str], None] = None,
    spacing: float = 0.1,
    **kwargs
) -> "Union[matplotlib.axes.Axes, None]":
    """Plot multiple Y axes on the same chart with same x axis.

    Args:
        data: dataframe which contains x and y columns
        x: column to use as x axis. If None, use index.
        y: list of columns to use as Y axes. If None, all columns are used
            except x column.
        spacing: spacing between the plots
        **kwargs: keyword arguments to pass to data.plot()

    Returns:
        a matplotlib.axes.Axes object returned from data.plot(), or None
        when no y column is left to plot

    Example:
    >>> plot_multi(df, figsize=(22, 10))
    >>> plot_multi(df, x='time', figsize=(22, 10))
    >>> plot_multi(df, y='price qty value'.split(), figsize=(22, 10))
    >>> plot_multi(df, x='time', y='price qty value'.split(), figsize=(22, 10))
    >>> plot_multi(df['time price qty'.split()], x='time', figsize=(22, 10))

    See Also:
        This code is mentioned in https://stackoverflow.com/q/11640243/2593810
    """
    if y is None:
        y = data.columns

    # remove x_col from y_cols
    if x:
        y = [col for col in y if col != x]

    if len(y) == 0:
        return None

    # Imported only once there is something to draw, so the early return
    # above does not need the plotting backend.
    from pandas.plotting._matplotlib.style import get_standard_colors

    # Get default color style from pandas - can be changed to any other color list
    colors = get_standard_colors(num_colors=len(y))

    kwargs.setdefault("legend", False)  # prevent multiple legends

    # First axis
    ax = data.plot(x=x, y=y[0], color=colors[0], **kwargs)
    ax.set_ylabel(ylabel=y[0])
    lines, labels = ax.get_legend_handles_labels()

    for i in range(1, len(y)):
        # Multiple y-axes
        ax_new = ax.twinx()
        ax_new.spines["right"].set_position(("axes", 1 + spacing * (i - 1)))
        data.plot(
            ax=ax_new, x=x, y=y[i], color=colors[i % len(colors)], **kwargs
        )
        ax_new.set_ylabel(ylabel=y[i])

        # Proper legend position
        line, label = ax_new.get_legend_handles_labels()
        lines += line
        labels += label

    ax.legend(lines, labels, loc=0)
    return ax
Here's one way to use it:
plot_multi(df, x='time', y='price qty value'.split(), figsize=(22, 10))