How to reuse figures and axes created with subplot(num= ... ) - python

I have a function I want to use several times to plot data in the same Figure/Axis.
The function toto() in the following code works fine. It calls plt.figure(num='Single plot'), which checks, whenever it is called, whether a figure with the id 'Single plot' has already been created. If so, the same figure is reused (the figure instance is global, I think):
import matplotlib.pyplot as plt
import numpy as np
def toto(x, y):
    plt.figure('Single plot')
    plt.plot(x, y)
x = np.linspace(0, 2)
toto(x, x)
toto(x, x**2)
plt.show()
Now I want to use subplots() instead of figure(), because it will be more useful for me. The function would be:
def toto2(x, y):
    fig, axes = plt.subplots(num='Single plot 2')
    axes.plot(x, y)
But the result is not as expected at all: the y ticks are overprinted and the straight line is not plotted (or has been overprinted).
I'd like to understand the rationale behind this and what to modify so that it works as expected.

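Every time you call plt.subplots, you create a new set of axes. Because you pass num='Single plot 2', the existing figure is reused, but a fresh Axes is added on top of the previous one, which is why the ticks are overprinted and the first line seems to vanish. If you want to reuse the figure and its axes, create the fig and axes objects outside of your function and pass them as arguments to it.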
import matplotlib.pyplot as plt
import numpy as np
def toto3(fig, axes, x, y):
    axes.plot(x, y)
x = np.linspace(0, 2)
fig, axes = plt.subplots(num='Single plot 2')
toto3(fig, axes, x, x)
toto3(fig, axes, x, x**2)
plt.show()
(in this case, you could even simplify your code to only pass axes to your function, but the idea is applicable to both the fig and the axes objects)
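A minimal sketch of that simplification, assuming you only need the axes inside the function. The name toto4, the optional ax argument, and the figure label 'Single plot 3' are illustrative, not part of the original code. Falling back to plt.gca() when no axes is given is one common convention, not the only option.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np

def toto4(x, y, ax=None):
    """Plot y against x on the given axes (hypothetical helper)."""
    if ax is None:
        # No axes supplied: fall back to pyplot's current axes
        ax = plt.gca()
    ax.plot(x, y)
    return ax

x = np.linspace(0, 2)
# Create the figure and axes once, outside the plotting function
fig, ax = plt.subplots(num='Single plot 3')
toto4(x, x, ax=ax)
toto4(x, x**2, ax=ax)
```

Both curves end up on the same axes because the axes object is created once and handed to every call.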

Related

How to add axis labels to imshow plots in python?

I copied from this website (and simplified) the following code to plot the result of a function with two variables using imshow.
from numpy import exp,arange
from pylab import meshgrid,cm,imshow,contour,clabel,colorbar,axis,title,show
# the function that I'm going to plot
def z_func(x,y):
    return (x+y**2)
x = arange(-3.0,3.0,0.1)
y = arange(-3.0,3.0,0.1)
X,Y = meshgrid(x, y) # grid of points
Z = z_func(X, Y) # evaluation of the function on the grid
im = imshow(Z,cmap=cm.RdBu) # drawing the function
colorbar(im) # adding the colorbar on the right
show()
How do I add axis labels (like 'x' and 'y' or 'var1' and 'var2') to the plot? In R I would use xlab = 'x' within (most of) the plotting function(s).
I tried im.ylabel('y'), which failed with the error
AttributeError: 'AxesImage' object has no attribute 'ylabel'
Besides this, I only found how to remove the axis labels, but not how to add them.
Bonus question: how to have the ticks range from -3 to 3 and not from 0 to 60?
To specify axes labels:
matplotlib.pyplot.xlabel for the x axis
matplotlib.pyplot.ylabel for the y axis
Regarding your bonus question, consider the extent kwarg of imshow. (Thanks to @Jona).
Moreover, consider absolute imports as recommended by PEP 8 -- Style Guide for Python Code:
Absolute imports are recommended, as they are usually more readable
and tend to be better behaved (or at least give better error messages)
if the import system is incorrectly configured (such as when a
directory inside a package ends up on sys.path)
import matplotlib.pyplot as plt
import numpy as np
# the function that I'm going to plot
def z_func(x,y):
    return (x+y**2)
x = np.arange(-3.0,3.0,0.1)
y = np.arange(-3.0,3.0,0.1)
X,Y = np.meshgrid(x, y) # grid of points
Z = z_func(X, Y) # evaluation of the function on the grid
plt.xlabel('x axis')
plt.ylabel('y axis')
im = plt.imshow(Z,cmap=plt.cm.RdBu, extent=[-3, 3, -3, 3]) # drawing the function
plt.colorbar(im) # adding the colorbar on the right
plt.show()
This gives an image with labeled axes and ticks running from -3 to 3.
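The same idea can be sketched with the object-oriented interface, using Axes.set_xlabel and Axes.set_ylabel instead of the pyplot functions. One addition here is not in the answer above: origin="lower". With imshow's default origin ("upper"), row 0 of Z (which corresponds to y = -3) is drawn at the top of the image, so the tick labels supplied by extent would be upside down relative to the data. Treat this as an illustration under that assumption rather than the answer's exact method.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(-3.0, 3.0, 0.1)
y = np.arange(-3.0, 3.0, 0.1)
X, Y = np.meshgrid(x, y)
Z = X + Y**2

fig, ax = plt.subplots()
# extent maps the array indices onto data coordinates;
# origin="lower" (an addition to the original answer) puts row 0 at the bottom
im = ax.imshow(Z, cmap="RdBu", extent=[-3, 3, -3, 3], origin="lower")
ax.set_xlabel("x")
ax.set_ylabel("y")
fig.colorbar(im, ax=ax)
```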

Matplotlib: animate plot_surface using ArtistAnimation [duplicate]

I am looking to create an animation in a surface plot. The animation has fixed x and y data (1 to 64 in each dimension), and reads through an np array for the z information. An outline of the code is like so:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
def update_plot(frame_number, zarray, plot):
    #plot.set_3d_properties(zarray[:,:,frame_number])
    ax.collections.clear()
    plot = ax.plot_surface(x, y, zarray[:,:,frame_number], color='0.75')
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
N = 64
x = np.arange(N+1)
y = np.arange(N+1)
x, y = np.meshgrid(x, y)
zarray = np.zeros((N+1, N+1, nmax+1))
for i in range(nmax):
    #Generate the data in array z
    #store data into zarray
    #zarray[:,:,i] = np.copy(z)
plot = ax.plot_surface(x, y, zarray[:,:,0], color='0.75')
animate = animation.FuncAnimation(fig, update_plot, 25, fargs=(zarray, plot))
plt.show()
So the code generates the z data and updates the plot in FuncAnimation. This is very slow, however. I suspect it is because the plot is redrawn on every loop.
I tried the function
ax.set_3d_properties(zarray[:,:,frame_number])
but it comes up with an error
AttributeError: 'Axes3DSubplot' object has no attribute 'set_3d_properties'
How can I update the data in only the z direction without redrawing the whole plot? (Or otherwise increase the framerate of the graphing procedure)
There is a lot going on under the surface when calling plot_surface. You would need to replicate all of it when trying to set new data on the Poly3DCollection.
This might actually be possible, and there might also be a way to do it slightly more efficiently than the matplotlib code does. The idea would then be to calculate all the vertices from the grid points and supply them directly to Poly3DCollection._vec.
However, the speed of the animation is mainly determined by the time it takes to perform the 3D->2D projection and the time to draw the actual plot. Hence the above will not help much when it comes to drawing speed.
In the end, you might simply stick to the current way of animating the surface, which is to remove the previous plot and plot a new one. Using fewer points on the surface will significantly increase speed, though.
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
def update_plot(frame_number, zarray, plot):
    plot[0].remove()
    plot[0] = ax.plot_surface(x, y, zarray[:,:,frame_number], cmap="magma")
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
N = 14
nmax=20
x = np.linspace(-4,4,N+1)
x, y = np.meshgrid(x, x)
zarray = np.zeros((N+1, N+1, nmax))
f = lambda x,y,sig : 1/np.sqrt(sig)*np.exp(-(x**2+y**2)/sig**2)
for i in range(nmax):
    zarray[:,:,i] = f(x,y,1.5+np.sin(i*2*np.pi/nmax))
plot = [ax.plot_surface(x, y, zarray[:,:,0], color='0.75', rstride=1, cstride=1)]
ax.set_zlim(0,1.5)
animate = animation.FuncAnimation(fig, update_plot, nmax, fargs=(zarray, plot))
plt.show()
Note that the speed of the animation itself is determined by the interval argument to FuncAnimation. In the above it is not specified, so the default of 200 milliseconds is used. Depending on the data, you can still decrease this value before running into issues with lagging frames, e.g. try 40 milliseconds and adapt it to your needs.
animate = animation.FuncAnimation(fig, update_plot, ..., interval=40, ...)
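To illustrate the "fewer points" suggestion for the 65 x 65 grid from the question, one option is to subsample the grid with slicing before calling plot_surface. This is only a sketch: the step of 4 and the sin/cos test surface are arbitrary choices for the example, and how coarse you can go depends on your data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

N = 64
step = 4  # hypothetical subsampling factor; tune to your data
x, y = np.meshgrid(np.arange(N + 1), np.arange(N + 1))
z = np.sin(x / 10.0) * np.cos(y / 10.0)  # stand-in for one frame of zarray

# Keep every step-th grid point in both directions
xs, ys, zs = x[::step, ::step], y[::step, ::step], z[::step, ::step]

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
# 17 x 17 points instead of 65 x 65, so far fewer polygons to project and draw
surf = ax.plot_surface(xs, ys, zs, rstride=1, cstride=1, color="0.75")
```

Inside an update function, the same slices would be applied to each frame before the surface is removed and re-plotted.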
set_3d_properties() is a function of the Poly3DCollection class, not the Axes3DSubplot.
You should run
plot.set_3d_properties(zarray[:,:,frame_number])
as you have it commented in your update function BTW, instead of
ax.set_3d_properties(zarray[:,:,frame_number])
I'm not sure that will solve your problem, though, since the function set_3d_properties has no documentation attached. I wonder if you'd be better off trying plot.set_verts() instead.

populating matplotlib subplots through a loop and a function

I need to draw subplots of a figure through loop iterations; each iteration calls a function defined in another module (=another py file), which draws a pair of subplots. Here is what I tried -- and alas does not work:
1) Before the loop, create a figure with the adequate number of rows, and 2 columns:
import matplotlib.pyplot as plt
fig, axarr = plt.subplots(nber_rows,2)
2) Inside the loop, at iteration number iter_nber, call on the function drawing each subplot:
fig, axarr = module.graph_function(fig,axarr,iter_nber,some_parameters, some_data)
3) The function in question is basically like this; each iteration creates a pair of subplots on the same row:
def graph_function(fig,axarr,iter_nber,some_parameters, some_data):
    axarr[iter_nber,1].plot(--some plotting 1--)
    axarr[iter_nber,2].plot(--some plotting 2--)
    return fig,axarr
This does not work. I end up with an empty figure at the end of the loop.
I have tried various combinations of the above, like leaving only axarr in the function's return argument, to no avail. Obviously I do not understand the logic of this figure and its subplots.
Any suggestions much appreciated.
The code you've posted seems largely correct. Other than the indexing, as @hitzg mentioned (indices are zero-based, so the two columns are 0 and 1, not 1 and 2), nothing you're doing looks terribly out of the ordinary.
However, it doesn't make much sense to return the figure and axes array from your plotting function. (If you need access to the figure object, you can always get it through ax.figure.) It won't change anything to pass them in and return them, though.
Here's a quick example of the type of thing it sounds like you're trying to do. Maybe it helps clear some confusion?
import numpy as np
import matplotlib.pyplot as plt
def main():
    nrows = 3
    fig, axes = plt.subplots(nrows, 2)
    for row in axes:
        x = np.random.normal(0, 1, 100).cumsum()
        y = np.random.normal(0, 0.5, 100).cumsum()
        plot(row, x, y)
    plt.show()
def plot(axrow, x, y):
    axrow[0].plot(x, color='red')
    axrow[1].plot(y, color='green')
main()
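If you would rather keep the row index in the function signature, as in the question, a sketch along these lines should behave the same way. The data and the plotted quantities are made up for illustration. squeeze=False is used so that axarr stays two-dimensional even when there is only one row.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np

def graph_function(axarr, iter_nber, data):
    """Draw one pair of subplots on row iter_nber (columns are 0 and 1)."""
    axarr[iter_nber, 0].plot(data, color="red")
    axarr[iter_nber, 1].plot(data.cumsum(), color="green")

nber_rows = 3
# squeeze=False keeps axarr 2-D for any number of rows
fig, axarr = plt.subplots(nber_rows, 2, squeeze=False)
rng = np.random.default_rng(0)  # made-up data for the example
for iter_nber in range(nber_rows):
    graph_function(axarr, iter_nber, rng.normal(size=100))
```

The function does not need to return anything: it draws on the axes objects it receives, and the caller still holds fig and axarr.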

Matplotlib multiple lines on graph

I've been having an issue with saving matplotlib graphs as images.
The images are saving differently from what shows up when I call the .show() method on the graph.
An example is here:
http://s1.postimg.org/lbyei5cfz/blue5.png
I'm not sure what else to do. I've spent the past hours trying to figure out what's causing it, but I can't figure it out.
Here is my code in its entirety.
import matplotlib.pyplot as plt
import random
turn = 1 #for the x values
class Graph():
    def __init__(self, name, color):
        self.currentValue = 5 #for the y values
        self.x = [turn]
        self.y = [self.currentValue]
        self.name = name
        self.color = color
    def update(self):
        if random.randint(0,1): #just to show if the graph's value goes up or down
            self.currentValue += random.randint(0,10)
            self.y.append(self.currentValue)
        else:
            self.currentValue -= random.randint(0,10)
            self.y.append(self.currentValue)
        self.x.append(turn)
    def plot(self):
        lines = plt.plot(self.x,self.y)
        plt.setp(lines, 'color',self.color)
        plt.savefig(self.name + str(turn))
        #plt.show() will have a different result from plt.savefig(args)
graphs = [Graph("red",'r'),Graph("blue",'b'),Graph("green",'g')]
for i in range(5):
    for i in graphs:
        i.update() #changes the x and y value
        i.plot() #saves the picture of the graph
    turn += 1
Sorry if this is a stupid mistake I'm making, I just find it peculiar how plt.show() and plt.savefig are different.
Thanks for the help.
As David correctly stated, plt.show() resets the current figure. plt.savefig(), however, does not, so you need to reset it explicitly. plt.clf() and plt.figure() are two functions that can do it for you. Just insert the call right after plt.savefig:
plt.savefig(self.name + str(turn))
plt.clf()
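A small self-contained sketch of that pattern: each series is plotted, saved, and then cleared, so the next file contains only its own line. The helper name save_series, the sample data, and the temporary output directory are assumptions made for the example.

```python
import os
import tempfile
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt

def save_series(x, y, color, path):
    """Plot one series on the current figure, save it, then clear the figure."""
    plt.plot(x, y, color=color)
    plt.savefig(path)
    n_lines = len(plt.gca().lines)  # lines present in the saved image
    plt.clf()  # without this, the next call would draw on top of this line
    return n_lines

out_dir = tempfile.mkdtemp()  # hypothetical output location
series = [("red", "r"), ("blue", "b"), ("green", "g")]
counts = [
    save_series([1, 2, 3], [5, 7, 4], color, os.path.join(out_dir, name + ".png"))
    for name, color in series
]
```

Each entry in counts is the number of lines that were on the figure at save time. It stays at one per file because of the plt.clf() call.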
If you want to save the figure after displaying it, you'll need to hold on to the figure instance. The reason that plt.savefig doesn't work after calling show is that the current figure has been reset.
pyplot keeps track of which figures, axes, etc are "current" (i.e. have not yet been displayed with show) behind-the-scenes. gcf and gca get the current figure and current axes instances, respectively. plt.savefig (and essentially any other pyplot method) just does plt.gcf().savefig(...). In other words, get the current figure instance and call its savefig method. Similarly plt.plot basically does plt.gca().plot(...).
After show is called, the list of "current" figures and axes is empty.
In general, you're better off directly using the figure and axes instances to plot/save/show/etc, rather than using plt.plot, etc, to implicitly get the current figure/axes and plot on it. There's nothing wrong with using pyplot for everything (especially interactively), but it makes it easier to shoot yourself in the foot.
Use pyplot for plt.show() and to generate a figure and an axes object(s), but then use the figure or axes methods directly. (e.g. ax.plot(x, y) instead of plt.plot(x, y), etc) The main advantage of this is that it's explicit. You know what objects you're plotting on, and don't have to reason about what the pyplot state-machine does (though it's not that hard to understand the state-machine interface, either).
As an example of the "recommended" way of doing things, do something like:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-1, 1, 100)
y = x**2
fig, ax = plt.subplots()
ax.plot(x, y)
fig.savefig('fig1.pdf')
plt.show()
fig.savefig('fig2.pdf')
If you'd rather use the pyplot interface for everything, then just grab the figure instance before you call show. For example:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-1, 1, 100)
y = x**2
plt.plot(x, y)
fig = plt.gcf()
fig.savefig('fig1.pdf')
plt.show()
fig.savefig('fig2.pdf')

Matplotlib 3D plot doesn't plot correctly

I'm having a problem trying to plot a series of lines in a 3D plot in MatPlotLib.
When I run the code below all the lines are plotted at the last value of y??? Even though y is correctly incremented in the loop.
Any Help understanding this would be appreciated.
Thanks
David
#========== Code Start=================
import numpy as np
import matplotlib
from matplotlib.figure import Figure
import pylab as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = Axes3D(fig)
x=np.arange(5)
y=np.zeros(len(x))
for i in range(1,10):
    y.fill(i)
    z=plt.randn(len(y))
    ax.plot(xs=x, ys=y, zs=z)#, zdir='z', label='ys=0, zdir=z')
    plt.draw()
    print(i,len(y),y,x,z)
plt.xlabel('X')
plt.ylabel('Y')
ax.set_zlabel('Z')
plt.show()
#========== Code End=================
It looks like all the plotted lines refer to the same y array. ax.plot is handed a reference to y, and y.fill(i) changes that one array in place on each pass, so by the time plt.show() draws the figure every line sees the final value, 9. So, create a different object for y on each pass, with the values you want for that pass:
y = np.zeros(len(x))
y.fill(i)
There might be a numpy command that fills with the value you want in one go, but you get the picture.
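One such command is np.full, which allocates a new array already filled with a given value. A minimal sketch of the loop using it, so that every line gets its own y array (the random z data and the ys_per_line list are only there for illustration):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

x = np.arange(5)
rng = np.random.default_rng(0)  # made-up z values for the example
ys_per_line = []  # kept only to show that each line has its own array
for i in range(1, 10):
    y = np.full(len(x), float(i))  # fresh array on every pass, filled with i
    z = rng.standard_normal(len(x))
    ax.plot(x, y, z)
    ys_per_line.append(y)

ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
```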
