Adding labels from another variable to a single bar in a barchart - python

I am trying to add a string variable to my graph, but I don't know how to place it inside the bar.
My string variable is called city.
# Horizontal bars: one bar per date, bar length = number of cases.
ax.barh(df['date'].values, df['cases'].values, color= '#49cff3')
# Caption positioned in axes-fraction coordinates (transAxes), just above the plot area.
ax.text(0, 1.06, 'Number of Cases', transform=ax.transAxes, size=12, color='#777777')
# Format x tick values with thousands separators and no decimals.
ax.xaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))
ax.xaxis.set_ticks_position('top')
ax.tick_params(axis='x', colors='#777777', labelsize=12)
# An empty tick list removes every y tick together with its label.
ax.set_yticks([])
ax.grid(which='major', axis='x', linestyle='-')
# Keep the grid lines behind the bars.
ax.set_axisbelow(True)
# Headline, also in axes-fraction coordinates, placed above the caption.
ax.text(0, 1.15, 'Total Number of cases',
transform=ax.transAxes, size=24, weight=600, ha='left', va='top')
# NOTE(review): called without arguments this only queries the current y ticks — confirm it is needed.
plt.yticks()
# Hide the axes frame.
plt.box(False)
here is my output
Many thanks in advance
Edit: I tried @JohanC's solution and got the following output:

If you don't set dates as the y-axis, ax.set_yticks(range(len(df))) followed by ax.set_yticklabels(df['city']) would work just fine.
But a more recommended way is to directly use the cities for the y-axis. Tick marks can be made of length zero, so they keep signalling the position for the tick labels without the marks being visible.
If you want dates as the y-axis, and put the labels inside the bars, you could just place the text via a loop. But it doesn't make too much sense to use the date as the y-axis without showing them.
To invert the y-axis, with the smallest (or first) values at the top, use ax.invert_yaxis().
As the post didn't include test data, the example below starts with creating some random data.
import matplotlib.pyplot as plt
from matplotlib import ticker
import numpy as np
import pandas as pd
# Random demo data: ten consecutive days, each with a case count and a city name.
df = pd.DataFrame({'date': pd.date_range('2020-01-01', periods=10),
                   'cases': np.random.randint(80, 200, 10),
                   'city': [f'City {i}' for i in range(1, 11)]})
fig, ax = plt.subplots()
ax.barh(df['date'], df['cases'], color='#49cff3')
ax.xaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))
ax.xaxis.set_ticks_position('top')
ax.tick_params(axis='x', colors='#777777', labelsize=12)
# Zero-length tick marks: the marks vanish but the date labels stay in place.
ax.tick_params(axis='y', length=0)
# ax.set_yticks([])  # this would make both the tick marks and their labels disappear
ax.grid(which='major', axis='x', linestyle='-')
# Write each city name inside its bar, starting at the left edge (x = 0).
for row in df.itertuples(index=False):
    ax.text(0, row.date, ' ' + row.city, ha='left', va='center', color='crimson')
# Must run once, after the loop: flips the y-axis so the first date is at the top.
ax.invert_yaxis()
plt.box(False)
plt.tight_layout()
plt.show()

Related

Matplotlib How to change x and y ticks background color? [duplicate]

I'd like to change the color of the axis, as well as the ticks and value labels, for a plot I made using matplotlib and PyQt.
Any ideas?
As a quick example (using a slightly cleaner method than the potentially duplicate question):
import matplotlib.pyplot as plt
# Minimal demo: draw a line, then recolor the x-axis spines, label and ticks.
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(range(10))
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')

highlight = 'red'
for side in ('bottom', 'top'):
    ax.spines[side].set_color(highlight)
ax.xaxis.label.set_color(highlight)
ax.tick_params(axis='x', colors=highlight)

plt.show()
Alternatively
# Recolor the x tick marks and tick labels one artist at a time. Plain loops
# are used because the calls are made only for their side effects.
for tick_line in ax.xaxis.get_ticklines():
    tick_line.set_color('red')
for tick_label in ax.xaxis.get_ticklabels():
    tick_label.set_color('red')
