Spark GraphX 计算 PageRank 的源代码分析



阅读次数

## 入口函数

/**
 * Entry point for dynamic (run-until-convergence) PageRank on `graph`.
 *
 * Delegates directly to `PageRank.runUntilConvergence` with the same arguments.
 *
 * @param tol       convergence tolerance forwarded to `PageRank.runUntilConvergence`
 * @param resetProb random-reset probability forwarded unchanged (defaults to 0.15)
 * @return the graph produced by `PageRank.runUntilConvergence`
 */
def pageRank(tol: Double, resetProb: Double = 0.15): Graph[Double, Double] =
  PageRank.runUntilConvergence(graph, tol, resetProb)

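其中 resetProb 是随机重置(teleport)概率,默认值为 0.15:随机游走者每一步以 resetProb 的概率跳转到一个随机顶点(个性化 PageRank 中跳回起点),以 1 - resetProb 的概率沿出边继续前进,所以下面代码中每一轮收到的权重都要乘以 (1 - resetProb)。tol 是收敛阈值:一个顶点本轮 PageRank 的变化量不超过 tol 时,它就不再向外发送消息。更详细的原理可以参考 PageRank 算法的说明。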

## 实现函数

/**
 * Runs delta-based ("dynamic") PageRank with GraphX Pregel until no vertex's rank changes by
 * more than `tol`. When `srcId` is given, runs personalized PageRank seeded at that vertex.
 *
 * Every vertex carries `(rank, delta)`. Along each out-edge a vertex sends only the change in its
 * rank from the previous superstep, weighted by 1 / out-degree of the sending vertex. Ranks are
 * not normalized: in the non-personalized case every vertex starts from `resetProb`.
 *
 * @param graph     input graph; its vertex and edge attributes are ignored
 * @param tol       convergence tolerance; a vertex stops sending once its delta is <= tol
 * @param resetProb random-reset (teleport) probability, in [0, 1)
 * @param srcId     optional start vertex for personalized PageRank
 * @return a graph whose vertex attribute is the final rank and whose edge attribute is
 *         1 / out-degree of the edge's source vertex
 */
def runUntilConvergenceWithOptions[VD: ClassTag, ED: ClassTag](
    graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15,
    srcId: Option[VertexId] = None): Graph[Double, Double] =
{
  // With a negative tolerance even a zero delta passes `delta > tol`, so sendMessage would keep
  // producing messages and the Pregel loop would never stop.
  require(tol >= 0, s"Tolerance must be no less than 0, but got $tol")
  // resetProb == 1 would make the non-personalized initial message resetProb / (1 - resetProb)
  // infinite.
  require(resetProb >= 0 && resetProb < 1,
    s"Random reset probability must belong to [0, 1), but got $resetProb")

  // True when the caller chose a start vertex (personalized PageRank).
  val personalized = srcId.isDefined

  // Build the working graph:
  //  - edge attribute: 1 / out-degree of the edge's source vertex;
  //  - vertex attribute: (rank, delta). The personalized start vertex begins at
  //    (resetProb, -Inf), where -Inf marks "first superstep not yet run"; every other vertex
  //    begins at (0.0, 0.0).
  // The start vertex is recognized by comparing against `srcId` itself rather than a -1L
  // sentinel, so a real vertex with id -1 is never mistaken for it.
  val pagerankGraph: Graph[(Double, Double), Double] = graph
    // Associate the out-degree with each vertex
    .outerJoinVertices(graph.outDegrees) {
      (vid, vdata, deg) => deg.getOrElse(0)
    }
    // Set the weight on the edges based on the degree
    .mapTriplets(e => 1.0 / e.srcAttr)
    // Set the vertex attributes to (initialPR, delta)
    .mapVertices { (id, attr) =>
      if (srcId.exists(_ == id)) (resetProb, Double.NegativeInfinity) else (0.0, 0.0)
    }
    .cache()

  // The three functions needed to implement PageRank in the GraphX version of Pregel.

  // Vertex program (non-personalized): add the damped sum of incoming deltas to the rank.
  // `msgSum` is the merged (summed) message this vertex received for the current superstep.
  def vertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = {
    val (oldPR, _) = attr
    val newPR = oldPR + (1.0 - resetProb) * msgSum
    (newPR, newPR - oldPR)
  }

  // Vertex program (personalized). Ranks accumulate exactly as in `vertexProgram`. The only
  // difference is the start vertex's first superstep (lastDelta == -Inf): there the whole seed
  // rank is reported as the delta, so it gets propagated.
  // An earlier version of this function used `oldPR` as the base only for the start vertex and
  // 0 for all others. That overwrote the rank a non-source vertex had already accumulated each
  // time a new message arrived, and could even yield negative deltas.
  def personalizedVertexProgram(id: VertexId, attr: (Double, Double),
      msgSum: Double): (Double, Double) = {
    val (oldPR, lastDelta) = attr
    val newPR = oldPR + (1.0 - resetProb) * msgSum
    val newDelta = if (lastDelta == Double.NegativeInfinity) newPR else newPR - oldPR
    (newPR, newDelta)
  }

  // Send function: a vertex whose last rank change (delta) exceeds `tol` pushes that delta,
  // weighted by the edge weight (1 / out-degree of the source), to the edge's destination.
  def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = {
    if (edge.srcAttr._2 > tol) {
      Iterator((edge.dstId, edge.srcAttr._2 * edge.attr))
    } else {
      Iterator.empty
    }
  }

  // Merge function: sum the deltas arriving at the same vertex.
  def messageCombiner(a: Double, b: Double): Double = a + b

  // The initial message received by all vertices in PageRank. In the non-personalized case,
  // vertexProgram turns it into a starting rank (and delta) of exactly resetProb:
  // 0 + (1 - resetProb) * resetProb / (1 - resetProb).
  val initialMessage = if (personalized) 0.0 else resetProb / (1.0 - resetProb)

  // Choose the vertex program once, then execute a dynamic version of Pregel.
  val vp = if (personalized) {
    (id: VertexId, attr: (Double, Double), msgSum: Double) =>
      personalizedVertexProgram(id, attr, msgSum)
  } else {
    (id: VertexId, attr: (Double, Double), msgSum: Double) =>
      vertexProgram(id, attr, msgSum)
  }

  Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)(
    vp, sendMessage, messageCombiner)
    .mapVertices((vid, attr) => attr._1)
} // end of runUntilConvergenceWithOptions
264 
265 } 

### Pregel 的具体实现(每次 mapReduceTriplets 完成一轮全局的消息(权重增量)传递;vprog 则在相邻两轮之间用收到的消息更新每个顶点的权重)

import scala.reflect.ClassTag

/**
 * Pregel-style bulk-synchronous message passing over `graph`.
 *
 * `mapVertices`, `joinVertices` and `mapReduceTriplets` need `ClassTag`s for the vertex and
 * message types, hence the context bounds.
 *
 * @param graph the graph the computation runs on (the original excerpt called `mapVertices`
 *              with no receiver, so no graph was in scope)
 */
class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) {

  /**
   * Runs supersteps until no messages remain or `maxIter` loop iterations have run.
   *
   * One superstep = `vprog` merges each vertex's pending message into its value, then one
   * `mapReduceTriplets` pass (`sendMsg` + `mergeMsg`) produces the next round of messages.
   *
   * @param initialMsg message passed to `vprog` for every vertex before the loop starts
   * @param maxIter    maximum number of loop iterations after that initial call
   * @param activeDir  edge direction handed to `mapReduceTriplets` together with the previous
   *                   messages, so edges where neither side received a message are skipped
   * @param vprog      vertex program: (vertex id, current value, merged message) => new value
   * @param sendMsg    emits (destination id, message) pairs for an edge triplet
   * @param mergeMsg   combines two messages for the same vertex (presumably must be commutative
   *                   and associative, since merge order is not fixed — confirm)
   * @return the graph with the final vertex values
   */
  def pregel[A: ClassTag]
      (initialMsg: A,
       maxIter: Int = Int.MaxValue,
       activeDir: EdgeDirection = EdgeDirection.Out)
      (vprog: (VertexId, VD, A) => VD,
       sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
       mergeMsg: (A, A) => A)
    : Graph[VD, ED] = {
    // Receive the initial message at each vertex
    var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()

    // Compute the first round of messages; cached like later rounds so it can be released
    // once the next round is materialized.
    var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg).cache()
    var activeMessages = messages.count()

    // Loop until no messages remain or maxIter is reached
    var i = 0
    while (activeMessages > 0 && i < maxIter) {
      // Receive the messages and update the vertices.
      val prevG = g
      g = g.joinVertices(messages)(vprog).cache()

      val oldMessages = messages
      // Send new messages, skipping edges where neither side received a message. Caching
      // `messages` and materializing it with count() is what lets the previous iteration's
      // RDDs be uncached below.
      messages = GraphXUtils.mapReduceTriplets(
        g, sendMsg, mergeMsg, Some((oldMessages, activeDir))).cache()
      activeMessages = messages.count()

      // Release the previous iteration's cached data; without this every superstep's graph and
      // messages stay cached for the whole run.
      oldMessages.unpersist(blocking = false)
      prevG.unpersistVertices(blocking = false)
      prevG.edges.unpersist(blocking = false)

      i += 1
    }
    // The last round of messages is empty or unused once the loop exits.
    messages.unpersist(blocking = false)
    g
  }
}