I tried to adapt animated scatter plot-example in such a way, so it shows real-time the results of an agent based-model I developed. However, the result shown in the graph are not that what I except them to be.
It goes wrong when updating the values, and strange patterns appear where the agents tend to cluster in a diagonal line.
I added some simple code that illustrates this problem. Does anyone has an idea what goes wrong?
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import os
n = 25 ## nr of agents
x,y = 10, 10 ## matrix of x by y dimension
dataX, dataY, binaryRaster = [],[],[]
class AnimatedScatter(object):
"""An animated scatter plot using matplotlib.animations.FuncAnimation."""
def __init__(self):
global n
self.numpoints = n
self.stream = self.data_stream()
self.fig, self.ax = plt.subplots()
self.ax.set_title("My first Agent Based Model (ABM)",fontsize=14)
self.ax.grid(True,linestyle='-',color='0.75')
self.ani = animation.FuncAnimation(self.fig, self.update, interval=100,
init_func=self.setup_plot, blit=True,
repeat=False)
def setup_plot(self):
"""Initial drawing of the scatter plot."""
global x,y
dataX,dataY = next(self.stream)
self.scat = self.ax.scatter(dataY, dataX, c="tomato", s=20, animated=True)
self.ax.axis([0, y, x, 0])
return self.scat,
def data_stream(self):
"""Generate a random walk (brownian motion). Data is scaled to produce
a soft "flickering" effect."""
global x,y, n
dataX,dataY = self.createRandomData()
#printing results to ascii for validation
lines = []
binaryData = np.zeros((x,y), dtype=np.int)
for i in range(n):
binaryData[dataX,dataY] =1
for i in range(x):
line = ""
for j in range(y):
line += str(binaryData[i,j])+ ","
line= line[:-1]+ "\n"
lines.append(line)
lines.append("\n")
yx = np.array([dataY,dataX])
cnt = 0
while cnt < 10:
dataX,dataY = self.createRandomData()
yx = np.array([dataY,dataX])
#printing results to ascii for validation
binaryData = np.zeros((x,y), dtype=np.int)
for i in range(n):
binaryData[dataX,dataY] =1
for i in range(x):
line = ""
for j in range(y):
line += str(binaryData[i,j])+ ","
line= line[:-1]+ "\n"
lines.append(line)
lines.append("\n")
cnt+=1
yield yx
#printing results to ascii for validation
outNm = os.getcwd()+"\\ScatterValidation.txt"
outfile = open(outNm, "w")
outfile.writelines(lines)
outfile.close()
return
def update(self, i):
"""Update the scatter plot."""
data = next(self.stream)
self.scat.set_offsets(data[:2, :])
return self.scat,
def show(self):
plt.show()
def createRandomData(self):
"""Positions n agents randomly on a raster of x by y cells.
Each cell can only hold a single agent."""
global x,y,n
binaryData = np.zeros((x,y), dtype=np.int)
newAgents = 0
dataX,dataY = [],[]
while newAgents < n:
row = np.random.randint(0,x,1)[0]
col = np.random.randint(0,y,1)[0]
if binaryData[row][col] != 1:
binaryData[row][col] = 1
newAgents+=1
for row in range(x):
for col in range(y):
if binaryData[row][col] == 1:
dataX.append(row)
dataY.append(col)
return dataX, dataY
def main():
global n, x, y, dataX, dataY, binaryRaster
a = AnimatedScatter()
a.show()
return
if __name__ == "__main__":
main()
You can fix your script in 2 ways, both involve changing the update function:
Using a scatter call in the update function, is clearer I think
Transposing the data array before calling set_offsets in update
Using a scatter call is the clearest fix, and you could increase the agents during your run:
def update(self, i):
"""Update the scatter plot."""
dataX, dataY = next(self.stream)
self.scat = self.ax.scatter(dataX, dataY, c="tomato", s=20, animated=True)
return self.scat,
Transposing the offsets array will also work:
def update(self, i):
"""Update the scatter plot."""
data = next(self.stream)
self.scat.set_offsets(data.transpose())
return self.scat,
Offsets are given as a N tuples of 2 items each while the data array is given as 2 tuples with N items each, transposing the data array will fix your problem.
Note: If you do not change the global variables, you do not need to specify the globals with a global statement, so in setup_plot, __init__ etc. you can remove the global n,x,y lines.
I would put n,x and y as instance variables of your class, plus there is no need for dataX, dataY and binaryRasted to be defined at the top of your script.
Related
I'm trying to animate a plot using matplotlib's FuncAnimation, however no frames of the animation are visible until the animation reaches the final frame. If I set repeat = True nothing is ever displayed. When I first run the code a matplotlib icon appears but nothing displays when I click on it until it shows me the final frame:
If I save the animation I see the animation display correctly so this leads me to think that my code is mostly correct so I hope this is a simple fix that I'm just missing.
Apologies if I'm dumping too much code but I'm not sure if there's anything that's not needed for the minimum reproducible example.
Here's the main code
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from quantum_custom.constants import spin_down, spin_up, H00, H11, H
import quantum_custom.walk as walk
class QuantumState:
def __init__(self, state):
self.state = state
#"coin flips"
max_N = 100 #this will be the final number of coin flips
positions = 2*max_N + 1
#initial conditions
initial_spin = spin_down
initial_position = np.zeros(positions)
initial_position[max_N] = 1
initial_state = np.kron(np.matmul(H, initial_spin), initial_position) #initial state is Hadamard acting on intial state, tensor product with the initial position
quantum_state = QuantumState(initial_state)
#plot the graph
fig, ax = plt.subplots()
plt.title("N = 0")
x = np.arange(positions)
line, = ax.plot([],[])
loc = range(0, positions, positions // 10)
plt.xticks(loc)
plt.xlim(0, positions)
plt.ylim((0, 1))
ax.set_xticklabels(range(-max_N, max_N + 1, positions // 10))
ax.set_xlabel("x")
ax.set_ylabel("Probability")
def init():
line.set_data([],[])
return line,
def update(N):
next_state = walk.flip_once(quantum_state.state, max_N)
probs = walk.get_prob(next_state, max_N)
quantum_state.state = next_state
start_index = N % 2 + 1
cleaned_probs = probs[start_index::2]
cleaned_x = x[start_index::2]
line.set_data(cleaned_x, cleaned_probs)
if cleaned_probs.max() != 0:
plt.ylim((0, cleaned_probs.max()))
plt.title(f"N = {N}")
return line,
anim = animation.FuncAnimation(
fig,
update,
frames = max_N + 1,
init_func = init,
interval = 20,
repeat = False,
blit = True,
)
anim.save("animated.gif", writer = "ffmpeg", fps = 15)
plt.show()
Here's my quantum_custom.constants module.
#define spin up and spin down vectors as standard basis
spin_up = np.array([1,0])
spin_down = np.array([0,1])
#define our Hadamard operator, H, in terms of ith, jth entries, Hij
H00 = np.outer(spin_up, spin_up)
H01 = np.outer(spin_up, spin_down)
H10 = np.outer(spin_down, spin_up)
H11 = np.outer(spin_down, spin_down)
H = (H00 + H01 + H10 - H11)/np.sqrt(2.0) #matrix representation of Hadamard gate in standard basis
Here's my quantum_custom.walk module.
import numpy as np
from quantum_custom.constants import H00, H11, H
#define walk operators
def walk_operator(max_N):
position_count = 2 * max_N + 1
shift_plus = np.roll(np.eye(position_count), 1, axis = 0)
shift_minus = np.roll(np.eye(position_count), -1, axis = 0)
step_operator = np.kron(H00, shift_plus) + np.kron(H11, shift_minus)
return step_operator.dot(np.kron(H, np.eye(position_count)))
def flip_once(state, max_N):
"""
Flips the Hadamard coin once and acts on the given state appropriately.
Returns the state after the Hadamard coin flip.
"""
walk_op = walk_operator(max_N)
next_state = walk_op.dot(state)
return next_state
def get_prob(state, max_N):
"""
For the given state, calculates the probability of being in any possible position.
Returns an array of probabilities.
"""
position_count = 2 * max_N + 1
prob = np.empty(position_count)
for k in range(position_count):
posn = np.zeros(position_count)
posn[k] = 1
posn_outer = np.outer(posn, posn)
alt_measurement_k = eye_kron(2, posn_outer)
proj = alt_measurement_k.dot(state)
prob[k] = proj.dot(proj.conjugate()).real
return prob
def eye_kron(eye_dim, mat):
"""
Speeds up the calculation of the tensor product of an identity matrix of dimension eye_dim with a given matrix.
This exploits the fact that majority of values in the resulting matrix will be zeroes apart from on the leading diagonal where we simply have copies of the given matrix.
Returns a matrix.
"""
mat_dim = len(mat)
result_dim = eye_dim * mat_dim #dimension of the resulting matrix
result = np.zeros((result_dim, result_dim))
result[0:mat_dim, 0:mat_dim] = mat
result[mat_dim:result_dim, mat_dim:result_dim] = mat
return result
I know that saving the animation is a solution but I'd really like to have the plot display just from running the code as opposed to having to save it. Thanks!
As per Sameeresque's suggestion I tried using different backends for matplot lib. This was done by altering by import statements as follows.
import matplotlib
matplotlib.use("tkagg")
import matplotlib.pyplot as plt
Note that it's important to add the two additional lines before import matplotlib.pyplot as plt otherwise it won't do anything.
I am building a Quadtree using python, and have managed to create a solution for randomly generated points. the main class is the QTree class:
class QTree():
def __init__(self, treshold, customerCount):
self.threshold = treshold
self.points = [Point(random.uniform(0, 100), random.uniform(0, 100)) for x in range(customerCount)]
self.root = Node(0, 0, 100, 100, self.points)
def add_point(x, y):
self.points.append(Point(x, y))
def get_points(self):
return self.points
def subdivide(self):
recursive_subdivide(self.root, self.threshold)
def graph(self):
fig = plt.figure(figsize=(12, 8))
x = [point.x for point in self.points]
y = [point.y for point in self.points]
ax = fig.add_subplot(111)
c = find_children(self.root)
print("\n\nNumber of segments: %d" % len(c))
areas = set()
for el in c:
areas.add(el.width * el.height)
print("Minimum segment area: %.3f units" % min(areas))
for n in c:
ax.add_patch(patches.Rectangle((n.x0, n.y0), n.width, n.height, fill=False))
plt.title("Quadtree")
plt.plot(x, y, 'ro', markersize=3, color='b')
plt.savefig('QuadtreeDiagram.png', dpi=1000)
plt.show()
return
I am obtaining my Quadtree diagram (https://i.stack.imgur.com/ojYHO.png) by calling the Qtree class and various other def funcitons:
def test(treshold, customerCount):
qt = QTree(treshold, customerCount)
qt.subdivide()
qt.graph()
# Tests
test(1, 50)
My question is: how do I change the random points and use my own CSV file with co-ordinates? Thank you :D
You can add a new function to load the data from file and set the points:
def load_points():
Df1 = pd.read_csv("D:\\test.txt", sep='\t' )
self.points=[Point(Df1["Col0"],Df1["Col1"]) for x in range(len(Df1))]
Note1: You must get the relation between len of you new data and customerCount
Note2: I assume that the file is for example like:
Col0 Col1
1 2
2 4
3 5
I am following the book «Learning form data» which has the following exercise:
Create a linearly separable Input data set.
Implement single layer perceptron on the generated data set.
To create a data set, it says to choose a two dimensional plane and then to choose a random line in the plane. Points to one side of the plane are classified positive and points to the other are classified negative. I was able to follow this with the help of https://datasciencelab.wordpress.com/2014/01/10/machine-learning-classics-the-perceptron/
import numpy as np
import random
import os, subprocess
class Perceptron:
def __init__(self, N):
# Random linearly separated data
xA,yA,xB,yB = [random.uniform(-1, 1) for i in range(4)]
self.V = np.array([xB*yA-xA*yB, yB-yA, xA-xB])
self.X = self.generate_points(N)
def generate_points(self, N):
X = []
for i in range(N):
x1,x2 = [random.uniform(-1, 1) for i in range(2)]
x = np.array([1,x1,x2])
s = int(np.sign(self.V.T.dot(x)))
X.append((x, s))
return X
def plot(self, mispts=None, vec=None, save=False):
fig = plt.figure(figsize=(5,5))
plt.xlim(-1,1)
plt.ylim(-1,1)
V = self.V
a, b = -V[1]/V[2], -V[0]/V[2]
l = np.linspace(-1,1)
plt.plot(l, a*l+b, 'k-')
cols = {1: 'r', -1: 'b'}
for x,s in self.X:
plt.plot(x[1], x[2], cols[s]+'o')
if mispts:
for x,s in mispts:
plt.plot(x[1], x[2], cols[s]+'.')
if vec != None:
aa, bb = -vec[1]/vec[2], -vec[0]/vec[2]
plt.plot(l, aa*l+bb, 'g-', lw=2)
if save:
if not mispts:
plt.title('N = %s' % (str(len(self.X))))
else:
plt.title('N = %s with %s test points' \
% (str(len(self.X)),str(len(mispts))))
plt.savefig('p_N%s' % (str(len(self.X))), \
dpi=200, bbox_inches='tight')
def classification_error(self, vec, pts=None):
# Error defined as fraction of misclassified points
if not pts:
pts = self.X
M = len(pts)
n_mispts = 0
for x,s in pts:
if int(np.sign(vec.T.dot(x))) != s:
n_mispts += 1
error = n_mispts / float(M)
return error
def choose_miscl_point(self, vec):
# Choose a random point among the misclassified
pts = self.X
mispts = []
for x,s in pts:
if int(np.sign(vec.T.dot(x))) != s:
mispts.append((x, s))
return mispts[random.randrange(0,len(mispts))]
def pla(self, save=False):
# Initialize the weigths to zeros
w = np.zeros(3)
X, N = self.X, len(self.X)
it = 0
# Iterate until all points are correctly classified
while self.classification_error(w) != 0:
it += 1
# Pick random misclassified point
x, s = self.choose_miscl_point(w)
# Update weights
w += s*x
if save:
self.plot(vec=w)
plt.title('N = %s, Iteration %s\n' \
% (str(N),str(it)))
plt.savefig('p_N%s_it%s' % (str(N),str(it)), \
dpi=200, bbox_inches='tight')
self.w = w
def check_error(self, M, vec):
check_pts = self.generate_points(M)
return self.classification_error(vec, pts=check_pts)
Image of the logic and code is given below:
2 D linearly separable data
I want to create a N dimensional linearly separable data set similarly.
for example a 10 dimensional set of points. Points to one side of 9 dimensional hyperplane are to be classified positive and to the other as negative.
I have no clue on how to proceed. Any help is appreciated.
I'm trying to do some animation with matplotlib (Conways game of life, to be specific) and have some problems with the .FuncAnimation
I figured out diffrent cases wich partly worked (but not the way I want) or result in diffrent errors. I would like to understand the errors and work out a proper version of the code. Thanks for your help!
The function called through the .FuncAnimation is gameoflife wich uses the variables w, h, grid to uptdate the image.
For the whole commented code see below.
Case 1: Global Variables
If I use global variables everthing works fine.
I define w, h, grid global before i call gameoflife(self) through anim = animation.FuncAnimation(fig, gameoflife)
In gameoflife(self) i also define w, h, grid as global variables
w, h, grid = "something"
def gameoflife(self):
global w
global h
global grid
.
.
.
img = ax.imshow(grid)
return img
fig, ax = plt.subplots()
plt.axis('off')
img = ax.imshow(grid)
anim = animation.FuncAnimation(fig, gameoflife)
plt.show()
As said, this results in the animation as wanted. But I would like to get rid of the global variables, because of which I tried something else:
Case 2: Passing Objects
I don't defined w, h, grid as globals in gameoflife but passed them with anim = animation.FuncAniation(fig, gameoflife(w,h,grid)).
(I know that w, h, grid are still global in my example. I work on another version where they are not but since the errors are the same I think this simplyfied version should do it.)
This results in the following Error:
TypeError: 'AxesImage' object is not callable
I dont understand this error, since I don't call ax with the code changes.
w, h, grid = "something"
def gameoflife(w, h, grid):
.
.
.
img = ax.imshow(grid)
return img
fig, ax = plt.subplots()
plt.axis('off')
img = ax.imshow(grid)
anim = animation.FuncAnimation(fig, gameoflife(w,h,grid))
plt.show()
Case 3: Passing Objects with fargs
In the third case I try to pass w, h, grid with the "frags" argument of .FuncAnimation resulting in just the first frame. (Or the first two, depending how you see it. The "frist" frame is accually drawn through img = ax.imshow(grid))
w, h, grid = "something"
def gameoflife(self, w, h, grid):
.
.
.
img = ax.imshow(grid)
return img
fig, ax = plt.subplots()
plt.axis('off')
img = ax.imshow(grid)
anim = animation.FuncAnimation(fig, gameoflife, fargs=(w,h,grid))
plt.show()
Complete Code
I hope its properly commented ;)
There are two parts (beginning and end) where you can comment/uncomment parts to generate the respective case. By deafault its Case 1.
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
##defining grid size
w= 20
h = 20
##generating random grid
grid = np.array([[random.randint(0,1) for x in range(w)] for y in range(h)])
######
# Choose for diffrent cases
######
##Case 1: Global Variables
def gameoflife(self):
global w
global h
global grid
##Case 2: Passing Objects
#def gameoflife(w, h, grid):
##Case 3: Passing Objects with fargs
#def gameoflife(self, w, h, grid):
####### Choose part over
# wt, ht as test values for position
# x,y to set calculation position
wt = w-1
ht = h-1
x,y = -1,0 #results in 0,0 for the first postion
# defining grid for calculation (calgrid)
calgrid = np.array([[0 for x in range(w)] for y in range(h)])
# testing for last position
while y<ht or x<wt:
# moving position through the grid
if x == wt:
y +=1
x = 0
else:
x += 1
#sorrounding cells check value
scv = 0
#counting living cells around position x,y
#if-else for exeptions at last column and row
if y == ht:
if x == wt:
scv = grid[x-1][y-1] + grid[x][y-1] + grid[0][y-1] + grid[x-1][y] + grid[0][y] + grid[x-1][0] + grid[x][0] + grid[0][0]
else:
scv = grid[x-1][y-1] + grid[x][y-1] + grid[x+1][y-1] + grid[x-1][y] + grid[x+1][y] + grid[x-1][0] + grid[x][0] + grid[x+1][0]
else:
if x == wt:
scv = grid[x-1][y-1] + grid[x][y-1] + grid[0][y-1] + grid[x-1][y] + grid[0][y] + grid[x-1][y+1] + grid[x][y+1] + grid[0][y+1]
else:
scv = grid[x-1][y-1] + grid[x][y-1] + grid[x+1][y-1] + grid[x-1][y] + grid[x+1][y] + grid[x-1][y+1] + grid[x][y+1] + grid[x+1][y+1]
# test cell to condidions and write result in calgrid
if grid[x][y] == 0:
if scv == 3:
calgrid [x][y] = 1
else :
if 1<scv<4:
calgrid [x][y] = 1
# updating grid, generating img and return it
grid = calgrid
img = ax.imshow(grid)
return img
fig, ax = plt.subplots()
plt.axis('off')
img = ax.imshow(grid) #generates "first" Frame from seed
#####
# Choose vor Case
#####
## Case 1: Global Variables
anim = animation.FuncAnimation(fig, gameoflife)
## Case 2: Passing Variables
#anim = anim = animation.FuncAnimation(fig, gameoflife(w,h,grid))
## Case 3: Passing Variables with fargs
#anim = animation.FuncAnimation(fig, gameoflife, fargs=(w,h,grid))
####### Choose part over
plt.show()
Tanks for help and everything
Greetings Tobias
Case 2: You call the function and pass the result into FuncAnimation.
def gameoflife(w,h,grid):
# ...
return ax.imshow(grid)
anim = animation.FuncAnimation(fig, gameoflife(w,h,grid))
Is essentially the same as
anim = animation.FuncAnimation(fig, ax.imshow(grid))
which will not work because the second argument is expected to be a function, not the return of a function (in this case an image).
To explain this better, consider a simple test case. g is a function and expects a function as input. It will return the function evaluated at 4. If you supply a function f, all works as expected, but if you supply the return of a function, it would fail if the return is not itself a function, which can be evaluated.
def f(x):
return 3*x
def g(func):
return func(4)
g(f) # works as expected
g(f(2)) # throws TypeError: 'int' object is not callable
Case 3: You calling the function repeatedly with the same arguments
In the case of
anim = animation.FuncAnimation(fig, gameoflife, fargs=(w,h,grid))
you call the function gameoflife with the same initial arguments w,h,grid for each frame in the animation. Hence you get a static animation (the plot is animated, but each frame is the same, because the same arguments are used).
Conclusion. Stay with Case 1
Because case 1 is working fine, I don't know why not use it. A more elegant way would be to use a class and use class variables as e.g. in this question.
I have a code in which for 3 different values of D ,i have 3 different values of dx and so,3 different plots.
I want to do a plot which will have all 3 plots in one.
...
D=(0.133e-4,0.243e-4,0.283e-4)
dx=sc.zeros(3)
for i in D:
dx[i]=sc.sqrt(D[i]*dt/M)
plt.ion()
while n<N:
Vw_n=Vw_n1
C_n=C_n1
R2=(Vw_n+B1)/(Vw_0+B1)
Cc=C_n1[0]/C0
F2_1=10000/3*Pw*A*(C0*Vw_0/Vw_n1-C_n[1])
dV=F2_1*dt
Vw_n1=Vw_n+dV
C_n1[0]=C0*Vw_0/Vw_n1
F_i_2=-D[i]/dx[i]*(C_n[1:7]-C_n[0:6])
C_n1[0:6]=C_n[0:6]-F_i_2*A*dt/(L/(V0/A)*V0/5)
n+=1
ttime=n*0.02*1000
#-----PLOT AREA---------------------------------#
mylabels=('T=273','T=293','T=298')
colors=('-b','or','+k')
if x==1:
plt.plot(ttime,R2,mylabels[i],colors[i])
elif x==2:
plt.plot(ttime,Cc,mylabels[i],colors[i])
plt.draw()
plt.show()
----------RUNNABLE--------------------------
import scipy as sc
import matplotlib.pyplot as plt
def graph(x):
A=1.67e-6
V0=88e-12
Vw_n1=71.7/100*V0
Pw=0.22
L=4e-4
B1=V0-Vw_n1
C7=0.447e-3
dt=0.2e-4
M=0.759e-1
C_n1=sc.zeros(7)
C_n1[0:6]=0.290e-3
C_n1[6]=0.447e-3
C0=C_n1[0]
Vw_0=Vw_n1
N=2000
n =1
D = ,0.243e-4
dx = sc.sqrt(D*dt/M)
plt.ion()
while n<N:
Vw_n=Vw_n1
C_n=C_n1
R2=(Vw_n+B1)/(Vw_0+B1)
Cc=C_n1[0]/C0
F2_1=10000/3*Pw*A*(C0*Vw_0/Vw_n1-C_n[1])
dV=F2_1*dt
Vw_n1=Vw_n+dV
C_n1[0]=C0*Vw_0/Vw_n1
F_i_2=-D/dx*(C_n[1:7]-C_n[0:6])
C_n1[0:6]=C_n[0:6]-F_i_2*A*dt/(L/(V0/A)*V0/5)
n+=1
ttime=n*0.02*1000
#-----PLOT AREA---------------------------------#
if x==1:
plt.plot(ttime,R2)
elif x==2:
plt.plot(ttime,Cc)
plt.draw()
plt.show()
My problem is that i want to plot (ttime,R2) and (ttime,Cc).
But i can't figure how to call R2 and Cc for the 3 different values of D (and dx).
Also, i am taking an error: tuple indices must be integers, not float
at dx[i]=sc.sqrt(D[i]*dt/M).
Thanks!
Consider these lines:
D=(0.133e-4,0.243e-4,0.283e-4)
for i in D:
dx[i]=sc.sqrt(D[i]*dt/M)
i is a float. It can not be used as an index into the tuple D.
(D[i] does not make sense.)
Perhaps you meant
D=(0.133e-4,0.243e-4,0.283e-4)
for i, dval in enumerate(D):
dx[i] = sc.sqrt(dval*dt/M)
Or, simply
import scipy as sc
D = sc.array([0.133e-4,0.243e-4,0.283e-4])
dx = sc.sqrt(D*dt/M)
Don't call plt.plot once for each point. That road leads to
sluggish behavior. Instead, accumulate an entire curve's worth of
data points, and then call plt.plot once for the entire curve.
To plot 3 curves on the same figure, simply call plt.plot 3 times.
Do that first before calling plt.show().
The while not flag loop was not ending when you enter 1 for x,
because if x==2 should have been elif x==2.
To animate a matplotlib plot, you should still try to avoid multiple
calls to plt.plot. Instead, use plt.plot once to make a Line2D
object, and then update the underlying data with calls to
line.set_xdata and line.set_ydata. See Joe Kington's example and this example from the matplotlib docs.
import scipy as sc
import matplotlib.pyplot as plt
def graph(x):
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
lines = []
D = (0.133e-4, 0.243e-4, 0.283e-4)
temperatures = ('T = 273','T = 293','T = 298')
N = 2000
linestyles = ('ob', '-r', '+m')
for dval, linestyle, temp in zip(D, linestyles, temperatures):
line, = ax.plot([], [], linestyle, label = temp)
lines.append(line)
plt.xlim((0, N*0.02*1000))
if x == 1:
plt.ylim((0.7, 1.0))
else:
plt.ylim((1.0, 1.6))
plt.legend(loc = 'best')
for dval, line in zip(D, lines):
A = 1.67e-6
V0 = 88e-12
Vw_n1 = 71.7/100*V0
Pw = 0.22
L = 4e-4
B1 = V0-Vw_n1
C7 = 0.447e-3
dt = 0.2e-4
M = 0.759e-1
C_n1 = sc.zeros(7)
C_n1[0:6] = 0.290e-3
C_n1[6] = 0.447e-3
C0 = C_n1[0]
Vw_0 = Vw_n1
tvals = []
yvals = []
dx = sc.sqrt(dval*dt/M)
for n in range(1, N+1, 1):
Vw_n = Vw_n1
C_n = C_n1
R2 = (Vw_n+B1)/(Vw_0+B1)
Cc = C_n1[0]/C0
F2_1 = 10000/3*Pw*A*(C0*Vw_0/Vw_n1-C_n[1])
dV = F2_1*dt
Vw_n1 = Vw_n+dV
C_n1[0] = C0*Vw_0/Vw_n1
F_i_2 = -dval/dx*(C_n[1:7]-C_n[0:6])
C_n1[0:6] = C_n[0:6]-F_i_2*A*dt/(L/(V0/A)*V0/5)
tvals.append(n*0.02*1000)
yvals.append(R2 if x == 1 else Cc)
if not len(yvals) % 50:
line.set_xdata(tvals)
line.set_ydata(yvals)
fig.canvas.draw()
if __name__ == "__main__":
flag = False
while not flag:
try:
x = int(raw_input("Give a choice 1 or 2 : "))
flag = True
if x == 1:
plt.title('Change in cell volume ratio as a function of time \n\
at various temperatures')
plt.xlabel('Time')
plt.ylabel('Ceil volume ratio (V/V0)')
graph(x)
elif x == 2:
plt.title('Increase of solute concentration at various temperatures')
plt.xlabel('Time')
plt.ylabel('Solute concentration in the Ceil (Cc)')
graph(x)
else:
flag = False
print("You must input 1 or 2")
except ValueError:
print("You must input 1 or 2")
raw_input('Press a key when done')