If you have several figures or subplots that you want to modify, it can be helpful to use the matplotlib context manager to change the color, instead of changing each one individually. The context manager allows you to temporarily change the rc parameters only for the immediately following indented code, but does not affect the global rc parameters.
This snippet yields two figures, the first one with modified colors for the axis, ticks and ticklabels, and the second one with the default rc parameters.
import matplotlib.pyplot as plt
# The color overrides apply only to code inside the with-block.
with plt.rc_context({'axes.edgecolor': 'orange', 'xtick.color': 'red',
                     'ytick.color': 'green', 'figure.facecolor': 'white'}):
    # Temporary rc parameters in effect
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.plot(range(10))
    ax2.plot(range(10))
# Back to default rc parameters
fig, ax = plt.subplots()
ax.plot(range(10))
You can type plt.rcParams to view all available rc parameters, and use list comprehension to search for keywords:
# List every rc parameter whose name contains 'color', paired with its current value
[(name, setting) for name, setting in plt.rcParams.items() if 'color' in name]
For those using pandas.DataFrame.plot(), matplotlib.axes.Axes is returned when creating a plot from a dataframe. Therefore, the dataframe plot can be assigned to a variable, ax, which enables the usage of the associated formatting methods.
The default plotting backend for pandas is matplotlib.
See matplotlib.spines
Tested in python 3.10, pandas 1.4.2, matplotlib 3.5.1, seaborn 0.11.2
import pandas as pd
# Small test frame: an integer column and a run of daily dates
data = {'a': range(20), 'date': pd.bdate_range('2021-01-09', freq='D', periods=20)}
df = pd.DataFrame(data)
# DataFrame.plot returns the axes it drew on; keep it for styling
ax = df.plot(x='date', color='green', ylabel='values', xlabel='date', figsize=(8, 6))
# one color per spine
spine_colors = {'bottom': 'blue', 'top': 'red', 'right': 'magenta', 'left': 'orange'}
for side, shade in spine_colors.items():
    ax.spines[side].set_color(shade)
# thicker vertical spines
for side in ('right', 'left'):
    ax.spines[side].set_linewidth(3)
# axis labels
ax.xaxis.label.set_color('purple')
ax.yaxis.label.set_color('silver')
ax.tick_params(colors='red', which='both')  # 'both' covers major and minor ticks
seaborn axes-level plot
import seaborn as sns
# Create the axes first, then let seaborn draw the line onto them
fig, ax = plt.subplots(figsize=(12, 5))
g = sns.lineplot(data=df, x='date', y='a', color='g', label='a', ax=ax)
# remove the padding around the data on both axes
ax.margins(x=0, y=0)
# one color per spine
spine_colors = {'bottom': 'blue', 'top': 'red', 'right': 'magenta', 'left': 'orange'}
for side, shade in spine_colors.items():
    ax.spines[side].set_color(shade)
# thicker vertical spines
for side in ('right', 'left'):
    ax.spines[side].set_linewidth(3)
# axis labels
ax.xaxis.label.set_color('purple')
ax.yaxis.label.set_color('silver')
ax.tick_params(colors='red', which='both')  # 'both' covers major and minor ticks
seaborn figure-level plot
# Figure-level plot: relplot returns a grid object, not a single axes
g = sns.relplot(kind='line', data=df, x='date', y='a', color='g', aspect=2)
# style every axes held by the grid
for ax in g.axes.flat:
    # set the margins to 0
    ax.margins(x=0, y=0)
    # make the top and right spines visible
    ax.spines[['top', 'right']].set_visible(True)
    # set various colors
    ax.spines['bottom'].set_color('blue')
    ax.spines['top'].set_color('red')
    ax.spines['right'].set_color('magenta')
    ax.spines['right'].set_linewidth(3)
    ax.spines['left'].set_color('orange')
    ax.spines['left'].set_lw(3)
    ax.xaxis.label.set_color('purple')
    ax.yaxis.label.set_color('silver')
    ax.tick_params(colors='red', which='both')  # 'both' covers major and minor ticks
