Find the substring avoiding the use of recursive function - python

I am studying algorithms in Python and solving a question that is:
Let x(k) be a recursively defined string with base case x(1) = "123"
and x(k) is "1" + x(k-1) + "2" + x(k-1) + "3". Given three positive
integers k, s, and t, find the substring x(k)[s:t].
For example, if k = 2, s = 1 and t = 5, x(2) = 112321233 and x(2)[1:5]
= 1232.
I have solved it using a simple recursive function:
def generate_string(k):
    if k == 1:
        return "123"
    part = generate_string(k - 1)
    return "1" + part + "2" + part + "3"
print(generate_string(k)[s:t])
Although my first approach gives the correct answer, building x takes too long when k is greater than 20. The program needs to finish within 16 seconds for any k below 50. I have tried memoization, but it does not help because I am not allowed to cache the results for each test case. I therefore think I must avoid the recursive function to speed up the program. Are there any approaches I should consider?

We can see that the string represented by x(k) grows exponentially in length with increasing k:
len(x(1)) == 3
len(x(k)) == len(x(k-1)) * 2 + 3
So:
len(x(k)) == 3 * (2**k - 1)
For k equal to 100, this amounts to a length of more than 10**30. That's more characters than there are atoms in a human body!
Since the parameters s and t select (in comparison) a tiny, tiny slice of that, you should not need to produce the whole string. You can still use recursion, but keep passing the s and t range to each call. When you see that the slice lies entirely outside the part of the string that call would generate, you can exit without recursing deeper, saving a lot of time and (string) space.
Here is how you could do it:
def getslice(k, s, t):
    def recur(xsize, s, t):
        if xsize == 0 or s >= xsize or t <= 0:
            return ""
        smaller = (xsize - 3) // 2
        return ( ("1" if s <= 0 else "")
               + recur(smaller, s-1, t-1)
               + ("2" if s <= smaller+1 < t else "")
               + recur(smaller, s-smaller-2, t-smaller-2)
               + ("3" if t >= xsize else "") )
    return recur(3 * (2**k - 1), s, t)
This doesn't use any caching of x(k) results... In my tests this was fast enough.
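To check the function against the brute-force version from the question, something like the following harness could be used. It is only a sketch: it copies getslice as written above and compares it with generate_string on every slice for small k, then shows a slice deep inside x(49), which is far too long to build.

```python
def getslice(k, s, t):
    def recur(xsize, s, t):
        # This (sub)string does not overlap the requested slice at all
        if xsize == 0 or s >= xsize or t <= 0:
            return ""
        smaller = (xsize - 3) // 2  # length of x(k-1)
        return (("1" if s <= 0 else "")
                + recur(smaller, s - 1, t - 1)
                + ("2" if s <= smaller + 1 < t else "")
                + recur(smaller, s - smaller - 2, t - smaller - 2)
                + ("3" if t >= xsize else ""))
    return recur(3 * (2**k - 1), s, t)


def generate_string(k):
    # Brute force from the question: builds the whole string, so small k only
    if k == 1:
        return "123"
    part = generate_string(k - 1)
    return "1" + part + "2" + part + "3"


# Compare every slice of x(1) .. x(5) with the brute-force result
for k in range(1, 6):
    full = generate_string(k)
    for s in range(len(full) + 1):
        for t in range(s, len(full) + 1):
            assert getslice(k, s, t) == full[s:t], (k, s, t)

print(getslice(2, 1, 5))  # 1232
print(len(getslice(49, 10**14, 10**14 + 100)))  # 100
```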

Based on @FMc's answer, here's some Python 3 code that calculates x(k)[s:t]:
from functools import lru_cache
from typing import List
def f_len(k) -> int:
    return 3 * ((2 ** k) - 1)
@lru_cache(None)
def f(k) -> str:
    if k == 1:
        return "123"
    return "1" + f(k - 1) + "2" + f(k - 1) + "3"
def substring_(k, s, t, output) -> None:
    # Empty substring.
    if s >= t or k == 0:
        return
    # (An optimization):
    # If all the characters need to be included, just calculate the string and cache it.
    if s == 0 and t == f_len(k):
        output.append(f(k))
        return
    if s == 0:
        output.append("1")
    sub_len = f_len(k - 1)
    substring_(k - 1, max(0, s - 1), min(sub_len, t - 1), output)
    if s <= 1 + sub_len < t:
        output.append("2")
    substring_(k - 1, max(0, s - sub_len - 2), min(sub_len, t - sub_len - 2), output)
    if s <= 2 * (1 + sub_len) < t:
        output.append("3")
def substring(k, s, t) -> str:
    output: List[str] = []
    substring_(k, s, t, output)
    return "".join(output)
def test(k, s, t) -> bool:
    actual = substring(k, s, t)
    expected = f(k)[s:t]
    return actual == expected
assert test(1, 0, 3)
assert test(2, 2, 6)
assert test(2, 1, 5)
assert test(2, 0, f_len(2))
assert test(3, 0, f_len(3))
assert test(8, 44, 89)
assert test(10, 1001, 2022)
assert test(14, 12345, 45678)
assert test(17, 12345, 112345)
# print(substring(30, 10000, 10100))
print("Tests passed")

This is an interesting problem. I'm not sure whether I'll have time to write the code, but here's an outline of how you can solve it. Note: see the better answer from trincot.
As discussed in the comments, you cannot generate the actual string: you will quickly run out of memory as k grows. But you can easily compute the length of that string.
First some notation:
f(k) : The generated string.
n(k) : The length of f(k).
nk1 : n(k-1), which is used several times in table below.
For discussion purposes, we can divide the string into the following regions. The start/end values use standard Python slice numbering:
Region | Start         | End           | Len | Substring | k = 2 slice | k = 2 value
-----------------------------------------------------------------------------------
A      | 0             | 1             | 1   | 1         | 0:1         | 1
B      | 1             | 1 + nk1       | nk1 | f(k-1)    | 1:4         | 123
C      | 1 + nk1       | 2 + nk1       | 1   | 2         | 4:5         | 2
D      | 2 + nk1       | 2 + nk1 + nk1 | nk1 | f(k-1)    | 5:8         | 123
E      | 2 + nk1 + nk1 | 3 + nk1 + nk1 | 1   | 3         | 8:9         | 3
Given k, s, and t we need to figure out which region of the string is relevant. Take a small example:
k=2, s=6, and t=8.
The substring defined by 6:8 does not require the full f(k). We only need
region D, so we can turn our attention to f(k-1).
To make the shift from k=2 to k=1, we need to adjust s and t: specifically,
we need to subtract the total length of regions A + B + C. For k=2, that
length is 5 (1 + nk1 + 1).
Now we are dealing with: k=1, s=1, and t=3.
Repeat as needed.
Whenever k gets small enough, we stop this nonsense and actually generate the string so we can grab the needed substring directly.
It's possible that some values of s and t could cross region boundaries. In that case, divide the problem into two subparts (one for each region needed). But the general idea is the same.
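Since the question explicitly wants to avoid recursion, the outline above can also be driven by a loop with an explicit stack. The following is a minimal sketch, assuming k >= 1 and 0 <= s <= t <= n(k). The names x_len and substring_iterative are made up for this example, and the region boundaries are taken straight from the table above. Slices that cross region boundaries are simply split into one stack entry per overlapping region.

```python
def x_len(k: int) -> int:
    """n(k) = len(f(k)) = 3 * (2**k - 1)."""
    return 3 * (2**k - 1)


def substring_iterative(k: int, s: int, t: int) -> str:
    """Return f(k)[s:t] without recursion (assumes k >= 1, 0 <= s <= t <= n(k)).

    The stack holds pending work in left-to-right order (top = next):
    either a single literal character or a (level, lo, hi) slice of f(level).
    """
    out: list = []
    stack: list = [(k, s, t)]
    while stack:
        item = stack.pop()
        if isinstance(item, str):  # region A, C or E: one character
            out.append(item)
            continue
        level, lo, hi = item
        if lo >= hi:  # empty slice
            continue
        if level == 1:  # small enough: take the slice directly
            out.append("123"[lo:hi])
            continue
        nk1 = x_len(level - 1)
        # (start, end, literal) for regions A..E; None marks a copy of f(level - 1)
        regions = [
            (0, 1, "1"),
            (1, 1 + nk1, None),
            (1 + nk1, 2 + nk1, "2"),
            (2 + nk1, 2 + 2 * nk1, None),
            (2 + 2 * nk1, 3 + 2 * nk1, "3"),
        ]
        # Push overlapping regions in reverse so that region A is popped first
        for start, end, literal in reversed(regions):
            if lo >= end or hi <= start:  # no overlap with [lo, hi)
                continue
            if literal is not None:
                stack.append(literal)
            else:
                # Shift s and t into the coordinates of f(level - 1)
                stack.append((level - 1, max(lo, start) - start, min(hi, end) - start))
    return "".join(out)


def brute(k: int) -> str:
    # Builds the whole string; only for checking small k
    x = "123"
    for _ in range(k - 1):
        x = "1" + x + "2" + x + "3"
    return x


for k in range(1, 6):
    full = brute(k)
    assert all(substring_iterative(k, s, t) == full[s:t]
               for s in range(len(full) + 1) for t in range(s, len(full) + 1))

print(substring_iterative(2, 6, 8))  # 23 (the k=2, s=6, t=8 example above)
```

Each loop iteration either emits output or descends one level, so the work is roughly proportional to k plus the length of the requested slice.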

Here's a commented iterative version in JavaScript that's very easy to convert to Python.
In addition to being what you asked for (non-recursive), it lets us solve things like f(10000, 10000, 10050), which seems to exceed Python's default recursion depth.
// Generates the full string
function g(k) {
  if (k == 1)
    return "123";
  const prev = g(k - 1);
  return "1" + prev + "2" + prev + "3";
}
// Length of g(k). (1 << k) would break for k >= 31 because JavaScript
// shifts operate on 32-bit integers, so use 2 ** k (exact for k <= 51).
function size(k) {
  return 3 * (2 ** k - 1);
}
// Given a depth and index,
// we'd like (1) a string to
// output, (2) the possible next
// part of the same depth to
// push to the stack, and (3)
// possibly the current section
// mapped deeper to also push to
// the stack. (2) and (3) can be
// in a single list.
function getParams(depth, i) {
  const psize = size(depth - 1);
  if (i == 0) {
    return ["1", [[depth, 1 + psize], [depth - 1, 0]]];
  } else if (i < 1 + psize) {
    return ["", [[depth, 1 + psize], [depth - 1, i - 1]]];
  } else if (i == 1 + psize) {
    return ["2", [[depth, 2 + 2 * psize], [depth - 1, 0]]];
  } else if (i < 2 + 2 * psize) {
    return ["", [[depth, 2 + 2 * psize], [depth - 1, i - 2 - psize]]];
  } else {
    return ["3", []];
  }
}
function f(k, s, t) {
  const len = t - s;
  let str = "";
  const stack = [[k, s]];
  while (str.length < len) {
    const [depth, i] = stack.pop();
    if (depth == 1) {
      const toTake = Math.min(3 - i, len - str.length);
      str = str + "123".substring(i, i + toTake);
    } else {
      const [s, rest] = getParams(depth, i);
      str = str + s;
      stack.push(...rest);
    }
  }
  return str;
}
function test(k, s, t) {
  const l = g(k).substring(s, t);
  const r = f(k, s, t);
  console.log(g(k).length);
  //console.log(g(k))
  console.log(l);
  console.log(r);
  console.log(l == r);
}
test(1, 0, 3);
test(2, 2, 6);
test(2, 1, 5);
test(4, 44, 45);
test(5, 30, 40);
test(7, 100, 150);

Related

Efficiently Find Median of an Unordered Set of Data

Background
I was looking into the statistics.median() function in the Standard Library and decided to see how it was implemented in the source code. To my surprise, the median is calculated by sorting the entire data set and returning the "middle value".
Example:
from statistics import StatisticsError
def median(data):
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    if n % 2 == 1:
        return data[n // 2]
    i = n // 2
    return (data[i - 1] + data[i]) / 2
This is an okay implementation for smaller data sets, but with large data sets, this can be costly.
So, I went through numerous sources and decided that the algorithm designed by Floyd and Rivest would be the best for finding the median. Some of the other algorithms I saw are:
Quickselect
Introselect
I chose the Floyd-Rivest algorithm because its average number of comparisons (about n + min(k, n - k) plus lower-order terms) is very close to optimal, and it seems resistant to cases such as the median-of-3 killer sequence.
Floyd-Rivest Algorithm
Python 3.10 with type hints
from collections.abc import Iterable, Sequence
from math import (
    exp,
    log,
    sqrt)
def sign(value: int | float) -> int:
    return bool(value > 0) - bool(value < 0)
def swap(sequence: list[int | float], x: int, y: int) -> None:
    sequence[x], sequence[y] = sequence[y], sequence[x]
    return
def floyd_rivest(sequence: list[int | float], left: int, right: int, k: int) -> int | float:
    while right > left:
        if right - left > 600:
            n: int = right - left + 1
            i: int = k - left + 1
            z: float = log(n)
            s: float = 0.5 * exp(2 * z / 3)
            sd: float = 0.5 * sqrt(z * s * (n - s) / n) * sign(i - n / 2)
            new_left: int = max((left, int(k - i * s / n + sd)))
            new_right: int = min((right, int(k + (n - i) * s / n + sd)))
            floyd_rivest(sequence, new_left, new_right, k)
        t: int | float = sequence[k]
        sliding_left: int = left
        sliding_right: int = right
        swap(sequence, left, k)
        if sequence[right] > t:
            swap(sequence, left, right)
        while sliding_left < sliding_right:
            swap(sequence, sliding_left, sliding_right)
            sliding_left += 1
            sliding_right -= 1
            while sequence[sliding_left] < t:
                sliding_left += 1
            while sequence[sliding_right] > t:
                sliding_right -= 1
        if sequence[left] == t:
            swap(sequence, left, sliding_right)
        else:
            sliding_right += 1
            swap(sequence, right, sliding_right)
        if sliding_right <= k:
            left = sliding_right + 1
        if k <= sliding_right:
            right = sliding_right - 1
    return sequence[k]
def median(data: Iterable[int | float] | Sequence[int | float]) -> int | float:
    sequence: list[int | float] = list(data)
    length: int = len(sequence)
    end: int = length - 1
    midpoint: int = end // 2
    if length % 2 == 1:
        return floyd_rivest(sequence, 0, end, midpoint)
    return (floyd_rivest(sequence, 0, end, midpoint) + floyd_rivest(sequence, 0, end, midpoint + 1)) / 2
Question
Apparently, the Floyd-Rivest algorithm does not perform as well with nondistinct data, for example, a list containing multiple 1s: [1, 1, 1, 1, 2, 3, 4, 5]. However, this was studied and seemingly solved by a person named Krzysztof C. Kiwiel who wrote a paper titled, "On Floyd and Rivest's SELECT algorithm". They modified the algorithm to perform better with nondistinct data.
My question is how would I implement/code Kiwiel's modified Floyd-Rivest algorithm?
Extra Considerations
In Kiwiel's paper, they also mention a non-recursive version of the algorithm. If you feel inclined, it would be nice to have an iterative algorithm to prevent overflowing stack frames (deep recursion). I am aware that a stack can be mimicked, but if you can find a way to rewrite the algorithm in such a way that it is elegantly written iteratively, that would be preferable.
Lastly, any input on speeding up the algorithm or using alternative ("better") algorithms is welcome! (I know Numpy has a median function and I know languages such as C would be more performant, but I am looking to find the "best" algorithm logic and not just generically making it faster)
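For comparison with the alternatives listed above, here is a plain quickselect written iteratively with a three-way (less / equal / greater) partition. This is not Kiwiel's algorithm and not Floyd-Rivest. It is a simpler technique, shown because a three-way partition settles a whole run of equal values in one pass, which is the duplicate-heavy case raised in the question, and because it needs no recursion at all. The names select and median are illustrative. With a random pivot the expected running time is linear, but the worst case is still quadratic (just very unlikely).

```python
import random
import statistics
from collections.abc import Iterable


def select(seq: list, k: int):
    """Return the k-th smallest item (0-based) of seq, reordering seq in place."""
    left, right = 0, len(seq) - 1
    while left < right:
        pivot = seq[random.randint(left, right)]
        # Invariant: seq[left:lt] < pivot, seq[lt:i] == pivot, seq[gt+1:right+1] > pivot
        lt, i, gt = left, left, right
        while i <= gt:
            if seq[i] < pivot:
                seq[lt], seq[i] = seq[i], seq[lt]
                lt += 1
                i += 1
            elif seq[i] > pivot:
                seq[i], seq[gt] = seq[gt], seq[i]
                gt -= 1
            else:
                i += 1
        if k < lt:
            right = lt - 1
        elif k > gt:
            left = gt + 1
        else:
            return seq[k]  # k lies inside the block of items equal to the pivot
    return seq[k]


def median(data: Iterable):
    seq = list(data)
    n = len(seq)
    if n == 0:
        raise ValueError("no median for empty data")
    mid = n // 2
    if n % 2 == 1:
        return select(seq, mid)
    return (select(seq, mid - 1) + select(seq, mid)) / 2


# Cross-check against statistics.median on data with many duplicates
for _ in range(200):
    data = [random.randint(0, 5) for _ in range(random.randint(1, 300))]
    assert median(data) == statistics.median(data)
```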

Volume of pile of cubes

I'm trying a challenge. The idea is the following:
"Your task is to construct a building which will be a pile of n cubes.
The cube at the bottom will have a volume of n^3, the cube above will
have volume of (n-1)^3 and so on until the top which will have a
volume of 1^3.
You are given the total volume m of the building. Being given m can
you find the number n of cubes you will have to build? If no such n
exists return -1"
I saw that apparently:
2³ + 1 = 9 = 3² and 3 - 1 = 2
3³ + 2³ + 1 = 36 = 6² and 6 - 3 = 3
4³ + 3³ + 2³ + 1 = 100 = 10² and 10 - 6 = 4
5³ + 4³ + 3³ + 2³ + 1 = 225 = 15² and 15 - 10 = 5
6³ + 5³ + 4³ + 3³ + 2³ + 1 = 441 = 21² and 21 - 15 = 6
So I thought: if I check whether a certain number is a perfect square, I can already exclude a few. Then I can start a variable at 1 and subtract that value (incrementing it each time) from the square root. The values will eventually match, or what is left of the square root will become negative.
So I wrote this code:
def find_nb(m):
    x = m**0.5
    if (x%1==0):
        c = 1
        while (x != c and x > 0):
            x = x - c
            c = c + 1
        if (x == c):
            return c
        else:
            return -1
    return -1
Shouldn't this work? What am I missing?
I fail a third of the sample tests. For example, for 10170290665425347857 the answer should be -1, but my program gives 79863.
Am I missing something obvious?
You're running up against a floating point precision problem. Namely, we have
In [101]: (10170290665425347857)**0.5
Out[101]: 3189089316.0
In [102]: ((10170290665425347857)**0.5) % 1
Out[102]: 0.0
and so the inner branch is taken, even though it's not actually a square:
In [103]: int((10170290665425347857)**0.5)**2
Out[103]: 10170290665425347856
If you borrow one of the many integer square root options from this question and verify that the sqrt squared gives the original number, you should be okay with your algorithm, at least if I haven't overlooked some corner case.
(Aside: you've already noticed the critical pattern. The numbers 1, 3, 6, 10, 15.. are quite famous and have a formula of their own, which you could use to solve for whether there is such a number that works directly.)
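A sketch of that direct approach using only integer arithmetic, so there is no floating-point rounding anywhere. It assumes Python 3.8+ for math.isqrt. The idea follows the pattern above: the sum of the first n cubes is T(n)**2 with T(n) = n(n+1)/2, so m must be a perfect square whose root is a triangular number.

```python
from math import isqrt


def find_nb(m: int) -> int:
    """Return n such that 1**3 + 2**3 + ... + n**3 == m, or -1 if there is none."""
    # The sum of the first n cubes equals T(n)**2, with T(n) = n * (n + 1) // 2
    root = isqrt(m)
    if root * root != m:  # exact perfect-square test, no floats involved
        return -1
    # Solve n * (n + 1) / 2 == root, i.e. n = (sqrt(8 * root + 1) - 1) / 2
    disc = 8 * root + 1
    d = isqrt(disc)
    if d * d != disc:
        return -1
    return (d - 1) // 2


print(find_nb(10170290665425347857))  # -1
print(find_nb(4183059834009))  # 2022
```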
DSM's answer is the one, but to add my two cents to improve the solution...
This expression from Brilliant.org is for summing cube numbers:
sum of k**3 from k=1 to n:
n**2 * (n+1)**2 / 4
This can of course be solved for n, given the total volume v. This is one of the four solutions (requiring both n and v to be positive):
from math import sqrt
def n(v):
    return 1/2*(sqrt(8*sqrt(v) + 1) - 1)
But this function also returns 79863.0. If we then sum all the cube numbers from 1 to 79863 exactly, the result is one less than v, so 79863 is only an artifact of floating-point rounding:
v = 10170290665425347857
cubes = int(n(v))  # 79863
x = sum([i**3 for i in range(cubes+1)])
# v == 10170290665425347857 (original)
# x == 10170290665425347856
I don't know if your answer is correct, but I have another solution to this problem which is waaaay easier:
def max_level(remain_volume, currLevel):
    if remain_volume < currLevel ** 3:
        return -1
    if remain_volume == currLevel ** 3:
        return currLevel
    return max_level(remain_volume - currLevel**3, currLevel + 1)
And you find the answer with max_level(m, 0). It takes O(n) time. However, Python does not eliminate tail calls, so it also uses O(n) stack frames, and for inputs like the failing test above (n around 80000) it exceeds Python's default recursion limit of 1000.
I found a simple solution to this in PHP that fits my requirements.
function findNb($m) {
    $total = 0;
    $n = 0;
    while($total < $m) {
        $n += 1;
        $total += $n ** 3;
    }
    return $total === $m ? $n : -1;
}
In Python it would be:
def find_nb(m):
    total = 0
    n = 0
    while (total < m):
        n = n + 1
        total = total + n ** 3
    return n if total == m else -1

Maximum tip calculator - naive solution

I am working through a GeeksforGeeks practice question. I have come up with a naive recursive solution to the "maximum tip calculator" problem.
The problem definition is:
Restaurant receives N orders. If Rahul takes the ith order, gain
$A[i]. If Ankit takes this order, the tip would be $B[i]. One order
per person. Rahul takes max X orders. Ankit takes max Y orders.
X + Y >= N. Find out the maximum possible amount of total tip money
after processing all the orders.
Input:
The first line contains one integer, number of test cases. The second
line contains three integers N, X, Y. The third line contains N
integers. The ith integer represents Ai. The fourth line contains N
integers. The ith integer represents Bi.
Output: Print a single integer representing the maximum tip money they
would receive.
My Code and working sample:
def max_tip(N, A, B, X, Y, n=0):
    if n == len(A) or N == 0:
        return 0
    if X == 0 and Y > 0:  # rahul cannot take more orders
        return max(B[n] + max_tip(N - 1, A, B, X, Y - 1, n + 1),  # ankit takes the order
                   max_tip(N, A, B, X, Y, n + 1))  # ankit does not take order
    elif Y == 0 and X > 0:  # ankit cannot take more orders
        return max(A[n] + max_tip(N - 1, A, B, X - 1, Y, n + 1),  # rahul takes the order
                   max_tip(N, A, B, X, Y, n + 1))  # rahul does not take order
    elif Y == 0 and X == 0:  # neither can take orders
        return 0
    else:
        return max(A[n] + max_tip(N - 1, A, B, X - 1, Y, n + 1),  # rahul takes the order
                   B[n] + max_tip(N - 1, A, B, X, Y - 1, n + 1),  # ankit takes the order
                   max_tip(N, A, B, X, Y, n + 1))  # nobody takes the order
T = int(input())
for i in range(T):
    nxy = [int(n) for n in input().strip().split(" ")]
    N = nxy[0]
    X = nxy[1]
    Y = nxy[2]
    A = [int(n) for n in input().strip().split(" ")]
    B = [int(n) for n in input().strip().split(" ")]
    print(max_tip(N, A, B, X, Y))
I've annotated my recursive call decisions. Essentially, I extended the naive 0-1 knapsack solution into another dimension for the two waiters: either one takes the order, the other takes it, or neither takes it, depending on how many orders each can still take.
The solution checker is complaining for the following testcase:
Input:
7 3 3
8 7 15 19 16 16 18
1 7 15 11 12 31 9
Its Correct output is:
110
And Your Code's Output is:
106
This confuses me because the optimal solution seems to be what my code is getting (19 + 16 + 18) + (7 + 15 + 31). The immediate issue seems to be that X + Y < N. My thought is my code should work for the case where X + Y < N as well.
What's going on?
You are considering the case where nobody takes the order, but that case doesn't exist because X + Y >= N. This Java code worked for me; have a look:
private static int getMaxTip(int x, int y, int n, int[] A, int[] B) {
    int[][] dp = new int[x + 1][y + 1];
    dp[0][0] = 0;
    for (int i = 1;i <= x;i++) {
        dp[i][0] = (i <= n) ? dp[i - 1][0] + A[i - 1] : dp[i - 1][0];
    }
    for (int i = 1;i <= y;i++) {
        dp[0][i] = (i <= n) ? dp[0][i - 1] + B[i - 1] : dp[0][i - 1];
    }
    for (int i = 1;i <= x;i++) {
        for (int j = 1;j <= y;j++) {
            if (i + j <= n) {
                dp[i][j] = Math.max(dp[i - 1][j] + A[i + j - 1], dp[i][j - 1] + B[i + j - 1]);
            }
        }
    }
    int ans = Integer.MIN_VALUE;
    for (int i = 0;i <= x;i++) {
        for (int j = 0;j <= y;j++) {
            if (i + j == n) {
                ans = Math.max(ans, dp[i][j]);
            }
        }
    }
    return ans;
}
You are considering a case where nobody takes the order. It should not be considered, because the question states that X + Y >= N always holds. Removing that branch will work.
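A sketch of the question's recursion with that branch removed, plus memoization so it doesn't explode on larger inputs. It assumes X + Y >= N as the problem states. The function name and the (index, orders left for Rahul, orders left for Ankit) state are choices made for this example, not taken from the question.

```python
from functools import lru_cache


def max_tip_memo(a: list, b: list, x: int, y: int) -> int:
    """Maximum total tip if Rahul can take at most x orders and Ankit at most y.

    Assumes x + y >= len(a), so every order is taken by someone.
    """
    n = len(a)

    @lru_cache(maxsize=None)
    def best(i: int, x_left: int, y_left: int) -> int:
        if i == n:
            return 0
        options = []
        if x_left > 0:  # Rahul takes order i
            options.append(a[i] + best(i + 1, x_left - 1, y_left))
        if y_left > 0:  # Ankit takes order i
            options.append(b[i] + best(i + 1, x_left, y_left - 1))
        # No "nobody takes it" option: x + y >= n keeps options non-empty
        return max(options)

    return best(0, x, y)


print(max_tip_memo([1, 2, 3, 4, 5], [5, 4, 3, 2, 1], 3, 3))  # 21
```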
I am assuming this is the source of your question:
https://practice.geeksforgeeks.org/problems/maximum-tip-calculator/0
Here is my solution written in Python that passed all cases:
https://github.com/Madhu-Guddana/My-Solutions/blob/master/adhoc/max_tip.py
Explanation:
Zip the corresponding elements of the two tip arrays into a new array.
Sort the new array by the difference between Rahul's and Ankit's tip for each order.
Then we can safely consider elements from the two ends of the array: whichever end gives more profit, add its value to the total.
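The linked code isn't reproduced here. Below is a minimal sketch of a common variant of the same sort-by-difference idea, assuming X + Y >= N. It handles orders in decreasing order of |A[i] - B[i]| and gives each one to whoever tips more, unless that person has already reached their limit. The function name is made up for this example.

```python
def max_tip_greedy(a: list, b: list, x: int, y: int) -> int:
    """Greedy: decide the orders where the two tips differ most first.

    Assumes x + y >= len(a).
    """
    # Pair up the tips and sort by how much it matters who takes the order
    orders = sorted(zip(a, b), key=lambda ab: abs(ab[0] - ab[1]), reverse=True)
    total = 0
    for tip_a, tip_b in orders:
        # Rahul takes it if he tips at least as much and has capacity,
        # or if Ankit is already full; otherwise Ankit takes it.
        if (tip_a >= tip_b and x > 0) or y == 0:
            total += tip_a
            x -= 1
        else:
            total += tip_b
            y -= 1
    return total


print(max_tip_greedy([1, 4, 3, 2, 7, 5, 9, 6], [1, 2, 3, 6, 5, 4, 9, 8], 4, 4))  # 43
```

Sorting dominates, so this runs in O(N log N), compared with the O(N * X * Y) states of a DP over (order, X left, Y left).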

Primitive Calculator - Dynamic Approach

I'm having some trouble getting the correct solution for the following problem:
Your goal is: given a positive integer n, find the minimum number of
operations needed to obtain the number n starting from the number 1,
where the allowed operations are multiply by 2, multiply by 3, and add 1.
The specific test case that fails is in the comments below.
# Failed case #3/16: (Wrong answer)
# got: 15 expected: 14
# Input:
# 96234
#
# Your output:
# 15
# 1 2 4 5 10 11 22 66 198 594 1782 5346 16038 16039 32078 96234
# Correct output:
# 14
# 1 3 9 10 11 22 66 198 594 1782 5346 16038 16039 32078 96234
# (Time used: 0.10/5.50, memory used: 8601600/134217728.)
import sys
def optimal_sequence(n):
    sequence = []
    while n >= 1:
        sequence.append(n)
        if n % 3 == 0:
            n = n // 3
            optimal_sequence(n)
        elif n % 2 == 0:
            n = n // 2
            optimal_sequence(n)
        else:
            n = n - 1
            optimal_sequence(n)
    return reversed(sequence)
input = sys.stdin.read()
n = int(input)
sequence = list(optimal_sequence(n))
print(len(sequence) - 1)
for x in sequence:
    print(x, end=' ')
It looks like I should be outputting 9 where I'm outputting 4 & 5 but I'm not sure why this isn't the case. What's the best way to troubleshoot this problem?
You are doing a greedy approach.
When n == 10, you check whether it's divisible by 2 and assume halving is the best step, which is wrong in this case: 10 -> 9 -> 3 -> 1 takes 3 operations, while 10 -> 5 -> 4 -> 2 -> 1 takes 4.
What you need to do is proper dynamic programming. v[x] will hold the minimum number of steps to get to result x.
def solve(n):
    v = [0]*(n+1)  # so that v[n] is there; v[x] == 0 means "not reached yet"
    v[1] = 1  # length of the sequence to 1 is 1
    for i in range(1, n):
        if not v[i]: continue
        for j in (i + 1, i * 2, i * 3):  # the three allowed operations
            if j <= n and (v[j] == 0 or v[j] > v[i] + 1): v[j] = v[i] + 1
    solution = []
    while n > 1:
        solution.append(n)
        if v[n-1] == v[n] - 1: n = n-1
        elif n%2 == 0 and v[n//2] == v[n] - 1: n = n//2
        else: n = n//3  # then n%3 == 0 and v[n//3] == v[n] - 1
    solution.append(1)
    return list(reversed(solution))
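To see the greedy failure concretely, a small experiment can compare the question's greedy rule with a bottom-up DP over every n up to the failing input. This is an illustrative sketch; greedy_steps and optimal_steps are names made up here. It confirms that 10 is the first n where the greedy answer is too long.

```python
def greedy_steps(n: int) -> int:
    # The question's rule: divide by 3 if possible, else by 2, else subtract 1
    steps = 0
    while n > 1:
        if n % 3 == 0:
            n //= 3
        elif n % 2 == 0:
            n //= 2
        else:
            n -= 1
        steps += 1
    return steps


def optimal_steps(limit: int) -> list:
    # Bottom-up DP: best[i] = minimum number of operations (+1, *2, *3) from 1 to i
    best = [0] * (limit + 1)
    for i in range(2, limit + 1):
        best[i] = best[i - 1] + 1
        if i % 2 == 0:
            best[i] = min(best[i], best[i // 2] + 1)
        if i % 3 == 0:
            best[i] = min(best[i], best[i // 3] + 1)
    return best


best = optimal_steps(96234)
first_bad = next(n for n in range(1, 96235) if greedy_steps(n) != best[n])
print(first_bad, greedy_steps(first_bad), best[first_bad])  # 10 4 3
print(greedy_steps(96234), best[96234])  # 15 14
```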
One more solution:
private static List<Integer> optimal_sequence(int n) {
    List<Integer> sequence = new ArrayList<>();
    int[] arr = new int[n + 1];
    for (int i = 1; i < arr.length; i++) {
        arr[i] = arr[i - 1] + 1;
        if (i % 2 == 0) arr[i] = Math.min(1 + arr[i / 2], arr[i]);
        if (i % 3 == 0) arr[i] = Math.min(1 + arr[i / 3], arr[i]);
    }
    for (int i = n; i > 1; ) {
        sequence.add(i);
        if (arr[i - 1] == arr[i] - 1)
            i = i - 1;
        else if (i % 2 == 0 && (arr[i / 2] == arr[i] - 1))
            i = i / 2;
        else if (i % 3 == 0 && (arr[i / 3] == arr[i] - 1))
            i = i / 3;
    }
    sequence.add(1);
    Collections.reverse(sequence);
    return sequence;
}
List<Integer> sequence = new ArrayList<Integer>();
while (n>0) {
    sequence.add(n);
    if (n % 3 == 0&&n % 2 == 0)
        n=n/3;
    else if(n%3==0)
        n=n/3;
    else if (n % 2 == 0&& n!=10)
        n=n/2;
    else
        n=n-1;
}
Collections.reverse(sequence);
return sequence;
Here's my dynamic programming (bottom-up and memoized) solution to the problem:
public class PrimitiveCalculator {
 1.     public int minOperations(int n){
 2.         int[] M = new int[n+1];
 3.         M[1] = 0; M[2] = 1; M[3] = 1;
 4.         for(int i = 4; i <= n; i++){
 5.             M[i] = M[i-1] + 1;
 6.             M[i] = Math.min(M[i], (i %3 == 0 ? M[i/3] + 1 : (i%3 == 1 ? M[(i-1)/3] + 2 : M[(i-2)/3] + 3)));
 7.             M[i] = Math.min(M[i], i%2 == 0 ? M[i/2] + 1: M[(i-1)/2] + 2);
 8.         }
 9.         return M[n];
10.     }
    public static void main(String[] args) {
        System.out.println(new PrimitiveCalculator().minOperations(96234));
    }
}
Before going ahead with the explanation of the algorithm, I would like to add a quick disclaimer:
A DP solution is rarely reached on the first attempt unless you have
experience solving a lot of DP problems.
Approach to solving through DP
If you are not comfortable with DP problems, the best approach to solving the problem would be the following:
Try to get a working brute-force recursive solution.
Once we have a recursive solution, we can look for ways to reduce the recursive work by adding memoization, wherein we remember the solutions to smaller subproblems that were already solved in a recursive step. This is generally a top-down solution.
After memoization, we try to flip the solution around and solve it bottom-up (my Java solution above is a bottom-up one).
Once you have done the above 3 steps, you have reached a DP solution.
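For this problem, steps 1 and 2 could look like the following in Python (a sketch; min_ops is a name made up here). Without the @lru_cache decorator this is the plain step-1 recursion, which recomputes the same subproblems over and over. With it, each n is solved once.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def min_ops(n: int) -> int:
    """Step 2 (top-down with memoization): minimum number of +1, x2, x3
    operations needed to get from 1 to n."""
    if n == 1:
        return 0
    candidates = [min_ops(n - 1)]  # last operation was +1
    if n % 2 == 0:
        candidates.append(min_ops(n // 2))  # last operation was x2
    if n % 3 == 0:
        candidates.append(min_ops(n // 3))  # last operation was x3
    return 1 + min(candidates)


print(min_ops(10))  # 3 (1 -> 3 -> 9 -> 10)
```

Because min_ops(n) first recurses into min_ops(n - 1), the recursion is about n levels deep, so a call like min_ops(96234) would exceed Python's default recursion limit. That is one practical reason to finish with step 3 and go bottom-up.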
Now coming to the explanation of the solution above:
Given a number 'n' and only 3 operations {+1, x2, x3}, the minimum number of operations needed to reach 'n' from 1 is given by the recursive formula:
min_operations_to_reach(n) = 1 + Math.min(min_operations_to_reach(n-1), min_operations_to_reach(n/2) [only if n is even], min_operations_to_reach(n/3) [only if n is divisible by 3])
If we flip the memoization process around and begin with the number 1 itself, the above code starts to make more sense.
Starting off with the trivial cases of 1, 2, 3:
min_operations_to_reach(1) = 0 because we don't need to do any operation.
min_operations_to_reach(2) = 1 because we can either do (1 +1) or (1 x2); in either case the number of operations is 1.
Similarly, min_operations_to_reach(3) = 1 because we can multiply 1 by 3, which is one operation.
Now, for any number x > 3, min_operations_to_reach(x) is the minimum of the following 3:
min_operations_to_reach(x-1) + 1, because whatever the minimum number of operations to reach (x-1) is, we can add 1 to it to reach x.
Or, if we consider making x from 1 using multiplication by 3, we have to consider the following 3 options:
If x is divisible by 3, then it's min_operations_to_reach(x/3) + 1.
If x%3 == 1, then it's min_operations_to_reach((x-1)/3) + 2: +2 because one operation is needed to multiply by 3 and another to add 1 to make the number x.
Similarly, if x%3 == 2, then the value is min_operations_to_reach((x-2)/3) + 3: +3 because of 1 operation to multiply by 3 and then two subsequent +1 operations to make the number x.
Or, if we consider making x from 1 using multiplication by 2, we have to consider the following 2 options:
If x is divisible by 2, then it's min_operations_to_reach(x/2) + 1.
If x%2 == 1, then it's min_operations_to_reach((x-1)/2) + 2.
Taking the minimum of the above 3 gives the minimum number of operations to reach x. That's what the code above does in lines 5, 6 and 7.
def DPoptimal_sequence(n,operations):
    MinNumOperations=[0]
    l_no=[]
    l_no2=[]
    for i in range(1,n+1):
        MinNumOperations.append(None)
        for operation in operations:
            if operation==1:
                NumOperations=MinNumOperations[i-1]+1
            if operation==2 and i%2==0:
                NumOperations=MinNumOperations[int(i/2)]+1
            if operation==3 and i%3==0:
                NumOperations=MinNumOperations[int(i/3)]+1
            if MinNumOperations[i]==None:
                MinNumOperations[i]=NumOperations
            elif NumOperations<MinNumOperations[i]:
                MinNumOperations[i]=NumOperations
        if MinNumOperations[i] == MinNumOperations[i-1]+1:
            l_no2.append((i,i-1))
        elif MinNumOperations[i] == MinNumOperations[int(i/2)]+1 and i%2 == 0:
            l_no2.append((i,int(i/2)))
        elif MinNumOperations[i] == MinNumOperations[int(i/3)]+1 and i%3 == 0:
            l_no2.append((i,int(i/3)))
        l_no.append((i,MinNumOperations[i]-1))
    x=MinNumOperations[n]-1
    l_no3=[n]
    while n>1:
        a,b = l_no2[n-1]
        if b == a-1:
            n = n-1
            l_no3.append(n)
        elif b == int(a/2) and a%2==0:
            n = int(n/2)
            l_no3.append(n)
        elif b == int(a/3) and a%3==0:
            n = int(n/3)
            l_no3.append(n)
    return x,l_no3
def optimal_sequence(n):
    hop_count = [0] * (n + 1)
    hop_count[1] = 1
    for i in range(2, n + 1):
        indices = [i - 1]
        if i % 2 == 0:
            indices.append(i // 2)
        if i % 3 == 0:
            indices.append(i // 3)
        min_hops = min([hop_count[x] for x in indices])
        hop_count[i] = min_hops + 1
    ptr = n
    optimal_seq = [ptr]
    while ptr != 1:
        candidates = [ptr - 1]
        if ptr % 2 == 0:
            candidates.append(ptr // 2)
        if ptr % 3 == 0:
            candidates.append(ptr // 3)
        ptr = min(
            [(c, hop_count[c]) for c in candidates],
            key=lambda x: x[1]
        )[0]
        optimal_seq.append(ptr)
    return reversed(optimal_seq)
private int count(int n, Map<Integer, Integer> lookup) {
    if(lookup.containsKey(n)) {
        return lookup.get(n);
    }
    if(n==1) {
        return 0;
    } else {
        int result;
        if(n%2==0 && n%3==0) {
            result =1+
                //Math.min(count(n-1, lookup),
                Math.min(count(n/2, lookup),
                    count(n/3, lookup));
        } else if(n%2==0) {
            result = 1+ Math.min(count(n-1, lookup),
                count(n/2, lookup));
        } else if(n%3==0) {
            result = 1+ Math.min(count(n-1, lookup), count(n/3, lookup));
        } else {
            result = 1+ count(n-1, lookup);
        }
        //System.out.println(result);
        lookup.put(n, result);
        return result;
    }
}

Transforming expression given in prefix notation, identifying common subexpressions and dependencies

I am given a bunch of expressions in prefix notation in an ANSI text file. I would like to produce another ANSI text file containing the step-by-step evaluation of these expressions. For example:
- + ^ x 2 ^ y 2 1
should be turned into
t1 = x^2
t2 = y^2
t3 = t1 + t2
t4 = t3 - 1
t4 is the result
I also have to identify common subexpressions. For example given
expression_1: z = ^ x 2
expression_2: - + z ^ y 2 1
expression_3: - z y
I have to generate an output saying that x appears in expressions 1, 2 and 3 (through z).
I have to identify dependencies: expression_1 depends only on x, expression_2 depends on x and y, etc.
The original problem is more difficult than the examples above and I have no control over the input format, it is in prefix notation in a much more complicated way than the above ones.
I already have a working implementation in C++ however it is a lot of pain doing such things in C++.
What programming language is best suited for these type problems?
Could you recommend a tutorial / website / book where I could start?
What keywords should I look for?
UPDATE: Based on the answers, the above examples are somewhat unfortunate: I have unary, binary and n-ary operators in the input. (If you are wondering, exp is a unary operator and sum over a range is an n-ary operator.)
To give you an idea of what this would look like in Python (3), here is some example code:
operators = "+-*/^"
def parse(it, count=1):
    token = next(it)
    if token in operators:
        op1, count = parse(it, count)
        op2, count = parse(it, count)
        tmp = "t%s" % count
        print(tmp, "=", op1, token, op2)
        return tmp, count + 1
    return token, count
s = "- + ^ x 2 ^ y 2 1"
a = s.split()
res, dummy = parse(iter(a))
print(res, "is the result")
The output is the same as your example output.
This example aside, I think any of the high-level languages you listed are almost equally suited for the task.
The sympy python package does symbolic algebra, including common subexpression elimination and generating evaluation steps for a set of expressions.
See: http://docs.sympy.org/dev/modules/rewriting.html (Look at the cse method at the bottom of the page).
The Python example is elegantly short, but I suspect that you won't actually get enough control over your expressions that way. You're much better off actually building an expression tree, even though it takes more work, and then querying the tree. Here's an example in Scala (suitable for cutting and pasting into the REPL):
object OpParser {
  private def estr(oe: Option[Expr]) = oe.map(_.toString).getOrElse("_")
  case class Expr(text: String, left: Option[Expr] = None, right: Option[Expr] = None) {
    import Expr._
    def varsUsed: Set[String] = text match {
      case Variable(v) => Set(v)
      case Op(o) =>
        left.map(_.varsUsed).getOrElse(Set()) ++ right.map(_.varsUsed).getOrElse(Set())
      case _ => Set()
    }
    def printTemp(n: Int = 0, depth: Int = 0): (String,Int) = text match {
      case Op(o) =>
        val (tl,nl) = left.map(_.printTemp(n,depth+1)).getOrElse(("_",n))
        val (tr,nr) = right.map(_.printTemp(nl,depth+1)).getOrElse(("_",n))
        val t = "t"+(nr+1)
        println(t + " = " + tl + " " + text + " " + tr)
        if (depth==0) println(t + " is the result")
        (t, nr+1)
      case _ => (text, n)
    }
    override def toString: String = {
      if (left.isDefined || right.isDefined) {
        "(" + estr(left) + " " + text + " " + estr(right) + ")"
      }
      else text
    }
  }
  object Expr {
    val Digit = "([0-9]+)".r
    val Variable = "([a-z])".r
    val Op = """([+\-*/^])""".r
    def parse(s: String) = {
      val bits = s.split(" ")
      val parsed = (
        if (bits.length > 2 && Variable.unapplySeq(bits(0)).isDefined && bits(1)=="=") {
          parseParts(bits,2)
        }
        else parseParts(bits)
      )
      parsed.flatMap(p => if (p._2<bits.length) None else Some(p._1))
    }
    def parseParts(as: Array[String], from: Int = 0): Option[(Expr,Int)] = {
      if (from >= as.length) None
      else as(from) match {
        case Digit(n) => Some(Expr(n), from+1)
        case Variable(v) => Some(Expr(v), from+1)
        case Op(o) =>
          parseParts(as, from+1).flatMap(lhs =>
            parseParts(as, lhs._2).map(rhs => (Expr(o,Some(lhs._1),Some(rhs._1)), rhs._2))
          )
        case _ => None
      }
    }
  }
}
This may be a little much to digest all at once, but then again, this does rather a lot.
Firstly, it's completely bulletproof (note the heavy use of Option where a result might fail). If you throw garbage at it, it will just return None. (With a bit more work, you could make it complain about the problem in an informative way: the case Op(o) branch, which calls parseParts twice, could instead store the results and print an informative error message if the op didn't get two arguments. Likewise, parse could complain about trailing values instead of just returning None.)
Secondly, when you're done with it, you have a complete expression tree. Note that printTemp prints out the temporary variables you wanted, and varsUsed lists the variables used in a particular expression, which you can use to expand to a full list once you parse multiple lines. (You might need to fiddle with the regexp a little if your variables can be more than just a to z.) Note also that the expression tree prints itself out in normal infix notation. Let's look at some examples:
scala> OpParser.Expr.parse("4")
res0: Option[OpParser.Expr] = Some(4)
scala> OpParser.Expr.parse("+ + + + + 1 2 3 4 5 6")
res1: Option[OpParser.Expr] = Some((((((1 + 2) + 3) + 4) + 5) + 6))
scala> OpParser.Expr.parse("- + ^ x 2 ^ y 2 1")
res2: Option[OpParser.Expr] = Some((((x ^ 2) + (y ^ 2)) - 1))
scala> OpParser.Expr.parse("+ + 4 4 4 4") // Too many 4s!
res3: Option[OpParser.Expr] = None
scala> OpParser.Expr.parse("Q#$S!M$#!*)000") // Garbage!
res4: Option[OpParser.Expr] = None
scala> OpParser.Expr.parse("z =") // Assigned nothing?!
res5: Option[OpParser.Expr] = None
scala> res2.foreach(_.printTemp())
t1 = x ^ 2
t2 = y ^ 2
t3 = t1 + t2
t4 = t3 - 1
t4 is the result
scala> res2.map(_.varsUsed)
res10: Option[Set[String]] = Some(Set(x, y))
Now, you could do this in Python also without too much additional work, and in a number of the other languages besides. I prefer to use Scala, but you may prefer otherwise. Regardless, I do recommend creating the full expression tree if you want to retain maximum flexibility for handling tricky cases.
Prefix notation is easy to parse with a plain recursive parser. For instance:
object Parser {
val Subexprs = collection.mutable.Map[String, String]()
val Dependencies = collection.mutable.Map[String, Set[String]]().withDefaultValue(Set.empty)
val TwoArgsOp = "([-+*/^])".r // - at the beginning, ^ at the end
val Ident = "(\\p{Alpha}\\w*)".r
val Literal = "(\\d+)".r
var counter = 1
def getIdent = {
val ident = "t" + counter
counter += 1
ident
}
def makeOp(op: String) = {
val op1 = expr
val op2 = expr
val ident = getIdent
val subexpr = op1 + " " + op + " " + op2
Subexprs(ident) = subexpr
Dependencies(ident) = Dependencies(op1) ++ Dependencies(op2) + op1 + op2
ident
}
def expr: String = nextToken match {
case TwoArgsOp(op) => makeOp(op)
case Ident(id) => id
case Literal(lit) => lit
case x => sys.error("Unknown token "+x)
}
def nextToken = tokens.next()
var tokens: Iterator[String] = _
def parse(input: String) = {
tokens = input.trim.split("\\s+").iterator
counter = 1
// Start from empty maps so repeated calls don't mix results
Subexprs.clear()
Dependencies.clear()
expr
if (tokens.hasNext)
sys.error("Input not fully parsed: "+tokens.mkString(" "))
(Subexprs, Dependencies)
}
}
This will generate output like this:
scala> val (subexpressions, dependencies) = Parser.parse("- + ^ x 2 ^ y 2 1")
subexpressions: scala.collection.mutable.Map[String,String] = Map(t3 -> t1 + t2, t4 -> t3 - 1, t1 -> x ^ 2, t2 -> y ^ 2)
dependencies: scala.collection.mutable.Map[String,Set[String]] = Map(t3 -> Set(x, y, t2, 2, t1), t4 -> Set(x, y, t3, t2, 1, 2, t1), t1 -> Set(x, 2), t2 -> Set(y, 2))
scala> subexpressions.toSeq.sorted foreach println
(t1,x ^ 2)
(t2,y ^ 2)
(t3,t1 + t2)
(t4,t3 - 1)
scala> dependencies.toSeq.sortBy(_._1) foreach println
(t1,Set(x, 2))
(t2,Set(y, 2))
(t3,Set(x, y, t2, 2, t1))
(t4,Set(x, y, t3, t2, 1, 2, t1))
This can be easily expanded. For instance, to handle multiple expression statements you can use this:
object Parser {
val Subexprs = collection.mutable.Map[String, String]()
val Dependencies = collection.mutable.Map[String, Set[String]]().withDefaultValue(Set.empty)
val TwoArgsOp = "([-+*/^])".r // - at the beginning, ^ at the end
val Ident = "(\\p{Alpha}\\w*)".r
val Literal = "(\\d+)".r
var counter = 1
def getIdent = {
val ident = "t" + counter
counter += 1
ident
}
def makeOp(op: String) = {
val op1 = expr
val op2 = expr
val ident = getIdent
val subexpr = op1 + " " + op + " " + op2
Subexprs(ident) = subexpr
Dependencies(ident) = Dependencies(op1) ++ Dependencies(op2) + op1 + op2
ident
}
def expr: String = nextToken match {
case TwoArgsOp(op) => makeOp(op)
case Ident(id) => id
case Literal(lit) => lit
case x => sys.error("Unknown token "+x)
}
def assignment: Unit = {
val ident = nextToken
nextToken match {
case "=" =>
val tmpIdent = expr
Dependencies(ident) = Dependencies(tmpIdent)
Subexprs(ident) = Subexprs(tmpIdent)
Dependencies.remove(tmpIdent)
Subexprs.remove(tmpIdent)
case x => sys.error("Expected assignment, got "+x)
}
}
def stmts: Unit = while(tokens.hasNext) tokens.head match {
case TwoArgsOp(_) => expr
case Ident(_) => assignment
case x => sys.error("Unknown statement starting with "+x)
}
def nextToken = tokens.next()
var tokens: BufferedIterator[String] = _
def parse(input: String) = {
tokens = input.trim.split("\\s+").iterator.buffered
counter = 1
// Start from empty maps so repeated calls don't mix results
Subexprs.clear()
Dependencies.clear()
stmts
if (tokens.hasNext)
sys.error("Input not fully parsed: "+tokens.mkString(" "))
(Subexprs, Dependencies)
}
}
Yielding:
scala> val (subexpressions, dependencies) = Parser.parse("""
| z = ^ x 2
| - + z ^ y 2 1
| - z y
| """)
subexpressions: scala.collection.mutable.Map[String,String] = Map(t3 -> z + t2, t5 -> z - y, t4 -> t3 - 1, z -> x ^ 2, t2 -> y ^ 2)
dependencies: scala.collection.mutable.Map[String,Set[String]] = Map(t3 -> Set(x, y, t2, 2, z), t5 -> Set(x, 2, z, y), t4 -> Set(x, y, t3, t2, 1, 2, z), z -> Set(x, 2), t2 -> Set(y, 2))
scala> subexpressions.toSeq.sorted foreach println
(t2,y ^ 2)
(t3,z + t2)
(t4,t3 - 1)
(t5,z - y)
(z,x ^ 2)
scala> dependencies.toSeq.sortBy(_._1) foreach println
(t2,Set(y, 2))
(t3,Set(x, y, t2, 2, z))
(t4,Set(x, y, t3, t2, 1, 2, z))
(t5,Set(x, 2, z, y))
(z,Set(x, 2))
OK, since recursive parsers are not your thing, here's an alternative using parser combinators:
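// Needs the scala-parser-combinators library (a separate module in modern Scala versions)
import scala.util.parsing.combinator.JavaTokenParsers
object PrefixParser extends JavaTokenParsers {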
import scala.collection.mutable
// Maps generated through parsing
val Subexprs = mutable.Map[String, String]()
val Dependencies = mutable.Map[String, Set[String]]().withDefaultValue(Set.empty)
// Initialize, read, parse & evaluate string
def read(input: String) = {
counter = 1
Subexprs.clear
Dependencies.clear
parseAll(stmts, input)
}
// Grammar
def stmts = rep1(stmt)
def stmt = assignment | expr
def assignment = (ident <~ "=") ~ expr ^^ assignOp
def expr: P = subexpr | identifier | number
def subexpr: P = twoArgs | nArgs
def twoArgs: P = operator ~ expr ~ expr ^^ twoArgsOp
def nArgs: P = "sum" ~ ("\\d+".r >> args) ^^ nArgsOp
def args(n: String): Ps = repN(n.toInt, expr)
def operator = "[-+*/^]".r
def identifier = ident ^^ (id => Result(id, Set(id)))
def number = wholeNumber ^^ (Result(_, Set.empty))
// Evaluation helper class and types
case class Result(ident: String, dependencies: Set[String])
type P = Parser[Result]
type Ps = Parser[List[Result]]
// Evaluation methods
def assignOp: (String ~ Result) => Result = {
case ident ~ result =>
val value = assign(ident,
Subexprs(result.ident),
result.dependencies - result.ident)
Subexprs.remove(result.ident)
Dependencies.remove(result.ident)
value
}
def assign(ident: String,
value: String,
dependencies: Set[String]): Result = {
Subexprs(ident) = value
Dependencies(ident) = dependencies
Result(ident, dependencies)
}
def twoArgsOp: (String ~ Result ~ Result) => Result = {
case op ~ op1 ~ op2 => makeOp(op, op1, op2)
}
def makeOp(op: String,
op1: Result,
op2: Result): Result = {
val ident = getIdent
assign(ident,
"%s %s %s" format (op1.ident, op, op2.ident),
op1.dependencies ++ op2.dependencies + ident)
}
def nArgsOp: (String ~ List[Result]) => Result = {
case op ~ ops => makeNOp(op, ops)
}
def makeNOp(op: String, ops: List[Result]): Result = {
val ident = getIdent
assign(ident,
"%s(%s)" format (op, ops map (_.ident) mkString ", "),
ops.foldLeft(Set(ident))(_ ++ _.dependencies))
}
var counter = 1
def getIdent = {
val ident = "t" + counter
counter += 1
ident
}
// Debugging helper methods
def printAssignments = Subexprs.toSeq.sorted foreach println
def printDependencies = Dependencies.toSeq.sortBy(_._1) map {
case (id, dependencies) => (id, dependencies - id)
} foreach println
}
This is the kind of results you get:
scala> PrefixParser.read("""
| z = ^ x 2
| - + z ^ y 2 1
| - z y
| """)
res77: PrefixParser.ParseResult[List[PrefixParser.Result]] = [5.1] parsed: List(Result(z,Set(x)), Result(t4,Set(t4, y, t3, t2, z)), Result(t5,Set(z, y, t5)))
scala> PrefixParser.printAssignments
(t2,y ^ 2)
(t3,z + t2)
(t4,t3 - 1)
(t5,z - y)
(z,x ^ 2)
scala> PrefixParser.printDependencies
(t2,Set(y))
(t3,Set(z, y, t2))
(t4,Set(y, t3, t2, z))
(t5,Set(z, y))
(z,Set(x))
n-ary operator:
scala> PrefixParser.read("""
| x = sum 3 + 1 2 * 3 4 5
| * x x
| """)
res93: PrefixParser.ParseResult[List[PrefixParser.Result]] = [4.1] parsed: List(Result(x,Set(t1, t2)), Result(t4,Set(x, t4)))
scala> PrefixParser.printAssignments
(t1,1 + 2)
(t2,3 * 4)
(t4,x * x)
(x,sum(t1, t2, 5))
scala> PrefixParser.printDependencies
(t1,Set())
(t2,Set())
(t4,Set(x))
(x,Set(t1, t2))
It turns out that this sort of parsing is of interest to me also, so I've done a bit more work on it.
There seems to be a sentiment that things like simplifying expressions are hard. I'm not so sure. Let's take a look at a fairly complete solution. (Printing out the t1...tn temporaries is not useful for me, and you've got several Scala examples of that already, so I'll skip it.)
First, we need to extract the various parts of the language. I'll pick regular expressions, though parser combinators could be used also:
object OpParser {
val Natural = "([0-9]+)".r
val Number = """((?:-)?[0-9]+(?:\.[0-9]+)?(?:[eE](?:-)?[0-9]+)?)""".r
val Variable = "([a-z])".r
val Unary = "(exp|sin|cos|tan|sqrt)".r
val Binary = "([-+*/^])".r
val Nary = "(sum|prod|list)".r
Pretty straightforward: we define the various tokens that might appear. (I've decided that user-defined variables can only be a single lowercase letter, and that numbers can be floating-point since you have the exp function.) The r at the end turns the string into a regular expression, and pattern matching against it gives us whatever was captured in the parentheses.
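As a small illustration of how these extractors behave, the hypothetical describe function below (not part of the answer) classifies a single token. Scala's Regex extractors match the whole string, so a token like "xy" is not mistaken for the single-letter variable x:

```scala
// Hypothetical demo of regex extractors like the ones defined above.
object RegexDemo {
  val Variable = "([a-z])".r
  val Binary   = "([-+*/^])".r // '-' first and '^' not first, so both are literal
  val Natural  = "([0-9]+)".r

  // Each case binds the text captured by the regex's single group.
  def describe(token: String): String = token match {
    case Variable(v) => s"variable $v"
    case Binary(b)   => s"binary operator $b"
    case Natural(n)  => s"natural number ${n.toInt}"
    case other       => s"unknown token $other"
  }
}
```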
Now we need to represent our tree. There are a number of ways to do this, but I'll choose an abstract base class with specific expressions as case classes, since this makes pattern matching easy. Furthermore, we might want nice printing, so we'll override toString. Mostly, though, we'll use recursive functions to do the heavy lifting.
abstract class Expr {
def text: String
def args: List[Expr]
override def toString = args match {
case l :: r :: Nil => "(" + l + " " + text + " " + r + ")"
case Nil => text
case _ => args.mkString(text+"(", ",", ")")
}
}
case class Num(text: String, args: List[Expr]) extends Expr {
val quantity = text.toDouble
}
case class Var(text: String, args: List[Expr]) extends Expr {
override def toString = args match {
case arg :: Nil => "(" + text + " <- " + arg + ")"
case _ => text
}
}
case class Una(text: String, args: List[Expr]) extends Expr
case class Bin(text: String, args: List[Expr]) extends Expr
case class Nar(text: String, args: List[Expr]) extends Expr {
override def toString = text match {
case "list" =>
(for ((a,i) <- args.zipWithIndex) yield {
"%3d: %s".format(i+1,a.toString)
}).mkString("List[\n","\n","\n]")
case _ => super.toString
}
}
Mostly this is pretty dull: each case class extends the base class, and its text and args constructor parameters automatically implement the abstract defs. Note that I've decided that a list is a possible n-ary function, and that it will be printed out with line numbers. (The reason is that if you have multiple lines of input, it's sometimes more convenient to work with them all together as one expression; this lets them be one function.)
Once our data structures are defined, we need to parse the expressions. It's convenient to represent the stuff to parse as a list of tokens; as we parse, we'll return both an expression and the remaining tokens that we haven't parsed--this is a particularly useful structure for recursive parsing. Of course, we might fail to parse anything, so it had better be wrapped in an Option also.
def parse(tokens: List[String]): Option[(Expr,List[String])] = tokens match {
case Variable(x) :: "=" :: rest =>
for ((expr,remains) <- parse(rest)) yield (Var(x,List(expr)), remains)
case Variable(x) :: rest => Some(Var(x,Nil), rest)
case Number(n) :: rest => Some(Num(n,Nil), rest)
case Unary(u) :: rest =>
for ((expr,remains) <- parse(rest)) yield (Una(u,List(expr)), remains)
case Binary(b) :: rest =>
for ((lexp,lrem) <- parse(rest); (rexp,rrem) <- parse(lrem)) yield
(Bin(b,List(lexp,rexp)), rrem)
case Nary(a) :: Natural(b) :: rest =>
val buffer = new collection.mutable.ArrayBuffer[Expr]
def parseN(tok: List[String], n: Int = b.toInt): List[String] = {
if (n <= 0) tok
else {
for ((expr,remains) <- parse(tok)) yield { buffer += expr; parseN(remains, n-1) }
}.getOrElse(tok)
}
val remains = parseN(rest)
if (buffer.length == b.toInt) Some( Nar(a,buffer.toList), remains )
else None
case _ => None
}
Note that we use pattern matching and recursion to do most of the heavy lifting: we pick off part of the list, figure out how many arguments we need, and pass those along recursively. The n-ary operation is a bit less friendly, but we create a small local recursive function that parses N expressions in a row for us, storing the results in a buffer.
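If you would rather avoid the mutable buffer, the same idea can be written so that the parsed arguments are threaded through the Option as well. This is a minimal sketch under simplifying assumptions: only numbers and a count-prefixed sum are supported, and the names PureNary, Ast, Lit and Call are hypothetical.

```scala
// Hypothetical buffer-free version of count-prefixed (n-ary) parsing.
object PureNary {
  sealed trait Ast
  final case class Lit(value: Double) extends Ast
  final case class Call(name: String, args: List[Ast]) extends Ast

  private val Natural = "([0-9]+)".r
  private val Number  = "(-?[0-9]+(?:\\.[0-9]+)?)".r

  // Parses one expression; returns it together with the unconsumed tokens.
  def parse(tokens: List[String]): Option[(Ast, List[String])] = tokens match {
    case "sum" :: Natural(n) :: rest =>
      parseN(rest, n.toInt).map { case (args, remains) => (Call("sum", args), remains) }
    case Number(x) :: rest => Some((Lit(x.toDouble), rest))
    case _                 => None
  }

  // Parses exactly n expressions in a row; fails if any of them fails.
  def parseN(tokens: List[String], n: Int): Option[(List[Ast], List[String])] =
    if (n <= 0) Some((Nil, tokens))
    else parse(tokens).flatMap { case (first, afterFirst) =>
      parseN(afterFirst, n - 1).map { case (others, remains) => (first :: others, remains) }
    }
}
```

For example, PureNary.parse("sum 2 1 2 9".split(" ").toList) parses sum(1, 2) and leaves List("9") unconsumed, while "sum 3 1 2" yields None because the third argument is missing.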
Of course, this is a little unfriendly to use, so we add some wrapper functions that let us interface with it nicely:
def parse(s: String): Option[Expr] = parse(s.split(" ").toList).flatMap(x => {
if (x._2.isEmpty) Some(x._1) else None
})
def parseLines(ls: List[String]): Option[Expr] = {
val attempt = ls.map(parse).flatten
if (attempt.length<ls.length) None
else if (attempt.length==1) attempt.headOption
else Some(Nar("list",attempt))
}
Okay, now, what about simplification? One thing we might want to do is numeric simplification, where we precompute the expressions and replace the original expression with the reduced version thereof. That sounds like some sort of a recursive operation--find numbers, and combine them. First we get some helper functions to do calculations on numbers:
def calc(n: Num, f: Double => Double): Num = Num(f(n.quantity).toString, Nil)
def calc(n: Num, m: Num, f: (Double,Double) => Double): Num =
Num(f(n.quantity,m.quantity).toString, Nil)
def calc(ln: List[Num], f: (Double,Double) => Double): Num =
Num(ln.map(_.quantity).reduceLeft(f).toString, Nil)
and then we do the simplification:
def numericSimplify(expr: Expr): Expr = expr match {
case Una(t,List(e)) => numericSimplify(e) match {
case n @ Num(_,_) => t match {
case "exp" => calc(n, math.exp _)
case "sin" => calc(n, math.sin _)
case "cos" => calc(n, math.cos _)
case "tan" => calc(n, math.tan _)
case "sqrt" => calc(n, math.sqrt _)
}
case a => Una(t,List(a))
}
case Bin(t,List(l,r)) => (numericSimplify(l), numericSimplify(r)) match {
case (n @ Num(_,_), m @ Num(_,_)) => t match {
case "+" => calc(n, m, _ + _)
case "-" => calc(n, m, _ - _)
case "*" => calc(n, m, _ * _)
case "/" => calc(n, m, _ / _)
case "^" => calc(n, m, math.pow)
}
case (a,b) => Bin(t,List(a,b))
}
case Nar("list",list) => Nar("list",list.map(numericSimplify))
case Nar(t,list) =>
val simple = list.map(numericSimplify)
val nums = simple.collect { case n @ Num(_,_) => n }
if (simple.length == 0) t match {
case "sum" => Num("0",Nil)
case "prod" => Num("1",Nil)
}
else if (nums.length == simple.length) t match {
case "sum" => calc(nums, _ + _)
case "prod" => calc(nums, _ * _)
}
else Nar(t, simple)
case Var(t,List(e)) => Var(t, List(numericSimplify(e)))
case _ => expr
}
Notice again the heavy use of pattern matching to find when we're in a good case, and to dispatch the appropriate calculation.
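The same bottom-up folding can be seen in isolation on a smaller tree. This is a sketch rather than the answer's code: FoldDemo, Num, Var and Bin are hypothetical, numbers are stored as Double instead of text, and only binary operators are handled. Note the n @ Num(_) binder, which names the whole matched node.

```scala
// Hypothetical stand-alone constant folder on a binary-only tree.
object FoldDemo {
  sealed trait E
  final case class Num(v: Double) extends E
  final case class Var(name: String) extends E
  final case class Bin(op: String, l: E, r: E) extends E

  // Arithmetic for the five binary operators used in this thread.
  private def eval(op: String, a: Double, b: Double): Double = op match {
    case "+" => a + b
    case "-" => a - b
    case "*" => a * b
    case "/" => a / b
    case "^" => math.pow(a, b)
    case _   => throw new IllegalArgumentException(s"unknown operator $op")
  }

  // Simplify both children first; combine only if both became numbers.
  def fold(e: E): E = e match {
    case Bin(op, l, r) =>
      (fold(l), fold(r)) match {
        case (n @ Num(_), m @ Num(_)) => Num(eval(op, n.v, m.v))
        case (fl, fr)                 => Bin(op, fl, fr)
      }
    case leaf => leaf
  }
}
```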
Now, surely algebraic substitution is much more difficult! Actually, all you need to do is notice that an expression has already been used, and assign it to a variable. Since the syntax I've defined above allows in-place variable assignment, we can just modify our expression tree to include more variable assignments. So we do (edited so that it only inserts a new variable where the user hasn't already assigned one):
def algebraicSimplify(expr: Expr): Expr = {
val all, dup, used = new collection.mutable.HashSet[Expr]
val made = new collection.mutable.HashMap[Expr,Int]
val user = new collection.mutable.HashMap[Expr,Expr]
def findExpr(e: Expr): Unit = {
e match {
case Var(t,List(v)) =>
user += v -> e
if (all contains e) dup += e else all += e
case Var(_,_) | Num(_,_) => // Do nothing in these cases
case _ => if (all contains e) dup += e else all += e
}
e.args.foreach(findExpr)
}
findExpr(expr)
def replaceDup(e: Expr): Expr = {
if (made contains e) Var("x"+made(e),Nil)
else if (used contains e) Var(user(e).text,Nil)
else if (dup contains e) {
val fixed = replaceDupChildren(e)
made += e -> made.size
Var("x"+made(e),List(fixed))
}
else replaceDupChildren(e)
}
def replaceDupChildren(e: Expr): Expr = e match {
case Una(t,List(u)) => Una(t,List(replaceDup(u)))
case Bin(t,List(l,r)) => Bin(t,List(replaceDup(l),replaceDup(r)))
case Nar(t,list) => Nar(t,list.map(replaceDup))
case Var(t,List(v)) =>
used += v
Var(t,List(if (made contains v) replaceDup(v) else replaceDupChildren(v)))
case _ => e
}
replaceDup(expr)
}
That's it: a fully functional algebraic replacement routine. It builds up sets of the expressions it has seen, keeping special track of which ones are duplicates. Thanks to case classes, structural equality and hashing are defined for us, so the sets just work. Then we replace any duplicates as we recurse through the tree. Note that the replacement routine is split into two mutually recursive halves (replaceDup and replaceDupChildren): the lookups in made, used and dup are done on the original, unreplaced subtree, while the tree that gets returned is built from the already-replaced children.
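The reason case classes "just work" here is that they get structural equals and hashCode, so two separately built copies of x + y are the same key in a set or map. A minimal sketch of the duplicate-detection half, assuming a hypothetical two-node tree type (DupDemo, Leaf, Bin) rather than the answer's Expr:

```scala
// Hypothetical demo: finding repeated subtrees via case-class equality.
object DupDemo {
  sealed trait E
  final case class Leaf(t: String) extends E
  final case class Bin(op: String, l: E, r: E) extends E

  // All operator subtrees, visited pre-order (leaves are never worth a variable).
  def subtrees(e: E): List[E] = e match {
    case b @ Bin(_, l, r) => b :: subtrees(l) ::: subtrees(r)
    case _: Leaf          => Nil
  }

  // Subtrees that occur more than once; equal structure means equal keys.
  def duplicates(e: E): Set[E] =
    subtrees(e).groupBy(identity).collect { case (t, occ) if occ.size > 1 => t }.toSet
}
```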
Okay, now let's add a few tests:
def main(args: Array[String]): Unit = {
val test1 = "- + ^ x 2 ^ y 2 1"
val test2 = "+ + +" // Bad!
val test3 = "exp sin cos sum 5" // Bad!
val test4 = "+ * 2 3 ^ 3 2"
val test5 = List(test1, test4, "^ y 2").mkString("list 3 "," ","")
val test6 = "+ + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y"
def performTest(test: String) = {
println("Start with: " + test)
val p = OpParser.parse(test)
if (p.isEmpty) println(" Parsing failed")
else {
println("Parsed: " + p.get)
val q = OpParser.numericSimplify(p.get)
println("Numeric: " + q)
val r = OpParser.algebraicSimplify(q)
println("Algebraic: " + r)
}
println()
}
List(test1,test2,test3,test4,test5,test6).foreach(performTest)
}
}
How does it do?
$ scalac OpParser.scala; scala OpParser
Start with: - + ^ x 2 ^ y 2 1
Parsed: (((x ^ 2) + (y ^ 2)) - 1)
Numeric: (((x ^ 2) + (y ^ 2)) - 1)
Algebraic: (((x ^ 2) + (y ^ 2)) - 1)
Start with: + + +
Parsing failed
Start with: exp sin cos sum 5
Parsing failed
Start with: + * 2 3 ^ 3 2
Parsed: ((2 * 3) + (3 ^ 2))
Numeric: 15.0
Algebraic: 15.0
Start with: list 3 - + ^ x 2 ^ y 2 1 + * 2 3 ^ 3 2 ^ y 2
Parsed: List[
1: (((x ^ 2) + (y ^ 2)) - 1)
2: ((2 * 3) + (3 ^ 2))
3: (y ^ 2)
]
Numeric: List[
1: (((x ^ 2) + (y ^ 2)) - 1)
2: 15.0
3: (y ^ 2)
]
Algebraic: List[
1: (((x ^ 2) + (x0 <- (y ^ 2))) - 1)
2: 15.0
3: x0
]
Start with: + + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y
Parsed: ((x + y) + ((((x + y) * (4 + 5)) + ((x + y) * (4 + y))) + ((x + y) + (4 + y))))
Numeric: ((x + y) + ((((x + y) * 9.0) + ((x + y) * (4 + y))) + ((x + y) + (4 + y))))
Algebraic: ((x0 <- (x + y)) + (((x0 * 9.0) + (x0 * (x1 <- (4 + y)))) + (x0 + x1)))
So I don't know if that's useful for you, but it turns out to be useful for me. And this is the sort of thing that I would be very hesitant to tackle in C++, because things that are supposed to be easy there often end up being painful.
Edit: Here's an example of using this structure to print temporary assignments, just to demonstrate that this structure is perfectly okay for doing such things.
Code:
def useTempVars(expr: Expr): Expr = {
var n = 0
def temp = { n += 1; "t"+n }
def replaceTemp(e: Expr, exempt: Boolean = false): Expr = {
def varify(x: Expr) = if (exempt) x else Var(temp,List(x))
e match {
case Var(t,List(e)) => Var(t,List(replaceTemp(e, exempt = true)))
case Una(t,List(u)) => varify( Una(t, List(replaceTemp(u,false))) )
case Bin(t,lr) => varify( Bin(t, lr.map(replaceTemp(_,false))) )
case Nar(t,ls) => varify( Nar(t, ls.map(replaceTemp(_,false))) )
case _ => e
}
}
replaceTemp(expr)
}
def varCut(expr: Expr): Expr = expr match {
case Var(t,_) => Var(t,Nil)
case Una(t,List(u)) => Una(t,List(varCut(u)))
case Bin(t,lr) => Bin(t, lr.map(varCut))
case Nar(t,ls) => Nar(t, ls.map(varCut))
case _ => expr
}
def getAssignments(expr: Expr): List[Expr] = {
val children = expr.args.flatMap(getAssignments)
expr match {
case Var(t,List(e)) => children :+ expr
case _ => children
}
}
def listAssignments(expr: Expr): List[String] = {
getAssignments(expr).collect {
case Var(t,List(v)) => t + " = " + varCut(v)
} :+ (expr.text + " is the answer")
}
Selected results (from listAssignments(useTempVars(r)).foreach(printf(" %s\n",_))):
Start with: - + ^ x 2 ^ y 2 1
Assignments:
t1 = (x ^ 2)
t2 = (y ^ 2)
t3 = (t1 + t2)
t4 = (t3 - 1)
t4 is the answer
Start with: + + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y
Algebraic: ((x0 <- (x + y)) + (((x0 * 9.0) + (x0 * (x1 <- (4 + y)))) + (x0 + x1)))
Assignments:
x0 = (x + y)
t1 = (x0 * 9.0)
x1 = (4 + y)
t2 = (x0 * x1)
t3 = (t1 + t2)
t4 = (x0 + x1)
t5 = (t3 + t4)
t6 = (x0 + t5)
t6 is the answer
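The temporaries above come from naming each operator node only after its children have been named, i.e. a post-order walk. A stripped-down sketch of that idea on a hypothetical binary-only tree (TempVars, Leaf and Bin are not from the answer), which prints operands without the extra parentheses:

```scala
// Hypothetical post-order temporary-variable generator.
object TempVars {
  sealed trait E
  final case class Leaf(t: String) extends E
  final case class Bin(op: String, l: E, r: E) extends E

  def assignments(e: E): List[String] = {
    val out = List.newBuilder[String]
    var n = 0
    // Returns the name standing for x: a leaf's own text, or a fresh t<n>.
    def name(x: E): String = x match {
      case Leaf(t) => t
      case Bin(op, l, r) =>
        val (a, b) = (name(l), name(r)) // children first, left to right
        n += 1
        out += s"t$n = $a $op $b"
        s"t$n"
    }
    val result = name(e)
    (out += s"$result is the result").result()
  }
}
```

For the tree of - + ^ x 2 ^ y 2 1 it produces t1 = x ^ 2, t2 = y ^ 2, t3 = t1 + t2, t4 = t3 - 1 and then t4 is the result.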
Second edit: finding dependencies is also not too bad.
Code:
def directDepends(expr: Expr): Set[Expr] = expr match {
case Var(t,_) => Set(expr)
case _ => expr.args.flatMap(directDepends).toSet
}
def indirectDepends(expr: Expr) = {
val depend = getAssignments(expr).map(e =>
e -> e.args.flatMap(directDepends).toSet
).toMap
val tagged = for ((k,v) <- depend) yield (k.text -> v.map(_.text))
def percolate(tags: Map[String,Set[String]]): Option[Map[String,Set[String]]] = {
val expand = for ((k,v) <- tags) yield (
k -> (v union v.flatMap(x => tags.get(x).getOrElse(Set())))
)
if (tags.exists(kv => expand(kv._1) contains kv._1)) None // Cyclic dependency!
else if (tags == expand) Some(tags)
else percolate(expand)
}
percolate(tagged)
}
def listDependents(expr: Expr): List[(String,String)] = {
def sayNothing(s: String) = if (s=="") "nothing" else s
val e = expr match {
case Var(_,_) => expr
case _ => Var("result",List(expr))
}
indirectDepends(e).map(_.toList.map(x =>
(x._1, sayNothing(x._2.toList.sorted.mkString(" ")))
)).getOrElse(List((e.text,"cyclic")))
}
And if we add new test cases val test7 = "list 3 z = ^ x 2 - + z ^ y 2 1 w = - z y" and val test8 = "list 2 x = y y = x" and show the answers with for ((v,d) <- listDependents(r)) println(" "+v+" requires "+d) we get (selected results):
Start with: - + ^ x 2 ^ y 2 1
Dependencies:
result requires x y
Start with: list 3 z = ^ x 2 - + z ^ y 2 1 w = - z y
Parsed: List[
1: (z <- (x ^ 2))
2: ((z + (y ^ 2)) - 1)
3: (w <- (z - y))
]
Dependencies:
z requires x
w requires x y z
result requires w x y z
Start with: list 2 x = y y = x
Parsed: List[
1: (x <- y)
2: (y <- x)
]
Dependencies:
result requires cyclic
Start with: + + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y
Algebraic: ((x0 <- (x + y)) + (((x0 * 9.0) + (x0 * (x1 <- (4 + y)))) + (x0 + x1)))
Dependencies:
x0 requires x y
x1 requires y
result requires x x0 x1 y
So I think that on top of this sort of structure, all of your individual requirements are met by blocks of one or two dozen lines of Scala code.
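The percolate step in indirectDepends is a fixed-point iteration for the transitive closure of the dependency map. Here it is on its own, as a sketch over plain Map[String, Set[String]] values; the Closure object and the transitive name are hypothetical.

```scala
// Hypothetical stand-alone version of the "percolate" idea.
object Closure {
  // Expands each name's direct dependencies until nothing changes.
  // Returns None as soon as some name ends up depending on itself.
  def transitive(direct: Map[String, Set[String]]): Option[Map[String, Set[String]]] = {
    @annotation.tailrec
    def loop(deps: Map[String, Set[String]]): Option[Map[String, Set[String]]] = {
      val expanded = deps.map { case (k, v) =>
        k -> (v ++ v.flatMap(x => deps.getOrElse(x, Set.empty[String])))
      }
      if (expanded.exists { case (k, v) => v.contains(k) }) None // cyclic dependency
      else if (expanded == deps) Some(deps)                        // fixed point reached
      else loop(expanded)
    }
    loop(direct)
  }
}
```

Because the sets only grow and the set of names is finite, the loop must terminate; Map("x" -> Set("y"), "y" -> Set("x")) returns None on the first pass.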
Edit: here's expression evaluation, if you're given a mapping from vars to values:
def numericEvaluate(expr: Expr, initialValues: Map[String,Double]) = {
val chain = new collection.mutable.ArrayBuffer[(String,Double)]
val evaluated = new collection.mutable.HashMap[String,Double]
def note(xv: (String,Double)): Unit = { chain += xv; evaluated += xv }
evaluated ++= initialValues
def substitute(expr: Expr): Expr = expr match {
case Var(t,List(n @ Num(v,_))) => { note(t -> v.toDouble); n }
case Var(t,_) if (evaluated contains t) => Num(evaluated(t).toString,Nil)
case Var(t,ls) => Var(t,ls.map(substitute))
case Una(t,List(u)) => Una(t,List(substitute(u)))
case Bin(t,ls) => Bin(t,ls.map(substitute))
case Nar(t,ls) => Nar(t,ls.map(substitute))
case _ => expr
}
def recurse(e: Expr): Expr = {
val sub = numericSimplify(substitute(e))
if (sub == e) e else recurse(sub)
}
(recurse(expr), chain.toList)
}
and it's used like so in the testing routine:
val (num,ops) = numericEvaluate(r,Map("x"->3,"y"->1.5))
println("Evaluated:")
for ((v,n) <- ops) println(" "+v+" = "+n)
println(" result = " + num)
giving results like these (with input of x = 3 and y = 1.5):
Start with: list 3 - + ^ x 2 ^ y 2 1 + * 2 3 ^ 3 2 ^ y 2
Algebraic: List[
1: (((x ^ 2) + (x0 <- (y ^ 2))) - 1)
2: 15.0
3: x0
]
Evaluated:
x0 = 2.25
result = List[
1: 10.25
2: 15.0
3: 2.25
]
Start with: list 3 z = ^ x 2 - + z ^ y 2 1 w = - z y
Algebraic: List[
1: (z <- (x ^ 2))
2: ((z + (y ^ 2)) - 1)
3: (w <- (z - y))
]
Evaluated:
z = 9.0
w = 7.5
result = List[
1: 9.0
2: 10.25
3: 7.5
]
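For comparison, if you only need the final value rather than the chain of intermediate assignments, a direct recursive evaluator is enough. A minimal sketch, assuming a hypothetical binary-only tree (EvalDemo, Num, Var, Bin) and reporting an unbound variable or unknown operator as None:

```scala
// Hypothetical minimal evaluator; not the answer's numericEvaluate.
object EvalDemo {
  sealed trait E
  final case class Num(v: Double) extends E
  final case class Var(name: String) extends E
  final case class Bin(op: String, l: E, r: E) extends E

  // None for an operator this sketch doesn't know about.
  def applyOp(op: String, a: Double, b: Double): Option[Double] = op match {
    case "+" => Some(a + b)
    case "-" => Some(a - b)
    case "*" => Some(a * b)
    case "/" => Some(a / b)
    case "^" => Some(math.pow(a, b))
    case _   => None
  }

  // Variables are looked up in env; any missing one makes the result None.
  def eval(e: E, env: Map[String, Double]): Option[Double] = e match {
    case Num(v) => Some(v)
    case Var(n) => env.get(n)
    case Bin(op, l, r) =>
      for {
        a   <- eval(l, env)
        b   <- eval(r, env)
        res <- applyOp(op, a, b)
      } yield res
  }
}
```

With x = 3 and y = 1.5, evaluating x ^ 2 + y ^ 2 - 1 this way gives Some(10.25), matching the first list entry above.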
The other challenge, picking out the variables that haven't already been used, is just set subtraction on the dependency results. In Scala, diff is the name of the set-subtraction method.
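For instance, assuming "not yet used" means variables that the result requires but no statement assigns, the last example above (result requires w x y z, while z and w are assigned) reduces to one diff call. The object and method names are hypothetical.

```scala
// Hypothetical helper: required-but-unassigned variables via set subtraction.
object UnusedVars {
  def freeInputs(required: Set[String], assigned: Set[String]): Set[String] =
    required diff assigned // same as required -- assigned

  def main(args: Array[String]): Unit =
    println(freeInputs(Set("w", "x", "y", "z"), Set("z", "w")))
}
```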
The problem consists of two subproblems: parsing and symbolic manipulation. It seems to me the answer boils down to two possible solutions.
One is to implement everything from scratch: "I do recommend creating the full expression tree if you want to retain maximum flexibility for handling tricky cases." - proposed by Rex. As Sven points out: "any of the high-level languages you listed are almost equally suited for the task," however "Python (or any of the high-level languages you listed) won't take away the complexity of the problem."
I have received very nice solutions in Scala (many thanks to Rex and Daniel) and a nice little example in Python (from Sven). However, I am still interested in Lisp, Haskell or Erlang solutions.
The other solution is to use some existing library/software for the task, with all the implied pros and cons. Candidates are Maxima (Common Lisp), SymPy (Python, proposed by payne) and GiNaC (C++).
