Recently I was trying out this problem and my code got 60% of the marks, with the remaining cases returning TLEs.
Bazza and Shazza do not like bugs. They wish to clear out all the bugs
on their garden fence. They come up with a brilliant idea: they buy
some sugar frogs and release them near the fence, letting them eat up
all the bugs.
The plan is a great success and the bug infestation is gone. But
strangely, they now have a sugar frog infestation. Instead of getting
rid of the frogs, Bazza and Shazza decide to set up an obstacle course
and watch the frogs jump along it for their enjoyment.
The fence is a series of \$N\$ fence posts of varying heights. Bazza and
Shazza will select three fence posts to create the obstacle course,
where the middle post is strictly higher than the other two. The frogs
are to jump up from the left post to the middle post, then jump down
from the middle post to the right post. The three posts do not have to
be next to each other as frogs can jump over other fence posts,
regardless of the height of those other posts.
The difficulty of an obstacle course is the height of the first jump
plus the height of the second jump. The height of a jump is equal to
the difference in height between it's two fence posts. Your task is to
help Bazza and Shazza find the most difficult obstacle course for the
frogs to jump.
Input
Your program should read from the file. The file will describe
a single fence.
The first line of input will contain one integer \$N\$: the number of
fence posts. The next \$N\$ lines will each contain one integer \$h_i\$: the
height of the ith fence post. You are guaranteed that there will be at
least one valid obstacle course: that is, there will be at least one
combination of three fence posts where the middle post is strictly
higher than the other two.
Output
Your program should write to the file. Your output file should
contain one line with one integer: the greatest difficulty of any
possible obstacle course.
Constraints
To evaluate your solution, the judges will run your
program against several different input files. All of these files will
adhere to the following bounds:
\$3 \leq N \leq 100,000\$ (the number of fence posts)
\$1 \leq h_i \leq 100,000\$ (the height of each post)
As some of the test cases will be quite large,
you may need to think about how well your solution scales for larger
input values. However, not all the cases will be large. In particular:
For 30% of the marks, \$N \leq 300\$. For an additional 30% of the
marks, \$N \leq 3,000\$. For the remaining 40% of the marks, no special > constraints apply.
Hence, I was wondering if anyone could think of a way to optimize my code (below), or perhaps provide a more elegant, efficient algorithm than the one I am currently using.
Here is my code:
infile = open('frogin.txt', 'r')
outfile = open('frogout.txt', 'w')
N = int(infile.readline())
l = []
for i in range(N):
l.append(int(infile.readline()))
m = 0
#find maximum z-x+z-y such that the middle number z is the largest of x, y, z
for j in range(1, N - 1):
x = min(l[0: j])
y = min(l[j + 1:])
z = l[j]
if x < z and y < z:
n = z - x + z - y
m = n if n > m else m
outfile.write(str(m))
infile.close()
outfile.close()
exit()
If you require additional information regarding my solution or the problem, please do comment below.
Ok, first let's evaluate your program. I created a test file like
from random import randint
n = 100000
max_ = 100000
with open("frogin.txt", "w") as outf:
outf.write(str(n) + "\n")
outf.write("\n".join(str(randint(1, max_)) for _ in range(n)))
then ran your code in IPython like
%load_ext line_profiler
def test():
infile = open('frogin.txt', 'r')
outfile = open('frogout.txt', 'w')
N = int(infile.readline())
l = []
for i in range(N):
l.append(int(infile.readline()))
m = 0
for j in range(1, N - 1):
pre_l = l[0: j] # I split these lines
x = min(pre_l) # for a bit more detail
post_l = l[j + 1:] # on exactly which operations
y = min(post_l) # are taking the most time
z = l[j]
if x < z and y < z:
n = z - x + z - y
m = n if n > m else m
outfile.write(str(m))
infile.close()
outfile.close()
%lprun -f test test() # instrument the `test` function, then run `test()`
which gave
Total time: 197.565 s
File: <ipython-input-37-afa35ce6607a>
Function: test at line 1
Line # Hits Time Per Hit % Time Line Contents
==============================================================
1 def test():
2 1 479 479.0 0.0 infile = open('frogin.txt', 'r')
3 1 984 984.0 0.0 outfile = open('frogout.txt', 'w')
4 1 195 195.0 0.0 N = int(infile.readline())
5 1 2 2.0 0.0 l = []
6 100001 117005 1.2 0.0 for i in range(N):
7 100000 269917 2.7 0.0 l.append(int(infile.readline()))
8 1 2 2.0 0.0 m = 0
9 99999 226984 2.3 0.0 for j in range(1, N - 1):
10 99998 94137525 941.4 12.2 pre_l = l[0: j]
11 99998 300309109 3003.2 38.8 x = min(pre_l)
12 99998 85915575 859.2 11.1 post_l = l[j + 1:]
13 99998 291183808 2911.9 37.7 y = min(post_l)
14 99998 441185 4.4 0.1 z = l[j]
15 99998 212870 2.1 0.0 if x < z and y < z:
16 99978 284920 2.8 0.0 n = z - x + z - y
17 99978 181296 1.8 0.0 m = n if n > m else m
18 1 114 114.0 0.0 outfile.write(str(m))
19 1 170 170.0 0.0 infile.close()
20 1 511 511.0 0.0 outfile.close()
which shows that 23.3% of your time (46 s) is spent repeatedly slicing your array, and 76.5% (151 s) is spent running min() on the slices 200k times.
So - how can we speed this up? Consider
a = min(l[0:50001]) # 50000 comparisons
b = min(l[0:50002]) # 50001 comparisons
c = min(a, l[50001]) # 1 comparison
Here's the magic: b and c are exactly equivalent but b takes something like 10k times longer to run. You have to have a calculated first - but you can repeat the same trick, shifted back by 1, to get a cheaply, and the same for the a's predecessor, and so on.
In one pass from start to end you can keep a running tally of 'minimum value seen previous to this index'. You can then do the same thing from end to start, keeping a running tally of 'minimum value seen after this index'. You can then zip all three arrays together and find the maximum achievable values.
I wrote a quick version,
def test():
ERROR_VAL = 1000000 # too big to be part of any valid solution
# read input file
with open("frogin.txt") as inf:
nums = [int(i) for i in inf.read().split()]
# check contents
n = nums.pop(0)
if len(nums) < n:
raise ValueError("Input file is too short!")
elif len(nums) > n:
raise ValueError("Input file is too long!")
# min_pre[i] == min(nums[:i])
min_pre = [0] * n
min_pre[0] = ERROR_VAL
for i in range(1, n):
min_pre[i] = min(nums[i - 1], min_pre[i - 1])
# min_post[i] == min(nums[i+1:])
min_post = [0] * n
min_post[n - 1] = ERROR_VAL
for i in range(n - 2, -1, -1):
min_post[i] = min(nums[i + 1], min_post[i + 1])
return max((nums[i] - min_pre[i]) + (nums[i] - min_post[i]) for i in range(1, n - 1) if min_pre[i] < nums[i] > min_post[i])
and profiled it,
Total time: 0.300842 s
File: <ipython-input-99-2097216e4420>
Function: test at line 1
Line # Hits Time Per Hit % Time Line Contents
==============================================================
1 def test():
2 1 5 5.0 0.0 ERROR_VAL = 1000000 # too big to be part of any valid solution
3 # read input file
4 1 503 503.0 0.0 with open("frogin.txt") as inf:
5 1 99903 99903.0 8.5 nums = [int(i) for i in inf.read().split()]
6 # check contents
7 1 212 212.0 0.0 n = nums.pop(0)
8 1 7 7.0 0.0 if len(nums) < n:
9 raise ValueError("Input file is too short!")
10 1 2 2.0 0.0 elif len(nums) > n:
11 raise ValueError("Input file is too long!")
12 # min_pre[i] == min(nums[:i])
13 1 994 994.0 0.1 min_pre = [0] * n
14 1 3 3.0 0.0 min_pre[0] = ERROR_VAL
15 100000 162915 1.6 13.8 for i in range(1, n):
16 99999 267593 2.7 22.7 min_pre[i] = min(nums[i - 1], min_pre[i - 1])
17 # min_post[i] == min(nums[i+1:])
18 1 1050 1050.0 0.1 min_post = [0] * n
19 1 3 3.0 0.0 min_post[n - 1] = ERROR_VAL
20 100000 167021 1.7 14.2 for i in range(n - 2, -1, -1):
21 99999 272080 2.7 23.1 min_post[i] = min(nums[i + 1], min_post[i + 1])
22 1 205222 205222.0 17.4 return max((nums[i] - min_pre[i]) + (nums[i] - min_post[i]) for i in range(1, n - 1) if min_pre[i] < nums[i] > min_post[i])
and you can see the run-time for processing 100k values has dropped from 197 s to 0.3 s.
Related
I have "random" points and would like to check which points can be connected by straight lines. Therefore I iterate through a list of points and draw a line at different angles. After all lines at all angles for every single point is drawn, I iterate over each line checking whether they are connecting 3 or more points. If the line connects 3 or more points, it is saved by appending it to a new list (newLines), if not the next line gets tested.
The problem which the following code is that it is way to slow... My testing image took about 30 min and my actual image was not done after about 14 hours. I read about speeding up for loops by using numpy (like in this article). I found plenty of examples for replacing for loops with numpy but in these example it was just simple iterating over a list without declaring the values as variables for usage.
Any hint for speeding up the following code is appreciated, it does not necessarily need to be numpy.
# list for saving rotated lines
lines=[]
for point in points:
# length of line is the diagonal of the point image so it still covers the whole image after rotation
length = sqrt(image.shape[0]**2+image.shape[1]**2)
start = Point(point)
end = Point(start.x+length, start.y)
line = LineString([start,end])
# rotating the generated line for 5 degrees and appeding it to the list
for a in range(0, 360, 5):
angle = np.deg2rad(a)
line = rotate(line, angle, origin=start, use_radians=True)
lines.append(line)
multiLines = MultiLineString(lines)
# list for rotated lines which connect 3 or more points
newLines = []
start = ()
for multiLine in multiLines.geoms:
lst = list(multiLine.coords)
# a: starting point of line | b: ending point of line
a = np.asarray(lst[0])
b = np.asarray(lst[1])
count = 0
# again iterating over point array to check which point is on line
for point in points:
p = np.asarray(point)
# check if point (p) is on line (a - b)
if np.cross(p-a,b-a) == 0:
if count == 0:
start = point
count += 1
else:
end = point
count += 1
if count >= 3:
line = (start, end)
newLines.append(line)
I'm not sure what your current benchmarks are, but you want to try with numpy you can do something like this. I'm using pandas which is a numpy wrapper, but it's effectively doing the same thing
I think this is doing the same thing as you want. I'm looking at each pair of points, calculating the m and c coefficients in the equation y = mx + c through the two points, then checking for cases where these match. I expect you might want some accepted error depending on your input data.
Sorry if I'm way off piste.
import pandas as pd
import numpy as np
import random
import itertools
import time
def get_matches(points):
# get all combinations of two points
combinations_of_points = ([(a[0], a[1], b[0], b[1]) for a, b in itertools.combinations(points, 2) if a != b])
data = pd.DataFrame(combinations_of_points, columns=['x1', 'y1', 'x2', 'y2'])
data['m'] = (data.y1 - data.y2) / (data.x1 - data.x2)
# swap negative gradients so all lines are in same direction
data.loc[np.isfinite(data.m) & data.m < 0, 'm'] = -(1 / data.m)
data.loc[np.isneginf(data.m), 'm'] = -data.m
# y = mx + c
data['c'] = data.y1 - (data.m * data.x1)
data = data.sort_values(['m', 'c', 'x1']).reset_index(drop=True)
# filter to items which are duplicated
filtered = data[
# matching m and c values
(np.isfinite(data.m) & data.duplicated(['m', 'c'], keep=False)) |
# infinite m and x equal (straight line up)
(np.isposinf(data.m) & data.duplicated(['m', 'x1'], keep=False))
]
return filtered
points = [(0, 0), (1, 1), (2, 2)]
print(get_matches(points))
random.seed(1)
count = 500
random_points = [(round(random.random(), 3), round(random.random(), 3)) for i in range(count)]
results = get_matches(random_points)
print(results)
print('\nPerformance with increasing points')
for i in [i ** 2 for i in range(5, 101, 5)]:
random.seed(1)
random_points = [(round(random.random(), 3), round(random.random(), 3)) for i in range(i)]
start = time.perf_counter()
results = get_matches(random_points)
stop = time.perf_counter()
print(f'{i:<9}{stop - start:03f}')
returns:
x1 y1 x2 y2 m c
0 0 0 1 1 1.0 0.0
1 0 0 2 2 1.0 0.0
2 1 1 2 2 1.0 0.0
x1 y1 x2 y2 m c
12243 0.606 0.262 0.400 0.880 -3.0 2.080
12244 0.606 0.262 0.440 0.760 -3.0 2.080
12251 0.378 0.970 0.506 0.586 -3.0 2.104
12252 0.505 0.589 0.378 0.970 -3.0 2.104
12253 0.505 0.589 0.506 0.586 -3.0 2.104
... ... ... ... ... ... ...
124741 0.971 0.382 0.971 0.716 inf -inf
124742 0.971 0.543 0.971 0.716 inf -inf
124744 0.983 0.593 0.983 0.296 inf -inf
124745 0.983 0.593 0.983 0.448 inf -inf
124746 0.983 0.296 0.983 0.448 inf -inf
[237 rows x 6 columns]
Performance with increasing points
25 0.010577
100 0.016897
225 0.045443
400 0.136834
625 0.338148
900 0.765913
1225 1.525819
1600 2.645753
2025 4.834811
2500 8.112012
3025 12.960043
3600 18.262522
4225 27.221498
4900 37.329662
5625 53.064736
6400 67.325213
7225 84.843119
8100 116.864120
9025 140.131420
10000 171.630961
As one of you comments pointed out earlier, the order of growth of the problem is approximately N ^ 2 because it is look at all the combinations of points so the performance very quickly degrades with increasing numbers of points. Note you could use this relationship to estimate how long it would take for your program to run if you know the number of points.
I am trying to compare the running time of factorial functions of different implementations. However, I found that the tail-recursive version of the factorial function is much slower than the iterative version and non-tail-recursive version and I can't figure out an explanation for this.
Here is my code implementation. I am using Python 3.7.4 to test the code.
import sys
from line_profiler import LineProfiler
sys.setrecursionlimit(20000)
def fact_iter(n):
"""Return the factorial of n, using iteration"""
product = 1
for i in range(2, n + 1):
product *= i
return product
def fact_non_tail_recur(n):
"""Return the factorial of n, using non-tail recursion"""
if n == 1:
return 1
product = n * fact_non_tail_recur(n - 1)
return product
def fact_tail_recur(n, product=1):
"""Return the factorial of n, using tail recursion"""
if n == 1:
return product
product *= n
return fact_tail_recur(n - 1, product)
def fact_tail_recur_2(n, i=1, product=1):
if i == n:
return product * i
product *= i
return fact_tail_recur_2(n, i+1, product)
def profile(f):
lp = LineProfiler()
lp_wrapper = lp(f)
lp_wrapper(10000)
lp.print_stats()
if __name__ == '__main__':
profile(fact_iter)
profile(fact_non_tail_recur)
profile(fact_tail_recur)
profile(fact_tail_recur_2)
And here is the running time profile of the functions.
Timer unit: 1e-06 s
Total time: 0.040521 s
File: fact.py
Function: fact_iter at line 6
Line # Hits Time Per Hit % Time Line Contents
==============================================================
6 def fact_iter(n):
7 """Return the factorial of n, using iteration"""
8 1 2.0 2.0 0.0 product = 1
9 10000 13513.0 1.4 33.3 for i in range(2, n + 1):
10 9999 27005.0 2.7 66.6 product *= i
11 1 1.0 1.0 0.0 return product
Timer unit: 1e-06 s
Total time: 0.042846 s
File: fact.py
Function: fact_non_tail_recur at line 14
Line # Hits Time Per Hit % Time Line Contents
==============================================================
14 def fact_non_tail_recur(n):
15 """Return the factorial of n, using non-tail recursion"""
16 10000 13389.0 1.3 31.2 if n == 1:
17 1 2.0 2.0 0.0 return 1
18 9999 16481.0 1.6 38.5 product = n * fact_non_tail_recur(n - 1)
19 9999 12974.0 1.3 30.3 return product
Timer unit: 1e-06 s
Total time: 0.085538 s
File: fact.py
Function: fact_tail_recur at line 22
Line # Hits Time Per Hit % Time Line Contents
==============================================================
22 def fact_tail_recur(n, product=1):
23 """Return the factorial of n, using tail recursion"""
24 10000 13812.0 1.4 16.1 if n == 1:
25 1 2.0 2.0 0.0 return product
26 9999 55390.0 5.5 64.8 product *= n
27 9999 16334.0 1.6 19.1 return fact_tail_recur(n - 1, product)
Timer unit: 1e-06 s
Total time: 0.07916 s
File: fact.py
Function: fact_tail_recur_2 at line 30
Line # Hits Time Per Hit % Time Line Contents
==============================================================
30 def fact_tail_recur_2(n, i=1, product=1):
31 10000 13521.0 1.4 17.1 if i == n:
32 1 12.0 12.0 0.0 return product * i
33 9999 49390.0 4.9 62.4 product *= i
34 9999 16237.0 1.6 20.5 return fact_tail_recur_2(n, i+1, product)
The graph of running time measured by timeit
I am working on a problem where I have been asked to a) output Fibonacci numbers in a sequence based on user input, as I have done below, and b) divide and print the ratio of the two most recent terms.
fixed_start = [0, 1]
def fib(fixed_start, n):
if n == 0:
return fixed_start
else:
fixed_start.append(fixed_start[-1] + fixed_start[-2])
return fib(fixed_start, n -1)
numb = int(input('How many terms: '))
fibonacci_list = fib(fixed_start, numb)
print(fibonacci_list[:-1])
I would like for my output to look something like the below:
"How many terms:" 3
1 1
the ratio is 1.0
1 2
the ratio is 2.0
2 3
the ratio is 1.5
Are you looking for ratio of the last 2 items in the list? If yes, this should work.
print(fibonacci_list[-2:])
print(float(fibonacci_list[-1]/fibonacci_list[-2]))
Or, if you are looking for ratio between every 2 numbers (except 0 & 1 right at the start), the below code should do the trick
for x,y in zip(fibonacci_list[1:],fibonacci_list[2:]):
print(x,y)
print('the ratio is ' + str(round((y/x),3)))
output is something like below for a fibonacci list of 15 terms
1 1
the ratio is 1.0
1 2
the ratio is 2.0
2 3
the ratio is 1.5
3 5
the ratio is 1.667
5 8
the ratio is 1.6
8 13
the ratio is 1.625
13 21
the ratio is 1.615
21 34
the ratio is 1.619
34 55
the ratio is 1.618
55 89
the ratio is 1.618
89 144
the ratio is 1.618
144 233
the ratio is 1.618
233 377
the ratio is 1.618
377 610
the ratio is 1.618
610 987
the ratio is 1.618
As you have already solved the part one of generating the Fibonacci series in the form of a list, you can access the last two elements (most recent) from it and take their ratio. Python allows us to access the elements of the list from backwards using the negative indexing
def fibonacci_ratio(fibonacci_list):
last_element = fibonacci_list[-1]
second_last_element = fibonacci_list[-2]
ratio = last_element//second_last_element
return ratio
The double // in python will ensure floating point division.
Hope this helps!
I'm fairly new to python, so I don't know all the tips and tricks quite yet. But I'm trying to read in data line by line from a file, then into a numpy array. I have to read it in line by line in this manner, but I have freedom when it comes to moving that data into the array. Here is the relevant code:
xyzi_point_array = np.zeros((0,4))
x_list = []
y_list = []
z_list = []
i_list = []
points_read = 0
while True: #FOR EVERY LINE DO:
line = decryptLine(inFile.readline()) #grabs the next line of data
if not line: break
.
.
.
index = 0
for entry in line: #FOR EVERY VALUE IN THE LINE
x_list.append(X)
y_list.append(Y)
z_list.append(z_catalog[index])
i_list.append(entry)
index += 1
points_read += 1
xyzi_point_array = np.zeros((points_read,4))
xyzi_point_array[:,0] = x_list
xyzi_point_array[:,1] = y_list
xyzi_point_array[:,2] = z_list
xyzi_point_array[:,3] = i_list
Where X and Y are scalars which are different for each line, and where z_catalog is a 1D numpy array.
For smaller data sets, the imbedded for loop is the biggest draw, with the xyzi_point_array[points_read,:] = line taking the majority of processor time. However with larger data sets, working with tempArr to expand xyzi_point_array becomes the worst, so I'll need to optimize both.
Any ideas? General tips on how to better handle numpy arrays are also welcome, I come from a C++ background, and am probably not handling these arrays in the most pythonic way..
For reference, here is the lineprofiler readout for this bit of the code:
Line # Hits Time Per Hit % Time Line Contents
==============================================================
138 150 233 1.6 0.0 index = 0
139 489600 468293 1.0 11.6 for entry in line: #FOR EVERY VALUE IN THE LINE
140 489450 457227 0.9 11.4 x_list.append(lineX)
141 489450 441687 0.9 11.0 y_list.append(lineY)
142 489450 541891 1.1 13.5 z_list.append(z_catalog[index])
143 489450 450191 0.9 11.2 i_list.append(entry)
144 489450 421573 0.9 10.5 index += 1
145 489450 408764 0.8 10.2 points_read += 1
146
149 1 78 78.0 0.0 xyzi_point_array = np.zeros((points_read,4))
150 1 39539 39539.0 1.0 xyzi_point_array[:,0] = x_list
151 1 33876 33876.0 0.8 xyzi_point_array[:,1] = y_list
152 1 48619 48619.0 1.2 xyzi_point_array[:,2] = z_list
153 1 47219 47219.0 1.2 xyzi_point_array[:,3] = i_list
I'm trying to work out how to speed up a Python function which uses numpy. The output I have received from lineprofiler is below, and this shows that the vast majority of the time is spent on the line ind_y, ind_x = np.where(seg_image == i).
seg_image is an integer array which is the result of segmenting an image, thus finding the pixels where seg_image == i extracts a specific segmented object. I am looping through lots of these objects (in the code below I'm just looping through 5 for testing, but I'll actually be looping through over 20,000), and it takes a long time to run!
Is there any way in which the np.where call can be speeded up? Or, alternatively, that the penultimate line (which also takes a good proportion of the time) can be speeded up?
The ideal solution would be to run the code on the whole array at once, rather than looping, but I don't think this is possible as there are side-effects to some of the functions I need to run (for example, dilating a segmented object can make it 'collide' with the next region and thus give incorrect results later on).
Does anyone have any ideas?
Line # Hits Time Per Hit % Time Line Contents
==============================================================
5 def correct_hot(hot_image, seg_image):
6 1 239810 239810.0 2.3 new_hot = hot_image.copy()
7 1 572966 572966.0 5.5 sign = np.zeros_like(hot_image) + 1
8 1 67565 67565.0 0.6 sign[:,:] = 1
9 1 1257867 1257867.0 12.1 sign[hot_image > 0] = -1
10
11 1 150 150.0 0.0 s_elem = np.ones((3, 3))
12
13 #for i in xrange(1,seg_image.max()+1):
14 6 57 9.5 0.0 for i in range(1,6):
15 5 6092775 1218555.0 58.5 ind_y, ind_x = np.where(seg_image == i)
16
17 # Get the average HOT value of the object (really simple!)
18 5 2408 481.6 0.0 obj_avg = hot_image[ind_y, ind_x].mean()
19
20 5 333 66.6 0.0 miny = np.min(ind_y)
21
22 5 162 32.4 0.0 minx = np.min(ind_x)
23
24
25 5 369 73.8 0.0 new_ind_x = ind_x - minx + 3
26 5 113 22.6 0.0 new_ind_y = ind_y - miny + 3
27
28 5 211 42.2 0.0 maxy = np.max(new_ind_y)
29 5 143 28.6 0.0 maxx = np.max(new_ind_x)
30
31 # 7 is + 1 to deal with the zero-based indexing, + 2 * 3 to deal with the 3 cell padding above
32 5 217 43.4 0.0 obj = np.zeros( (maxy+7, maxx+7) )
33
34 5 158 31.6 0.0 obj[new_ind_y, new_ind_x] = 1
35
36 5 2482 496.4 0.0 dilated = ndimage.binary_dilation(obj, s_elem)
37 5 1370 274.0 0.0 border = mahotas.borders(dilated)
38
39 5 122 24.4 0.0 border = np.logical_and(border, dilated)
40
41 5 355 71.0 0.0 border_ind_y, border_ind_x = np.where(border == 1)
42 5 136 27.2 0.0 border_ind_y = border_ind_y + miny - 3
43 5 123 24.6 0.0 border_ind_x = border_ind_x + minx - 3
44
45 5 645 129.0 0.0 border_avg = hot_image[border_ind_y, border_ind_x].mean()
46
47 5 2167729 433545.8 20.8 new_hot[seg_image == i] = (new_hot[ind_y, ind_x] + (sign[ind_y, ind_x] * np.abs(obj_avg - border_avg)))
48 5 10179 2035.8 0.1 print obj_avg, border_avg
49
50 1 4 4.0 0.0 return new_hot
EDIT I have left my original answer at the bottom for the record, but I have actually looked into your code in more detail over lunch, and I think that using np.where is a big mistake:
In [63]: a = np.random.randint(100, size=(1000, 1000))
In [64]: %timeit a == 42
1000 loops, best of 3: 950 us per loop
In [65]: %timeit np.where(a == 42)
100 loops, best of 3: 7.55 ms per loop
You could get a boolean array (that you can use for indexing) in 1/8 of the time you need to get the actual coordinates of the points!!!
There is of course the cropping of the features that you do, but ndimage has a find_objects function that returns enclosing slices, and appears to be very fast:
In [66]: %timeit ndimage.find_objects(a)
100 loops, best of 3: 11.5 ms per loop
This returns a list of tuples of slices enclosing all of your objects, in 50% more time thn it takes to find the indices of one single object.
It may not work out of the box as I cannot test it right now, but I would restructure your code into something like the following:
def correct_hot_bis(hot_image, seg_image):
# Need this to not index out of bounds when computing border_avg
hot_image_padded = np.pad(hot_image, 3, mode='constant',
constant_values=0)
new_hot = hot_image.copy()
sign = np.ones_like(hot_image, dtype=np.int8)
sign[hot_image > 0] = -1
s_elem = np.ones((3, 3))
for j, slice_ in enumerate(ndimage.find_objects(seg_image)):
hot_image_view = hot_image[slice_]
seg_image_view = seg_image[slice_]
new_shape = tuple(dim+6 for dim in hot_image_view.shape)
new_slice = tuple(slice(dim.start,
dim.stop+6,
None) for dim in slice_)
indices = seg_image_view == j+1
obj_avg = hot_image_view[indices].mean()
obj = np.zeros(new_shape)
obj[3:-3, 3:-3][indices] = True
dilated = ndimage.binary_dilation(obj, s_elem)
border = mahotas.borders(dilated)
border &= dilated
border_avg = hot_image_padded[new_slice][border == 1].mean()
new_hot[slice_][indices] += (sign[slice_][indices] *
np.abs(obj_avg - border_avg))
return new_hot
You would still need to figure out the collisions, but you could get about a 2x speed-up by computing all the indices simultaneously using a np.unique based approach:
a = np.random.randint(100, size=(1000, 1000))
def get_pos(arr):
pos = []
for j in xrange(100):
pos.append(np.where(arr == j))
return pos
def get_pos_bis(arr):
unq, flat_idx = np.unique(arr, return_inverse=True)
pos = np.argsort(flat_idx)
counts = np.bincount(flat_idx)
cum_counts = np.cumsum(counts)
multi_dim_idx = np.unravel_index(pos, arr.shape)
return zip(*(np.split(coords, cum_counts) for coords in multi_dim_idx))
In [33]: %timeit get_pos(a)
1 loops, best of 3: 766 ms per loop
In [34]: %timeit get_pos_bis(a)
1 loops, best of 3: 388 ms per loop
Note that the pixels for each object are returned in a different order, so you can't simply compare the returns of both functions to assess equality. But they should both return the same.
One thing you could do to same a little bit of time is to save the result of seg_image == i so that you don't need to compute it twice. You're computing it on lines 15 & 47, you could add seg_mask = seg_image == i and then reuse that result (It might also be good to separate out that piece for profiling purposes).
While there a some other minor things that you could do to eke out a little bit of performance, the root issue is that you're using a O(M * N) algorithm where M is the number of segments and N is the size of your image. It's not obvious to me from your code whether there is a faster algorithm to accomplish the same thing, but that's the first place I'd try and look for a speedup.