I need to add error bars like the ones in the graph, but I don't know how to use plt.errorbar. My attempt raises an error and doesn't work. Can you please help me?
#! /usr/bin/python3
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
x = np.arange(9)
country = [
"Finland",
"Denmark",
"Switzerland",
"Iceland",
"Netherland",
"Norway",
"Sweden",
"Luxembourg"
]
data = {
"Explained by : Log GDP per Capita": [1.446, 1.502, 1.566, 1.482, 1.501, 1.543, 1.478, 1.751],
"Explained by : Social Support": [1.106, 1.108, 1.079, 1.172, 1.079, 1.108, 1.062, 1.003],
"Explained by : Healthy life expectancy": [0.741, 0.763, 0.816, 0.772, 0.753, 0.782, 0.763, 0.760],
"Explained by : Freedom to make life choices": [0.691, 0.686, 0.653, 0.698, 0.647, 0.703, 0.685, 0.639],
"Explained by : Generosity": [0.124, 0.208, 0.204, 0.293, 0.302, 0.249, 0.244, 0.166],
"Explained by : Perceptions of corruption": [0.481, 0.485, 0.413, 0.170, 0.384, 0.427, 0.448, 0.353],
"Dystopia + residual": [3.253, 2.868, 2.839, 2.967, 2.798, 2.580, 2.683, 2.653]
}
error_value = [[7.904, 7.780], [7.687, 7.552], [7.643, 7.500], [7.670, 7.438], [7.518, 7.410], [7.462, 7.323], [7.433, 7.293], [7.396, 7.252]]
df = pd.DataFrame(data, index=country)
df.plot(width=0.1, kind='barh', stacked=True, figsize=(11, 8))
plt.subplots_adjust(bottom=0.2)
# plt.errorbar(country, error_value, yerr=error_value)
plt.axvline(x=2.43, label="Dystopia (hapiness=2.43)")
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
fancybox=True, shadow=True, ncol=3)
plt.xticks(x)
plt.show()
Error bars are drawn as differences from the center. You seem to provide the values where each error bar ends, so you have to calculate the distance from the center to each endpoint and pass a NumPy array of shape (2, N), where the first row contains the negative error bar values and the second row the positive values:
...
df.plot(kind='barh', stacked=True, width=0.1, figsize=(11, 8))
# Each error bar is centred on the end of its stacked bar, i.e. the row total.
country_sum = df.sum(axis=1).to_numpy()
# error_value lists the two end points of every error bar (larger value first).
# xerr expects distances from the centre as a (2, N) array:
# row 0 = extent on the minus side, row 1 = extent on the plus side.
upper_ends, lower_ends = np.asarray(error_value).T
err_vals = np.abs(np.vstack((lower_ends, upper_ends)) - country_sum)
# barh places the bars at y = 0, 1, ..., N-1
bar_positions = np.arange(len(df))
plt.errorbar(country_sum, bar_positions, xerr=err_vals, capsize=4, color="k", ls="none")
plt.subplots_adjust(bottom=0.2)
...
Sample output:
Related
The code below produces this graph. I wonder if there is a way to turn the lines between value1 and value2 into arrows pointing from value1 to value2, i.e. from blue to green (in this case no green value is lower than its blue value).
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Create a dataframe
value1=np.random.uniform(size=20)
value2=value1+np.random.uniform(size=20)/4
df = pd.DataFrame({'group':list(map(chr, range(65, 85))), 'value1':value1 , 'value2':value2 })
# Reorder it following the values of the first value:
ordered_df = df.sort_values(by='value1')
my_range=range(1,len(df.index)+1)
# The horizontal plot is made using the hline function
plt.hlines(y=my_range, xmin=ordered_df['value1'], xmax=ordered_df['value2'], color='grey', alpha=0.4)
plt.scatter(ordered_df['value1'], my_range, color='skyblue', alpha=1, label='value1')
plt.scatter(ordered_df['value2'], my_range, color='green', alpha=0.4 , label='value2')
plt.legend()
# Add title and axis names
plt.yticks(my_range, ordered_df['group'])
plt.title("Comparison of the value 1 and the value 2", loc='left')
plt.xlabel('Value of the variables')
plt.ylabel('Group')
# Show the graph
plt.show()
The best option for multiple arrows is matplotlib.pyplot.quiver, because it accepts an array or dataframe of locations, unlike matplotlib.pyplot.arrow, which only accepts a single value.
Since the y-axis labels are defined by 'group', which are letters, use V = np.zeros(len(ordered_df)) or V = ordered_df.index - ordered_df.index for the .quiver direction vector.
Plot the dataframe directly with pandas.DataFrame.plot and kind='scatter'.
Tested in python 3.8.12, pandas 1.3.3, matplotlib 3.4.3
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Create a dataframe
np.random.seed(354)
value1=np.random.uniform(size=20)
value2=value1+np.random.uniform(size=20)/4
df = pd.DataFrame({'group':list(map(chr, range(65, 85))), 'value1':value1 , 'value2':value2 })
# Reorder it following the values of the first value and reset the index so the index values correspond to the y-axis tick locations
ordered_df = df.sort_values(by='value1').reset_index(drop=True)
# plot the dataframe
ax = ordered_df.plot(kind='scatter', x='value1', y='group', color='skyblue', alpha=1, figsize=(8, 6), label='value1')
ordered_df.plot(kind='scatter', x='value2', y='group', color='green', alpha=1, ax=ax, label='value2', xlabel='Value of the variables', ylabel='Group')
# plot the arrows
V = ordered_df.index - ordered_df.index # the Y direction vector is 0 for each
ax.quiver(ordered_df.value1, ordered_df.group, (ordered_df.value2-ordered_df.value1), V, width=0.003, color='gray', scale_units='x', scale=1)
# Add title with position
ax.set_title("Comparison of the value 1 and the value 2", loc='left')
# Show the graph
plt.show()
You can use plt.arrow instead of plt.hlines, but you have to loop over the rows:
# Draw one arrow per row, with its tail at value1 and its tip at value2.
arrow_head_length = 0.02  # loop-invariant, so it is defined once up front
for y, (_, row) in enumerate(ordered_df.iterrows(), start=1):
    # length_includes_head=True makes dx the full tail-to-tip distance, so the
    # tip lands exactly on value2 and the arrow cannot flip direction when the
    # two values are closer together than the head is long.
    # Only `color` is passed: it sets both face and edge colour, so also
    # giving fc/ec would be overridden and merely trigger a warning.
    plt.arrow(x=row['value1'], y=y, dx=row['value2'] - row['value1'], dy=0,
              head_width=0.5, head_length=arrow_head_length,
              length_includes_head=True, color='grey', alpha=0.4)
