Usage examples for the `times` method of `org.apache.mahout.math.Vector`
/**
 * Signature shown for this listing.
 *
 * NOTE(review): presumably returns a new vector holding the element-wise product of this vector
 * and {@code x} (per the Mahout {@code Vector} Javadoc; confirm there). The examples that follow
 * actually call {@code times} with a {@code double} argument (scalar scaling), i.e. a different
 * overload than the one shown here.
 */
Vector times(Vector x);
From source file: org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java
License: Apache License
/**
 * Recomputes the vertex's LDA result from its edges and stores it back on the vertex.
 *
 * <p>The new result starts as a zeroed copy of the current result (same size). Each edge's vector
 * is scaled by that edge's word count and added in. If the vertex id reports itself as a word,
 * the new result is also contributed to the {@code SUM_WORD_VERTEX_VALUE} aggregator.
 *
 * @param vertex the graph vertex whose LDA result is recomputed from its edges
 */
private void updateVertex(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex) {
    LdaVertexData data = vertex.getValue();
    Vector accumulated = data.getLdaResult().clone().assign(0d);

    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getEdges()) {
        LdaEdgeData edgeData = edge.getValue();
        double wordCount = edgeData.getWordCount();
        accumulated = accumulated.plus(edgeData.getVector().times(wordCount));
    }

    data.setLdaResult(accumulated);
    if (!vertex.getId().isWord()) {
        return;
    }
    aggregate(SUM_WORD_VERTEX_VALUE, new VectorWritable(accumulated));
}
From source file: org.trustedanalytics.atk.giraph.algorithms.lda.GiraphLdaComputation.java
License: Apache License
/** * Update vertex value with edge value/*from w w w . j av a 2 s. c o m*/ * * @param vector vector of vertex value * @param edge of the graph */ private Vector updateVector(Vector vector, Edge<LdaVertexId, LdaEdgeData> edge) { double weight = edge.getValue().getWordCount(); Vector gamma = edge.getValue().getVector(); vector = vector.plus(gamma.times(weight)); return vector; }
From source file: org.trustedanalytics.atk.giraph.algorithms.lp.LabelPropagationComputation.java
License: Apache License
@Override public void compute(Vertex<LongWritable, VertexData4LPWritable, DoubleWritable> vertex, Iterable<IdWithVectorMessage> messages) throws IOException { long superStep = getSuperstep(); if (superStep == 0) { initializeVertexEdges(vertex);/*ww w .j a va 2 s . c o m*/ vertex.voteToHalt(); } else if (superStep <= maxSupersteps) { VertexData4LPWritable vertexValue = vertex.getValue(); Vector prior = vertexValue.getPriorVector(); Vector posterior = vertexValue.getPosteriorVector(); double degree = vertexValue.getDegree(); // collect messages sent to this vertex HashMap<Long, Vector> map = new HashMap(); for (IdWithVectorMessage message : messages) { map.put(message.getData(), message.getVector()); } // Update belief and calculate cost double hi = prior.getQuick(0); double fi = posterior.getQuick(0); double crossSum = 0d; Vector newBelief = posterior.clone().assign(0d); for (Edge<LongWritable, DoubleWritable> edge : vertex.getEdges()) { double weight = edge.getValue().get(); if (weight <= 0d) { throw new IllegalArgumentException( "Vertex ID: " + vertex.getId() + "has an edge with negative or zero value"); } long targetVertex = edge.getTargetVertexId().get(); if (map.containsKey(targetVertex)) { Vector tempVector = map.get(targetVertex); newBelief = newBelief.plus(tempVector.times(weight)); double fj = tempVector.getQuick(0); crossSum += weight * fi * fj; } } double cost = degree * ((1 - lambda) * (Math.pow(fi, 2) - crossSum) + 0.5 * lambda * Math.pow((fi - hi), 2)); aggregate(SUM_COST, new DoubleWritable(cost)); // Update posterior if the vertex was not processed if (vertexValue.wasLabeled() == false) { newBelief = (newBelief.times(1 - lambda).plus(prior.times(lambda))).normalize(1d); vertexValue.setPosteriorVector(newBelief); } // Send out messages if not the last step if (superStep != maxSupersteps) { IdWithVectorMessage newMessage = new IdWithVectorMessage(vertex.getId().get(), vertexValue.getPosteriorVector()); sendMessageToAllEdges(vertex, newMessage); } } 
vertex.voteToHalt(); }