I'm using parallel processing to generate a plot of functions of complex numbers. My script lets you zoom in on an area of the plot using the standard matplotlib controls and then regenerate the plot within the new limits to improve resolution.
This is my first foray into parallel processing, and I've got as far as understanding that I need to guard the entry point with if __name__ == "__main__": so that the module can be imported properly. When I run my script, the first plot is generated successfully and appears as expected. However, when the plotting function is called again from my event handler, it hangs indefinitely. I assume the hang is caused by an issue similar to the one that makes the if __name__ == "__main__": guard necessary, since the parallel processes are being spawned from outside the main body of the script, but I haven't figured out anything beyond that.
import numpy as np
import matplotlib.pyplot as plt
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
res = [1000, 1000]
base_factor = 2.
cpuNum = multiprocessing.cpu_count()
def brot(c, depth=200):
    z = complex(0)
    for i in range(depth):
        z = pow(z, 2) + c
        if abs(z) > 2:
            return i
    return -1
def brot_gen(span):
    re_span = span[0]
    im_span = span[1]
    mset = np.zeros([len(im_span), len(re_span)])
    for re in range(len(re_span)):
        for im in range(len(im_span)):
            mset[im][re] = brot(complex(re_span[re], im_span[im]))
    return mset
def brot_gen_parallel(re_lim, im_lim):
    re_span = np.linspace(re_lim[0], re_lim[1], res[0])
    im_span = np.linspace(im_lim[0], im_lim[1], res[1])
    split_re_span = np.array_split(re_span, cpuNum)
    packages = [(sec, im_span) for sec in split_re_span]
    print("Generating set between", re_lim, "and", im_lim, "...")
    with ProcessPoolExecutor(max_workers=cpuNum) as executor:
        result = executor.map(brot_gen, packages)
        mset = np.concatenate(list(result), axis=1)
    print("Set generated")
    return mset
def handler(ax):
    def action(event):
        if event.button == 2:
            cur_re_lim = ax.get_xlim()
            cur_im_lim = ax.get_ylim()
            mset = brot_gen_parallel(cur_re_lim, cur_im_lim)
            ax.cla()
            ax.imshow(mset, extent=[cur_re_lim[0], cur_re_lim[1], cur_im_lim[0], cur_im_lim[1]], origin="lower", vmin=0, vmax=200, interpolation="bilinear")
            plt.draw()
    fig = ax.get_figure()
    fig.canvas.mpl_connect('button_release_event', action)
    return action
if __name__ == "__main__":
re_lim = np.array([-2.5, 2.5])
im_lim = res[1]/res[0] * re_lim
mset = brot_gen_parallel(re_lim, im_lim)
plt.imshow(mset, extent=[re_lim[0], re_lim[1], im_lim[0], im_lim[1]], origin="lower", vmin=0, vmax=200, interpolation="bilinear")
ax = plt.gca()
f = handler(ax)
plt.show()
EDIT: I wondered whether a bug in the code was raising an exception that wasn't being passed back to the console. I tested this by running the same task without splitting it into parallel tasks, and it completed successfully.
I have discovered the answer to my own question: it lies in the IDE I was using. In most IDEs, plt.show() blocks execution by default, but in Spyder the default is effectively plt.show(block=False). The script therefore ran to completion, and whatever state was required to successfully start the parallel processes was no longer available, causing the hang. Changing the statement to plt.show(block=True) keeps the script alive and solves the problem.
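For reference, the one-line version of the fix (the comment reflects my understanding of why it works, as described above):

# Force show() to block so the main process stays alive while the event
# handler spawns new worker processes; in Spyder the default is non-blocking.
plt.show(block=True)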
I'm still very new to parallel processing, so I'd be very interested in any more information anyone can give on exactly what was missing that stopped the parallel processing from working.
I'm using scipy.integrate.solve_ivp to solve a system of ODEs because it has the event functions.
The reason I need this function is that during the integration I sometimes get a singular matrix, and every time that happens I need to stop the integration and restart it with new parameters.
I would like to know whether it is possible to restart scipy.integrate.solve_ivp with new parameters after a terminal event has occurred, and if so, how to do it.
Any help would be very much appreciated.
This is my current script based on an example from
https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from animate_plot import animate
def upward_cannon(t, y):
    return [y[1], -0.5]

def hit_ground(t, y):
    return y[0]

def apex(t, y):
    return y[1]

hit_ground.terminal = True
hit_ground.direction = -1
t0 = 0
tf = 10
sol = solve_ivp(upward_cannon, t_span=[t0, tf], y0=[90, 10], t_eval=np.arange(t0, tf, 0.01),
                events=[hit_ground, apex], dense_output=True)
linesData = { 1: [[-0.0, 0.0],[0.0, 0.0]]}#,
# 2: [[-0.5, 0],[0.5, 0.0]]}#, 3: [[-0.5, 0],[0.5, 0]]}
pointsofInterest = {}#3: [[0.5, 0.0]]}#, 2: [[180.0, 10]]}
model_markers = np.array([])
plot_title = 'Upward Particle'
plot_legend = ['Forward Dynamics']
q_rep = sol.y.T[:,0]
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(t0, tf, 0.01)
for idx in range(0, q_rep.shape[0]):  # loop over the total number of frames
    y = q_rep[idx]  # traveling sine wave
    ax.cla()
    ax.scatter(xs[idx], y, s=50)
    plt.ylim([-10, 190])
    plt.xlim([-100, 100])
    plt.pause(0.001)
plt.show()
You have two options; both are recursive.
Option 1: Write the function to call itself inside the script. This would be true recursion, and the more elegant choice.
Option 2: If your function comes across values that require a restart, use argparse and os to re-invoke the script with the new values.
Example:
os.system("python3 filename.py -f argparseinputs")
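For what it's worth, here is a minimal sketch of Option 1's idea without literal recursion: call solve_ivp in a loop and, whenever sol.status == 1 (a terminal event fired), restart from the event state with new parameters. The restart rule below (a bounce) is purely illustrative; substitute whatever your singular-matrix recovery requires.

import numpy as np
from scipy.integrate import solve_ivp

def upward_cannon(t, y):
    return [y[1], -0.5]

def hit_ground(t, y):
    return y[0]
hit_ground.terminal = True
hit_ground.direction = -1

t0, tf = 0.0, 10.0
y0 = [90.0, 10.0]
pieces = []  # collect the solution segments between restarts
while t0 < tf:
    sol = solve_ivp(upward_cannon, t_span=[t0, tf], y0=y0, events=hit_ground,
                    dense_output=True)
    pieces.append(sol)
    if sol.status != 1:  # no terminal event: integration reached tf (or failed)
        break
    t0 = sol.t_events[0][0]    # restart at the event time...
    y0 = [0.0, -sol.y[1, -1]]  # ...with a new state, e.g. a bounce (illustrative)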
I'm having an issue exactly like this post, but slightly more frustrating.
I'm using matplotlib to read from a file that is being fed data from another application. For some reason, the ends of the data only connect after the animation (animation.FuncAnimation) has completed its first refresh. Here are some images:
This is before the refresh:
And this is after the refresh:
Any ideas as to why this could be happening? Here is my code:
import json
import itertools
import dateutil.parser
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import style
import scipy.signal as sig
import numpy as np
sensors = {}
data = []
lastLineReadNum = 0
class Sensor:
    def __init__(self, name, points=0, lastReading=0):
        self.points = points
        self.lastReading = lastReading
        self.name = name
        self.x = []
        self.y = []

class ScanResult:
    def __init__(self, name, id, rssi, macs, time):
        self.name = name
        self.id = id
        self.rssi = rssi
        self.macs = macs
        # Not an integer, but a datetime.datetime
        self.time = time
def readJSONFile(filepath):
    with open(filepath, "r") as read_file:
        global lastLineReadNum
        # Load JSON results into objects that hold the important info
        for line in itertools.islice(read_file, lastLineReadNum, None):
            temp = json.loads(line)
            # Only reads the most recent points...
            data.append(ScanResult(name=temp["dev_id"],
                                   id=temp["hardware_serial"],
                                   rssi=temp["payload_fields"]["rssis"],
                                   macs=temp["payload_fields"]["macs"],
                                   time=dateutil.parser.parse(temp["metadata"]["time"])))
            lastLineReadNum += 1
        return data
style.use('fivethirtyeight')
fig = plt.figure()
ax1 = fig.add_subplot(1, 1, 1)
def smooth(y, box_pts):
    box = np.ones(box_pts) / box_pts
    y_smooth = np.convolve(y, box, mode='same')
    return y_smooth

def determineClosestSensor():
    global sensors
    #sensors.append(Sensor(time = xs3, rssi = ys3))

def determineXAxisTime(scanresult):
    return ((scanresult.time.hour * 3600) + (scanresult.time.minute * 60) + scanresult.time.second) / 1000.0
def animate(i):
    data = readJSONFile(filepath="C:/python_testing/rssi_logging_json.json")
    for scan in data:
        sensor = sensors.get(scan.name)
        # First time seeing the sensor
        if(sensor == None):
            sensors[scan.name] = Sensor(scan.name)
            sensor = sensors.get(scan.name)
            sensor.name = scan.name
            sensor.x.append(determineXAxisTime(scan))
            sensor.y.append(scan.rssi)
        else:
            sensor.x.append(determineXAxisTime(scan))
            sensor.y.append(scan.rssi)
    ax1.clear()
    # basic smoothing using nearby averages
    #y_smooth3 = smooth(np.ndarray.flatten(np.asarray(sensors.get("sentrius_sensor_3").y)), 1)
    for graphItem in sensors.itervalues():
        smoothed = smooth(np.ndarray.flatten(np.asarray(graphItem.y)), 1)
        ax1.plot(graphItem.x, smoothed, label=graphItem.name, linewidth=2.0)
    ax1.legend()
    determineClosestSensor()
    fig.suptitle("Live RSSI Graph from Sentrius Sensors", fontsize=14)
def main():
    ani = animation.FuncAnimation(fig, animate, interval=15000)
    plt.show()

if __name__ == "__main__":
    main()
As far as I can tell, you are regenerating your data in each animation step by appending to the existing datasets, which means the last x point from the first step is followed by the first x point of the second step, leading to a "rewind" in the plot. This appears as the line connecting the last data point with the first one; the rest of the data is unchanged.
The relevant part of animate:
def animate(i):
    data = readJSONFile(filepath="C:/python_testing/rssi_logging_json.json")
    for scan in data:
        sensor = sensors.get(scan.name)
        # First time seeing the sensor
        if(sensor is None):  # always check for None with `is`!
            ...  # stuff here
        else:
            sensor.x.append(determineXAxisTime(scan))  # always append!
            sensor.y.append(scan.rssi)  # always append!
    ...  # rest of the stuff here
So, in each animation step you
1. load the same JSON file
2. append the same data to an existing sensor identified by sensors.get(scan.name)
3. plot stuff without ever using i.
Firstly, your animate should naturally make use of the index i: you're trying to draw something that changes with step i, yet I can't see i being used anywhere.
Secondly, your animate should be as lightweight as possible in order to get a smooth animation. Load your data once before plotting, and only handle the drawing differences in animate. This will involve slicing or manipulating your data as a function of i.
Of course, if the file really does change from step to step, and this is the actual dynamic in the animation (i.e. i is a dummy variable that is never used), all you need to do is zero-initialize all the plotting data in each step. Start with a clean slate; then you'll stop seeing the lines corresponding to these artificial jumps in the data. But again, if you want a lightweight animate, you should look into manipulating the underlying data of existing plots rather than replotting everything all the time (especially since calls to ax1.plot will keep earlier points on the canvas, which is not what you usually want in an animation). A sketch of this idea follows below.
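To illustrate that last point, a minimal self-contained sketch (toy data, not the asker's JSON pipeline) that updates a single persistent line via set_data instead of calling ax1.plot every frame:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

fig, ax = plt.subplots()
xs, ys = [], []                       # data that grows by one point per frame
line, = ax.plot([], [], linewidth=2)  # a single persistent Line2D

def animate(i):
    # append only the new point, then update the existing artist in place
    xs.append(i)
    ys.append(np.sin(i / 5.0))
    line.set_data(xs, ys)
    ax.relim()           # recompute data limits from the updated line
    ax.autoscale_view()  # and rescale the axes accordingly
    return line,

ani = animation.FuncAnimation(fig, animate, interval=100)
plt.show()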
Try changing:
ani = animation.FuncAnimation(fig, animate, interval = 15000)
to:
ani = animation.FuncAnimation(fig, animate, interval = 15000, repeat = False)
I'm using Python's multiprocessing lib to speed up some code (least squares fitting with scipy).
It works fine on 3 different machines, but it shows a strange behaviour on a 4th machine.
The code:
import numpy as np
from scipy.optimize import least_squares
import time
import parmap
from multiprocessing import Pool

p0 = [1., 1., 0.5]

def f(p, xx):
    return p[0]*np.exp(-xx ** 2 / p[1] ** 2) + p[2]

def errorfunc(p, xx, yy):
    return f(p, xx) - yy

def do_fit(yy, xx):
    return least_squares(errorfunc, p0[:], args=(xx, yy))
if __name__ == '__main__':
    # create data
    x = np.linspace(-10, 10, 1000)
    y = []
    np.random.seed(42)
    for i in range(1000):
        y.append(f([np.random.rand(1) * 10, np.random.rand(1), 0.], x) + np.random.rand(len(x)))

    # fit without multiprocessing
    t1 = time.time()
    for y_data in y:
        p1 = least_squares(errorfunc, p0[:], args=(x, y_data))
    t2 = time.time()
    print t2 - t1

    # fit with the multiprocessing lib
    times = []
    for p in range(1, 13):
        my_pool = Pool(p)
        t3 = time.time()
        results = parmap.map(do_fit, y, x, pool=my_pool)
        t4 = time.time()
        times.append(t4 - t3)
        my_pool.close()
    print times
For the 3 machines where it works, it speeds up roughly in the expected way. E.g. on my i7 laptop it gives:
[4.92650294303894, 2.5883090496063232, 1.7945551872253418, 1.629533052444458,
1.4896039962768555, 1.3550388813018799, 1.1796400547027588, 1.1852478981018066,
1.1404039859771729, 1.2239141464233398, 1.1676840782165527, 1.1416618824005127]
I'm running Ubuntu 14.10, Python 2.7.6, numpy 1.11.0 and scipy 0.17.0.
I tested it on another Ubuntu machine, a Dell PowerEdge R210, with similar results, and on a MacBook Pro Retina (there with Python 2.7.11 and the same numpy and scipy versions).
The computer that causes issues is a PowerEdge R710 (two hex-core CPUs) running Ubuntu 15.10, Python 2.7.11 and the same numpy and scipy versions as above.
However, I don't observe any speedup. Times are around 6 seconds, no matter what poolsize I use. In fact, it is slightly better for a poolsize of 2 and gets worse for more processes.
htop shows that somehow more processes get spawned than I would expect.
E.g. on my laptop htop shows one entry per process (which matches the poolsize) and eventually each process shows 100% CPU load.
On the PowerEdge R710 I see about 8 python processes for a poolsize of 1 and about 20 processes for a poolsize of 2 etc. each of which shows 100% CPU load.
I checked BIOS settings of the R710 and I couldn't find anything unusual.
What should I look for?
EDIT:
In response to the comment, I tried another simple script. Surprisingly, this one seems to 'work' on all machines:
from multiprocessing import Pool
import time
import math
import numpy as np

def f_np(x):
    return x**np.sin(x) + np.fabs(np.cos(x))**np.arctan(x)

def f(x):
    return x**math.sin(x) + math.fabs(math.cos(x))**math.atan(x)

if __name__ == '__main__':
    print "#pool", ", numpy", ", pure python"
    for p in range(1, 9):
        pool = Pool(processes=p)
        np.random.seed(42)
        a = np.random.rand(1000, 1000)
        t1 = time.time()
        for i in range(5):
            pool.map(f_np, a)
        t2 = time.time()
        for i in range(5):
            pool.map(f, range(1000000))
        print p, t2 - t1, time.time() - t2
        pool.close()
gives:
#pool , numpy , pure python
1 1.34186911583 5.87641906738
2 0.697530984879 3.16030216217
3 0.470160961151 2.20742988586
4 0.35701417923 1.73128080368
5 0.308979988098 1.47339701653
6 0.286448001862 1.37223601341
7 0.274246931076 1.27663207054
8 0.245123147964 1.24748778343
on the machine that caused the trouble. No more threads (or processes?) are spawned than I would expect.
It looks like numpy is not the problem; as soon as I use scipy.optimize.least_squares, the issue arises.
Inspecting the running processes shows a lot of sched_yield() calls, which I don't see if I don't use scipy.optimize.least_squares, and which I also don't see on my laptop even when using least_squares.
According to here, there is an issue when OpenBLAS is used together with joblib.
Similar issues occur when MKL is used (see here).
The solution given there also worked for me: adding
import os
os.environ['MKL_NUM_THREADS'] = '1'
at the beginning of my python script solves the issue.
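For completeness, a sketch of how this looks in practice. Note the variable must be set before numpy/scipy are first imported; for OpenBLAS builds the analogous variable is OPENBLAS_NUM_THREADS, and setting both is harmless:

import os
os.environ['MKL_NUM_THREADS'] = '1'       # cap MKL's internal threading
os.environ['OPENBLAS_NUM_THREADS'] = '1'  # same idea for OpenBLAS builds

# only import BLAS-backed libraries after the variables are set
import numpy as np
from scipy.optimize import least_squares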
I wrote a small program using web.py, and in one of the classes I use numpy/matplotlib. I found that the first time I visit the page it works fine, but after several minutes the call to plt.figure() freezes and never returns. That's so weird.
Please have a look at my code:
def DrawMapMain(MapParameter, inputfile='out.txt', imgfile='out.png'):
    print "DrawMapMain..."
    plt.ioff()  # turn off interactive mode
    plt.close('all')
    xmin, xmax, ymin, ymax = MapParameter['xmin'], MapParameter['xmax'], MapParameter['ymin'], MapParameter['ymax']
    print('LevelFile:', MapParameter['LevelFile'])
    LonCenter = (xmin + xmax) / 2.0
    LatCenter = (ymin + ymax) / 2.0
    nx, ny = 200, 200
    if(not os.path.isfile(inputfile)):
        print(u'Input file %s does not exist, please check!' % (inputfile))
        sys.exit(0)
    Region = np.loadtxt(inputfile)
    #print(Region)
    x, y, z = Region[:, 1], Region[:, 2], Region[:, 3]
    lon_array = np.linspace(xmin, xmax, nx)
    lat_array = np.linspace(ymin, ymax, ny)
    print('Data lon/lat box :', x.min(), x.max(), y.min(), y.max())
    print(u'Interpolating scattered points onto the grid')
    zi, xi, yi = Interpolater.griddata_all(x, y, z, lon_array, lat_array, func='line_rbf')  #scipy_idw')# #line_rbf
    print(u'Interpolating on the extended grid:')
    zi, xi, yi, lon_array, lat_array, nx, ny = Interpolater.extened_grid(zi, lon_array, lat_array, zoom=int(2))
    print(u'Masking the non-plot region')
    grid1 = Interpolater.build_inside_mask_array(MapParameter['ShapeFile'], lon_array, lat_array)
    zi[np.logical_not(grid1)] = np.NaN
    #-----------------------------------------------------------------------------------
    print(u'Create figure...')
    #fig = plt.figure(num=1,figsize=(12, 9), dpi=100)
    fig = plt.figure(figsize=(12, 9), dpi=100)
    #fig = plt.figure()
    print(u'Create figure...Done')
    # ......... skipped
The first time I visit the page, I get:
Masking the non-plot region
Create figure...
Create figure...Done
(104, 35, 108, 39.5)
This is OK, but after a while, visiting again, I get:
Masking the non-plot region
Create figure...
and I can see the 'python' process taking 25% of my CPU (which has 4 cores), which means it has fallen into an infinite loop!
This is my web.py class, for reference:
class Month:
    def POST(self):
        form = ParameterForm()
        if not form.validates():
            return render.Month(form)
        else:
            StationInfoFile = './StationsId.txt'  # make sure this file exists
            if(not os.path.isfile(StationInfoFile)):
                print(u'StationInfoFile %s does not exist!' % (StationInfoFile))
                sys.exit(0)
            StationsInfo = np.loadtxt(StationInfoFile)  # load all data as integer and float, not string
            StationsId, StationsLon, StationsLat = StationsInfo[:, 0].astype(np.int64), StationsInfo[:, 1], StationsInfo[:, 2]
            basedir, DataCats, DataCatsDict = u'D:/测试数据', [u'逐日平均', u'逐日降水'], {u'逐日平均': 'td', u'逐日降水': 'rd'}  # keys: "daily mean", "daily precipitation"
            iFrom, iEnd = \
                int(form['Start Year'].value)*10000 + int(form['Start Month'].value)*100 + int(form['Start Day'].value), \
                int(form['End Year'].value)*10000 + int(form['End Month'].value)*100 + int(form['End Day'].value)  # value from form is a string!
            MapParameter = GetMapParameter()
            if (u'温度' == form['Data Source'].value):  # "temperature"
                d = u'逐日平均'
                tmpDataTxt = 'Test_temp.txt'
                tmpOutPNG = './static/' + 'Test_temp.png'
                MapParameter['LevelFile'] = '.\maplev_temp.LEV'
                MapParameter['Title'] = u'逐日平均'
            elif (u'降水(mm)' == form['Data Source'].value):  # "precipitation (mm)"
                d = u'逐日降水'
                tmpDataTxt = 'Test_pred.txt'
                tmpOutPNG = './static/' + 'Test_temp.png'
                MapParameter['LevelFile'] = '.\maplev_rain.LEV'
                MapParameter['Title'] = u'逐日降水'
            else:
                print "form['Data Source'].value=", form['Data Source'].value
            print "----------- PROCESSING FOR CATEGORY:", d
            tmpMeanVal = []
            for i in range(len(StationsId)):
                s, lo, la = StationsId[i], StationsLon[i], StationsLat[i]
                #print basedir, d, str(s), DataCatsDict[d]+'.txt'
                datafile = os.path.join(basedir, d, str(s)) + DataCatsDict[d] + '.txt'
                print datafile, iFrom, iEnd
                data = getdata.GetData(datafile, iFrom, iEnd)
                a = np.mean(np.array(data)[:, 1])*0.1
                tmpMeanVal.append([s, lo, la, a])
            rec = np.array(tmpMeanVal, dtype=[('int', 'int'), ('float', 'float')])
            print 'Writing data ...'
            np.savetxt(tmpDataTxt, tmpMeanVal, fmt="%6i %-7.2f %-7.2f %8.2f")
            print 'Writing data ... Done.'
            sssss = open(tmpDataTxt, 'r')
            print sssss.read()
            sssss.close()
            DrawMapMain(MapParameter, inputfile=tmpDataTxt, imgfile=tmpOutPNG)
            return render.Reports(tmpOutPNG)
First I suspected that plt.figure might have a memory-leak problem, so I call clf and plt.close('all') at both the beginning and the end of the function. I even wrote a segment of test code:
if __name__ == "__main__":
MapParameter=GetMapParameter()
MapParameter['LevelFile']='.\maplev_rain.LEV'
MapParameter['Title']=u'逐日降水'
for iloop in range(0,10):
DrawMapMain(MapParameter,inputfile='Test_pred.txt',imgfile='c:/Test_pred'+str(iloop)+'.png')
MapParameter['LevelFile']='.\maplev_temp.LEV'
MapParameter['Title']=u'逐日temp'
for iloop in range(0,10):
DrawMapMain(MapParameter,inputfile='Test_temp.txt',imgfile='c:/Test_temp'+str(iloop)+'.png')
This test code works fine. It's so weird; does anybody have a clue? Thanks very much!
This question has an answer in the comments:
What matplotlib backend are you using? You should be using one of the non-interactive ones if you're running things from a webserver. E.g. do import matplotlib; matplotlib.use('Agg') before import matplotlib.pyplot as plt. – Joe Kington Jan 16 at 17:56
For more information on matplotlib backends, see: http://matplotlib.org/faq/usage_faq.html#what-is-a-backend – Joe Kington Jan 17 at 15:21
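In code, the fix from those comments looks like this; the backend must be selected before pyplot is imported anywhere in the web application:

import matplotlib
matplotlib.use('Agg')  # non-interactive backend, safe for use in a web server
import matplotlib.pyplot as plt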
I have a large number of files to process. I have written a script that gets, sorts, and plots the data I want. So far, so good. I have tested it and it gives the desired result.
Then I wanted to do this using multithreading. I have looked into the docs and examples on the internet, and using one thread in my program works fine. But when I use more, at some point I get random matplotlib errors, and I suspect some conflict there, even though I use a function with names for the plots, and I can't see where the problem could be.
Here is the whole script; should you need more comments, I'll add them. Thank you.
#!/usr/bin/python
import matplotlib
matplotlib.use('GTKAgg')
import numpy as np
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
import matplotlib.colors as mcl
from matplotlib import rc  # for latex
import time as tm
import sys
import threading
import Queue  # queue in 3.2 and Queue in 2.7!
import pdb  # the debugger

rc('text', usetex=True)  # for latex

map = 0  # initialize the map index. It will be used to index the array like this: array[map, [x, y]]
time = np.zeros(1)  # an array to store the time
middle_h = np.zeros((0, 3))  # x phi c
# for the middle of the box
current_file = open("single_void_cyl_periodic_phi_c_middle_h_out", 'r')
for line in current_file:
    if line.startswith('# === time'):
        map += 1
        np.append(time, [float(line.strip('# === time '))])
    elif line.startswith('#'):
        pass
    else:
        v = np.fromstring(line, dtype=float, sep=' ')
        middle_h = np.vstack((middle_h, v[[1, 3, 4]]))
current_file.close()
middle_h = middle_h.reshape((map, -1, 3))  # 3d array: map, x, phi, c
#####
def load_and_plot():  # load a map file and plot it along with the corresponding profile loaded before
    while not exit_flag:
        print("fetching work ...")
        #try:
        if not tasks_queue.empty():
            map_index = tasks_queue.get()
            print("----> working on map: %s" % map_index)
            x, y, zp = np.loadtxt("single_void_cyl_growth_periodic_post_map_" + str(map_index), unpack=True, usecols=[1, 2, 3])
            for i, el in enumerate(zp):
                if el < 0.:
                    zp[i] = 0.
            xv = np.unique(x)
            yv = np.unique(y)
            X, Y = np.meshgrid(xv, yv)
            Z = griddata((x, y), zp, (X, Y), method='nearest')
            figure = plt.figure(num=map_index, figsize=(14, 8))
            ax1 = plt.subplot2grid((2, 2), (0, 0))
            ax1.plot(middle_h[map_index, :, 0], middle_h[map_index, :, 1], '*b')
            ax1.grid(True)
            ax1.axis([-15, 15, 0, 1])
            ax1.set_title('Profiles')
            ax1.set_ylabel(r'$\phi$')
            ax1.set_xlabel('x')
            ax2 = plt.subplot2grid((2, 2), (1, 0))
            ax2.plot(middle_h[map_index, :, 0], middle_h[map_index, :, 2], '*r')
            ax2.grid(True)
            ax2.axis([-15, 15, 0, 1])
            ax2.set_ylabel('c')
            ax2.set_xlabel('x')
            ax3 = plt.subplot2grid((2, 2), (0, 1), rowspan=2, aspect='equal')
            sub_contour = ax3.contourf(X, Y, Z, np.linspace(0, 1, 11), vmin=0.)
            figure.colorbar(sub_contour, ax=ax3)
            figure.savefig('single_void_cyl_' + str(map_index) + '.png')
            plt.close(map_index)
            tasks_queue.task_done()
        else:
            print("nothing left to do, other threads finishing, sleeping 2 seconds...")
            tm.sleep(2)
        # except:
        #     print("failed this time: %s" % map_index + ". Sleeping 2 seconds")
        #     tm.sleep(2)
#####
exit_flag = 0
nb_threads = 2
tasks_queue = Queue.Queue()
threads_list = []
jobs = list(range(map))  # each job is composed of a map
print("inserting jobs in the queue...")
for job in jobs:
    tasks_queue.put(job)
print("done")
# launch the threads
for i in range(nb_threads):
    working_bee = threading.Thread(target=load_and_plot)
    working_bee.daemon = True
    print("starting thread " + str(i) + ' ...')
    threads_list.append(working_bee)
    working_bee.start()
# wait for all tasks to be treated
tasks_queue.join()
# flip the flag, so the threads know it's time to stop
exit_flag = 1
for t in threads_list:
    print("waiting for thread %s to stop..." % t)
    t.join()
print("all threads stopped")
Following David's suggestion, I did it with multiprocessing. I get a 5x speed-up with 8 processors. I believe the rest is due to the single-process work at the beginning of my script.
Edit: However, sometimes the script "hangs" at the last map, even though it produces the right maps, with the following error:
File "single_void_cyl_plot_mprocess.py", line 90, in tasks_queue.join()
File "/usr/local/epd-7.0-2-rh5-x86_64/lib/python2.7/multiprocessing/queues.py", line 316, in join self._cond.wait()
File "/usr/local/epd-7.0-2-rh5-x86_64/lib/python2.7/multiprocessing/synchronize.py", line 220, in wait self._wait_semaphore.acquire(True, timeout)
import numpy as np
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
from matplotlib import rc  # for latex
from multiprocessing import Process, JoinableQueue
import pdb  # the debugger

rc('text', usetex=True)  # for latex

map = 0  # initialize the map index. It will be used to index the array like this: array[map, x, y, ...]
time = np.zeros(1)  # an array to store the time
middle_h = np.zeros((0, 3))  # x phi c
# for the middle of the box
current_file = open("single_void_cyl_periodic_phi_c_middle_h_out", 'r')
for line in current_file.readlines():
    if line.startswith('# === time'):
        map += 1
        np.append(time, [float(line.strip('# === time '))])
    elif line.startswith('#'):
        pass
    else:
        v = np.fromstring(line, dtype=float, sep=' ')
        middle_h = np.vstack((middle_h, v[[1, 3, 4]]))
current_file.close()
middle_h = middle_h.reshape((map, -1, 3))  # 3d array: map, x, phi, c
#######
def load_and_plot():  # load a map file and plot it along with the corresponding profile loaded before
    while tasks_queue.empty() == False:
        print("fetching work ...")
        try:
            map_index = tasks_queue.get()  # get some work to do from the queue
            print("----> working on map: %s" % map_index)
            x, y, zp = np.loadtxt("single_void_cyl_growth_periodic_post_map_" + str(map_index),
                                  unpack=True, usecols=[1, 2, 3])
            for i, el in enumerate(zp):
                if el < 0.:
                    zp[i] = 0.
            xv = np.unique(x)
            yv = np.unique(y)
            X, Y = np.meshgrid(xv, yv)
            Z = griddata((x, y), zp, (X, Y), method='nearest')
            figure = plt.figure(num=map_index, figsize=(14, 8))
            ax1 = plt.subplot2grid((2, 2), (0, 0))
            ax1.plot(middle_h[map_index, :, 0], middle_h[map_index, :, 1], '*b')
            ax1.grid(True)
            ax1.axis([-15, 15, 0, 1])
            ax1.set_title('Profiles')
            ax1.set_ylabel(r'$\phi$')
            ax1.set_xlabel('x')
            ax2 = plt.subplot2grid((2, 2), (1, 0))
            ax2.plot(middle_h[map_index, :, 0], middle_h[map_index, :, 2], '*r')
            ax2.grid(True)
            ax2.axis([-15, 15, 0, 1])
            ax2.set_ylabel('c')
            ax2.set_xlabel('x')
            ax3 = plt.subplot2grid((2, 2), (0, 1), rowspan=2, aspect='equal')
            sub_contour = ax3.contourf(X, Y, Z, np.linspace(0, 1, 11), vmin=0.)
            figure.colorbar(sub_contour, ax=ax3)
            figure.savefig('single_void_cyl_' + str(map_index) + '.png')
            plt.close(map_index)
            tasks_queue.task_done()  # work for this item finished
        except:
            print("failed this time: %s" % map_index)
#######
nb_proc = 8  # number of processes
tasks_queue = JoinableQueue()  # a queue to pile up the work to do
jobs = list(range(map))  # each job is composed of a map
print("inserting jobs in the queue...")
for job in jobs:
    tasks_queue.put(job)
print("done")
# launch the processes
for i in range(nb_proc):
    current_process = Process(target=load_and_plot)
    current_process.start()
# wait for all tasks to be treated
tasks_queue.join()
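A note on the hang: tasks_queue.empty() followed by a blocking tasks_queue.get() is racy. One process can drain the queue between another's empty() check and its get(), leaving that worker blocked in get() forever, and the bare except can also swallow a failure without calling task_done(), so tasks_queue.join() never unblocks cleanly. A minimal sketch of one common fix (the names are illustrative, not the original script): block in get() unconditionally and shut each worker down with a sentinel.

from multiprocessing import Process, JoinableQueue

SENTINEL = None  # unique "stop" marker; None works if None is never a real job

def worker(q):
    while True:
        item = q.get()  # blocks until a job or a sentinel arrives
        try:
            if item is SENTINEL:
                break
            # ... do the plotting work for `item` here ...
        finally:
            q.task_done()  # always account for the item, even on failure

if __name__ == '__main__':
    q = JoinableQueue()
    for job in range(100):
        q.put(job)
    workers = [Process(target=worker, args=(q,)) for _ in range(8)]
    for p in workers:
        p.start()
    for _ in workers:
        q.put(SENTINEL)  # one sentinel per worker
    q.join()  # returns once every put() has a matching task_done()
    for p in workers:
        p.join()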