example:
I have a data frame that contains 4 columns of data. Each of these columns is a character variable containing 5 different values ( i.e. column1 contains the values A,B,C,D or E . column2 contains the values EXCELLENT , VERY GOOD, GOOD, AVERAGE, and POOR. columns 3 and 4 are similar.
I'm trying to get a separate bar chart for each of the columns by using the below for loop. Unfortunately, it only provides me with the bar chart for column 4. It does not provide the bar chart for the previous 3 columns. Not sure what I am doing wrong.
categorical_attribs=list(CharacterVarDF)
for i in categorical_attribs:
CharacterVarDF [i].value_counts().plot(kind='bar')
Simply set up matplotlib subplots with number of rows and columns. Then in loop, assign each column bar plot to each ax:
import matplotlib.pyplot as plt
...
# One subplot row per DataFrame column. squeeze=False makes plt.subplots
# always return a 2-D array of Axes (even for a single column), so it can be
# flattened with .flat and the snippet needs no numpy import.
fig, axes = plt.subplots(figsize=(8, 6), ncols=1,
                         nrows=CharacterVarDF.shape[1], squeeze=False)
for col, ax in zip(CharacterVarDF.columns, axes.flat):
    # Passing ax= draws each bar chart on its own Axes instead of every
    # chart being drawn onto the same current Axes.
    CharacterVarDF[col].value_counts().plot(kind='bar', ax=ax, rot=0, title=col)
plt.tight_layout()
plt.show()
To demonstrate with random data:
import numpy as np
import pandas as pd
from matplotlib import rc
import matplotlib.pyplot as plt
np.random.seed(52021)
env_df = pd.DataFrame({
"planetary_boundaries": np.random.choice(
["ocean", "land", "biosphere", "atmosphere",
"climate", "soil", "ozone", "freshwater"], 50),
"species": np.random.choice(
["invertebrates", "vertebrates", "plants", "fungi & protists"], 50),
"tipping_points": np.random.choice(
["Arctic Sea Ice", "Greenland ice sheet", "West Antarctica ice sheet",
"Amazon Rainforest", "Boreal forest", "Indian Monsoon",
"Atlantic meridional overturning circulation",
"West African Monsoon", "Coral reef"], 50)
})
rc('font', **{'family' : 'Arial'})
fig, axes = plt.subplots(ncols=1, nrows=env_df.shape[1], figsize=(7,7))
for col, ax in zip(env_df.columns, np.ravel(axes)):
env_df[col] = env_df[col].str.replace(" ", "\n")
env_df[col].value_counts(sort=False).sort_index().plot(
kind='bar', ax=ax, color='g', rot=0,
title=col.replace("_", " ").title(),
)
plt.tight_layout()
plt.show()
plt.clf()
plt.close()
I'm making a barplot using 3 datasets in seaborn, however each datapoint overlays the previous, regardless of if it is now hiding the previous plot. eg:
sns.barplot(x="Portfolio", y="Factor", data=d2,
label="Portfolio", color="g")
sns.barplot(x="Benchmark", y="Factor", data=d2,
label="Benchmark", color="b")
sns.barplot(x="Active Exposure", y="Factor", data=d2,
label="Active", color="r")
ax.legend(frameon=True)
ax.set(xlim=(-.1, .5), ylabel="", xlabel="Sector Decomposition")
sns.despine(left=True, bottom=True)
However, I want it to show green, even if the blue being overlayed is greater. Any ideas?
Without being able to see your data, I can only guess that your dataframe is not in long-form. There's a section in the seaborn tutorial on the shape of DataFrames that seaborn expects; I'd take a look there for more info, specifically the section on messy data.
Because I can't see your DataFrame, I have made some assumptions about its shape:
import numpy as np
import pandas as pd
import seaborn as sns
df = pd.DataFrame({
"Factor": list("ABC"),
"Portfolio": np.random.random(3),
"Benchmark": np.random.random(3),
"Active Exposure": np.random.random(3),
})
# Active Exposure Benchmark Factor Portfolio
# 0 0.140177 0.112653 A 0.669687
# 1 0.823740 0.078819 B 0.072474
# 2 0.450814 0.702114 C 0.039068
We can melt this DataFrame to get the long-form data seaborn wants:
d2 = df.melt(id_vars="Factor", var_name="exposure")
# Factor exposure value
# 0 A Active Exposure 0.140177
# 1 B Active Exposure 0.823740
# 2 C Active Exposure 0.450814
# 3 A Benchmark 0.112653
# 4 B Benchmark 0.078819
# 5 C Benchmark 0.702114
# 6 A Portfolio 0.669687
# 7 B Portfolio 0.072474
# 8 C Portfolio 0.039068
Then, finally, we can plot our bar plot using seaborn's built-in aggregation:
ax = sns.barplot(x="value", y="Factor", hue="exposure", data=d2)
ax.set(ylabel="", xlabel="Sector Decomposition")
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
Which produces:
Here's the plot params I used to make this chart:
import matplotlib as mpl
# Plot configuration
mpl.style.use("seaborn-pastel")
mpl.rcParams.update(
{
"font.size": 14,
"figure.facecolor": "w",
"axes.facecolor": "w",
"axes.spines.right": False,
"axes.spines.top": False,
"axes.spines.bottom": False,
"xtick.top": False,
"xtick.bottom": False,
"ytick.right": False,
"ytick.left": False,
}
)
If you are fine without using seaborn you can use pandas plotting to create a stacked horizontal bar chart (barh):
import pandas as pd
import matplotlib as mpl
# Plot configuration
mpl.style.use("seaborn-pastel")
mpl.rcParams.update(
{
"font.size": 14,
"figure.facecolor": "w",
"axes.facecolor": "w",
"axes.spines.right": False,
"axes.spines.top": False,
"axes.spines.bottom": False,
"xtick.top": False,
"xtick.bottom": False,
"ytick.right": False,
"ytick.left": False,
}
)
df = pd.DataFrame({
"Factor": list("ABC"),
"Portfolio": [0.669687, 0.072474, 0.039068],
"Benchmark": [0.112653, 0.078819, 0.702114],
"Active Exposure": [0.140177, 0.823740, 0.450814],
}).set_index("Factor")
ax = df.plot.barh(stacked=True)
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
ax.set_ylabel("")
ax.set_xlabel("Sector Decomposition")
Notice in the code above the index is set to Factor which then becomes the y axis.
If you don't set stacked=True you get almost the same chart as seaborn produced:
ax = df.plot.barh(stacked=False)
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
ax.set_ylabel("")
ax.set_xlabel("Sector Decomposition")
UPDATE
Trying some more, I managed to run this code without error:
from matplotlib.pyplot import figure
dict = pd.DataFrame({"Return": mkw_returns, "Standard Deviation": mkw_stds})
dict.head()
#plt.annotate("Sharpe Ratio", xytext=(0.5,0.5), xy=(0.03,0.03) , arrowprops=dict(facecolor='blue', shrink=0.01, width=220)) # arrowprops={width = 3, "facecolor":
#dict.plot(x="Standard Deviation", y = "Return", kind="scatter", figsize=(10,6))
#plt.xlabel("Standard Deviations")
#plt.ylabel("log_Return YoY")
figure(num=None, figsize=(15, 10), dpi=100, facecolor='w', edgecolor='k')
plt.plot( 'Standard Deviation', 'Return', data=dict, linestyle='none', marker='o')
plt.xlabel("Standard Deviations")
plt.ylabel("log_Return YoY")
# Annotate with text + Arrow
plt.annotate(
# Label and coordinate
'This is a Test', xy=(0.01, 1), xytext=(0.01, 1), color= "r", arrowprops={"facecolor": 'black', "shrink": 0.05}
)
Which now works YaY, can anybody shed some light onto this issue? Im not so sure why it suddenly started working. Thank you :)
Also, how would I simply mark a point, instead of using the arrow?
Problem: Cannot figure out how to mark/select/highlight a specific point in my scatter graph
(Python 3 Beginner)
So my goal is to highlight one or more points in a scatter graph with some text by it or supplied by a legend.
https://imgur.com/a/VWeO1EH
(not enough reputation to post images, sorry)
dict = pd.DataFrame({"Return": mkw_returns, "Standard Deviation": mkw_stds})
dict.head()
#plt.annotate("Sharpe Ratio", xytext=(0.5,0.5), xy=(0.03,0.03) , arrowprops=dict(facecolor='blue', shrink=0.01, width=220)) # arrowprops={width = 3, "facecolor":
dict.plot(x="Standard Deviation", y = "Return", kind="scatter", figsize=(10,6))
plt.xlabel("Standard Deviations")
plt.ylabel("log_Return YoY")
The commented-out "plt.annotate" call would give the error specified below.
Specifically i would like to select the sharpe ratio, but for now Im happy if I manage to select any point in the scatter graph.
Im truly confused how to work with matplotlib, so any help is welcomed
I tried the following solutions I found online:
I)
This shows a simple way to use annotate in a plot, to mark a specific point by an arrow.
https://www.youtube.com/watch?v=ItHDZEE5wSk
However the pd.dataframe environment does not like annotate and i get the error:
TypeError: 'DataFrame' object is not callable
II)
Since Im running into issues with annotate in a Data Frame environment, I looked at the following solution
Annotate data points while plotting from Pandas DataFrame
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import string
df = pd.DataFrame({'x':np.random.rand(10), 'y':np.random.rand(10)},
index=list(string.ascii_lowercase[:10]))
fig, ax = plt.subplots()
df.plot('x', 'y', kind='scatter', ax=ax, figsize=(10,6))
for k, v in df.iterrows():
ax.annotate(k, v)
However, the resulting plot does not show any annotation whatsoever when applied to my problem, apart from this very long horizontal scroll bar
https://imgur.com/a/O8ykmeg
III)
Further, I stumbled upon this solution, to use a marker instead of an arrow,
Matplotlib annotate with marker instead of arrow
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
x=[1,2,3,4,5,6,7,8,9,10]
y=[1,1,1,2,10,2,1,1,1,1]
line, = ax.plot(x, y)
ymax = max(y)
xpos = y.index(ymax)
xmax = x[xpos]
# Add dot and corresponding text
ax.plot(xmax, ymax, 'ro')
ax.text(xmax, ymax+2, 'local max:' + str(ymax))
ax.set_ylim(0,20)
plt.show()
however the code does absolutely nothing, when applied to my situation like so
dict = pd.DataFrame({"Return": mkw_returns, "Standard Deviation": mkw_stds})
dict.head()
plt.annotate("Sharpe Ratio", xytext=(0.5,0.5), xy=(0.03,0.03) , arrowprops=dict(facecolor='blue', shrink=0.01, width=220)) # arrowprops={width = 3, "facecolor":
dict.plot(x="Standard Deviation", y = "Return", kind="scatter", figsize=(10,6))
plt.xlabel("Standard Deviations")
plt.ylabel("log_Return YoY")
ymax = max(y)
xpos = y.index(ymax)
xmax = x[xpos]
# Add dot and corresponding text
ax.plot(xmax, ymax, 'ro')
ax.text(xmax, ymax+2, 'local max:' + str(ymax))
ax.set_ylim(0,20)
plt.show()
IV)
Lastly, I tried a solution that apparently works flawlessly with an arrow in a pd.dataframe,
https://python-graph-gallery.com/193-annotate-matplotlib-chart/
# Library
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Basic chart
df=pd.DataFrame({'x': range(1,101), 'y': np.random.randn(100)*15+range(1,101) })
plt.plot( 'x', 'y', data=df, linestyle='none', marker='o')
# Annotate with text + Arrow
plt.annotate(
# Label and coordinate
'This point is interesting!', xy=(25, 50), xytext=(0, 80),
# Custom arrow
arrowprops=dict(facecolor='black', shrink=0.05)
)
however running this code yields me the same error as above:
TypeError: 'DataFrame' object is not callable
Version:
import sys; print(sys.version)
3.7.1 (default, Dec 10 2018, 22:54:23) [MSC v.1915 64 bit (AMD64)]
Sorry for the WoT, but I thought its best to have everything I tried together in one post.
Any help is appreciated, thank you!
I think one solution is the following, as posted above as the "UPDATE":
UPDATE
Trying some more, I managed to run this code without error:
from matplotlib.pyplot import figure

