I found an excellent tutorial on drawing a heatmap for a confusion matrix, but I want to add some errors of commission and omission on the sides.
I'll try to explain using this image:
This means:
I need to insert a number beside each of the boxes containing 0, 6, and 9, just right of the right edge of the image and to the left of the legend.
I need to insert a number above each of the boxes containing 13, 0 and 0, just above the top edge of the image and just below the title.
(so 6 numbers in total)
Is this even possible? I know nothing about the plotting functions in Python, as I'm new to the language. It just seems like a very difficult task from where I'm standing.
Use the following modified function. The idea is as follows:
Add two twin axes - one to the right and other to the top.
Set the limits of the twin axes equal to that of the original axes
Set the positions of the ticks on the twin axes to be the same as that of the original axes
Hide the tick marks and assign the tick-labels
Shift the title a bit upward using y=1.1
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels


def plot_confusion_matrix(y_true, y_pred, classes, normalize=False,
                          title=None, cmap=plt.cm.Blues):
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'
    cm = confusion_matrix(y_true, y_pred)
    # Keep only the labels that occur in the data (classes is a NumPy array)
    classes = classes[unique_labels(y_true, y_pred)]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    fig, ax = plt.subplots(figsize=(6.5, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           ylabel='True label',
           xlabel='Predicted label')
    ax.set_title(title, y=1.1)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    # Adding data to the right
    # (cm[:, -1] only serves as example text; pass your own values here)
    ax2 = ax.twinx()
    ax2.set_ylim(ax.get_ylim())
    ax2.set_yticks(np.arange(cm.shape[0]))
    ax2.set_yticklabels(cm[:, -1])
    ax2.tick_params(axis="y", right=False)
    # Adding data to the top
    # (cm[:, 0] only serves as example text; pass your own values here)
    ax3 = ax.twiny()
    ax3.set_xlim(ax.get_xlim())
    ax3.set_xticks(np.arange(cm.shape[1]))
    ax3.set_xticklabels(cm[:, 0])
    ax3.tick_params(axis="x", top=False)
    # Twin axes do not combine well with imshow's default 'equal' aspect
    ax.set_aspect('auto')
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax
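The function above uses columns of cm as the text on the twin axes. If the numbers you want there are the errors of omission and commission themselves, they could be computed along these lines. This is a minimal sketch, assuming rows hold the true labels and columns the predicted labels (the scikit-learn convention); omission_commission is a hypothetical helper name, and apart from the top row and right column mentioned in the question, the counts are made up for illustration.

```python
def omission_commission(cm):
    """Return (omission, commission) error rates per class.

    Assumes cm[i][j] counts samples of true class i predicted as class j.
    Works on a nested list or a 2-D NumPy array.
    """
    n = len(cm)
    row_sums = [sum(cm[i][j] for j in range(n)) for i in range(n)]
    col_sums = [sum(cm[i][j] for i in range(n)) for j in range(n)]
    # Omission: share of true class i that was missed (1 - recall).
    omission = [(row_sums[i] - cm[i][i]) / row_sums[i] if row_sums[i] else 0.0
                for i in range(n)]
    # Commission: share of class-j predictions that were wrong (1 - precision).
    commission = [(col_sums[j] - cm[j][j]) / col_sums[j] if col_sums[j] else 0.0
                  for j in range(n)]
    return omission, commission


# Illustrative counts; only the top row and right column come from the question.
cm = [[13, 0, 0], [0, 10, 6], [0, 0, 9]]
omission, commission = omission_commission(cm)
print(omission)
print(commission)
```

The resulting lists can be formatted as strings (for example with format(value, '.2f')) and passed to ax2.set_yticklabels and ax3.set_xticklabels in place of the cm columns.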
You could do this using ticks.
Let me present this approach with the following easy plot:
from matplotlib import pyplot as plt
ax = plt.axes()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)
for i in range(3):
    for j in range(3):
        ax.fill_between((i, i+1), j, j+1)
plt.show()
I will not focus on the colors or the tick style here, but both are easy to change.
You can create an Axes object that will share ax's Y axis, with ax.twiny(). Then, you can add X ticks on this new Axes, which will appear on top of the plot:
from matplotlib import pyplot as plt
ax = plt.axes()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)
for i in range(3):
    for j in range(3):
        ax.fill_between((i, i+1), j, j+1)
ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks([0.5, 1.5, 2.5])
ax2.set_xticklabels([13, 0, 0])
plt.show()
In order to display ticks for the X axis, you have to create an Axes object that shares ax's Y axis, with ax.twiny(). This might seem counter-intuitive, but if you used ax.twinx() instead, then modifying ax2's X ticks would modify ax's as well, because they're actually the same.
Then, you set the x limits of ax2 to match those of ax, so that it spans the same three squares.
After that, you can set the ticks: one in every square, at the horizontal center, so at [0.5, 1.5, 2.5].
Finally, you can set the tick labels to display the desired value.
Then, you just do the same with the Y ticks:
from matplotlib import pyplot as plt
ax = plt.axes()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)
for i in range(3):
    for j in range(3):
        ax.fill_between((i, i+1), j, j+1)
ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks([0.5, 1.5, 2.5])
ax2.set_xticklabels([13, 0, 0])
ax3 = ax.twinx()
ax3.set_ylim(ax.get_ylim())
ax3.set_yticks([0.5, 1.5, 2.5])
ax3.set_yticklabels([0, 6, 9])
plt.show()
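The tick positions used above are simply the midpoints of the squares. A small sketch of that calculation, with cell_centers as a hypothetical helper name, also shows why the same approach on an imshow plot ends up with integer tick positions: imshow draws n cells from -0.5 to n - 0.5 by default.

```python
def cell_centers(edges):
    """Midpoints between consecutive cell edges."""
    return [(lo + hi) / 2 for lo, hi in zip(edges, edges[1:])]


# Squares drawn from 0 to 3, as in the fill_between plot above
print(cell_centers([0, 1, 2, 3]))            # [0.5, 1.5, 2.5]
# Default imshow extent for 3 cells: -0.5 to 2.5
print(cell_centers([-0.5, 0.5, 1.5, 2.5]))   # [0.0, 1.0, 2.0]
```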
A rather manual approach would consist of combining the following items until the result is satisfactory:
using twinx and twiny to get new axes on the top and on the right: twinax = ax.twinx().twiny()
using twinax.set(xlim=ax.get_xlim(), ylim=ax.get_ylim()) to match their range with the range of the original axes, then...
using twinax.set(xticks=ax.get_xticks(), yticks=ax.get_yticks(), xticklabels=('0','1','2'), yticklabels=('0','1','2')) to set the labels on the new axes as was done in your example (these two calls can be combined if you like).
(If you don't want the actual ticks (only the labels) you can give them 0 length through tick_params.)
You can reposition the axes with set_position.
See this question for info on how to move the colorbar.
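Put together, the items above might look like the following. This is a rough sketch rather than a finished solution: it uses two separate twin axes (right and top are names chosen here), explicit setter calls instead of a single set(...), and illustrative counts where only the top row and right column come from the question.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to open a window
import matplotlib.pyplot as plt
import numpy as np

# Illustrative counts; only the top row and right column come from the question.
cm = np.array([[13, 0, 0], [0, 10, 6], [0, 0, 9]])

fig, ax = plt.subplots()
ax.imshow(cm, cmap="Blues", aspect="auto")  # 'auto' avoids aspect conflicts with twins
ax.set_xticks(range(cm.shape[1]))
ax.set_yticks(range(cm.shape[0]))

right = ax.twinx()  # shares x with ax; its own y axis is drawn on the right
right.set_ylim(ax.get_ylim())
right.set_yticks(ax.get_yticks())
right.set_yticklabels(["0", "6", "9"])
right.tick_params(axis="y", length=0)  # keep the labels, hide the tick marks

top = ax.twiny()  # shares y with ax; its own x axis is drawn on top
top.set_xlim(ax.get_xlim())
top.set_xticks(ax.get_xticks())
top.set_xticklabels(["13", "0", "0"])
top.tick_params(axis="x", length=0)

fig.canvas.draw()  # render once; use plt.show() or fig.savefig(...) in practice
```

Copying the limits from ax matters for imshow in particular, because its y axis is inverted by default and the twin axes would otherwise list the labels in the opposite order.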
I'm trying to make a matplotlib plot with a second y-axis. This works so far, but I was wondering whether it is possible to shorten the second y-axis.
Furthermore, I struggle on some other formatting issues.
a) I want to draw an arrow on the second y-axis, just as drawn on the first y-axis.
b) I want to align the second y-axis at -1, so that the intersection of x- and 2nd y-axis is at(...; -1)
c) The x-axis crosses the x- and y-ticks at the origin, which I want to avoid.
d) How can I get a common legend for both y-axes?
Here is my code snippet so far.
fig, ax = plt.subplots()
bx = ax.twinx() # 2nd y-axis
ax.spines['bottom'].set_position(('data',0))
ax.spines['left'].set_position(('data',0))
ax.xaxis.set_ticks_position('bottom')
bx.spines['left'].set_position(('data',-1))
bx.spines['bottom'].set_position(('data',-1))
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
bx.spines["top"].set_visible(False)
bx.spines["bottom"].set_visible(False)
bx.spines["left"].set_visible(False)
## Graphs
x_val = np.arange(0,10)
y_val = 0.1*x_val
ax.plot(x_val, y_val, 'k--')
bx.plot(x_val, -y_val+1, color = 'purple')
## Arrows
ms=2
#ax.plot(1, 0, ">k", ms=ms, transform=ax.get_yaxis_transform(), clip_on=False)
ax.plot(0, 1, "^k", ms=ms, transform=ax.get_xaxis_transform(), clip_on=False)
bx.plot(1, 1, "^k", ms=ms, transform=bx.get_xaxis_transform(), clip_on=False)
plt.ylim((-1, 1.2))
bx.set_yticks([-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5])
## Legend
ax.legend([r'$A_{hull}$'], frameon=False,
loc='upper left', bbox_to_anchor=(0.2, .75))
plt.show()
I've uploaded a screenshot of my plot so far, annotating the questioned points.
EDIT: I've changed the plotted values in the code snippet so that the example is easier to reproduce. The question is mostly about formatting, so the actual values are not important. The image is unchanged, so don't be surprised if the graphs look different when you plot the edited values.
To avoid the strange overlap at x=0 and y=0, you could leave out the calls to ax.spines[...].set_position(('data',0)). You can change the transforms that place the arrows. Explicitly setting the x and y limits to start at 0 will also have the spines at those positions.
ax2.spines["right"].set_bounds(...) shortens the right y-axis.
To put items in the legend, each plotted item needs a label. get_legend_handles_labels can fetch the handles and labels of both axes, which can be combined in a new legend.
Renaming bx to something like ax2 makes the code easier to compare with existing example code. In matplotlib it often also helps to put the plotting code first and only afterwards change limits, ticks and embellishments.
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
ax2 = ax.twinx() # 2nd y-axis
## Graphs
x_val = np.arange(0, 10)
y_val = 0.1 * x_val
ax.plot(x_val, y_val, 'k--', label=r'$A_{hull}$')
ax2.plot(x_val, -y_val + 1, color='purple', label='right axis')
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax2.spines["top"].set_visible(False)
ax2.spines["bottom"].set_visible(False)
ax2.spines["left"].set_visible(False)
ax2_upper_bound = 0.55
ax2.spines["right"].set_bounds(-1, ax2_upper_bound) # shorten the right y-axis
## add arrows to spines
ms = 2
# ax.plot(1, 0, ">k", ms=ms, transform=ax.get_yaxis_transform(), clip_on=False)
ax.plot(0, 1, "^k", ms=ms, transform=ax.transAxes, clip_on=False)
ax2.plot(1, ax2_upper_bound, "^k", ms=ms, transform=ax2.get_yaxis_transform(), clip_on=False)
# set limits to the axes
ax.set_xlim(xmin=0)
ax.set_ylim(ymin=0)
ax2.set_ylim((-1, 1.2))
ax2.set_yticks(np.arange(-1, 0.5001, 0.25))
## Legend
handles1, labels1 = ax.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax.legend(handles1 + handles2, labels1 + labels2, frameon=False,
loc='upper left', bbox_to_anchor=(0.2, .75))
plt.show()
I am trying to place multiple histograms in a vertical stack. I am able to get the plots in a stack but then all the histograms are in the same graph.
fig, ax = plt.subplots(2, 1, sharex=True, figsize=(20, 18))
n = 2
axes = ax.flatten()
for i, j in zip(range(n), axes):
    plt.hist(np.array(dates[i]), bins=20, alpha=0.3)
plt.show()
You have a 2x1 grid of axis objects. By directly looping over axes.flatten(), you are accessing one subplot at a time. You then need to use the corresponding axis instance for plotting the histogram.
fig, axes = plt.subplots(2, 1, sharex=True, figsize=(20, 18)) # axes here
n = 2
for i, ax in zip(range(n), axes.flatten()): # <--- flatten directly in the zip
    ax.hist(np.array(dates[i]), bins=20, alpha=0.3)
Your original version would also have worked if, instead of plt, you had called hist on the loop variable j, which holds the current subplot:
for i, j in zip(range(n), axes):
    j.hist(np.array(dates[i]), bins=20, alpha=0.3)
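The reason everything landed in one graph is that plt.hist always draws on the current axes, which after plt.subplots is the last subplot created. A self-contained sketch of the per-axes version follows; the dates list here is a hypothetical stand-in made of random numbers, since the original data is not shown.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to open a window
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for the question's `dates`: two numeric samples.
dates = [rng.normal(0, 1, 500), rng.normal(3, 2, 500)]

fig, axes = plt.subplots(len(dates), 1, sharex=True, figsize=(8, 6))
for ax, sample in zip(axes, dates):
    ax.hist(sample, bins=20, alpha=0.3)  # draw on this subplot, not the current one

# Each subplot now holds its own 20 bars.
print([len(ax.patches) for ax in axes])
```

Zipping the axes with the data directly also removes the need for the separate counter n.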
I have a Matplotlib function to create a confusion matrix and save it to a file:
def pretty_print_conf_matrix(y_true, y_pred,
                             classes,
                             normalize=False,
                             title='{} Confusion matrix'.format(describe_test_setup()),
                             cmap=plt.cm.Blues,
                             out_dir=None):
    """
    Code adapted from: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    cm = confusion_matrix(y_true, y_pred, labels=classes)
    # Configure Confusion Matrix Plot Aesthetics (no text yet)
    cax = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title, fontsize=16)
    ax.set_xticks(np.arange(len(classes)))
    ax.set_yticks(np.arange(len(classes)))
    ax.set_xticklabels(classes)
    ax.set_yticklabels(classes)
    ax.tick_params(axis='x', labelsize=14)
    ax.tick_params(axis='y', labelsize=14)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')
    plt.colorbar(cax)
    ax.set_ylabel('True label', fontsize=16)
    ax.set_xlabel('Predicted label', fontsize=16, rotation='horizontal')
    # Calculate normalized values (so all cells sum to 1) if desired
    if normalize:
        cm = np.round(cm.astype('float') / cm.sum(), 2) # (axis=1)[:, np.newaxis]
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    # Place Numbers as Text on Confusion Matrix Plot
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, cm[i, j],
                ha="center",
                va="center",
                color="white" if cm[i, j] > thresh else "black",
                fontsize=12)
    #fig.tight_layout()
    plt.show(block=False)
    if out_dir is not None:
        out_file = os.path.join(out_dir, 'Confusion Matrix{}.png'.format(describe_test_setup()))
        fig.savefig(out_file, dpi=300)
This works well on two of my machines, but on the third it produces ugly squashed images. They are all running the same source code.
Example of it working properly (resolution 4500x4500):
Example of it working poorly (resolution 1028x715):
I thought this could be caused by me running different matplotlib versions, but using pip freeze I can see matplotlib==3.1.2 on both machines.
Any ideas what the cause might be?
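One thing that may be worth ruling out (an assumption, not a confirmed diagnosis) is the display: plt.show(block=False) runs before fig.savefig, and an interactive window that does not fit on a smaller screen can end up resizing the figure. A minimal sketch of a check, using the non-interactive Agg backend and a hypothetical helper saved_png_size, confirms that without a window the saved size is simply figsize times dpi:

```python
import io
import struct

import matplotlib
matplotlib.use("Agg")  # assumption: off-screen rendering is acceptable for saving
import matplotlib.pyplot as plt


def saved_png_size(fig, dpi):
    """Save fig to an in-memory PNG and return its (width, height) in pixels."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    # A PNG stores width and height as big-endian 32-bit ints at bytes 16..24.
    return struct.unpack(">II", buf.getvalue()[16:24])


fig, ax = plt.subplots(figsize=(15, 15))
print(matplotlib.get_backend(), fig.get_size_inches())
print(saved_png_size(fig, dpi=100))  # 15 in * 100 dpi on each side
```

Printing matplotlib.get_backend() and fig.get_size_inches() right before fig.savefig on each machine would show whether the backend or the figure size differs between them. If it does, saving before calling plt.show is one possible workaround.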
I'm using nested GridSpecFromSubplotSpec to create a nested grid of axes. I have two independent set of axes, a top one and a bottom one. Each set has four axes, arranged in a 2x2 grid.
Here is the code I'm using and the result I obtain:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gsp
fig = plt.figure()
global_gsp = gsp.GridSpec(2, 1)
for i in range(2):
    axes = np.empty(shape=(2, 2), dtype=object)
    local_gsp = gsp.GridSpecFromSubplotSpec(2, 2, subplot_spec=global_gsp[i])
    for j in range(2):
        for k in range(2):
            ax = plt.Subplot(fig, local_gsp[j, k],
                             sharex=axes[0, 0], sharey=axes[0, 0])
            fig.add_subplot(ax)
            axes[j, k] = ax
    for j in range(2):
        for k in range(2):
            ax = axes[j, k]
            x = i + np.r_[0:1:11j]
            y = 10*i + np.random.random(11)
            ax.plot(x, y, color=f'C{i}')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
plt.show()
As you can see, the top set has blue lines, the bottom set has orange lines, and the blue lines are well represented with the limits [0, 1]x[0, 1], while the orange lines are represented with the limits [1, 2]x[10, 11]. When I create the subplots with plt.Subplot, I use the sharex and sharey arguments to have exactly the same scale on all four axes in each set (but different scale across different sets).
I would like to avoid repeating the label and the ticks on each axis. How can I achieve that?
Subplot axes have methods is_{first,last}_{col,row}() (although I could not find the documentation anywhere), as shown in this matplotlib tutorial. These methods are useful for printing the label(s) and/or ticks only in the right spot. Note that the axes methods were deprecated in Matplotlib 3.4 and later removed; in newer releases, call the same methods on ax.get_subplotspec() instead. To hide the tick labels, shared_axis_demo.py recommends using setp(ax.get_{x,y}ticklabels(), visible=False)
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gsp

fig = plt.figure()
global_gsp = gsp.GridSpec(2, 1)
for i in range(2):
    axes = np.empty(shape=(2, 2), dtype=object)
    local_gsp = gsp.GridSpecFromSubplotSpec(2, 2, subplot_spec=global_gsp[i])
    for j in range(2):
        for k in range(2):
            ax = plt.Subplot(fig, local_gsp[j, k],
                             sharex=axes[0, 0], sharey=axes[0, 0])
            fig.add_subplot(ax)
            axes[j, k] = ax
    for j in range(2):
        for k in range(2):
            ax = axes[j, k]
            x = i + np.r_[0:1:11j]
            y = 10*i + np.random.random(11)
            ax.plot(x, y, color=f'C{i}')
            #
            # adjust axes and tick labels here
            # (newer Matplotlib: query the SubplotSpec; older releases
            # offered ax.is_last_row() / ax.is_first_col() directly)
            #
            ss = ax.get_subplotspec()
            if ss.is_last_row():
                ax.set_xlabel('x')
            else:
                plt.setp(ax.get_xticklabels(), visible=False)
            if ss.is_first_col():
                ax.set_ylabel('y')
            else:
                plt.setp(ax.get_yticklabels(), visible=False)
fig.tight_layout()
plt.show()
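As an alternative to checking rows and columns by hand, Matplotlib's label_outer() does the same bookkeeping relative to each subplot's own grid, so it should also work per nested 2x2 set. The following is a sketch of that variant rather than the method from the answer above; the groups list is a name introduced here only to keep the axes around for inspection.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to open a window
import matplotlib.pyplot as plt
import matplotlib.gridspec as gsp
import numpy as np

fig = plt.figure()
global_gsp = gsp.GridSpec(2, 1)
groups = []  # one 2x2 array of axes per set
for i in range(2):
    axes = np.empty(shape=(2, 2), dtype=object)
    local_gsp = gsp.GridSpecFromSubplotSpec(2, 2, subplot_spec=global_gsp[i])
    for j in range(2):
        for k in range(2):
            # axes[0, 0] is still None for the first subplot of each set,
            # so that one is created without sharing.
            ax = fig.add_subplot(local_gsp[j, k],
                                 sharex=axes[0, 0], sharey=axes[0, 0])
            ax.plot(i + np.linspace(0, 1, 11), 10 * i + np.random.random(11),
                    color=f'C{i}')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            axes[j, k] = ax
    for ax in axes.flat:
        ax.label_outer()  # keep labels only on the left column and bottom row
    groups.append(axes)

fig.canvas.draw()  # render once; use plt.show() or fig.savefig(...) in practice
```

Because label_outer() looks at the subplot's position within its local GridSpecFromSubplotSpec, the top set keeps its own bottom-row x labels instead of losing them to the bottom set.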
I'm trying to add a subplot onto the bottom of a figure of subplots. The issue I've got is that all of the first set of subplots want to share their x axis, but not the bottom one.
channels are the subplots that want to share the x axis.
So how can I add a subplot that does not share the x axis?
This is my code:
def plot(reader, tdata):
    '''function to plot the channels'''
    channels=[]
    for i in reader:
        channels.append(i)
    fig, ax = plt.subplots(len(channels)+1, sharex=False, figsize=(30,16), squeeze=False)
    plot=0
    #where j is the channel name
    for i, j in enumerate(reader):
        y=reader["%s" % j]
        ylim=np.ceil(np.nanmax(y))
        x=range(len((reader["%s" % j])))
        ax[plot,0].plot(y, lw=1, color='b')
        ax[plot,0].set_title("%s" % j)
        ax[plot,0].set_xlabel('Time / s')
        ax[plot,0].set_ylabel('%s' % units[i])
        ax[plot,0].set_ylim([np.nanmin(y), ylim+(ylim/100)*10])
        plot=plot+1
    ###here is the new subplot that doesn't want to share the x axis###
    ax[plot, 0].plot()
    plt.tight_layout()
    plt.show()
This code does NOT work because they are sharing the x-axis of the last subplot. The length of channels changes depending on what I specify earlier in the code.
Is it a valid option to use add_subplot somehow, even though I don't have a constant number of channels?
Thanks so much for any help.
EDIT
Image for Joe:
It's easiest to use fig.add_subplot directly in this case.
As a quick example:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(6, 8))
# Axes that share the x-axis
ax = fig.add_subplot(4, 1, 1)
axes = [ax] + [fig.add_subplot(4, 1, i, sharex=ax) for i in range(2, 4)]
# The bottom independent axes
axes.append(fig.add_subplot(4, 1, 4))
# Let's hide the tick labels for all but the last shared-x axes
for ax in axes[:2]:
    plt.setp(ax.get_xticklabels(), visible=False)
# And plot on the first subplot just to demonstrate that the axes are shared
axes[0].plot(range(21), color='lightblue', lw=3)
plt.show()
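The same idea carries over to a variable number of channels, since the grid size passed to add_subplot can be computed at run time. A minimal sketch, with make_axes as a hypothetical helper name and n_channels standing in for len(channels):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to open a window
import matplotlib.pyplot as plt


def make_axes(n_channels, figsize=(6, 8)):
    """Create n_channels axes sharing x, plus one independent axes at the bottom."""
    fig = plt.figure(figsize=figsize)
    nrows = n_channels + 1
    first = fig.add_subplot(nrows, 1, 1)
    shared = [first] + [fig.add_subplot(nrows, 1, i, sharex=first)
                        for i in range(2, n_channels + 1)]
    independent = fig.add_subplot(nrows, 1, nrows)  # no sharex: its own x axis
    return fig, shared, independent


fig, shared, independent = make_axes(n_channels=5)
shared[0].set_xlim(0, 50)  # propagates to every axes in `shared`
print(shared[-1].get_xlim(), independent.get_xlim())
```

Changing the limits on one shared axes moves all of them, while the bottom axes keeps its default range.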