I wrote the following code based on the matplotlib site example.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
fig = plt.figure()
nFreqs = 1024
nFFTWindows = 512
viewport = np.ones((nFreqs, nFFTWindows))
im = plt.imshow(viewport, animated=True)
def updatefig(*args):
    global viewport
    print(viewport)
    viewport = np.roll(viewport, -1, axis=1)
    viewport[:, -1] = 0
    im.set_array(viewport)
    return im,
ani = animation.FuncAnimation(fig, updatefig, interval=50, blit=True)
plt.show()
Before my changes the animation worked, but now it doesn't. I expected it to start with a purple plot that slowly turns yellow from the right edge to the left. The viewport variable does update correctly (I checked it with print in my function).
I get the static image (all ones, like it was initially):
Where did I go wrong here?
The problem is you are defining a plot initially with a single colour (1.0) so the colour range is set to this. When you update the figure, the range of colours is 1.0 +- some small value so you don't see the change. You need to set the colour range to between one and zero with vmin/vmax arguments as follows:
im = plt.imshow(viewport, animated=True, vmin=0., vmax=1.)
The rest of the code stays the same and this should work as expected. Another alternative is to add the call,
im.autoscale()
after im.set_array(viewport) to force the colour range to be updated each time.
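To see the effect of both fixes, here is a minimal sketch. The small array size and the Agg backend are my own choices so that it runs without a display; they are not part of the original code. It checks the colour limits that imshow picks for constant data, and how vmin/vmax or autoscale() changes them:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, only so this runs headless
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
viewport = np.ones((4, 6))  # small stand-in for the (nFreqs, nFFTWindows) array

# Without vmin/vmax the limits are taken from the data: both end up at 1.0
auto_im = ax.imshow(viewport)
print(auto_im.get_clim())

# With explicit limits the full 0..1 range is mapped onto the colormap
fixed_im = ax.imshow(viewport, vmin=0.0, vmax=1.0)
print(fixed_im.get_clim())

# Updating the data alone does not touch the limits ...
viewport[:, -1] = 0
auto_im.set_array(viewport)
limits_before = auto_im.get_clim()

# ... but autoscale() recomputes them from the current array
auto_im.autoscale()
limits_after = auto_im.get_clim()
print(limits_before, limits_after)
```

Fixed limits are probably the better choice for an animation, since the colours then mean the same thing in every frame.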
The imshow plot is initialized with one single value (1 in this case), so any value normalized to the range between 1 and 1 becomes the same color.
In order to change this, you may
Initialize the imshow plot with limits for the color (vmin=0, vmax=1).
Initialize the imshow plot with a normalization instance:
norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
im = plt.imshow(arr, norm=norm)
Set the limits afterwards using im.set_clim(0,1).
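Why a single-valued image collapses to one colour can be checked on the normalization object directly. A small sketch, using made-up 2x3 data:

```python
import numpy as np
from matplotlib.colors import Normalize

data = np.ones((2, 3))
data[:, -1] = 0  # one column has already "scrolled in"

# Limits as they were autoscaled from the initial all-ones image
degenerate = Normalize(vmin=1, vmax=1)
# Limits chosen explicitly
fixed = Normalize(vmin=0, vmax=1)

print(degenerate(data))  # every entry maps to the same normalized value
print(fixed(data))       # zeros map to 0.0, ones map to 1.0
```

Calling im.set_clim(0, 1) on an existing image has the same effect as passing the fixed norm, since it sets vmin and vmax of the image's norm.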
In your IDE, go to Preferences > IPython Console > Graphics > Backend and change it from "Inline" to "Automatic".
Do not forget to restart your IDE (Spyder, PyCharm, etc.) after applying the above change.
I have a figure which contains a labelled colourbar below the x axis of the main plot. When I attempt to save this using plt.savefig(), the very bottom of the subscript character in the label is cropped from the saved image, like this, despite using bbox_inches="tight". However, if I simply save the figure manually in the pop-up window, the subscript character is not cropped, like this.
Although the latter image could be manually cropped, or cropped using additional lines in the code, I would be grateful for any advice on how to resolve this issue without the need for this additional work.
I have tried to add a line break to the colourbar label like so:
label="$U/U_{"+(u"\u221e")+"}$\n"
But this simply adds white space below the label; the bottom of the subscript character is still cropped.
I have also tried to add the line:
cb.set_label(label,labelpad=5)
But this simply offsets the label from the bottom of the colourbar; no additional padding is provided below the label to fully display the subscript character.
The code is below:
import numpy
import random
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
import matplotlib.colorbar as cbar
from matplotlib import cm
##########################################################
# Centre colourmap so 0=white
class MidpointNormalize(mpl.colors.Normalize):
    def __init__(self,vmin=None,vmax=None,midpoint=None,clip=False):
        self.midpoint=midpoint
        mpl.colors.Normalize.__init__(self,vmin,vmax,clip)
    def __call__(self,value,clip=None):
        x,y=[self.vmin,self.midpoint,self.vmax],[0,0.5,1]
        return numpy.ma.masked_array(numpy.interp(value,x,y),numpy.isnan(value))
##########################################################
# Set min and max values
xymin=0
xymax=10
valmin=-5
valmax=5
val=numpy.zeros((xymax,xymax),dtype=float)
# Configure plot
fig,ax=plt.subplots()
ax.set_xlim([xymin,xymax])
ax.set_ylim([xymin,xymax])
# Configure colour bar
colours=plt.cm.RdBu(numpy.linspace(0,1,256))
colourmap=mcolors.LinearSegmentedColormap.from_list('colourmap',colours)
normalisecolors=mpl.colors.Normalize(vmin=valmin,vmax=valmax)
scalecolors=cm.ScalarMappable(norm=normalisecolors,cmap=colourmap)
label="$U/U_{"+(u"\u221e")+"}$"
for ix in range(xymin,xymax):
    for iy in range(xymin,xymax):
        xlow=ix*1 # Calculate vertices of patch
        xhigh=(ix*1)+1
        ylow=iy*1
        yhigh=(iy*1)+1
        val[ix][iy]=random.randint(valmin,valmax) # Patch value
        rgbacolor=scalecolors.to_rgba(val[ix][iy]) # Calculate RGBA colour for value
        ax.add_patch(patches.Polygon([(xlow,ylow),(xlow,yhigh),(xhigh,yhigh),(xhigh,ylow)],fill=True,facecolor=rgbacolor)) # Add value as polygon patch
cax,_=cbar.make_axes(ax,orientation="horizontal")
cb=cbar.ColorbarBase(cax,cmap=colourmap,norm=MidpointNormalize(midpoint=0,vmin=valmin,vmax=valmax),orientation="horizontal",label=label)
plt.savefig("C:/Users/Christopher/Desktop/test.png",dpi=1200,bbox_inches="tight")
plt.clf()
plt.close()
I'm afraid I don't really have a good answer for you. This appears to be related to this bug https://github.com/matplotlib/matplotlib/issues/15313
The good news is that it is being worked on, the bad news is that there is no fix as of yet.
Two points to consider anyway (based on reading the thread on github):
The higher the dpi, the worse it is, so you may want to save at a lower dpi (300 works fine for me).
The problem is not present with the pdf backend, so you could save your plot as a pdf (and then convert it to png if needed).
BTW (this is unrelated to the bug in question): I'm confused by the complexity of your code. It seems to me the following code produces the same output:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
N=10
valmin=-5
valmax=5
valmid=0
val=np.random.randint(low=valmin, high=valmax, size=(N,N))
cmap = 'RdBu'
norm = TwoSlopeNorm(vcenter=valmid, vmin=valmin, vmax=valmax)
label="$U/U_{"+(u"\u221e")+"}$"
# Configure plot
fig, ax=plt.subplots()
im = ax.imshow(val, cmap=cmap, norm=norm, aspect='auto', origin='lower')
cbar = fig.colorbar(im, orientation='horizontal', label=label)
fig.savefig('./test-1200.png',dpi=1200,bbox_inches="tight") # subscript is cut
fig.savefig('./test-300.png',dpi=300,bbox_inches="tight") # subscript is not cut
fig.savefig('./test-pdf.pdf',dpi=1200,bbox_inches="tight") # subscript is not cut
[Figures: the output saved at 1200 dpi, at 300 dpi, and as pdf]
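As a quick sanity check that TwoSlopeNorm can stand in for the hand-written MidpointNormalize, the sketch below compares the two mappings on a few sample values. The asymmetric limits (-5 and 10) are my own choice, picked so that the two slopes actually differ; as far as I know, TwoSlopeNorm needs Matplotlib 3.2 or newer.

```python
import numpy as np
from matplotlib.colors import TwoSlopeNorm

vmin, vcenter, vmax = -5.0, 0.0, 10.0
values = np.array([-5.0, -2.5, 0.0, 5.0, 10.0])

# What MidpointNormalize.__call__ computes: piecewise-linear interpolation
# through (vmin, 0), (midpoint, 0.5) and (vmax, 1)
manual = np.interp(values, [vmin, vcenter, vmax], [0.0, 0.5, 1.0])

# The built-in norm with the same three anchor points
norm = TwoSlopeNorm(vcenter=vcenter, vmin=vmin, vmax=vmax)
builtin = np.asarray(norm(values))

print(manual)   # 0, 0.25, 0.5, 0.75 and 1
print(builtin)  # same values
```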
I have two imshow() problems that I suspect are closely related.
First, I can't figure out how to use set_data() to update an image I've created with imshow().
Second, I can't figure out why the colorbar I add to the imshow() plot after I'm done updating the plot doesn't match the colorbar I add to an imshow() plot of the same data that I create from scratch after I'm done taking data. The colorbar of the second plot appears to be correct.
Background.
I'm collecting measurement data in two nested loops, with each loop controlling one of the measurement conditions. I'm using pyplot.imshow() to plot my results, and I'm updating the imshow() plot every time I take data in the inner loop.
What I have works in terms of updating the imshow() plot, but it seems to get slower and slower as I add more loop iterations, so it's not scaling well. (The program I've included with this post creates a plot that is eight rows high and six columns wide. A "real" plot might be 10x or 20x this size, in both dimensions.)
I think what I want to do is use the image's set_data() method but I can't figure out how. What I've tried either throws an error or doesn't appear to have any effect on the imshow() plot.
Once I'm done with the "take the data" loops, I add a colorbar to the imshow() plot I've been updating. However, the colorbar scale is obviously bogus.
In contrast, if I create an entirely new imshow() plot, using the data I took in the loops, and then add a colorbar, the colorbar appears to be correct.
My hunch is the problem is with the vmin and vmax values associated with the imshow() plot I'm updating in the loops but I can't figure out how to fix it.
I've already looked at several related StackOverflow posts. For example:
update a figure made with imshow(), contour() and quiver()
Update matplotlib image in a function
How to update matplotlib's imshow() window interactively?
These have helped, in that they've pointed me to set_data() and given me solutions to some other problems I had, but I still have the two problems I mentioned at the start.
Here's a simplified version of my code. Note that there are repeated zero values on the X and Y axes. This is on purpose.
I'm running Python 3.5.1, matplotlib 1.5.1, and numpy 1.10.4. (Yes, some of these are quite old. Corporate IT reasons.)
import numpy as np
import matplotlib.pyplot as plt
import random
import time
import warnings
warnings.filterwarnings("ignore", ".*GUI is implemented.*") # Filter out bogus matplotlib warning.
# Create the simulated data for plotting
v_max = 120
v_step_size = 40
h_max = 50
h_step_size = 25
scale = 8
v_points = np.arange(-1*abs(v_max), 0, abs(v_step_size))
v_points = np.append(v_points, [-0.0])
reversed_v_points = -1 * v_points[::-1] # Not just reverse order, but reversed sign
v_points = np.append(v_points, reversed_v_points)
h_points = np.arange(-1*abs(h_max), 0, abs(h_step_size))
h_points = np.append(h_points, [-0.0])
reversed_h_points = -1 * h_points[::-1] # Not just reverse order, but reversed sign
h_points = np.append(h_points, reversed_h_points)
h = 0 # Initialize
v = 0 # Initialize
plt.ion() # Turn on interactive mode.
fig, ax = plt.subplots() # So I have access to the figure and the axes of the plot.
# Initialize the data_points
data_points = np.zeros((v_points.size, h_points.size))
im = ax.imshow(data_points, cmap='hot', interpolation='nearest') # Specify the color map and interpolation
ax.set_title('Dummy title for initial plot')
# Set up the X-axis ticks and label them
ax.set_xticks(np.arange(len(h_points)))
ax.set_xticklabels(h_points)
ax.set_xlabel('Horizontal axis measurement values')
# Set up the Y-axis ticks and label them
ax.set_yticks(np.arange(len(v_points)))
ax.set_yticklabels(v_points)
ax.set_ylabel('Vertical axis measurement values')
plt.pause(0.0001) # In interactive mode, need a small delay to get the plot to appear
plt.show()
for v, v_value in enumerate(v_points):
    for h, h_value in enumerate(h_points):
        # Measurement goes here.
        time.sleep(0.1) # Simulate the measurement delay.
        measured_value = scale * random.uniform(0.0, 1.0) # Create simulated data
        data_points[v][h] = measured_value # Update data_points with the simulated data
        # Update the heat map with the latest point.
        # - I *think* I want to use im.set_data() here, not ax.imshow(), but how?
        ax.imshow(data_points, cmap='hot', interpolation='nearest') # Specify the color map and interpolation
        plt.pause(0.0001) # In interactive mode, need a small delay to get the plot to appear
        plt.draw()
# Create a colorbar
# - Except the colorbar here is wrong. It goes from -0.10 to +0.10 instead
# of matching the colorbar in the second imshow() plot, which goes from
# 0.0 to "scale". Why?
cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.set_ylabel('Default heatmap colorbar label')
plt.pause(0.0001) # In interactive mode, need a small delay to get the colorbar to appear
plt.show()
fig2, ax2 = plt.subplots() # So I have access to the figure and the axes of the plot.
im = ax2.imshow(data_points, cmap='hot', interpolation='nearest') # Specify the color map and interpolation
ax2.set_title('Dummy title for plot with pseudo-data')
# Set up the X-axis ticks and label them
ax2.set_xticks(np.arange(len(h_points)))
ax2.set_xticklabels(h_points)
ax2.set_xlabel('Horizontal axis measurement values')
# Set up the Y-axis ticks and label them
ax2.set_yticks(np.arange(len(v_points)))
ax2.set_yticklabels(v_points)
ax2.set_ylabel('Vertical axis measurement values')
# Create a colorbar
cbar = ax2.figure.colorbar(im, ax=ax2)
cbar.ax.set_ylabel('Default heatmap colorbar label')
plt.pause(0.0001) # In interactive mode, need a small delay to get the plot to appear
plt.show()
dummy = input("In interactive mode, press the Enter key when you're done with the plots.")
OK, I Googled some more. More importantly, I Googled smarter and figured out my own answer.
To update my plot inside my nested loops, I was using the command:
ax.imshow(data_points, cmap='hot', interpolation='nearest') # Specify the color map and interpolation
What I tried using to update my plot more efficiently was:
im.set_data(data_points)
What I should have used was:
im.set_data(data_points)
im.autoscale()
This updates the pixel scaling, which fixed both my "plot doesn't update" problem and my "colorbar has the wrong scale" problem.
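Put together, the update loop could look roughly like the sketch below. The 3x4 grid, the seeded random generator, and the Agg backend are assumptions of mine to keep it short and runnable without a display; with a GUI backend you would keep plt.ion() and plt.pause() as in the question.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, only so the sketch runs anywhere
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
scale = 8
data_points = np.zeros((3, 4))  # hypothetical small measurement grid

fig, ax = plt.subplots()
im = ax.imshow(data_points, cmap="hot", interpolation="nearest")

for v in range(data_points.shape[0]):
    for h in range(data_points.shape[1]):
        data_points[v, h] = scale * rng.uniform(0.1, 1.0)  # simulated measurement
        im.set_data(data_points)  # reuse the existing image instead of calling imshow() again
        im.autoscale()            # recompute vmin/vmax from the new data
        fig.canvas.draw()         # redraw the figure

cbar = fig.colorbar(im, ax=ax)    # colorbar now follows the autoscaled limits
print(im.get_clim())
```

Because only one image artist ever exists on the axes, the loop should not slow down the way repeated ax.imshow() calls do.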
I have N=lots of 256x256 images (grayscales saved as numpy.ndarray with shape=(N, 256, 256)) and want to look at all of them by use of animation. I also want to add some label showing details related to each of the images, such as its index, its maximum value, etc. I'm using matplotlib, which I'm not familiar with.
There are a number of StackOverflow topics concerned with this exact problem (e.g. 1, 2, 4), as well as numerous tutorials (e.g. 3). I pieced together the attempts below from these sources.
The two possibilities I have tried are using the matplotlib.animation classes FuncAnimation and ArtistAnimation. I'm not happy with my solutions because:
I have not been able to display and animate text information together with the images. I can display animated text on top of the images using axes.text but don't know how to put text next to the image.
I strongly dislike the FuncAnimation solution for aesthetic reasons (use of global variables, etc.)
I also want an animated colorbar. I think this is possible (somehow) with FuncAnimation but I don't see how it is possible with ArtistAnimation
ArtistAnimation gets slow since a large number of Artists (each picture) are required
# python 3.6
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, ArtistAnimation
images = np.random.rand(1000,256,256)
fig, ax = plt.subplots()
# ####################### Solution using ArtistAnimation ##################################################
# for much larger numbers of pictures this gets very slow
# How do I display information about the current picture as text next to the plot?
ims = []
for i in range(images.shape[0]):
    ims.append([plt.imshow(images[i], animated=True)])
ani = ArtistAnimation(fig, ims, interval=250, blit=True, repeat_delay=5000)
plt.show()
# ####################### Solution using FuncAnimation ##################################################
# I don't like to use global variables in principle (but still want to know how to make this work).
# I can't figure out a way to display text while animating.
# Here I try to animate title and return it from update_figure (since it's an Artist and should update?!) but it has no effect.
nof_frames = images.shape[0]
i = 0
im = plt.imshow(images[0], animated=True)
# I do know that variables that aren't changed need not be declared global.
# However, I want to mark them and don't like accessing global variables in the first place.
def update_figure(frame, *frargs):
    global i, nof_frames, ax, images, im
    if i < nof_frames - 1:
        i += 1
    else:
        i = 0
    im.set_array(images[i])
    ax.set_title(str(i)) # this has no effect
    return im, ax
ani = FuncAnimation(fig, update_figure, interval=300, blit=True)
plt.show()
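One possible way around the global variables and the missing text is sketched below; it is not a tested answer to every point above. The helper name make_animation, the small 5x16x16 stack, and the Agg backend are all assumptions of mine. The update function is a closure, so it needs no globals, and the label is a text artist placed in axes-fraction coordinates to the right of the image. I use blit=False as a conservative choice, because with blitting an artist outside the axes area (such as a title) may not be refreshed, which would match the "no effect" observation in the code above.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

def make_animation(images, interval=300):
    """Hypothetical helper: animate a stack of images with a label next to the plot."""
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=0.7)  # leave room for the text on the right
    # Fixed colour limits so every frame uses the same scale
    im = ax.imshow(images[0], vmin=images.min(), vmax=images.max())
    # Axes-fraction coordinates; x > 1 places the text outside the image
    info = ax.text(1.05, 0.5, "", transform=ax.transAxes, va="center")

    def update(frame):
        # im, info and images come from the enclosing scope, no globals needed
        im.set_array(images[frame])
        info.set_text("index: {}\nmax: {:.3f}".format(frame, images[frame].max()))
        return im, info

    ani = FuncAnimation(fig, update, frames=len(images),
                        interval=interval, blit=False)
    return ani, update, info

images = np.random.rand(5, 16, 16)  # small stand-in for the (N, 256, 256) stack
ani, update, info = make_animation(images)
update(3)  # call one frame by hand to check the label
print(info.get_text())
```

Keep a reference to the returned animation object (here ani) for as long as the figure is shown, and call plt.show() when using an interactive backend.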
Question: Is there a way to check if a color bar already exists?
I am making many plots with a loop. The issue is that the color bar is drawn every iteration!
If I could determine if the color bar exists then I can put the color bar function in an if statement.
if cb_exists:
    pass # do nothing
else:
    plt.colorbar() # draw the colorbar
If I use multiprocessing to make the figures, is it possible to prevent multiple color bars from being added?
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing
def plot(number):
    a = np.random.random([5,5])*number
    plt.pcolormesh(a)
    plt.colorbar()
    plt.savefig('this_'+str(number))
# I want to make 50 plots
some_list = range(0,50)
num_proc = 5
p = multiprocessing.Pool(num_proc)
temps = p.map(plot, some_list)
I realize I can clear the figure with plt.clf() and plt.cla() before plotting the next iteration. But, I have data on my basemap layer I don't want to re-plot (that adds to the time it takes to create the plot). So, if I could remove the colorbar and add a new one I'd save some time.
It is actually not easy to remove a colorbar from a plot and later draw a new one in its place.
The best solution I can come up with at the moment is the following, which assumes that there is only one axes present in the plot. If there is a second axes, it must be the colorbar. So by checking how many axes we find in the figure, we can judge whether or not there is a colorbar.
Here we also respect the questioner's wish not to reference any named objects from outside. (Which does not make much sense, as we need to use plt anyway, but hey, that was the question.)
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="k", lw=6)
ax.set_title("Title")
cbar = plt.colorbar(im)
cbar.ax.set_ylabel("Label")
for i in range(10):
    # inside this loop we should not access any variables defined outside
    # why? no real reason, but questioner asked for it.
    #draw new colormesh
    im = plt.gcf().gca().pcolormesh(np.random.rand(2,2))
    #check if there is more than one axes
    if len(plt.gcf().axes) > 1:
        # if so, then the last axes must be the colorbar.
        # we get its extent
        pts = plt.gcf().axes[-1].get_position().get_points()
        # and its label
        label = plt.gcf().axes[-1].get_ylabel()
        # and then remove the axes
        plt.gcf().axes[-1].remove()
        # then we draw a new axes at the extents of the old one
        cax= plt.gcf().add_axes([pts[0][0],pts[0][1],pts[1][0]-pts[0][0],pts[1][1]-pts[0][1] ])
        # and add a colorbar to it
        cbar = plt.colorbar(im, cax=cax)
        cbar.ax.set_ylabel(label)
        # unfortunately the aspect is different between the initial call to colorbar
        # without cax argument. Try to reset it (but still it's somehow different)
        cbar.ax.set_aspect(20)
    else:
        plt.colorbar(im)
plt.show()
In general, a much better solution is to operate on the objects already present in the plot and only update them with the new data. This avoids the need to remove and add axes and gives a much cleaner and faster solution.
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="k", lw=6)
ax.set_title("Title")
cbar = plt.colorbar(im)
cbar.ax.set_ylabel("Label")
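for i in range(10):
    data = np.array(np.random.rand(2,2) )
    im.set_array(data.flatten())
    # cbar.set_clim() and cbar.draw_all() are no longer available in recent Matplotlib releases;
    # setting the limits on the mappable updates the attached colorbar as well
    im.set_clim(vmin=data.min(),vmax=data.max())
    plt.draw()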
plt.show()
Update:
Actually, the latter approach of referencing objects from outside even works together with the multiprocess approach desired by the questioner.
So, here is a code that updates the figure, without the need to delete the colorbar.
import matplotlib.pyplot as plt
import numpy as np
import multiprocessing
import time
fig, ax = plt.subplots()
im = ax.pcolormesh(np.array(np.random.rand(2,2) ))
ax.plot(np.cos(np.linspace(0.2,1.8))+0.9, np.sin(np.linspace(0.2,1.8))+0.9, c="w", lw=6)
ax.set_title("Title")
cbar = plt.colorbar(im)
cbar.ax.set_ylabel("Label")
tx = ax.text(0.2,0.8, "", fontsize=30, color="w")
tx2 = ax.text(0.2,0.2, "", fontsize=30, color="w")
def do(number):
    start = time.time()
    tx.set_text(str(number))
    data = np.array(np.random.rand(2,2)*(number+1) )
    im.set_array(data.flatten())
    # cbar.set_clim() and cbar.draw_all() are no longer available in recent Matplotlib releases;
    # setting the limits on the mappable updates the attached colorbar as well
    im.set_clim(vmin=data.min(),vmax=data.max())
    tx2.set_text("{m:.2f} < {ma:.2f}".format(m=data.min(), ma= data.max() ))
    plt.draw()
    plt.savefig("multiproc/{n}.png".format(n=number))
    stop = time.time()
    return np.array([number, start, stop])

if __name__ == "__main__":
    multiprocessing.freeze_support()
    some_list = range(0,50)
    num_proc = 5
    p = multiprocessing.Pool(num_proc)
    nu = p.map(do, some_list)
    nu = np.array(nu)
    plt.close("all")
    fig, ax = plt.subplots(figsize=(16,9))
    ax.barh(nu[:,0], nu[:,2]-nu[:,1], height=np.ones(len(some_list)), left=nu[:,1], align="center")
    plt.show()
(The code at the end shows a timetable from which you can see that multiprocessing has indeed taken place.)
If you have access to the axes and image information, the colorbar can be retrieved as a property of the image (or of the mappable with which the colorbar is associated).
Following a previous answer (How to retrieve colorbar instance from figure in matplotlib), an example could be:
ax=plt.gca() #plt.gca() for current axis, otherwise set appropriately.
im=ax.images #this is a list of all images that have been plotted
if im[-1].colorbar is None: #in this case I assume you are interested in the last one plotted, otherwise use the appropriate index or loop over the list
    plt.colorbar() #plot a new colorbar
Note that for an image without a colorbar, im[-1].colorbar is None.
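The same check can be applied to the pcolormesh case from the question. This is only a sketch under a few assumptions of mine: the loop reuses one mesh and updates its data, the Agg backend is used so it runs without a display, and the 5x5 random data is made up. Note that pcolormesh results are stored in ax.collections rather than ax.images, so the snippet above would need that list instead.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, only so the sketch runs anywhere
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
mesh = ax.pcolormesh(np.random.random((5, 5)))

for number in range(1, 4):
    data = np.random.random((5, 5)) * number
    mesh.set_array(data.ravel())   # update the existing mesh instead of plotting a new one
    mesh.autoscale()               # rescale the colours to the new data
    if mesh.colorbar is None:      # only true on the first pass
        fig.colorbar(mesh, ax=ax)

print(len(fig.axes))  # main axes plus a single colorbar axes
```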
One approach is:
initially (prior to having any color bar drawn), set a variable
colorBarPresent = False
in the method for drawing the color bar, check to see if it's already drawn. If not, draw it and set the colorBarPresent variable True:
def drawColorBar():
    global colorBarPresent
    if colorBarPresent:
        return # leave the function and don't draw the bar again
    plt.colorbar() # draw the color bar
    colorBarPresent = True
There is an indirect way of guessing (with reasonable accuracy for most applications, I think) whether an Axes instance is home to a color bar. Depending on whether it is a horizontal or vertical color bar, either the X axis or Y axis (but not both) will satisfy all of these conditions:
No ticks
No tick labels
No axis label
Axis range is (0, 1)
So here's a function for you:
def is_colorbar(ax):
    """
    Guesses whether a set of Axes is home to a colorbar
    :param ax: Axes instance
    :return: bool
        True if the x xor y axis satisfies all of the following and thus looks like it's probably a colorbar:
        No ticks, no tick labels, no axis label, and range is (0, 1)
    """
    xcb = (len(ax.get_xticks()) == 0) and (len(ax.get_xticklabels()) == 0) and (len(ax.get_xlabel()) == 0) and \
          (ax.get_xlim() == (0, 1))
    ycb = (len(ax.get_yticks()) == 0) and (len(ax.get_yticklabels()) == 0) and (len(ax.get_ylabel()) == 0) and \
          (ax.get_ylim() == (0, 1))
    return xcb != ycb # != is effectively xor in this case, since xcb and ycb are both bool
Thanks to this answer for the cool != xor trick: https://stackoverflow.com/a/433161/6605826
With this function, you can see if a colorbar exists by:
colorbar_exists = any([is_colorbar(ax) for ax in np.atleast_1d(plt.gcf().axes).flatten()])
or if you're sure the colorbar will always be last, you can get off easy with:
colorbar_exists = is_colorbar(plt.gcf().axes[-1])
Does someone perhaps know if it is possible to rotate a legend on a plot in matplotlib? I made a simple plot with the below code, and edited the graph in paint to show what I want.
plt.plot([4,5,6], label = 'test')
ax = plt.gca()
ax.legend()
plt.show()
I ran into a similar problem and solved it by writing the function legendAsLatex, which generates LaTeX code to be used as the label of the y-axis. The function gathers the color, the marker, the line style, and the label provided to the plot function. It requires enabling LaTeX and loading the required packages. Here is the code to generate your plot, with extra curves that use both vertical axes.
from matplotlib import pyplot as plt
import matplotlib.colors as cor
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath} \usepackage{wasysym}'+
r'\usepackage[dvipsnames]{xcolor} \usepackage{MnSymbol} \usepackage{txfonts}')
def legendAsLatex(axes, rotation=90) :
    '''Generate a latex code to be used instead of the legend.
    Uses the label, color, marker and linestyle provided to the pyplot.plot.
    The marker and the linestyle must be defined using the one or two character
    abbreviations shown in the help of pyplot.plot.
    Rotation of the markers must be multiple of 90.
    '''
    latexLine = {'-':'\\textbf{\Large ---}',
        '-.':'\\textbf{\Large --\:\!$\\boldsymbol{\cdot}$\:\!--}',
        '--':'\\textbf{\Large --\,--}',':':'\\textbf{\Large -\:\!-}'}
    latexSymbol = {'o':'medbullet', 'd':'diamond', 's':'filledmedsquare',
        'D':'Diamondblack', '*':'bigstar', '+':'boldsymbol{\plus}',
        'x':'boldsymbol{\\times}', 'p':'pentagon', 'h':'hexagon',
        ',':'boldsymbol{\cdot}', '_':'boldsymbol{\minus}','<':'LHD',
        '>':'RHD','v':'blacktriangledown', '^':'blacktriangle'}
    rot90=['^','<','v','>']
    di = [0,-1,2,1][rotation%360//90]
    latexSymbol.update({rot90[i]:latexSymbol[rot90[(i+di)%4]] for i in range(4)})
    return ', '.join(['\\textcolor[rgb]{'\
        + ','.join([str(x) for x in cor.to_rgb(handle.get_color())]) +'}{'
        + '$\\'+latexSymbol.get(handle.get_marker(),';')+'$'
        + latexLine.get(handle.get_linestyle(),'') + '} ' + label
        for handle,label in zip(*axes.get_legend_handles_labels())])
ax = plt.axes()
ax.plot(range(0,10), 'b-', label = 'Blue line')
ax.plot(range(10,0,-1), 'sm', label = 'Magenta squares')
ax.set_ylabel(legendAsLatex(ax))
ax2 = plt.twinx()
ax2.plot([x**0.5 for x in range(0,10)], 'ro', label = 'Red circles')
ax2.plot([x**0.5 for x in range(10,0,-1)],'g--', label = 'Green dashed line')
ax2.set_ylabel(legendAsLatex(ax2))
plt.savefig('legend.eps')
plt.close()
Figure generated by the code:
I spent a few hours chipping away at this yesterday, and made a bit of progress so I'll share that below along with some suggestions moving forward.
First, it seems that we can certainly rotate and translate the bounding box (bbox) or frame around the legend. In the first example below you can see that a transform can be applied, albeit requiring some oddly large translation numbers after applying the 90 degree rotation. But, there are actually problems saving the translated legend frame to an image file so I had to take a screenshot from the IPython notebook. I've added some comments as well.
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import matplotlib.transforms
fig = plt.figure()
ax = fig.add_subplot(121) #make room for second subplot, where we are actually placing the legend
ax2 = fig.add_subplot(122) #blank subplot to make space for legend
ax2.axis('off')
ax.plot([4,5,6], label = 'test')
transform = matplotlib.transforms.Affine2D(matrix=np.eye(3)) #start with the identity transform, which does nothing
transform.rotate_deg(90) #add the desired 90 degree rotation
transform.translate(410,11) #for some reason we need to play with some pretty extreme translation values to position the rotated legend
legend = ax.legend(bbox_to_anchor=[1.5,1.0])
legend.set_title('test title')
legend.get_frame().set_transform(transform) #This actually works! But, only for the frame of the legend (see below)
frame = legend.get_frame()
fig.subplots_adjust(wspace = 0.4, right = 0.9)
fig.savefig('rotate_legend_1.png',bbox_extra_artists=(legend,frame),bbox_inches='tight', dpi = 300) #even with the extra bbox parameters the legend frame is still getting clipped
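One possible explanation for the large translation values, which I have not verified against the legend internals: rotate_deg() rotates about the origin (0, 0) of whatever coordinate system the transform is applied in. If the frame is positioned in pixel-like display coordinates, a 90 degree rotation swings it over to negative x, and the translation has to bring it back into view. The sketch below only shows the arithmetic on a hypothetical point (100, 50):

```python
import numpy as np
from matplotlib.transforms import Affine2D

# Same sequence of operations as in the example above
transform = Affine2D().rotate_deg(90).translate(410, 11)

point = np.array([100.0, 50.0])  # hypothetical corner of a frame, e.g. in pixels
rotated_only = Affine2D().rotate_deg(90).transform(point)
moved = transform.transform(point)

print(rotated_only)  # approximately (-50, 100): x has become negative
print(moved)         # approximately (360, 111) after the translation
```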
Next, I thought it would be smart to explore the get_methods() of other legend components. You can sort of dig through these things with dir(legend) and legend.__dict__ and so on. In particular, I noticed that you can do this: legend.get_title().set_transform(transform), which would seem to imply that we could translate the legend text (and not just the frame as above). Let's see what happens when I tried that:
fig2 = plt.figure()
ax = fig2.add_subplot(121)
ax2 = fig2.add_subplot(122)
ax2.axis('off')
ax.plot([4,5,6], label = 'test')
transform = matplotlib.transforms.Affine2D(matrix=np.eye(3))
transform.rotate_deg(90)
transform.translate(410,11)
legend = ax.legend(bbox_to_anchor=[1.5,1.0])
legend.set_title('test title')
legend.get_frame().set_transform(transform)
legend.get_title().set_transform(transform) #one would expect this to apply the same transformation to the title text in the legend, rotating it 90 degrees and translating it
frame = legend.get_frame()
fig2.subplots_adjust(wspace = 0.4, right = 0.9)
fig2.savefig('rotate_legend_1.png',bbox_extra_artists=(legend,frame),bbox_inches='tight', dpi = 300)
The legend title seems to have disappeared in the screenshot from the IPython notebook. But, if we look at the saved file the legend title is now in the bottom left corner and seems to have ignored the rotation component of the transformation (why?):
I had similar technical difficulties with this type of approach:
bbox = matplotlib.transforms.Bbox([[0.,1],[1,1]])
trans_bbox = matplotlib.transforms.TransformedBbox(bbox, transform)
legend.set_bbox_to_anchor(trans_bbox)
Other notes and suggestions:
It might be a sensible idea to dig into the differences in behaviour between the legend title and frame objects--why do they both accept transforms, but only the frame accepts a rotation? Perhaps it would be possible to subclass the legend object in the source code and make some adjustments.
We also need to find a solution for the rotated / translated legend frame not being saved to output, even after following various related suggestions on SO (e.g., Matplotlib savefig with a legend outside the plot).