# NOTE: the frame must not be named `dict`. Binding that name shadows the
# built-in, so any later `arrowprops=dict(...)` call ends up calling the
# DataFrame and raises "TypeError: 'DataFrame' object is not callable".
returns_df = pd.DataFrame({"Return": mkw_returns, "Standard Deviation": mkw_stds})
returns_df.head()

figure(num=None, figsize=(15, 10), dpi=100, facecolor='w', edgecolor='k')
# Column names are resolved against the frame passed as data=.
plt.plot('Standard Deviation', 'Return', data=returns_df, linestyle='none', marker='o')
plt.xlabel("Standard Deviations")
plt.ylabel("log_Return YoY")

# Annotate with text + arrow
plt.annotate(
    # Label and coordinates
    'This is a Test', xy=(0.01, 1), xytext=(0.01, 1), color="r",
    arrowprops={"facecolor": 'black', "shrink": 0.05},
)
One question remains, how can I use a different marker or color and write about it in the legend instead?
Thanks in advance :)
Something like this:
There is a very good package to do it in R. In python, the best that I could figure out is this, using the squarify package (inspired by a post on how to do treemaps):
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns # just to have better line color and width
import squarify
# for those using jupyter notebooks
%matplotlib inline
df = pd.DataFrame({
'v1': np.ones(100),
'v2': np.random.randint(1, 4, 100)})
df.sort_values(by='v2', inplace=True)
# color scale
cmap = mpl.cm.Accent
mini, maxi = df['v2'].min(), df['v2'].max()
norm = mpl.colors.Normalize(vmin=mini, vmax=maxi)
colors = [cmap(norm(value)) for value in df['v2']]
# figure
fig = plt.figure()
ax = fig.add_subplot(111, aspect="equal")
ax = squarify.plot(df['v1'], color=colors, ax=ax)
ax.set_xticks([])
ax.set_yticks([]);
But when I create not 100 but 200 elements (or other non-square numbers), the squares become misaligned.
Another problem is that if I change v2 to some categorical variable (e.g., a hundred As, Bs, Cs and Ds), I get this error:
could not convert string to float: 'a'
So, could anyone help me with these two questions:
how can I solve the alignment problem with non-square numbers of observations?
how can I use categorical variables in v2?
Beyond this, I am really open if there are any other python packages that can create waffle plots more efficiently.
I spent a few days to build a more general solution, PyWaffle.
You can install it through
pip install pywaffle
The source code: https://github.com/gyli/PyWaffle
PyWaffle does not use matshow() method, but builds those squares one by one. That makes it easier for customization. Besides, what it provides is a custom Figure class, which returns a figure object. By updating attributes of the figure, you can basically control everything in the chart.
Some examples:
Colored or transparent background:
import matplotlib.pyplot as plt
from pywaffle import Waffle
data = {'Democratic': 48, 'Republican': 46, 'Libertarian': 3}
fig = plt.figure(
FigureClass=Waffle,
rows=5,
values=data,
colors=("#983D3D", "#232066", "#DCB732"),
title={'label': 'Vote Percentage in 2016 US Presidential Election', 'loc': 'left'},
labels=["{0} ({1}%)".format(k, v) for k, v in data.items()],
legend={'loc': 'lower left', 'bbox_to_anchor': (0, -0.4), 'ncol': len(data), 'framealpha': 0}
)
fig.gca().set_facecolor('#EEEEEE')
fig.set_facecolor('#EEEEEE')
plt.show()
Use icons replacing squares:
data = {'Democratic': 48, 'Republican': 46, 'Libertarian': 3}
fig = plt.figure(
FigureClass=Waffle,
rows=5,
values=data,
colors=("#232066", "#983D3D", "#DCB732"),
legend={'loc': 'upper left', 'bbox_to_anchor': (1, 1)},
icons='child', icon_size=18,
icon_legend=True
)
Multiple subplots in one chart:
import pandas as pd
data = pd.DataFrame(
{
'labels': ['Hillary Clinton', 'Donald Trump', 'Others'],
'Virginia': [1981473, 1769443, 233715],
'Maryland': [1677928, 943169, 160349],
'West Virginia': [188794, 489371, 36258],
},
).set_index('labels')
fig = plt.figure(
FigureClass=Waffle,
plots={
'311': {
'values': data['Virginia'] / 30000,
'labels': ["{0} ({1})".format(n, v) for n, v in data['Virginia'].items()],
'legend': {'loc': 'upper left', 'bbox_to_anchor': (1.05, 1), 'fontsize': 8},
'title': {'label': '2016 Virginia Presidential Election Results', 'loc': 'left'}
},
'312': {
'values': data['Maryland'] / 30000,
'labels': ["{0} ({1})".format(n, v) for n, v in data['Maryland'].items()],
'legend': {'loc': 'upper left', 'bbox_to_anchor': (1.2, 1), 'fontsize': 8},
'title': {'label': '2016 Maryland Presidential Election Results', 'loc': 'left'}
},
'313': {
'values': data['West Virginia'] / 30000,
'labels': ["{0} ({1})".format(n, v) for n, v in data['West Virginia'].items()],
'legend': {'loc': 'upper left', 'bbox_to_anchor': (1.3, 1), 'fontsize': 8},
'title': {'label': '2016 West Virginia Presidential Election Results', 'loc': 'left'}
},
},
rows=5,
colors=("#2196f3", "#ff5252", "#999999"), # Default argument values for subplots
figsize=(9, 5) # figsize is a parameter of plt.figure
)
I've put together a working example, below, which I think meets your needs. Some work is needed to fully generalize the approach, but I think you'll find that it's a good start. The trick was to use matshow() to solve your non-square problem, and to build a custom legend to easily account for categorical values.
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
# Let's make a default data frame with catagories and values.
df = pd.DataFrame({ 'catagories': ['cat1', 'cat2', 'cat3', 'cat4'],
'values': [84911, 14414, 10062, 8565] })
# Now, we define a desired height and width.
waffle_plot_width = 20
waffle_plot_height = 7
classes = df['catagories']
values = df['values']
def _waffle_matrix(values, height, width):
    """Return a (height, width) float matrix of 1-based class ids, filled column by column."""
    total = sum(values)
    if total <= 0:
        raise ValueError("values must sum to a positive number")
    total_tiles = width * height
    # Number of tiles per class, proportional to its share of the total.
    tiles_per_class = [int(round(float(v) / total * total_tiles)) for v in values]
    n_classes = len(tiles_per_class)
    # Lay the class ids out one after another; drop any surplus from rounding up.
    tile_classes = np.repeat(np.arange(1, n_classes + 1), tiles_per_class)[:total_tiles]
    # Tiles left over from rounding down go to the last class, so no id
    # beyond n_classes ever appears in the matrix.
    shortfall = total_tiles - tile_classes.size
    if shortfall > 0:
        tile_classes = np.concatenate([tile_classes, np.full(shortfall, n_classes)])
    # order="F" fills column by column (top to bottom, then left to right).
    return tile_classes.reshape((height, width), order="F").astype(float)