Motivated by previous contributors, this is an example with three axes.
import matplotlib.pyplot as plt
# Three data sets, each with its own x and y value ranges.
x_values1=[1,2,3,4,5]
y_values1=[1,2,2,4,1]
x_values2=[-1000,-800,-600,-400,-200]
y_values2=[10,20,39,40,50]
x_values3=[150,200,250,300,350]
y_values3=[-10,-20,-30,-40,-50]
fig=plt.figure()
# Three axes are created in the same grid cell (111); the second and third
# are created with frame_on=False.
# NOTE(review): the distinct labels are presumably there so older matplotlib
# versions do not hand back the first axes again — confirm.
ax=fig.add_subplot(111, label="1")
ax2=fig.add_subplot(111, label="2", frame_on=False)
ax3=fig.add_subplot(111, label="3", frame_on=False)
# First data set: line plot, everything colored C0.
ax.plot(x_values1, y_values1, color="C0")
ax.set_xlabel("x label 1", color="C0")
ax.set_ylabel("y label 1", color="C0")
ax.tick_params(axis='x', colors="C0")
ax.tick_params(axis='y', colors="C0")
# Second data set: scatter plot, colored C1, y ticks and label on the right.
ax2.scatter(x_values2, y_values2, color="C1")
ax2.set_xlabel('x label 2', color="C1")
ax2.xaxis.set_label_position('bottom') # put the x label of the second axes at the bottom
# Shift the bottom spine 36 points outward.
ax2.spines['bottom'].set_position(('outward', 36))
ax2.tick_params(axis='x', colors="C1")
ax2.set_ylabel('y label 2', color="C1")
ax2.yaxis.tick_right()
ax2.yaxis.set_label_position('right')
ax2.tick_params(axis='y', colors="C1")
# Third data set: line plot, colored C2, with larger/extra outward offsets.
ax3.plot(x_values3, y_values3, color="C2")
ax3.set_xlabel('x label 3', color='C2')
ax3.xaxis.set_label_position('bottom')
# 72 points outward for the bottom spine, twice the offset used for ax2.
ax3.spines['bottom'].set_position(('outward', 72))
ax3.tick_params(axis='x', colors='C2')
ax3.set_ylabel('y label 3', color='C2')
ax3.yaxis.tick_right()
ax3.yaxis.set_label_position('right')
# The right spine is shifted 36 points outward as well.
ax3.spines['right'].set_position(('outward', 36))
ax3.tick_params(axis='y', colors='C2')
plt.show()
You can also use this to draw multiple plots in same figure and style them using same color palette.
An example is given below
fig = plt.figure()
# Plot ROC curves
# NOTE(review): fpr1/tpr1, fpr2/tpr2 and p_fpr/p_tpr are not defined in this
# snippet — they have to be computed beforehand.
plotfigure(lambda: plt.plot(fpr1, tpr1, linestyle='--',color='orange', label='Logistic Regression'), fig)
plotfigure(lambda: plt.plot(fpr2, tpr2, linestyle='--',color='green', label='KNN'), fig)
plotfigure(lambda: plt.plot(p_fpr, p_tpr, linestyle='-', color='blue'), fig)
# Title
plt.title('ROC curve')
# X label
plt.xlabel('False Positive Rate')
# Y label
plt.ylabel('True Positive rate')
# Legend with white label text
plt.legend(loc='best',labelcolor='white')
# Save at 300 dpi before showing the figure
plt.savefig('ROC',dpi=300)
plt.show();
Output:
Here is a utility function that takes a plotting function with necessary args and plots the figure with required background-color styles. You can add more arguments as necessary.
def plotfigure(plot_fn, fig, background_col='xkcd:black', face_col=(0.06, 0.06, 0.06)):
    """
    Run a pyplot drawing callback and restyle the current axes with white-on-dark colors.

    Parameters:
        plot_fn (callable): Zero-argument callable (typically a lambda) that
            issues the plt plotting calls.
        fig: The Figure whose background patch is recolored, e.g. the object
            returned by plt.figure().
        background_col: Face color of the figure patch. Supports matplotlib colors.
        face_col: Face color of the axes. Supports matplotlib colors.

    Returns:
        None
    """
    fig.patch.set_facecolor(background_col)
    plot_fn()
    # Style whichever axes pyplot considers current once the callback has drawn.
    ax = plt.gca()
    ax.set_facecolor(face_col)
    for side in ('bottom', 'top', 'left', 'right'):
        ax.spines[side].set_color('white')
    ax.xaxis.label.set_color('white')
    ax.yaxis.label.set_color('white')
    ax.grid(alpha=0.1)
    ax.title.set_color('white')
    ax.tick_params(axis='x', colors='white')
    ax.tick_params(axis='y', colors='white')
