How to update 3D arrow animation in matplotlib - python

I am trying to reproduce the left plot of this animation in python using matplotlib.
I am able to generate the vector arrows using the 3D quiver function, but as I read here, it does not seem possible to set the lengths of the arrows. So, my plot does not look quite right:
So, the question is: how do I generate a number of 3D arrows with different lengths? Importantly, can I generate them in such a way so that I can easily modify for each frame of the animation?
Here's my code so far, with the not-so-promising 3D quiver approach:
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d
ax1 = plt.subplot(111,projection='3d')
t = np.linspace(0,10,40)
y = np.sin(t)
z = np.sin(t)
line, = ax1.plot(t,y,z,color='r',lw=2)
ax1.quiver(t,y,z, t*0,y,z)
plt.show()

As Azad suggests, an inelegant, but effective, solution is to simply edit the mpl_toolkits/mplot3d/axes3d.py to remove the normalization. Since I didn't want to mess with my actual matplotlib installation, I simply copied the axes3d.py file to the same directory as my other script and modified the line
norm = math.sqrt(u ** 2 + v ** 2 + w ** 2)
to
norm = 1
(Be sure to change the correct line. There is another use of "norm" a few lines higher.) Also, to get axes3d.py to function correctly when it's outside of the mpl directory, I changed
from . import art3d
from . import proj3d
from . import axis3d
to
from mpl_toolkits.mplot3d import art3d
from mpl_toolkits.mplot3d import proj3d
from mpl_toolkits.mplot3d import axis3d
And here is the nice animation that I was able to generate (not sure what's going wrong with the colors, it looks fine before I uploaded to SO).
And the code to generate the animation:
import numpy as np
import matplotlib.pyplot as plt
import axes3d_hacked
ax1 = plt.subplot(111,projection='3d')
plt.ion()
plt.show()
t = np.linspace(0,10,40)
for index,delay in enumerate(np.linspace(0,1,20)):
    y = np.sin(t+delay)
    z = np.sin(t+delay)
    if delay > 0:
        # remove the previous frame's curve and arrows before redrawing
        line.remove()
        linecol.remove()
    line, = ax1.plot(t,y,z,color='r',lw=2)
    linecol = ax1.quiver(t,y,z, t*0,y,z)
    plt.savefig('images/Frame%03i.gif'%index)
    plt.draw()
plt.ioff()
plt.show()
Now, if I could only get those arrows to look prettier, with nice filled heads. But that's a separate question...
EDIT: In the future, matplotlib will not automatically normalize the arrow lengths in the 3D quiver per this pull request.
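For readers on a matplotlib release where that change has landed, a minimal sketch of the same animation without the hacked axes3d.py might look like the following. It assumes the installed version's 3D quiver accepts the normalize keyword; the state dictionary and the update function are names made up for this illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
t = np.linspace(0, 10, 40)

# Hypothetical holder for the artists drawn in the previous frame.
state = {"line": None, "arrows": None}

def update(delay):
    """Redraw the curve and its arrows for one phase delay."""
    y = np.sin(t + delay)
    z = np.sin(t + delay)
    if state["line"] is not None:
        # Remove the previous frame's artists before drawing new ones.
        state["line"].remove()
        state["arrows"].remove()
    (state["line"],) = ax.plot(t, y, z, color="r", lw=2)
    # Assumption: normalize=False keeps each arrow at its vector's own length.
    state["arrows"] = ax.quiver(t, y, z, t * 0, y, z, normalize=False)
    return state["line"], state["arrows"]

anim = FuncAnimation(fig, update, frames=np.linspace(0, 1, 20), interval=50)
```

Removing and recreating the quiver each frame is the simplest approach; it is not necessarily the fastest one.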

Another solution is to call ax.quiver on each arrow individually, with each call having its own length attribute. This is not very efficient, but it will get you going.
And there's no need to change the matplotlib code.
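A rough sketch of this per-arrow approach, assuming a matplotlib version whose 3D quiver accepts the length and normalize keywords (the arrows list is just an illustrative name):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
t = np.linspace(0, 10, 40)
y = np.sin(t)
z = np.sin(t)
ax.plot(t, y, z, color="r", lw=2)

arrows = []  # keep the handles so each arrow can be removed for the next frame
for ti, yi, zi in zip(t, y, z):
    length = float(np.hypot(yi, zi))  # magnitude of the vector (0, yi, zi)
    if length == 0.0:
        continue  # nothing to draw for a zero-length vector
    # One quiver call per arrow: the direction is normalized,
    # then scaled by this arrow's own length.
    arrow = ax.quiver([ti], [yi], [zi], [0.0], [yi], [zi],
                      length=length, normalize=True)
    arrows.append(arrow)
```

For an animation, each handle in arrows can be removed with its remove() method before the next frame is drawn.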

Related

Create a detailed svg graph with matplotlib

I use the Python library matplotlib to draw a graph with a lot of data. Upon executing plt.show() I can zoom in and see the details of the graph. However, I would like to save the graph to an SVG file with plt.savefig and still see these details, which are not visible in the default non-zoomed-in view. How can I do that?
Please note that increasing the DPI or the dimensions in inches is meaningless when working with vector graphics formats such as SVG.
As an example, consider the following program.
import matplotlib.pyplot as plt
import numpy as np
import math
x = np.arange(0,100,0.00001)
y = x*np.sin(2*math.pi*(x**1.2))
plt.plot(y)
plt.savefig('test.svg')
We get the following plot, in which we cannot see the details of the sine wave periods even when we zoom in.
But we can see the details of the sine wave when displaying the image with plt.show instead and then zooming in.
Add the size of the figure:
import matplotlib.pyplot as plt
import numpy as np
import math
x = np.arange(0,100,0.00001)
y = x*np.sin(2*math.pi*(x**1.2))
fig = plt.figure(figsize=(19.20,10.80))
plt.plot(y)
plt.savefig('test.svg')
and you get the kind of resolution you wish.
As correctly observed by JohanC, another good solution is to reduce the width of the line:
import matplotlib.pyplot as plt
import numpy as np
import math
x = np.arange(0,100,0.00001)
y = x*np.sin(2*math.pi*(x**1.2))
#fig = plt.figure(figsize=(19.20,10.80))
plt.plot(y, linewidth=0.1)
plt.savefig('test.svg')

Tracing functions in python

I was searching for how to plot function graphs, not only linear ones. I know how to plot straight segments between simple points, like in the example below:
import numpy
import matplotlib.pyplot as plt
%matplotlib inline
_=plt.plot([4,7],[5,7],color ='w')
_=plt.plot([4,7],[7,7],color ='w')
ax = plt.gca()
ax.set_facecolor('xkcd:red')
plt.show()
then after a bit of searching, I've found this code:
import pylab
import numpy
x = numpy.linspace(-15,15,100) # 100 linearly spaced numbers
y = numpy.sin(x)/x # computing the values of sin(x)/x
# compose plot
pylab.plot(x,y) # sin(x)/x
pylab.plot(x,y,'co') # same function with cyan dots
pylab.plot(x,2*y,x,3*y) # 2*sin(x)/x and 3*sin(x)/x
pylab.show() # show the plot
That works perfectly! But what I'm wondering is: do we really need to use standard functions that are defined by NumPy (like sin(x)/x here)? Or can we define a function ourselves and use it with NumPy too, like x**3?
This solved the issue; thanks, FlyingTeller.
An example of y=x**3 graph:
import pylab
import numpy
x = numpy.linspace(-15,15,100) # 100 linearly spaced numbers
y = x**3 # change this expression to plot whatever function you want
# compose plot
pylab.plot(x,y)
pylab.show()
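To spell out the point a little further: any Python function built from arithmetic operators works element-wise on a NumPy array, while a function that only handles scalars (for example one with an if statement) can be wrapped with numpy.vectorize. The function names below are made up for illustration.

```python
import numpy as np

def cubic(x):
    # Plain arithmetic broadcasts over a NumPy array automatically.
    return x**3 - 2 * x

def step(x):
    # Works on a single number only, because of the scalar comparison.
    return 1.0 if x > 0 else 0.0

x = np.linspace(-15, 15, 100)
y_cubic = cubic(x)              # array in, array out
y_step = np.vectorize(step)(x)  # wrap the scalar-only function

# Either result can then be passed to pylab.plot(x, y_cubic), and so on.
```

Note that numpy.vectorize is a convenience wrapper around a Python loop, so it does not make the function faster.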

matplotlib: Don't compress plot in the horizontal direction

I am very new to matplotlib and I am having some difficulty with this figure:
I have a text file with x y point groups that I should plot. However, the x points overlap in each group, so I add an offset to each x axis point for each group.
Usually, the single groups look like this:
Note that the x axis in the first image ends where the x-axis in the second image begins.
My problem is that the resulting image is squelched/compressed and not really "readable".
I tried increasing the value that is added to the x-axis for each group/image, but it just compresses each group even more.
I tried suggestions to use rcParams or set the dpi value of the resulting image, but nothing does the job:
from pylab import rcParams
rcParams['figure.figsize'] = 50, 100
plt.savefig('result.png', dpi=200,pad_inches=5)
What am I doing wrong or looking for?
PS: The data and code is here. To see what my problem is, call python2.7 plotit.py text.txt
If I understand your question, you don't like your image being too compressed in the horizontal direction. This happens because, by default, matplotlib chooses the aspect ratio needed to fill the given figure size. You were on the right track with changing figsize, but if you want to change it in rcParams, you have to put this call somewhere before you start plotting. The other approach is to use the stateless (object-oriented) API, that is fig = plt.figure(figsize=(8,2)); s = fig.add_subplot(111); s.plot(...). That's what you get:
from pylab import rcParams
rcParams['figure.figsize'] = 8, 2
Note that I shrunk the circle sizes to make the lines more distinguishable:
plt.scatter(x,y,s=1)
if px!='':
    plt.plot([px,x],[py,y],'-o',markersize=1)
For a more accurate control you can actually set the aspect ratio directly:
plt.gca().set_aspect(1)
or use some of the predefined modes, e.g.
plt.axis('equal')
plt.tight_layout()
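As a sketch of the object-oriented variant mentioned above, using made-up data rather than the data from the question, the figure size can be fixed when the figure is created, so no rcParams ordering is involved:

```python
import io
import matplotlib
matplotlib.use("Agg")  # headless backend, as in the full script
import numpy as np
import matplotlib.pyplot as plt

# Made-up data standing in for the offset point groups.
x = np.linspace(0, 40, 400)
y = np.sin(x)

fig = plt.figure(figsize=(8, 2))  # wide, short canvas set at creation time
ax = fig.add_subplot(111)
ax.plot(x, y, '-o', markersize=1)
ax.set_aspect('auto')             # let the data fill the wide axes
fig.tight_layout()

buffer = io.BytesIO()             # in-memory target; a filename works too
fig.savefig(buffer, format='png', dpi=200)
```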
Edit: for the reference, full code for the final picture:
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import sys
from pylab import rcParams
rcParams['figure.figsize'] = 8, 2
def parsexydata(fdata):
    keys=[]
    xy=[]
    with open(fdata,'r') as f:
        pre=''
        for idx, i in enumerate(f.read().replace('\n','').split(',')[2:]):
            if idx%2==0:
                pre=i
                continue
            tmp = pre.split('.')
            if len(tmp)!=3: continue
            [before,key,after] = pre.split('.')
            pre = before+'.'+after
            if key not in keys: keys.append(key)
            xy.append([pre,i,key])
    return [xy,keys]
[xydata, keys] = parsexydata(sys.argv[1])
for idx, k in enumerate(keys):
    px=py=''
    for [x,y,key] in xydata:
        if key!=k: continue
        x=float(x)+float(k)
        if key=='01': print(x)
        plt.scatter(x,y,s=1)
        if px!='':
            plt.plot([px,x],[py,y],'-o',markersize=1)
        px,py=x,y
plt.axis('equal')
plt.tight_layout()
plt.savefig('result.png', dpi=200)

Redrawing a plot in IPython / matplotlib

I use IPython/Matplotlib, and I want to create functions that can plot various graphs in the same plotting window. However, I have trouble with redrawing. This is my program test_plot_simple.py:
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(10)
y2 = (x**2)/(10**2)
ye = (2**x)/(2**10)
fig, ax = plt.subplots()
def p_squared():
    ax.plot(x,y2, 'r')
    plt.show()
def p_exp():
    ax.plot(x,ye, 'r')
    plt.show()
I start IPython as $ ipython --matplotlib
On the IPython command line I do
In [1]: run test_plot_simple.py
In [2]: p_squared()
In [3]: p_exp()
After the second line, the squared graph is shown. But nothing happens after the third. Why is the plt.show() not working here?
It appears that you call subplots without actually taking advantage of it: you are trying to overplot on the same canvas. See here for a more thorough explanation. That being said, all you need is the following in order to get the functionality I think you want:
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(10)
y2 = (x**2)/(10**2)
ye = (2**x)/(2**10)
def p_squared():
    plt.plot(x,y2, 'r')
    plt.show()
def p_exp():
    plt.plot(x,ye, 'r')
    plt.show()
Now both the p_squared() and p_exp() calls produce plots. Hope this helps.
After some digging I think I found the right way to go about this. It seems that show() is not really intended for this purpose, but rather draw() is. And if I want to keep it object-oriented, I should draw via my figure or my axis. It seems to me that something like this is the best approach:
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(10)
y2 = (x**2)/(10**2)
ye = (2**x)/(2**10)
fig, ax = plt.subplots()
fig.show()
def p_squared():
    ax.plot(x,y2, 'r')
    fig.canvas.draw()
def p_exp():
    ax.plot(x,ye, 'r')
    fig.canvas.draw()
I.e., use fig.canvas.draw() in lieu of plt.show() (or fig.show(), for that matter.)
I still need one show() - I chose to do that right away after the figure has been created.
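Building on that, a small sketch of a single reusable helper might look like this. The replot function and its clear parameter are hypothetical names, and draw_idle() is used here as a lighter alternative to draw() that only schedules a redraw.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line inside IPython
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(10)
y2 = (x**2) / (10**2)
ye = (2**x) / (2**10)

fig, ax = plt.subplots()

def replot(y, style='r', clear=True):
    """Draw y on the shared axes, optionally wiping the old curves first."""
    if clear:
        ax.clear()
    ax.plot(x, y, style)
    fig.canvas.draw_idle()  # ask the canvas to redraw when it next can

replot(y2)                    # shows only the squared curve
replot(ye)                    # replaces it with the exponential curve
replot(y2, 'b', clear=False)  # overlays the squared curve in blue
```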

How to plot a 3d scatterplot with the color determined from another variable?

I made a 3d scatterplot that displays the position of galaxies in a cluster (basically like the latitude and longitude) as a function of their velocity. However, I've been asked to make the color of the data points be determined by another variable, h in the code. The purpose of the variable isn't important here; what matters is that in my actual code every data point is determined from 4 arrays. After spending a long time looking up how to do this, I finally (almost) have it. The only problem is that when I plot it, the colors of the dots change as soon as I move the plot around to see it from a different direction. Also, I've been having issues trying to display a colorbar.
import pylab as p
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib.cm as cm
ra=np.random.random((100))
dec=np.random.random((100))
h=np.random.random((100))
z=np.random.random((100))
datamin=min(h)
datamax=max(h)
fig=p.figure()
ax3D=fig.add_subplot(111, projection='3d')
ax3D.scatter(ra, dec, z, c=h, vmin=datamin, vmax=datamax,
             marker='o', cmap=cm.Spectral)
p.title("MKW4s-Position vs Velocity")
p.show()
The 'changing color upon redraw' issue was a bug, but it looks like it's fixed in the latest release (1.1.1). I've tested and confirmed that it works as it should with 1.1.1.
For the colorbar, it needs a mappable. You can use the collection returned from scatter:
import pylab as p
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib.cm as cm
ra=np.random.random((100))
dec=np.random.random((100))
h=np.random.random((100))
z=np.random.random((100))
datamin=min(h)
datamax=max(h)
fig=p.figure()
ax3D=fig.add_subplot(111, projection='3d')
collection = ax3D.scatter(ra, dec, z, c=h, vmin=datamin, vmax=datamax,
                          marker='o', cmap=cm.Spectral)
p.colorbar(collection)
p.title("MKW4s-Position vs Velocity")
p.show()
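The same idea can be written against the figure and axes objects directly. This is only a sketch with random stand-in data; the shrink value and the colorbar label are arbitrary choices for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import numpy as np
import matplotlib.pyplot as plt

# Random stand-in data, as in the question.
rng = np.random.default_rng(0)
ra, dec, z, h = rng.random((4, 100))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# scatter returns the collection that the colorbar uses as its mappable.
sc = ax.scatter(ra, dec, z, c=h, vmin=h.min(), vmax=h.max(),
                marker='o', cmap='Spectral')
cbar = fig.colorbar(sc, ax=ax, shrink=0.7)
cbar.set_label('h')
ax.set_title("MKW4s-Position vs Velocity")
```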