def waffle_plot(classes, values, height, width, colormap):
    """Draw a waffle chart with one tile colour per class and a matching legend.

    Args:
        classes: Class labels, in the same order as ``values``.
        values: One numeric value per class; each class receives a number of
            tiles proportional to its share of the total.
        height: Number of tile rows.
        width: Number of tile columns.
        colormap: Colormap used to colour the classes.

    Raises:
        ValueError: If ``values`` does not sum to a positive number.
    """
    plot_matrix = _waffle_matrix(values, height, width)
    n_classes = len(values)

    # matshow opens its own figure, so no separate plt.figure() call is made
    # (that would only leave an empty extra figure behind).
    # Fixing vmin/vmax pins class id k to one colour, which the legend reuses.
    image = plt.matshow(plot_matrix, cmap=colormap, vmin=1, vmax=n_classes)
    plt.colorbar()

    # Get the axis.
    ax = plt.gca()

    # Minor ticks sit on the tile borders ...
    ax.set_xticks(np.arange(-.5, width, 1), minor=True)
    ax.set_yticks(np.arange(-.5, height, 1), minor=True)
    # ... so a white grid on the minor ticks separates the tiles.
    ax.grid(which='minor', color='w', linestyle='-', linewidth=2)

    # Manually built legend: one patch per class, coloured through the image's
    # own colormap and normalisation so legend and tiles always agree.
    legend_handles = [
        mpatches.Patch(color=image.cmap(image.norm(class_id)),
                       label="{} ({})".format(name, value))
        for class_id, (name, value) in enumerate(zip(classes, values), start=1)
    ]

    # Add the legend. Still a bit of work to do here, to perfect centering.
    plt.legend(handles=legend_handles, loc=1, ncol=len(classes),
               bbox_to_anchor=(0., -0.1, 0.95, .10))

    # Hide the major ticks; only the tile grid should be visible.
    plt.xticks([])
    plt.yticks([])