A use case is defined below
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
# Synthetic two-class data set: 50 samples, 5 features, fixed random seed.
X, y = make_classification(n_samples=50, n_classes=2, n_features=5, random_state=27)
# NOTE(review): the train/test split is not used by the plotting call below.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=27)
fig=plt.figure()
# Scatter each sample's class label against its position, styled via plotfigure.
plotfigure(lambda: plt.scatter(range(0,len(y)), y, marker=".",c="orange"), fig)

Decrease index-ticks frequency

I have this sample data:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Patch
# Sample data: three boolean columns over a month-coded index,
# sorted so the highest month code comes first.
flags = {
    'AAAAAAAAAAAAAAAAAAAA': np.random.choice([False, True], 100000),
    'BBBBBBBBBBBBBBBBBBBB': np.random.choice([False, True], 100000),
    'CCCCCCCCCCCCCCCCCCCC': np.random.choice([False, True], 100000),
}
months = np.random.choice(
    [202006, 202006, 202006, 202005, 202005, 202005, 202004, 202004, 202003],
    100000,
)
df = pd.DataFrame(flags, index=months).sort_index(ascending=False)
With this plot:
fig, ax = plt.subplots(figsize=(5, 6))
# Two-color palette taken from "Set2"
cmap = sns.mpl_palette("Set2", 2)
# Draw the boolean frame as a heatmap without a colorbar
sns.heatmap(data=df, cmap=cmap, cbar=False)
plt.xticks(rotation=90, fontsize=10)
plt.yticks(rotation=0, fontsize=10)
# Legend patches: cmap is indexed with booleans, so True picks entry 1 and False entry 0
legend_handles = [Patch(color=cmap[True], label='Missing Value'), # second palette color
Patch(color=cmap[False], label='Non Missing Value')] # first palette color
# Two-column legend anchored just above the axes
plt.legend(handles=legend_handles, ncol=2, bbox_to_anchor=[0.5, 1.02], loc='lower center', fontsize=8, handlelength=.8)
plt.tight_layout()
plt.show()
The overlapping occurs because of the length of the variable names (I cannot change them, as they are informative in my real plot). So I need to decrease the frequency of the y-ticks: it could be two ticks per value (where the month changes), or simply whatever eliminates the overlapping you see in the image above. The y-ticks of this plot need to show clearly where each month starts and ends (202006 means June 2020), because with my real data I want to see whether a whole month (or more) of data is missing for any variable.
All the possibly adaptable solutions I have found assume that the ticks come from a column: Change tick frequency, adding space between tick labels, increase spacing between ticks, among others. I am still struggling to adapt any of them.
Any suggestions?
NOTE: You can't increase/decrease the size of the figure.
Create your DataFrame with a small correction, namely set the number
of elements as a variable (n):
# Row count, kept in one variable so the plotting code can reuse it.
n = 100000
column_names = ('AAAAAAAAAAAAAAAAAAAA', 'BBBBBBBBBBBBBBBBBBBB', 'CCCCCCCCCCCCCCCCCCCC')
# One random boolean column per name, drawn in the order listed.
flags = {name: np.random.choice([False, True], n) for name in column_names}
month_codes = [202006, 202006, 202006, 202005, 202005, 202005, 202004, 202004, 202003]
# Month-coded index, sorted so the highest month code comes first.
df = pd.DataFrame(flags, index=np.random.choice(month_codes, n)).sort_index(ascending=False)
Then run your drawing code with another 2 corrections, namely:
set yLabelNo = 10 (the number of y labels),
pass yticklabels=n // yLabelNo to sns.heatmap.
So the code is:
# Desired number of y tick labels
yLabelNo = 10
fig, ax = plt.subplots(figsize=(5, 6))
cmap = sns.mpl_palette("Set2", 2)
# An integer yticklabels value labels only every (n // yLabelNo)-th row
sns.heatmap(data=df, cmap=cmap, cbar=False, yticklabels=n // yLabelNo)
plt.xticks(rotation=90, fontsize=10)
plt.yticks(rotation=0, fontsize=10)
# Legend patches: cmap is indexed with booleans, so True picks entry 1 and False entry 0
legend_handles = [Patch(color=cmap[True], label='Missing Value'), # second palette color
Patch(color=cmap[False], label='Non Missing Value')] # first palette color
plt.legend(handles=legend_handles, ncol=2, bbox_to_anchor=[0.5, 1.02],
loc='lower center', fontsize=8, handlelength=.8)
plt.tight_layout()
plt.show()
And the result is:
If you wish, experiment with other (maybe smaller) values of yLabelNo.

Python - Pyplot x-axis not showing on graph

pyplot is not showing the x-axis on the graph:
import pandas as pd
import matplotlib.pyplot as plt
df = pd.read_csv('sitka_weather_2014.csv')
# Parse the AKST column as datetimes, then render it as strings such as "Jan 01, 2014"
df['AKST'] = pd.to_datetime(df.AKST)
df['Dates'] = df['AKST'].dt.strftime('%b %d, %Y')
# The formatted date strings become the index
df.set_index("Dates", inplace= True)
# Plot Data
fig = plt.figure(dpi=256, figsize=(14, 7))
plt.title("Daily high and low temperature - 2014")
df['Max TemperatureF'].plot(linewidth=1, c='blue', label="Max Temperature °F")
df['Min TemperatureF'].plot(linewidth=1, c='red', label="Min Temperature °F")
plt.grid(True)
# NOTE(review): these rc settings are changed after plt.grid(True) was already
# called — confirm they take effect for this figure.
plt.rc('grid', linestyle=":", linewidth=1, color='gray')
plt.legend(loc='upper left')
# The x label is deliberately set to an empty string
plt.xlabel('', fontsize=10)
plt.ylabel("Temperature (°F)", fontsize=10)
plt.tick_params(axis='both', which='major', labelsize=10)
# Rotate the x tick labels by 45 degrees
fig.autofmt_xdate(rotation=45)
plt.show()
The x-axis should be the index of the Pandas Dataframe (df) containing the dates.
Your code is actually fine. I tried to run it with the necessary sitka_weather_2014.csv file and it works.
The problem is that you can't see the x-axis because the figure is too big, and thus the x-axis labelling disappears. Try to scale your figure, e.g. by making the dpi smaller:
fig = plt.figure(dpi=100, figsize=(14, 7)) #dpi=100 instead of dpi=256
Or make the labelsize smaller:
plt.tick_params(axis='both', which='major', labelsize=5) #labelsize=5 instead of labelsize=10
Whatever works best for you. But the code is fine and the description of the x-axis is showing.
You have your xlabel set to an empty string:
plt.xlabel('', fontsize=10)

Unable to generate legend using python / matplotlib for 4 lines all labelled

Want labels for Bollinger Bands (R) ('upper band', 'rolling mean', 'lower band') to show up in legend. But legend just applies the same label to each line with the pandas label for the first (only) column, 'IBM'.
# Plot price values, rolling mean and Bollinger Bands (R)
# The first plot call creates the axes; the other three draw onto it via ax=ax.
ax = prices['IBM'].plot(title="Bollinger Bands")
rm_sym.plot(label='Rolling mean', ax=ax)
upper_band.plot(label='upper band', c='r', ax=ax)
lower_band.plot(label='lower band', c='r', ax=ax)
#
# Add axis labels and legend
ax.set_xlabel("Date")
ax.set_ylabel("Adjusted Closing Price")
ax.legend(loc='upper left')
plt.show()
I know this code may reflect a fundamental lack of understanding of how matplotlib works, so explanations are particularly welcome.
The problem is most probably that whatever upper_band and lower_band are, they are not labeled.
One option is to label them by putting them as columns into a dataframe. This allows plotting each dataframe column directly.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Toy data: a random series plus bands 0.2 above and below it.
y = np.random.rand(4)
yupper, ylower = y + 0.2, y - 0.2
df = pd.DataFrame({"price": y, "upper": yupper, "lower": ylower})

fig, ax = plt.subplots()
# (column, legend label, extra line style) for each curve, in drawing order
curves = [
    ("price", 'Rolling mean', {}),
    ("upper", 'upper band', {'c': 'r'}),
    ("lower", 'lower band', {'c': 'r'}),
]
for column, legend_label, style in curves:
    df[column].plot(label=legend_label, ax=ax, **style)
ax.legend(loc='upper left')
plt.show()
Otherwise you can also plot the data directly.
import matplotlib.pyplot as plt
import numpy as np
# Toy data: a random series plus bands 0.2 above and below it.
y = np.random.rand(4)
yupper = y + 0.2
ylower = y - 0.2

fig, ax = plt.subplots()
# (values, legend label, extra line style) for each curve, in drawing order
curves = (
    (y, 'Rolling mean', {}),
    (yupper, 'upper band', {'c': 'r'}),
    (ylower, 'lower band', {'c': 'r'}),
)
for values, legend_label, style in curves:
    ax.plot(values, label=legend_label, **style)
ax.legend(loc='upper left')
plt.show()
In both cases, you'll get a legend with labels. If that isn't enough, I recommend reading the Matplotlib Legend Guide which also tells you how to manually add labels to legends.

How to change the color of the axis, ticks and labels for a plot

I'd like to change the color of the axis, as well as the ticks and value labels, for a plot I made using matplotlib and PyQt.
Any ideas?
As a quick example (using a slightly cleaner method than the potentially duplicate question):
import matplotlib.pyplot as plt
# Minimal demo: draw a line, then recolor the x-axis spines, label and ticks.
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(range(10))
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')

highlight = 'red'
for side in ('bottom', 'top'):
    ax.spines[side].set_color(highlight)
ax.xaxis.label.set_color(highlight)
ax.tick_params(axis='x', colors=highlight)

plt.show()
Alternatively
# Recolor the x tick marks and tick labels one artist at a time. Plain loops
# are used because the calls are made only for their side effects.
for tick_line in ax.xaxis.get_ticklines():
    tick_line.set_color('red')
for tick_label in ax.xaxis.get_ticklabels():
    tick_label.set_color('red')
If you have several figures or subplots that you want to modify, it can be helpful to use the matplotlib context manager to change the color, instead of changing each one individually. The context manager allows you to temporarily change the rc parameters only for the immediately following indented code, but does not affect the global rc parameters.
This snippet yields two figures, the first one with modified colors for the axis, ticks and ticklabels, and the second one with the default rc parameters.
import matplotlib.pyplot as plt
# The color overrides apply only to code inside the with-block.
with plt.rc_context({'axes.edgecolor': 'orange', 'xtick.color': 'red',
                     'ytick.color': 'green', 'figure.facecolor': 'white'}):
    # Temporary rc parameters in effect
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.plot(range(10))
    ax2.plot(range(10))
# Back to default rc parameters
fig, ax = plt.subplots()
ax.plot(range(10))
You can type plt.rcParams to view all available rc parameters, and use list comprehension to search for keywords:
# List every rc parameter whose name contains 'color', paired with its current value
[(name, setting) for name, setting in plt.rcParams.items() if 'color' in name]
For those using pandas.DataFrame.plot(), matplotlib.axes.Axes is returned when creating a plot from a dataframe. Therefore, the dataframe plot can be assigned to a variable, ax, which enables the usage of the associated formatting methods.
The default plotting backend for pandas is matplotlib.
See matplotlib.spines
Tested in python 3.10, pandas 1.4.2, matplotlib 3.5.1, seaborn 0.11.2
import pandas as pd
# Small test frame: an integer column and a run of daily dates
data = {'a': range(20), 'date': pd.bdate_range('2021-01-09', freq='D', periods=20)}
df = pd.DataFrame(data)
# DataFrame.plot returns the axes it drew on; keep it for styling
ax = df.plot(x='date', color='green', ylabel='values', xlabel='date', figsize=(8, 6))
# one color per spine
spine_colors = {'bottom': 'blue', 'top': 'red', 'right': 'magenta', 'left': 'orange'}
for side, shade in spine_colors.items():
    ax.spines[side].set_color(shade)
# thicker vertical spines
for side in ('right', 'left'):
    ax.spines[side].set_linewidth(3)
# axis labels
ax.xaxis.label.set_color('purple')
ax.yaxis.label.set_color('silver')
ax.tick_params(colors='red', which='both')  # 'both' covers major and minor ticks
seaborn axes-level plot
import seaborn as sns
# Create the axes first, then let seaborn draw the line onto them
fig, ax = plt.subplots(figsize=(12, 5))
g = sns.lineplot(data=df, x='date', y='a', color='g', label='a', ax=ax)
# remove the padding around the data on both axes
ax.margins(x=0, y=0)
# one color per spine
spine_colors = {'bottom': 'blue', 'top': 'red', 'right': 'magenta', 'left': 'orange'}
for side, shade in spine_colors.items():
    ax.spines[side].set_color(shade)
# thicker vertical spines
for side in ('right', 'left'):
    ax.spines[side].set_linewidth(3)
# axis labels
ax.xaxis.label.set_color('purple')
ax.yaxis.label.set_color('silver')
ax.tick_params(colors='red', which='both')  # 'both' covers major and minor ticks
seaborn figure-level plot
# Figure-level plot: relplot returns a grid object, not a single axes
g = sns.relplot(kind='line', data=df, x='date', y='a', color='g', aspect=2)
# style every axes held by the grid
for ax in g.axes.flat:
    # set the margins to 0
    ax.margins(x=0, y=0)
    # make the top and right spines visible
    ax.spines[['top', 'right']].set_visible(True)
    # set various colors
    ax.spines['bottom'].set_color('blue')
    ax.spines['top'].set_color('red')
    ax.spines['right'].set_color('magenta')
    ax.spines['right'].set_linewidth(3)
    ax.spines['left'].set_color('orange')
    ax.spines['left'].set_lw(3)
    ax.xaxis.label.set_color('purple')
    ax.yaxis.label.set_color('silver')
    ax.tick_params(colors='red', which='both')  # 'both' covers major and minor ticks
Motivated by previous contributors, this is an example with three axes.
import matplotlib.pyplot as plt
# Three data sets, each with its own x and y value ranges.
x_values1=[1,2,3,4,5]
y_values1=[1,2,2,4,1]
x_values2=[-1000,-800,-600,-400,-200]
y_values2=[10,20,39,40,50]
x_values3=[150,200,250,300,350]
y_values3=[-10,-20,-30,-40,-50]
fig=plt.figure()
# Three axes are created in the same grid cell (111); the second and third
# are created with frame_on=False.
# NOTE(review): the distinct labels are presumably there so older matplotlib
# versions do not hand back the first axes again — confirm.
ax=fig.add_subplot(111, label="1")
ax2=fig.add_subplot(111, label="2", frame_on=False)
ax3=fig.add_subplot(111, label="3", frame_on=False)
# First data set: line plot, everything colored C0.
ax.plot(x_values1, y_values1, color="C0")
ax.set_xlabel("x label 1", color="C0")
ax.set_ylabel("y label 1", color="C0")
ax.tick_params(axis='x', colors="C0")
ax.tick_params(axis='y', colors="C0")
# Second data set: scatter plot, colored C1, y ticks and label on the right.
ax2.scatter(x_values2, y_values2, color="C1")
ax2.set_xlabel('x label 2', color="C1")
ax2.xaxis.set_label_position('bottom') # put the x label of the second axes at the bottom
# Shift the bottom spine 36 points outward.
ax2.spines['bottom'].set_position(('outward', 36))
ax2.tick_params(axis='x', colors="C1")
ax2.set_ylabel('y label 2', color="C1")
ax2.yaxis.tick_right()
ax2.yaxis.set_label_position('right')
ax2.tick_params(axis='y', colors="C1")
# Third data set: line plot, colored C2, with larger/extra outward offsets.
ax3.plot(x_values3, y_values3, color="C2")
ax3.set_xlabel('x label 3', color='C2')
ax3.xaxis.set_label_position('bottom')
# 72 points outward for the bottom spine, twice the offset used for ax2.
ax3.spines['bottom'].set_position(('outward', 72))
ax3.tick_params(axis='x', colors='C2')
ax3.set_ylabel('y label 3', color='C2')
ax3.yaxis.tick_right()
ax3.yaxis.set_label_position('right')
# The right spine is shifted 36 points outward as well.
ax3.spines['right'].set_position(('outward', 36))
ax3.tick_params(axis='y', colors='C2')
plt.show()
You can also use this to draw multiple plots in same figure and style them using same color palette.
An example is given below
fig = plt.figure()
# Plot ROC curves
# NOTE(review): fpr1/tpr1, fpr2/tpr2 and p_fpr/p_tpr are not defined in this
# snippet — they have to be computed beforehand.
plotfigure(lambda: plt.plot(fpr1, tpr1, linestyle='--',color='orange', label='Logistic Regression'), fig)
plotfigure(lambda: plt.plot(fpr2, tpr2, linestyle='--',color='green', label='KNN'), fig)
plotfigure(lambda: plt.plot(p_fpr, p_tpr, linestyle='-', color='blue'), fig)
# Title
plt.title('ROC curve')
# X label
plt.xlabel('False Positive Rate')
# Y label
plt.ylabel('True Positive rate')
# Legend with white label text
plt.legend(loc='best',labelcolor='white')
# Save at 300 dpi before showing the figure
plt.savefig('ROC',dpi=300)
plt.show();
Output:
Here is a utility function that takes a plotting function with necessary args and plots the figure with required background-color styles. You can add more arguments as necessary.
def plotfigure(plot_fn, fig, background_col='xkcd:black', face_col=(0.06, 0.06, 0.06)):
    """
    Run a pyplot drawing callback and restyle the current axes with white-on-dark colors.

    Parameters:
        plot_fn (callable): Zero-argument callable (typically a lambda) that
            issues the plt plotting calls.
        fig: The Figure whose background patch is recolored, e.g. the object
            returned by plt.figure().
        background_col: Face color of the figure patch. Supports matplotlib colors.
        face_col: Face color of the axes. Supports matplotlib colors.

    Returns:
        None
    """
    fig.patch.set_facecolor(background_col)
    plot_fn()
    # Style whichever axes pyplot considers current once the callback has drawn.
    ax = plt.gca()
    ax.set_facecolor(face_col)
    for side in ('bottom', 'top', 'left', 'right'):
        ax.spines[side].set_color('white')
    ax.xaxis.label.set_color('white')
    ax.yaxis.label.set_color('white')
    ax.grid(alpha=0.1)
    ax.title.set_color('white')
    ax.tick_params(axis='x', colors='white')
    ax.tick_params(axis='y', colors='white')
A use case is defined below
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
# Synthetic two-class data set: 50 samples, 5 features, fixed random seed.
X, y = make_classification(n_samples=50, n_classes=2, n_features=5, random_state=27)
# NOTE(review): the train/test split is not used by the plotting call below.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=27)
fig=plt.figure()
# Scatter each sample's class label against its position, styled via plotfigure.
plotfigure(lambda: plt.scatter(range(0,len(y)), y, marker=".",c="orange"), fig)

Categories