I'm trying to translate the recursive python code for tarjan algorithm to scala and especially this part :
def tarjan_recursive(g):
S = []
S_set = set()
index = {}
lowlink = {}
ret = []
def visit(v):
index[v] = len(index)
lowlink[v] = index[v]
S.append(v)
S_set.add(v)
for w in g.get(v,()):
print(w)
if w not in index:
visit(w)
lowlink[v] = min(lowlink[w], lowlink[v])
elif w in S_set:
lowlink[v] = min(lowlink[v], index[w])
if lowlink[v] == index[v]:
scc = []
w = None
while v != w:
w = S.pop()
scc.append(w)
S_set.remove(w)
ret.append(scc)
for v in g:
print(index)
if not v in index:
visit(v)
return ret
I know that there's tarjan algorithm in scala here or here but it doesn't return good result and translate it from python help me understand it.
Here's what I have :
def tj_recursive(g: Map[Int,List[Int]])= {
var s : mutable.ListBuffer[Int] = new mutable.ListBuffer()
var s_set : mutable.Set[Int] = mutable.Set()
var index : mutable.Map[Int,Int] = mutable.Map()
var lowlink : mutable.Map[Int,Int]= mutable.Map()
var ret : mutable.Map[Int,mutable.ListBuffer[Int]]= mutable.Map()
def visit(v: Int):Int = {
index(v) = index.size
lowlink(v) = index(v)
var zz :List[Int]= gg.get(v).toList(0)
for( w <- zz) {
if( !(index.contains(w)) ){
visit(w)
lowlink(v) = List(lowlink(w),lowlink(v)).min
}else if(s_set.contains(w)){
lowlink(v)=List(lowlink(v),index(w)).min
}
}
if(lowlink(v)==index(v)){
var scc:mutable.ListBuffer[Int] = new mutable.ListBuffer()
var w:Int=null.asInstanceOf[Int]
while(v!=w){
w= s.last
scc+=w
s_set-=w
}
ret+=scc
}
}
for( v <- g) {if( !(index.contains(v)) ){visit(v)}}
ret
}
I know this isn't the scala way at all (and not clean ...) but I'm planning to slowly change it to a more functional style when I get the first version working.
For now, I got this error :
type mismatch; found : Unit required: Int
at this line
if(lowlink(v)==index(v)){
I think it's coming from this line but I'm not sure :
if( !(index.contains(w))
But it's really hard to debug it since I can't just println my mistakes ...
Thanks !
Here's a fairly literal translation of the Python:
def tj_recursive(g: Map[Int, List[Int]])= {
val s = mutable.Buffer.empty[Int]
val s_set = mutable.Set.empty[Int]
val index = mutable.Map.empty[Int, Int]
val lowlink = mutable.Map.empty[Int, Int]
val ret = mutable.Buffer.empty[mutable.Buffer[Int]]
def visit(v: Int): Unit = {
index(v) = index.size
lowlink(v) = index(v)
s += v
s_set += v
for (w <- g(v)) {
if (!index.contains(w)) {
visit(w)
lowlink(v) = math.min(lowlink(w), lowlink(v))
} else if (s_set(w)) {
lowlink(v) = math.min(lowlink(v), index(w))
}
}
if (lowlink(v) == index(v)) {
val scc = mutable.Buffer.empty[Int]
var w = -1
while(v != w) {
w = s.remove(s.size - 1)
scc += w
s_set -= w
}
ret += scc
}
}
for (v <- g.keys) if (!index.contains(v)) visit(v)
ret
}
It produces the same output on e.g.:
tj_recursive(Map(
1 -> List(2), 2 -> List(1, 5), 3 -> List(4),
4 -> List(3, 5), 5 -> List(6), 6 -> List(7),
7 -> List(8), 8 -> List(6, 9), 9 -> Nil
))
The biggest problem with your implementation was the return type of visit (which should have been Unit, not Int) and the fact that you were iterating over the graph's items instead of the graph's keys in the final for-comprehension, but I've made a number of other edits for style and clarity (while still keeping the basic shape).
Here is an iterative version. It is a translation from the recursive version of the algorithm in Wikipedia.
case class Arc[A](from:A, to:A)
class SparseDG[A](src: Iterable[Arc[A]]) {
val verts = (src.map(_.from) ++ src.map(_.to)).toSet.toIndexedSeq
val qVert = verts.size
val vertMap = verts.zipWithIndex.toMap
val indexedSrc = src.map{ arc => Arc(vertMap(arc.from), vertMap(arc.to)) }
val exit = (0 until qVert)
.map(v => indexedSrc.filter(_.from == v).map(_.to).toIndexedSeq)
lazy val tarjan_iterative: Seq[Seq[A]] = {
trait Step
case object SetDepth extends Step
case object ConsiderSuccessors extends Step
case object CalcLowlink extends Step
case object PopIfRoot extends Step
case class StackFrame(v:Int, next:Step)
val result = Buffer[Seq[A]]()
val index = new Array[Int](qVert).map(_ => -1) // -1 = undefined
val lowlink = new Array[Int](qVert).map(_ => -1) // -1 = undefined
val wIndex = new Array[Int](qVert) // used to iterate w nodes
var _index = 0
val s = Stack[Int]()
val isRemoved = BitSet()
val strongconnect = Stack[StackFrame]()
(0 until qVert).foreach { v_idx =>
if(index(v_idx) == -1) {
strongconnect.push(StackFrame(v_idx, SetDepth))
while(!strongconnect.isEmpty) {
val StackFrame(v, step) = strongconnect.pop()
step match {
case SetDepth =>
index(v) = _index
lowlink(v) = _index
_index += 1
s.push(v)
isRemoved.remove(v)
strongconnect.push(StackFrame(v, ConsiderSuccessors))
case ConsiderSuccessors =>
if(wIndex(v) < exit(v).size){
val w = exit(v)(wIndex(v))
if(index(w) == -1){
strongconnect.push(StackFrame(v, CalcLowlink))
strongconnect.push(StackFrame(w, SetDepth))
}
else{
if(!isRemoved.contains(w)){
if(lowlink(v) > lowlink(w)) lowlink(v) = index(w)
}
wIndex(v) += 1
strongconnect.push(StackFrame(v, ConsiderSuccessors))
}
}
else{
strongconnect.push(StackFrame(v, PopIfRoot))
}
case CalcLowlink =>
val w = exit(v)(wIndex(v))
if(lowlink(v) > lowlink(w)) lowlink(v) = lowlink(w)
wIndex(v) += 1
strongconnect.push(StackFrame(v, ConsiderSuccessors))
case PopIfRoot =>
if(index(v) == lowlink(v)){
val buf = Buffer[A]()
var w = 0
do{
w = s.pop()
isRemoved += w
buf += verts(w)
}
while(w != v)
result += buf.toSeq
}
}
}
}
}
result.toSeq
}
lazy val hasCycle = tarjan_iterative.find(_.size >= 2).isDefined
lazy val topologicalSort =
if(hasCycle) None
else Some(tarjan_iterative.flatten.reverse)
}
Running the example graph in the Wikipedia article:
val g = new SparseDG(Seq(
Arc("1","2"),
Arc("2","3"),
Arc("3","1"),
Arc("4","2"),
Arc("4","3"),
Arc("6","3"),
Arc("6","7"),
Arc("7","6"),
Arc("4","5"),
Arc("5","4"),
Arc("5","6"),
Arc("8","5"),
Arc("8","8"),
Arc("8","7")
))
g.tarjan_iterative
returns:
ArrayBuffer(ArrayBuffer(1, 3, 2), ArrayBuffer(7, 6), ArrayBuffer(4, 5), ArrayBuffer(8))
I know this post is old, but I have lately been working with the implementation of Tarjans algorithm in Scala. In the implementation of the code, I was looking at this post and it occurred to me, that it could be done in a simpler way:
case class Edge[A](from: A, to: Set[A])
class TarjanGraph[A](src: Iterable[Edge[A]]) {
lazy val trajan: mutable.Buffer[mutable.Buffer[A]] = {
var s = mutable.Buffer.empty[A] //Stack to keep track of nodes reachable from current node
val index = mutable.Map.empty[A, Int] //index of each node
val lowLink = mutable.Map.empty[A, Int] //The smallest index reachable from the node
val ret = mutable.Buffer.empty[mutable.Buffer[A]] //Keep track of SCC in graph
def visit(v: A): Unit = {
//Set index and lowlink of node on first visit
index(v) = index.size
lowLink(v) = index(v)
//Add to stack
s += v
if (src.exists(_.from == v)) {
for (w <- src.find(e => e.from == v).head.to) {
if (!index.contains(w)) { //Node is not explored yet
//Perform DFS from node W
visit(w)
//Update the lowlink value of v so it has the value of the lowest node reachable from itself and from node w
lowLink(v) = math.min(lowLink(w), lowLink(v))
} else if (s.contains(w)) {
// Node w is on the stack meaning - it means there is a path from w to v
// and since node w is a neighbor to node v there is also a path from v to w
lowLink(v) = math.min(lowLink(v), index(w))
}
}
}
//The lowlink value haven't been updated meaning it is the root of a cycle/SCC
if (lowLink(v) == index(v)) {
//Add the elements to the cycle that has been added to the stack and whose lowlink has been updated by node v's lowlink
//This is the elements on the stack that is placed behind v
val n = s.length - s.indexOf(v)
ret += s.takeRight(n)
//Remove these elements from the stack
s.dropRightInPlace(n)
}
}
//Perform a DFS from all no nodes that hasn't been explored
src.foreach(v => if (!index.contains(v.from)) visit(v.from))
ret
}
// A cycle exist if there is a SCC with at least two components
lazy val hasCycle: Boolean = trajan.exists(_.size >= 2)
lazy val trajanCycle: Iterable[Seq[A]] = trajan.filter(_.size >= 2).distinct.map(_.toSeq).toSeq
lazy val topologicalSortedEdges: Seq[Edge[A]] =
if (hasCycle) Seq[Edge[A]]()
else trajan.flatten.reverse.flatMap(x => src.find(_.from == x)).toSeq
}
Related
I am trying to calculate the values for the Elliott oscillator breakbands in python. I have the indicator logic in C#, which I will leave below.
In python I have already calculated and checked the value of the histogram, on which depends the calculation of the breakbands I want to calculate.
However, the values that the current logic gives are not the correct ones and I can't find what is the fault, since as I mentioned, I am replicating an existing logic in another language.
Actual logic in Python:
# For LowerBand, this return directly 0:
osc_fast = 5
osc_slow = 35
lens = osc_fast + osc_slow
pr = 2 / lens
strenght = 100
df = df.assign(Lwr_Line=0)
df['Lwr_Line'] = np.where(df["EWO_Std"] < 0, (df['EWO_Std']*pr) + (df['Lwr_Line'].shift(1)*(1-pr)), df['Lwr_Line'].shift(1))
df['LineEWOLwr'] = strenght / 100 * df['Lwr_Line']
df.drop(columns='Lwr_Line', inplace=True)
# For UpperBand:
df = df.assign(Upr_Line=0)
df['Upr_Line'] = np.where(df["EWO_Std"] > 0, (df['EWO_Std']*pr) + (df['Upr_Line'].shift(1)*(1-pr)), df['Upr_Line'].shift(1))
df['LineEWOUpr'] = strenght / 100 * df['Upr_Line']
df.drop(columns='Upr_Line', inplace=True)
Logic checked in C#:
{
MP[0] = ( High[0] + Low[0] ) / 2;
UprLine[0] = 0;
LwrLine[0] = 0;
Lens = OscFast + OscSlow;
Pr = 2.0/Lens;
if(CurrentBar < OscSlow){
OscAG = 0;
if (OscAG > 0){
OscAGUpr[0] = OscAG;
if (OscAGUpr[0] > OscAGUpr[1]){
OscAGUprDiv[0] = OscAG;
}
}
else{
OscAGLwr[0] = OscAG;
OscAGLwrDiv[0] = OscAG;
}
}
else{
OscAG = SMA(MP,OscFast)[0] - SMA(MP,OscSlow)[0];
if (OscAG > 0){
UprLine[0] = (OscAG*Pr) + (UprLine[1]*(1-Pr));
LwrLine[0] = LwrLine[1];
OscAGUpr[0] = OscAG;
if (OscAGUpr[0] > OscAGUpr[1])
{
OscAGUprDiv[0] = OscAG;
}
}
else{
UprLine[0] = UprLine[1];
LwrLine[0] = (OscAG*Pr) + (LwrLine[1]*(1-Pr));
OscAGLwr[0] = OscAG;
if (OscAGLwr[0] > OscAGLwr[1])
{
OscAGLwrDiv[0] = OscAG;
}
}
}
LineEWOUpr[0] = BOBStrength / 100 * UprLine[0];
LineEWOLwr[0] = BOBStrength / 100 * LwrLine[0];
}
Does anyone know what the error could be?
Thanks!
I'm tried many combinations, but no one works
I have a dataset like below -
List((X,Set(" 1", " 7")), (Z,Set(" 5")), (D,Set(" 2")), (E,Set(" 8")), ("F ",Set(" 5", " 9", " 108")), (G,Set(" 2", " 11")), (A,Set(" 7", " 5")), (M,Set(108)))
Here X is related to A as 7 is common between them
Z is related to A as 5 is common between them
F is related to A as 5 is common between them
M is related to F as 108 is common between them
So, X, Z, A, F and M are related
D and G are related as 2 is common between them
E is not related to anybody
So, the output would be ((X, Z, A, F, M), (D,G), (E))
Order doesn't matter here.
I have used Scala here, but solution in Scala/Python or a pseudocode would work for me.
Build an undirected graph where each label is connected to each number from the corresponding set (i.e. (A, { 1, 2 }) would give two edges: A <-> 1 and A <-> 2)
Compute the connected components (using depth-first search, for example).
Filter out only the labels from the connected components.
import util.{Left, Right, Either}
import collection.mutable
def connectedComponentsOfAsc[F, V](faces: List[(F, Set[V])]): List[List[F]] = {
type Node = Either[F, V]
val graphBuilder = mutable.HashMap.empty[Node, mutable.HashSet[Node]]
def addEdge(a: Node, b: Node): Unit =
graphBuilder.getOrElseUpdate(a, mutable.HashSet.empty[Node]) += b
for
(faceLabel, vertices) <- faces
vertex <- vertices
do
val faceNode = Left(faceLabel)
val vertexNode = Right(vertex)
addEdge(faceNode, vertexNode)
addEdge(vertexNode, faceNode)
val graph = graphBuilder.view.mapValues(_.toSet).toMap
val ccs = connectedComponents(graph)
ccs.map(_.collect { case Left(faceLabel) => faceLabel }.toList)
}
def connectedComponents[V](undirectedGraph: Map[V, Set[V]]): List[Set[V]] = {
val visited = mutable.HashSet.empty[V]
var connectedComponent = mutable.HashSet.empty[V]
val components = mutable.ListBuffer.empty[Set[V]]
def dfs(curr: V): Unit = {
if !visited(curr) then
visited += curr
connectedComponent += curr
undirectedGraph(curr).foreach(dfs)
}
for v <- undirectedGraph.keys do
if !visited(v) then
connectedComponent = mutable.HashSet.empty[V]
dfs(v)
components += connectedComponent.toSet
components.toList
}
Can be used like this:
#main def main(): Unit = {
println(connectedComponentsOfAsc(
List(
("X",Set("1", "7")),
("Z",Set("5")),
("D",Set("2")),
("E",Set("8")),
("F",Set("5", "9", "108")),
("G",Set("2", "11")),
("A",Set("7", "5")),
("M",Set("108"))
)
).map(_.sorted).sortBy(_.toString))
}
Produces:
List(List(A, F, M, X, Z), List(D, G), List(E))
All steps are O(n) (scales linearly with the size of input).
This answer is self-contained, but using some kind of graph-library would be clearly advantageous here.
Ultimately using a simpler solution in python as below:
data=[
["X",{"1", "7"}],
["Z",{"5",}],
["D",{"2",}],
["E",{"8",}],
["F",{"5", "9", "108"}],
["G",{"2", "11"}],
["A",{"7", "5"}],
["M",{"108"}]
]
for i in range(len(data)):
for j in range(len(data)):
if(data[i][1].intersection(data[j][1])):
if(data[i][0]!=data[j][0] ):
data[i][1] = data[j][1] = (data[i][1]).union(data[j][1])
for k, g in groupby(sorted([[sorted(tuple(d[1])),d[0]] for d in data]), key=lambda x: x[0]):
print(list(l[1] for l in g))
Getting output as :
['A', 'F', 'M', 'X', 'Z']
['D', 'G']
['E']
Tested for few more datasets and it seems to be working fine.
// I put some values in quotes so we have consistent string input
val initialData :List[(String, Set[String])] = List(
("X",Set(" 1", " 7")),
("Z",Set(" 5")),
("D",Set(" 2")),
("E",Set(" 8")),
("F ",Set(" 5", " 9", " 108")),
("G",Set(" 2", " 11")),
("A",Set(" 7", " 5")),
("M",Set("108"))
)
// Clean up the Sets by turning the string data inside the sets into Ints.
val cleanedData = initialData.map(elem => (elem._1, elem._2.map(_.trim.toInt)))
> cleanedData: List[(String, scala.collection.immutable.Set[Int])] = List((X,Set(1, 7)), (Z,Set(5)), (D,Set(2)), (E,Set(8)), ("F ",Set(5, 9, 108)), (G,Set(2, 11)), (A,Set(7, 5)), (M,Set(108)))
// Explode the Sets into a list of simple mappings. X -> 1, X -> 7 individually.
val explodedList = cleanedData.flatMap(x => x._2.map(v => (x._1, v)))
> explodedList: List[(String, Int)] = List((X,1), (X,7), (Z,5), (D,2), (E,8), ("F ",5), ("F ",9), ("F ",108), (G,2), (G,11), (A,7), (A,5), (M,108))
Group them together by the new key
val mappings = explodedList.groupBy(_._2)
> mappings: scala.collection.immutable.Map[Int,List[(String, Int)]] = Map(5 -> List((Z,5), ("F ",5), (A,5)), 1 -> List((X,1)), 9 -> List(("F ",9)), 2 -> List((D,2), (G,2)), 7 -> List((X,7), (A,7)), 108 -> List(("F ",108), (M,108)), 11 -> List((G,11)), 8 -> List((E,8)))
Print the output
mappings.foreach { case (key, items) =>
println(s"${items.map(_._1).mkString(",")} are all related because of $key")
}
> Z,F ,A are all related because of 5
> X are all related because of 1
> F are all related because of 9
> D,G are all related because of 2
> X,A are all related because of 7
> F ,M are all related because of 108
> G are all related because of 11
> E are all related because of 8
Read input, creating a vector of pairs
e.g.
X 1
X 7
Z 5
...
Sort the vector in order of the second member of the pairs
e.g
X 1
D 2
G 2
...
Iterate over sorted vector, adding to a "pass1 group" so long as the second member does not change. If it does change, start a new pass1 group.
e.g.
X
D G
Z F A
X A
E
F
G
merge pass1 groups with common members to give the output groups.
Here is the C++ code that implements this
#include <string>
#include <iostream>
#include <vector>
#include <algorithm>
bool merge(
std::vector<char> &res,
std::vector<char> &vg)
{
bool ret = false;
for (char r : res)
{
for (char c : vg)
{
if (c == r)
ret = true;
}
}
if (!ret)
return false;
for (char c : vg)
{
if (std::find(res.begin(), res.end(), c) == res.end())
res.push_back(c);
}
return true;
}
void add(
std::vector<std::vector<char>> &result,
std::vector<char> &vg)
{
std::vector<char> row;
for (char c : vg)
row.push_back(c);
result.push_back(row);
}
main()
{
std::string input = "List((X,Set(\" 1\", \" 7\")), (Z,Set(\" 5\")), (D,Set(\" 2\")), (E,Set(\" 8\")), (F,Set(\" 5\", \" 9\", \" 108\")), (G,Set(\" 2\", \" 11\")), (A,Set(\" 7\", \" 5\")), (M,Set(\"108\")))";
input = "List((A,Set(\"0\", \"1\")),(B,Set(\"1\", \"2\")),(C,Set(\"2\", \"3\")),(D,Set(\"3\", \"4\")))";
std::vector<std::pair<char, int>> vinp;
int p = input.find("Set");
int q = input.find("Set", p + 1);
while (p != -1)
{
char c = input[p - 2];
int s = input.find_first_of("0123456789", p);
if( s == -1 )
break;
while (s < q)
{
vinp.push_back(std::make_pair(
c,
atoi(input.substr(s).c_str())));
s = input.find_first_of("0123456789", s + 3);
if( s == -1 )
break;
}
p = q;
q = input.find("Set", p + 1);
if( q == -1 )
q = input.length();
}
std::sort(vinp.begin(), vinp.end(),
[](std::pair<char, int> a, std::pair<char, int> b)
{
return a.second < b.second;
});
std::cout << "sorted\n";
for (auto &p : vinp)
std::cout << p.first << " " << p.second << "\n";
std::vector<std::vector<char>> vpass1;
std::vector<char> row;
int sec = -1;
for (auto &p : vinp)
{
if (p.second != sec)
{
// new group
if (row.size())
vpass1.push_back(row);
sec = p.second;
row.clear();
}
row.push_back(p.first);
}
std::cout << "\npass1\n";
for (auto &row : vpass1)
{
for (char c : row)
std::cout << c << " ";
std::cout << "\n";
}
std::vector<std::vector<char>> result;
std::vector<char> pass2group;
bool fmerge2 = true;
while (fmerge2)
{
fmerge2 = false;
for (auto &vg : vpass1)
{
if (!result.size())
add(result, vg);
else
{
bool fmerge1 = false;
for (auto &res : result)
{
if (merge(res, vg))
{
fmerge1 = true;
fmerge2 = true;
break;
}
}
if (!fmerge1)
add(result, vg);
}
}
if (fmerge2)
{
vpass1 = result;
result.clear();
}
}
std::cout << "\n(";
for (auto &res : result)
{
if (res.size())
{
std::cout << "(";
for (char c : res)
std::cout << c << " ";
std::cout << ")";
}
}
std::cout << ")\n";
return 0;
}
It produces the correct result
((X A Z F M )(D G )(E ))
I am trying to convert a code in c++ to python and it seems to me that tho there are similarities of the code. There are also some huge differences like syntax and loops and condition implementation. Below is the code that I am working on.
process_queue = []
process_arrival = []
process_burst = []
process_holder = []
completion = []
waiting = []
turnaround = []
count = 0
time = 0
smallest = 0
i = 0
x = 0
end = 0
tt = 0
avg = 0
base = 0
n = int(input('Enter the total no of processes: '))
for i in range(n):
process_queue.append([])
process_arrival.append([])
process_burst.append([])
process_holder.append([])
process_queue[i].append(input('Enter name: '))
process_arrival[i].append(int(input('Enter arrival: ')))
process_burst[i].append(int(input('Enter burst: ')))
process_holder[i] = process_burst[i]
process_burst[i-0] = 9999
for time in range(n):
if count != n:
smallest = 9
for i in range(n):
arr_val = ''.join(map(str,process_arrival[n]))
int_arr_val = int(arr_val)
burst_val = ''.join(map(str,process_burst[n]))
int_burst_val = int(burst_val)
burst_val_min = ''.join(map(str,process_burst[smallest]))
int_burst_val_min = int(burst_val_min)
if((int_proc_val <= time) and (int_burst_val < int_burst_val_min) and (int_burst_val > 0)):
smallest = i
process_burst[smallest] - 1
if process_burst[smallest] == 0:
count += 1
end = time + 1
completion[smallest] = end
waiting[smallest] = end - process_arrival[smallest] - process_holder[smallest]
turnaround[smallest] = end - process_arrival[smallest]
Here is the code in c++
int a[10],b[10],x[10];
int waiting[10],turnaround[10],completion[10];
int i,j,smallest,count=0,time,n;
double avg=0,tt=0,end;
cout<<"\nEnter the number of Processes: "; //input
cin>>n;
for(i=0; i<n; i++)
{
cout<<"\nEnter arrival time of process: "; //input
cin>>a[i];
}
for(i=0; i<n; i++)
{
cout<<"\nEnter burst time of process: "; //input
cin>>b[i];
}
for(i=0; i<n; i++)
x[i]=b[i];
b[9]=9999;
for(time=0; count!=n; time++)
{
smallest=9;
for(i=0; i<n; i++)
{
if(a[i]<=time && b[i]<b[smallest] && b[i]>0 )
smallest=i;
}
b[smallest]--;
if(b[smallest]==0)
{
count++;
end=time+1;
completion[smallest] = end;
waiting[smallest] = end - a[smallest] - x[smallest];
turnaround[smallest] = end - a[smallest];
}
}
Error that I encountered is
IndexError: list index out of range
I am doing the conversion wrong?
I am running community detection in graphs made from telecom CDR data. First I was working with very dense graphs containing 10000 nodes, and the algorithm was producing 150 to 170 communities per graph. I was using Louvain community detection algorithm implemented in Scala for Spark.
When I try to run the same algorithm but implemented in C#, I get around 10 communities per graph. I also did some testing with smaller graph, containing around 300 nodes, and same thing occur. When I run it in Spark with Scala I get around 50 communities. When I run it in python or C# I get from 8 to 10 communities.
I am really surprised to see such difference. Every implementation that I used (Scala, Python or C#) is referring to the paper by VD Blondel https://arxiv.org/abs/0803.0476, so the algorithm should be the same, but the output is completely different. Did anyone experienced something like that when using Spark/Scala vs. python/c#?
This is how the main class Louvain is called:
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.log4j.{Level, Logger}
object Driver {
def main(args: Array[String]): Unit ={
val config = LouvainConfig(
"src/data/input/file_with_edges.csv", //input file
"src/data/output/", //output dir
1, //parallelism
2000, //minimumComplessionProgress
1, //progressCounter
",") //delimiter
val sc = new SparkContext("local[*]", "Louvain")
val louvain = new Louvain()
louvain.run(sc, config)
}
}
This is Scala implementation that I am using:
import scala.reflect.ClassTag
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.serializers.DefaultArraySerializers.ObjectArraySerializer
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import org.apache.spark._
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.broadcast.Broadcast
//import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.{SparkContext}
class Louvain() extends Serializable{
def getEdgeRDD(sc: SparkContext, conf: LouvainConfig, typeConversionMethod: String => Long = _.toLong): RDD[Edge[Long]] = {
sc.textFile(conf.inputFile, conf.parallelism).map(row => {
val tokens = row.split(conf.delimiter).map(_.trim())
tokens.length match {
case 2 => new Edge(typeConversionMethod(tokens(0)),
typeConversionMethod(tokens(1)), 1L)
case 3 => new Edge(typeConversionMethod(tokens(0)),
typeConversionMethod(tokens(1)), tokens(2).toDouble.toLong)
case _ => throw new IllegalArgumentException("invalid input line: " + row)
}
})
}
/**
* Generates a new graph of type Graph[VertexState,Long] based on an
input graph of type.
* Graph[VD,Long]. The resulting graph can be used for louvain computation.
*
*/
def createLouvainGraph[VD: ClassTag](graph: Graph[VD, Long]):
Graph[LouvainData, Long] = {
val nodeWeights = graph.aggregateMessages(
(e:EdgeContext[VD,Long,Long]) => {
e.sendToSrc(e.attr)
e.sendToDst(e.attr)
},
(e1: Long, e2: Long) => e1 + e2
)
graph.outerJoinVertices(nodeWeights)((vid, data, weightOption) => {
val weight = weightOption.getOrElse(0L)
new LouvainData(vid, weight, 0L, weight, false)
}).partitionBy(PartitionStrategy.EdgePartition2D).groupEdges(_ + _)
}
/**
* Creates the messages passed between each vertex to convey
neighborhood community data.
*/
def sendCommunityData(e: EdgeContext[LouvainData, Long, Map[(Long, Long), Long]]) = {
val m1 = (Map((e.srcAttr.community, e.srcAttr.communitySigmaTot) -> e.attr))
val m2 = (Map((e.dstAttr.community, e.dstAttr.communitySigmaTot) -> e.attr))
e.sendToSrc(m2)
e.sendToDst(m1)
}
/**
* Merge neighborhood community data into a single message for each vertex
*/
def mergeCommunityMessages(m1: Map[(Long, Long), Long], m2: Map[(Long, Long), Long]) = {
val newMap = scala.collection.mutable.HashMap[(Long, Long), Long]()
m1.foreach({ case (k, v) =>
if (newMap.contains(k)) newMap(k) = newMap(k) + v
else newMap(k) = v
})
m2.foreach({ case (k, v) =>
if (newMap.contains(k)) newMap(k) = newMap(k) + v
else newMap(k) = v
})
newMap.toMap
}
/**
* Returns the change in modularity that would result from a vertex
moving to a specified community.
*/
def q(
currCommunityId: Long,
testCommunityId: Long,
testSigmaTot: Long,
edgeWeightInCommunity: Long,
nodeWeight: Long,
internalWeight: Long,
totalEdgeWeight: Long): BigDecimal = {
val isCurrentCommunity = currCommunityId.equals(testCommunityId)
val M = BigDecimal(totalEdgeWeight)
val k_i_in_L = if (isCurrentCommunity) edgeWeightInCommunity + internalWeight else edgeWeightInCommunity
val k_i_in = BigDecimal(k_i_in_L)
val k_i = BigDecimal(nodeWeight + internalWeight)
val sigma_tot = if (isCurrentCommunity) BigDecimal(testSigmaTot) - k_i else BigDecimal(testSigmaTot)
var deltaQ = BigDecimal(0.0)
if (!(isCurrentCommunity && sigma_tot.equals(BigDecimal.valueOf(0.0)))) {
deltaQ = k_i_in - (k_i * sigma_tot / M)
//println(s" $deltaQ = $k_i_in - ( $k_i * $sigma_tot / $M")
}
deltaQ
}
/**
* Join vertices with community data form their neighborhood and
select the best community for each vertex to maximize change in
modularity.
* Returns a new set of vertices with the updated vertex state.
*/
def louvainVertJoin(
louvainGraph: Graph[LouvainData, Long],
msgRDD: VertexRDD[Map[(Long, Long), Long]],
totalEdgeWeight: Broadcast[Long],
even: Boolean) = {
// innerJoin[U, VD2](other: RDD[(VertexId, U)])(f: (VertexId, VD, U) => VD2): VertexRDD[VD2]
louvainGraph.vertices.innerJoin(msgRDD)((vid, louvainData, communityMessages) => {
var bestCommunity = louvainData.community
val startingCommunityId = bestCommunity
var maxDeltaQ = BigDecimal(0.0);
var bestSigmaTot = 0L
// VertexRDD[scala.collection.immutable.Map[(Long, Long),Long]]
// e.g. (1,Map((3,10) -> 2, (6,4) -> 2, (2,8) -> 2, (4,8) -> 2, (5,8) -> 2))
// e.g. communityId:3, sigmaTotal:10, communityEdgeWeight:2
communityMessages.foreach({ case ((communityId, sigmaTotal), communityEdgeWeight) =>
val deltaQ = q(
startingCommunityId,
communityId,
sigmaTotal,
communityEdgeWeight,
louvainData.nodeWeight,
louvainData.internalWeight,
totalEdgeWeight.value)
//println(" communtiy: "+communityId+" sigma:"+sigmaTotal+"
//edgeweight:"+communityEdgeWeight+" q:"+deltaQ)
if (deltaQ > maxDeltaQ || (deltaQ > 0 && (deltaQ == maxDeltaQ &&
communityId > bestCommunity))) {
maxDeltaQ = deltaQ
bestCommunity = communityId
bestSigmaTot = sigmaTotal
}
})
// only allow changes from low to high communties on even cyces and
// high to low on odd cycles
if (louvainData.community != bestCommunity && ((even &&
louvainData.community > bestCommunity) || (!even &&
louvainData.community < bestCommunity))) {
//println(" "+vid+" SWITCHED from "+vdata.community+" to "+bestCommunity)
louvainData.community = bestCommunity
louvainData.communitySigmaTot = bestSigmaTot
louvainData.changed = true
}
else {
louvainData.changed = false
}
if (louvainData == null)
println("vdata is null: " + vid)
louvainData
})
}
def louvain(
sc: SparkContext,
graph: Graph[LouvainData, Long],
minProgress: Int = 1,
progressCounter: Int = 1): (Double, Graph[LouvainData, Long], Int) = {
var louvainGraph = graph.cache()
val graphWeight = louvainGraph.vertices.map(louvainVertex => {
val (vertexId, louvainData) = louvainVertex
louvainData.internalWeight + louvainData.nodeWeight
}).reduce(_ + _)
val totalGraphWeight = sc.broadcast(graphWeight)
println("totalEdgeWeight: " + totalGraphWeight.value)
// gather community information from each vertex's local neighborhood
var communityRDD =
louvainGraph.aggregateMessages(sendCommunityData, mergeCommunityMessages)
var activeMessages = communityRDD.count() //materializes the msgRDD
//and caches it in memory
var updated = 0L - minProgress
var even = false
var count = 0
val maxIter = 100000
var stop = 0
var updatedLastPhase = 0L
do {
count += 1
even = !even
// label each vertex with its best community based on neighboring
// community information
val labeledVertices = louvainVertJoin(louvainGraph, communityRDD,
totalGraphWeight, even).cache()
// calculate new sigma total value for each community (total weight
// of each community)
val communityUpdate = labeledVertices
.map({ case (vid, vdata) => (vdata.community, vdata.nodeWeight +
vdata.internalWeight)})
.reduceByKey(_ + _).cache()
// map each vertex ID to its updated community information
val communityMapping = labeledVertices
.map({ case (vid, vdata) => (vdata.community, vid)})
.join(communityUpdate)
.map({ case (community, (vid, sigmaTot)) => (vid, (community, sigmaTot))})
.cache()
// join the community labeled vertices with the updated community info
val updatedVertices = labeledVertices.join(communityMapping).map({
case (vertexId, (louvainData, communityTuple)) =>
val (community, communitySigmaTot) = communityTuple
louvainData.community = community
louvainData.communitySigmaTot = communitySigmaTot
(vertexId, louvainData)
}).cache()
updatedVertices.count()
labeledVertices.unpersist(blocking = false)
communityUpdate.unpersist(blocking = false)
communityMapping.unpersist(blocking = false)
val prevG = louvainGraph
louvainGraph = louvainGraph.outerJoinVertices(updatedVertices)((vid, old, newOpt) => newOpt.getOrElse(old))
louvainGraph.cache()
// gather community information from each vertex's local neighborhood
val oldMsgs = communityRDD
communityRDD = louvainGraph.aggregateMessages(sendCommunityData, mergeCommunityMessages).cache()
activeMessages = communityRDD.count() // materializes the graph
// by forcing computation
oldMsgs.unpersist(blocking = false)
updatedVertices.unpersist(blocking = false)
prevG.unpersistVertices(blocking = false)
// half of the communites can swtich on even cycles and the other half
// on odd cycles (to prevent deadlocks) so we only want to look for
// progess on odd cycles (after all vertcies have had a chance to
// move)
if (even) updated = 0
updated = updated + louvainGraph.vertices.filter(_._2.changed).count
if (!even) {
println(" # vertices moved: " + java.text.NumberFormat.getInstance().format(updated))
if (updated >= updatedLastPhase - minProgress) stop += 1
updatedLastPhase = updated
}
} while (stop <= progressCounter && (even || (updated > 0 && count < maxIter)))
println("\nCompleted in " + count + " cycles")
// Use each vertex's neighboring community data to calculate the
// global modularity of the graph
val newVertices =
louvainGraph.vertices.innerJoin(communityRDD)((vertexId, louvainData,
communityMap) => {
// sum the nodes internal weight and all of its edges that are in
// its community
val community = louvainData.community
var accumulatedInternalWeight = louvainData.internalWeight
val sigmaTot = louvainData.communitySigmaTot.toDouble
def accumulateTotalWeight(totalWeight: Long, item: ((Long, Long), Long)) = {
val ((communityId, sigmaTotal), communityEdgeWeight) = item
if (louvainData.community == communityId)
totalWeight + communityEdgeWeight
else
totalWeight
}
accumulatedInternalWeight = communityMap.foldLeft(accumulatedInternalWeight)(accumulateTotalWeight)
val M = totalGraphWeight.value
val k_i = louvainData.nodeWeight + louvainData.internalWeight
val q = (accumulatedInternalWeight.toDouble / M) - ((sigmaTot * k_i) / math.pow(M, 2))
//println(s"vid: $vid community: $community $q = ($k_i_in / $M) - ( ($sigmaTot * $k_i) / math.pow($M, 2) )")
if (q < 0)
0
else
q
})
val actualQ = newVertices.values.reduce(_ + _)
// return the modularity value of the graph along with the
// graph. vertices are labeled with their community
(actualQ, louvainGraph, count / 2)
}
def compressGraph(graph: Graph[LouvainData, Long], debug: Boolean = true): Graph[LouvainData, Long] = {
// aggregate the edge weights of self loops. edges with both src and dst in the same community.
// WARNING can not use graph.mapReduceTriplets because we are mapping to new vertexIds
val internalEdgeWeights = graph.triplets.flatMap(et => {
if (et.srcAttr.community == et.dstAttr.community) {
Iterator((et.srcAttr.community, 2 * et.attr)) // count the weight from both nodes
}
else Iterator.empty
}).reduceByKey(_ + _)
// aggregate the internal weights of all nodes in each community
val internalWeights = graph.vertices.values.map(vdata =>
(vdata.community, vdata.internalWeight))
.reduceByKey(_ + _)
// join internal weights and self edges to find new interal weight of each community
val newVertices = internalWeights.leftOuterJoin(internalEdgeWeights).map({ case (vid, (weight1, weight2Option)) =>
val weight2 = weight2Option.getOrElse(0L)
val state = new LouvainData()
state.community = vid
state.changed = false
state.communitySigmaTot = 0L
state.internalWeight = weight1 + weight2
state.nodeWeight = 0L
(vid, state)
}).cache()
// translate each vertex edge to a community edge
val edges = graph.triplets.flatMap(et => {
val src = math.min(et.srcAttr.community, et.dstAttr.community)
val dst = math.max(et.srcAttr.community, et.dstAttr.community)
if (src != dst) Iterator(new Edge(src, dst, et.attr))
else Iterator.empty
}).cache()
// generate a new graph where each community of the previous graph is
// now represented as a single vertex
val compressedGraph = Graph(newVertices, edges)
.partitionBy(PartitionStrategy.EdgePartition2D).groupEdges(_ + _)
// calculate the weighted degree of each node
val nodeWeights = compressedGraph.aggregateMessages(
(e:EdgeContext[LouvainData,Long,Long]) => {
e.sendToSrc(e.attr)
e.sendToDst(e.attr)
},
(e1: Long, e2: Long) => e1 + e2
)
// fill in the weighted degree of each node
// val louvainGraph = compressedGraph.joinVertices(nodeWeights)((vid,data,weight)=> {
val louvainGraph = compressedGraph.outerJoinVertices(nodeWeights)((vid, data, weightOption) => {
val weight = weightOption.getOrElse(0L)
data.communitySigmaTot = weight + data.internalWeight
data.nodeWeight = weight
data
}).cache()
louvainGraph.vertices.count()
louvainGraph.triplets.count() // materialize the graph
newVertices.unpersist(blocking = false)
edges.unpersist(blocking = false)
println("******************************************************")
println (louvainGraph.vertices.count())
louvainGraph
}
def saveLevel(
sc: SparkContext,
config: LouvainConfig,
level: Int,
qValues: Array[(Int, Double)],
graph: Graph[LouvainData, Long]) = {
val vertexSavePath = config.outputDir + "/level_" + level + "_vertices"
val edgeSavePath = config.outputDir + "/level_" + level + "_edges"
// save
graph.vertices.saveAsTextFile(vertexSavePath)
graph.edges.saveAsTextFile(edgeSavePath)
// overwrite the q values at each level
sc.parallelize(qValues, 1).saveAsTextFile(config.outputDir + "/qvalues_" + level)
}
//def run[VD: ClassTag](sc: SparkContext, config: LouvainConfig, graph: Graph[VD, Long]): Unit = {
def run[VD: ClassTag](sc: SparkContext, config: LouvainConfig): Unit = {
val edgeRDD = getEdgeRDD(sc, config)
val initialGraph = Graph.fromEdges(edgeRDD, None)
var louvainGraph = createLouvainGraph(initialGraph)
var compressionLevel = -1 // number of times the graph has been compressed
var q_modularityValue = -1.0 // current modularity value
var halt = false
var qValues: Array[(Int, Double)] = Array()
do {
compressionLevel += 1
println(s"\nStarting Louvain level $compressionLevel")
// label each vertex with its best community choice at this level of compression
val (currentQModularityValue, currentGraph, numberOfPasses) =
louvain(sc, louvainGraph, config.minimumCompressionProgress, config.progressCounter)
louvainGraph.unpersistVertices(blocking = false)
louvainGraph = currentGraph
println(s"qValue: $currentQModularityValue")
qValues = qValues :+ ((compressionLevel, currentQModularityValue))
saveLevel(sc, config, compressionLevel, qValues, louvainGraph)
// If modularity was increased by at least 0.001 compress the graph and repeat
// halt immediately if the community labeling took less than 3 passes
//println(s"if ($passes > 2 && $currentQ > $q + 0.001 )")
if (numberOfPasses > 2 && currentQModularityValue > q_modularityValue + 0.001) {
q_modularityValue = currentQModularityValue
louvainGraph = compressGraph(louvainGraph)
}
else {
halt = true
}
} while (!halt)
//finalSave(sc, compressionLevel, q_modularityValue, louvainGraph)
}
}
The code is taken from github https://github.com/athinggoingon/louvain-modularity.
Here is the example of input file, just first 10 lines. The graph is made from csv file, schema is : node1, node2, weight_of_the_edge
104,158,34.23767571520276
146,242,12.49338107205348
36,37,0.6821403413414481
28,286,2.5053934980726456
9,92,0.34412941554076487
222,252,10.502288293870677
235,282,0.25717021769814874
264,79,18.555996343792327
24,244,1.7094102023399587
231,75,21.698401383558213
I am trying to convert the Java Code to Python Code and i have done it so far. Java Code works but Python Code doesn't work. Please help me.
Python Code
import random
class QLearning():
alpha = 0.1
gamma = 0.9
state_a = 0
state_b = 1
state_c = 2
state_d = 3
state_e = 4
state_f = 5
states_count = 6
states = [state_a, state_b, state_c, state_d, state_e, state_f]
R = [[0 for x in range(states_count)] for x in range(states_count)]
Q = [[0 for x in range(states_count)] for x in range(states_count)]
action_from_a = [state_b, state_d]
action_from_b = [state_a, state_c, state_e]
action_from_c = [state_c]
action_from_d = [state_a, state_e]
action_from_e = [state_b, state_d, state_f]
action_from_f = [state_c, state_e]
actions = [action_from_a, action_from_b, action_from_c, action_from_d, action_from_e, action_from_f]
state_names = ["A","B","C","D","E","F"]
def __init__(self):
self.R[self.state_b][self.state_c] = 100
self.R[self.state_f][self.state_c] = 100
def run(self):
for i in range(1000):
state = random.randrange(self.states_count)
while(state != self.state_c):
actions_from_state = self.actions[state]
index = random.randrange(len(actions_from_state))
action = actions_from_state[index]
next_state = action
q = self.Q_Value(state, action)
max_Q = self.max_q(next_state)
r = self.R_Value(state, action)
value = q + self.alpha * (r + self.gamma * max_Q - q)
self.set_q(state, action, value)
state = next_state
def max_q(self, s):
self.run().actions_from_state = self.actions[s]
max_value = 5
for i in range(len(self.run().actions_from_state)):
self.run().next_state = self.run().actions_from_state[i]
self.run().value = self.Q[s][self.run().next_state]
if self.run().value > max_value:
max_value = self.run().value
return max_value
def policy(self, state):
self.run().actions_from_state = self.actions[state]
max_value = 5
policy_goto_state = state
for i in range(len(self.run().actions_from_state)):
self.run().next_state = self.run().actions_from_state[i]
self.run().value = self.Q[state][self.run().next_state]
if self.run().value > max_value:
max_value = self.run().value
policy_goto_state = self.run().next_state
return policy_goto_state
def Q_Value(self, s,a):
return self.Q[s][a]
def set_q(self, s, a, value):
self.Q[s][a] = value
def R_Value(self, s, a):
return self.R[s][a]
def print_result(self):
print("Print Result")
for i in range(len(self.Q)):
print("Out From (0)".format(self.state_names[i]))
for j in range(len(self.Q[i])):
print(self.Q[i][j])
def show_policy(self):
print("Show Policy")
for i in range(len(self.states)):
fro = self.states[i]
to = self.policy(fro)
print("From {0} goto {1}".format(self.state_names[fro], self.state_names[to]))
obj = QLearning()
obj.run()
obj.print_result()
obj.show_policy()
Java Code
import java.text.DecimalFormat;
import java.util.Random;
public class Qlearning {
final DecimalFormat df = new DecimalFormat("#.##");
// path finding
final double alpha = 0.1;
final double gamma = 0.9;
// states A,B,C,D,E,F
// e.g. from A we can go to B or D
// from C we can only go to C
// C is goal state, reward 100 when B->C or F->C
//
// _______
// |A|B|C|
// |_____|
// |D|E|F|
// |_____|
//
final int stateA = 0;
final int stateB = 1;
final int stateC = 2;
final int stateD = 3;
final int stateE = 4;
final int stateF = 5;
final int statesCount = 6;
final int[] states = new int[]{stateA,stateB,stateC,stateD,stateE,stateF};
// http://en.wikipedia.org/wiki/Q-learning
// http://people.revoledu.com/kardi/tutorial/ReinforcementLearning/Q-Learning.htm
// Q(s,a)= Q(s,a) + alpha * (R(s,a) + gamma * Max(next state, all actions) - Q(s,a))
int[][] R = new int[statesCount][statesCount]; // reward lookup
double[][] Q = new double[statesCount][statesCount]; // Q learning
int[] actionsFromA = new int[] { stateB, stateD };
int[] actionsFromB = new int[] { stateA, stateC, stateE };
int[] actionsFromC = new int[] { stateC };
int[] actionsFromD = new int[] { stateA, stateE };
int[] actionsFromE = new int[] { stateB, stateD, stateF };
int[] actionsFromF = new int[] { stateC, stateE };
int[][] actions = new int[][] { actionsFromA, actionsFromB, actionsFromC,
actionsFromD, actionsFromE, actionsFromF };
String[] stateNames = new String[] { "A", "B", "C", "D", "E", "F" };
public Qlearning() {
init();
}
public void init() {
R[stateB][stateC] = 100; // from b to c
R[stateF][stateC] = 100; // from f to c
}
public static void main(String[] args) {
long BEGIN = System.currentTimeMillis();
Qlearning obj = new Qlearning();
obj.run();
obj.printResult();
obj.showPolicy();
long END = System.currentTimeMillis();
System.out.println("Time: " + (END - BEGIN) / 1000.0 + " sec.");
}
void run() {
/*
1. Set parameter , and environment reward matrix R
2. Initialize matrix Q as zero matrix
3. For each episode: Select random initial state
Do while not reach goal state o
Select one among all possible actions for the current state o
Using this possible action, consider to go to the next state o
Get maximum Q value of this next state based on all possible actions o
Compute o Set the next state as the current state
*/
// For each episode
Random rand = new Random();
for (int i = 0; i < 1000; i++) { // train episodes
// Select random initial state
int state = rand.nextInt(statesCount);
while (state != stateC) // goal state
{
// Select one among all possible actions for the current state
int[] actionsFromState = actions[state];
// Selection strategy is random in this example
int index = rand.nextInt(actionsFromState.length);
int action = actionsFromState[index];
// Action outcome is set to deterministic in this example
// Transition probability is 1
int nextState = action; // data structure
// Using this possible action, consider to go to the next state
double q = Q(state, action);
double maxQ = maxQ(nextState);
int r = R(state, action);
double value = q + alpha * (r + gamma * maxQ - q);
setQ(state, action, value);
// Set the next state as the current state
state = nextState;
}
}
}
double maxQ(int s) {
int[] actionsFromState = actions[s];
double maxValue = Double.MIN_VALUE;
for (int i = 0; i < actionsFromState.length; i++) {
int nextState = actionsFromState[i];
double value = Q[s][nextState];
if (value > maxValue)
maxValue = value;
}
return maxValue;
}
// get policy from state
int policy(int state) {
int[] actionsFromState = actions[state];
double maxValue = Double.MIN_VALUE;
int policyGotoState = state; // default goto self if not found
for (int i = 0; i < actionsFromState.length; i++) {
int nextState = actionsFromState[i];
double value = Q[state][nextState];
if (value > maxValue) {
maxValue = value;
policyGotoState = nextState;
}
}
return policyGotoState;
}
double Q(int s, int a) {
return Q[s][a];
}
void setQ(int s, int a, double value) {
Q[s][a] = value;
}
int R(int s, int a) {
return R[s][a];
}
void printResult() {
System.out.println("Print result");
for (int i = 0; i < Q.length; i++) {
System.out.print("out from " + stateNames[i] + ": ");
for (int j = 0; j < Q[i].length; j++) {
System.out.print(df.format(Q[i][j]) + " ");
}
System.out.println();
}
}
// policy is maxQ(states)
void showPolicy() {
System.out.println("\nshowPolicy");
for (int i = 0; i < states.length; i++) {
int from = states[i];
int to = policy(from);
System.out.println("from "+stateNames[from]+" goto "+stateNames[to]);
}
}
}
Traceback
C:\Python33\python.exe "C:/Users/Ajay/Documents/Python Scripts/RL/QLearning.py"
Traceback (most recent call last):
File "C:/Users/Ajay/Documents/Python Scripts/RL/QLearning.py", line 4, in <module>
class QLearning():
File "C:/Users/Ajay/Documents/Python Scripts/RL/QLearning.py", line 19, in QLearning
R = [[0 for x in range(states_count)] for x in range(states_count)]
File "C:/Users/Ajay/Documents/Python Scripts/RL/QLearning.py", line 19, in <listcomp>
R = [[0 for x in range(states_count)] for x in range(states_count)]
NameError: global name 'states_count' is not defined
To access all of the class attributes you define (i.e. everything between class QLearning and def __init__), you need to use self or the class name:
self.states_count
or
QLearning.states_count
I don't know the algorithm, but it is possible that these class attributes should be instance attributes (i.e. separate for each instance of the class, rather than shared amongst all instances) and therefore defined in __init__ (or other instance methods) using self anyway.