# Call the plotting function.
waffle_plot(classes, values, waffle_plot_height, waffle_plot_width,
plt.cm.coolwarm)
Below is an example of the output this script produced. As you can see, it works fairly well for me, and meets all of your stated needs. Just let me know if it gives you any trouble. Enjoy!
You can use this function for automatic creation of a waffle with simple parameters:
def create_waffle_chart(categories, values, height, width, colormap, value_sign=''):
# compute the proportion of each category with respect to the total
total_values = sum(values)
category_proportions = [(float(value) / total_values) for value in values]
# compute the total number of tiles
total_num_tiles = width * height # total number of tiles
print ('Total number of tiles is', total_num_tiles)
# compute the number of tiles for each catagory
tiles_per_category = [round(proportion * total_num_tiles) for proportion in category_proportions]
# print out number of tiles per category
for i, tiles in enumerate(tiles_per_category):
print (df_dsn.index.values[i] + ': ' + str(tiles))
# initialize the waffle chart as an empty matrix
waffle_chart = np.zeros((height, width))
# define indices to loop through waffle chart
category_index = 0
tile_index = 0
# populate the waffle chart
for col in range(width):
for row in range(height):
tile_index += 1
# if the number of tiles populated for the current category
# is equal to its corresponding allocated tiles...
if tile_index > sum(tiles_per_category[0:category_index]):
# ...proceed to the next category
category_index += 1
# set the class value to an integer, which increases with class
waffle_chart[row, col] = category_index
# instantiate a new figure object
fig = plt.figure()
# use matshow to display the waffle chart
colormap = plt.cm.coolwarm
plt.matshow(waffle_chart, cmap=colormap)
plt.colorbar()
# get the axis
ax = plt.gca()
# set minor ticks
ax.set_xticks(np.arange(-.5, (width), 1), minor=True)
ax.set_yticks(np.arange(-.5, (height), 1), minor=True)
# add dridlines based on minor ticks
ax.grid(which='minor', color='w', linestyle='-', linewidth=2)
plt.xticks([])
plt.yticks([])
# compute cumulative sum of individual categories to match color schemes between chart and legend
values_cumsum = np.cumsum(values)
total_values = values_cumsum[len(values_cumsum) - 1]
# create legend
legend_handles = []
for i, category in enumerate(categories):
if value_sign == '%':
label_str = category + ' (' + str(values[i]) + value_sign + ')'
else:
label_str = category + ' (' + value_sign + str(values[i]) + ')'
color_val = colormap(float(values_cumsum[i])/total_values)
legend_handles.append(mpatches.Patch(color=color_val, label=label_str))
# add legend to chart
plt.legend(
handles=legend_handles,
loc='lower center',
ncol=len(categories),
bbox_to_anchor=(0., -0.2, 0.95, .1)
)