List of usage examples for org.apache.mahout.math.Vector.normalize
Vector normalize(double power);
From source file:org.trustedanalytics.atk.giraph.algorithms.lbp.LoopyBeliefPropagationComputation.java
License:Apache License
@Override public void compute(Vertex<LongWritable, VertexData4LBPWritable, DoubleWritable> vertex, Iterable<IdWithVectorMessage> messages) throws IOException { long step = getSuperstep(); if (step == 0) { initializeVertex(vertex);/* www. j a v a 2s . c o m*/ return; } // collect messages sent to this vertex HashMap<Long, Vector> map = new HashMap<Long, Vector>(); for (IdWithVectorMessage message : messages) { map.put(message.getData(), message.getVector()); } // update posterior according to prior and messages VertexData4LBPWritable vertexValue = vertex.getValue(); VertexType vt = vertexValue.getType(); vt = ignoreVertexType ? VertexType.TRAIN : vt; Vector prior = vertexValue.getPriorVector(); double nStates = prior.size(); if (vt != VertexType.TRAIN) { // assign a uniform prior for validate/test vertex prior = prior.clone().assign(Math.log(1.0 / nStates)); } // sum of prior and messages Vector sumPosterior = prior; for (IdWithVectorMessage message : messages) { sumPosterior = sumPosterior.plus(message.getVector()); } sumPosterior = sumPosterior.plus(-sumPosterior.maxValue()); // update posterior if this isn't an anchor vertex if (prior.maxValue() < anchorThreshold) { // normalize posterior Vector posterior = sumPosterior.clone().assign(Functions.EXP); posterior = posterior.normalize(1d); Vector oldPosterior = vertexValue.getPosteriorVector(); double delta = posterior.minus(oldPosterior).norm(1d); // aggregate deltas switch (vt) { case TRAIN: aggregate(SUM_TRAIN_DELTA, new DoubleWritable(delta)); break; case VALIDATE: aggregate(SUM_VALIDATE_DELTA, new DoubleWritable(delta)); break; case TEST: aggregate(SUM_TEST_DELTA, new DoubleWritable(delta)); break; default: throw new IllegalArgumentException("Unknown vertex type: " + vt.toString()); } // update posterior vertexValue.setPosteriorVector(posterior); } if (step < maxSupersteps) { // if it's not a training vertex, don't send out messages if (vt != VertexType.TRAIN) { return; } IdWithVectorMessage newMessage = new 
IdWithVectorMessage(); newMessage.setData(vertex.getId().get()); // update belief Vector belief = prior.clone(); for (Edge<LongWritable, DoubleWritable> edge : vertex.getEdges()) { double weight = edge.getValue().get(); long id = edge.getTargetVertexId().get(); Vector tempVector = sumPosterior; if (map.containsKey(id)) { tempVector = sumPosterior.minus(map.get(id)); } for (int i = 0; i < nStates; i++) { double sum = 0d; for (int j = 0; j < nStates; j++) { double msg = Math.exp( tempVector.getQuick(j) + edgePotential(Math.abs(i - j) / (nStates - 1), weight)); if (maxProduct) { sum = sum > msg ? sum : msg; } else { sum += msg; } } belief.setQuick(i, sum > 0d ? Math.log(sum) : Double.MIN_VALUE); } belief = belief.plus(-belief.maxValue()); newMessage.setVector(belief); sendMessage(edge.getTargetVertexId(), newMessage); } } else { // convert prior back to regular scale before output prior = vertexValue.getPriorVector(); prior = prior.assign(Functions.EXP); vertexValue.setPriorVector(prior); vertex.voteToHalt(); } }
From source file:org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java
License:Apache License
/** * Initialize vertex/edges, collect graph statistics and send out messages * * @param vertex of the graph/*w ww . j a v a 2 s . c om*/ */ private void initialize(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex) { // initialize vertex vector, i.e., the theta for doc and phi for word in LDA double[] vertexValues = new double[config.numTopics()]; vertex.getValue().setLdaResult(new DenseVector(vertexValues)); // initialize edge vector, i.e., the gamma in LDA Random rand1 = new Random(vertex.getId().seed()); long seed1 = rand1.nextInt(); double maxDelta = 0d; double sumWeights = 0d; for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getMutableEdges()) { double weight = edge.getValue().getWordCount(); // generate the random seed for this edge Random rand2 = new Random(edge.getTargetVertexId().seed()); long seed2 = rand2.nextInt(); long seed = seed1 + seed2; Random rand = new Random(seed); double[] edgeValues = new double[config.numTopics()]; for (int i = 0; i < config.numTopics(); i++) { edgeValues[i] = rand.nextDouble(); } Vector vector = new DenseVector(edgeValues); vector = vector.normalize(1d); edge.getValue().setVector(vector); // find the max delta among all edges double delta = vector.norm(1d) / config.numTopics(); if (delta > maxDelta) { maxDelta = delta; } // the sum of weights from all edges sumWeights += weight; } // update vertex value updateVertex(vertex); // aggregate max delta value aggregate(MAX_DELTA, new DoubleWritable(maxDelta)); // collect graph statistics if (vertex.getId().isDocument()) { aggregate(SUM_DOC_VERTEX_COUNT, new LongWritable(1)); } else { aggregate(SUM_OCCURRENCE_COUNT, new DoubleWritable(sumWeights)); aggregate(SUM_WORD_VERTEX_COUNT, new LongWritable(1)); } // send out messages LdaMessage newMessage = new LdaMessage(vertex.getId().copy(), vertex.getValue().getLdaResult()); sendMessageToAllEdges(vertex, newMessage); }
From source file:org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java
License:Apache License
/** * Update edge value according to vertex and messages * * @param vertex of the graph/*from ww w. j ava2 s . c om*/ * @param map of type HashMap */ private void updateEdge(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex, HashMap<LdaVertexId, Vector> map) { Vector vector = vertex.getValue().getLdaResult(); double maxDelta = 0d; for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getMutableEdges()) { Vector gamma = edge.getValue().getVector(); LdaVertexId id = edge.getTargetVertexId(); if (map.containsKey(id)) { Vector otherVector = map.get(id); Vector newGamma = null; if (vertex.getId().isDocument()) { newGamma = vector.minus(gamma).plus(config.alpha()) .times(otherVector.minus(gamma).plus(config.beta())) .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV)); } else { newGamma = vector.minus(gamma).plus(config.beta()) .times(otherVector.minus(gamma).plus(config.alpha())) .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV)); } newGamma = newGamma.normalize(1d); double delta = gamma.minus(newGamma).norm(1d) / config.numTopics(); if (delta > maxDelta) { maxDelta = delta; } // update edge vector edge.getValue().setVector(newGamma); } else { // this happens when you don't have your Vertex Id's being setup correctly throw new IllegalArgumentException( String.format("Vertex ID %s: A message is mis-matched.", vertex.getId())); } } aggregate(MAX_DELTA, new DoubleWritable(maxDelta)); }
From source file:org.trustedanalytics.atk.giraph.algorithms.lda.GiraphLdaComputation.java
License:Apache License
/** * Initialize vertex/edges, collect graph statistics and send out messages * * @param vertex of the graph/*from w ww. j ava 2 s . c om*/ */ private void initialize(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex) { // initialize vertex vector, i.e., the theta for doc and phi for word in LDA double[] vertexValues = new double[config.numTopics()]; vertex.getValue().setLdaResult(new DenseVector(vertexValues)); Vector updatedVector = vertex.getValue().getLdaResult().clone().assign(0d); // initialize edge vector, i.e., the gamma in LDA Random rand1 = new Random(vertex.getId().seed()); long seed1 = rand1.nextInt(); double maxDelta = 0d; double sumWeights = 0d; for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getMutableEdges()) { double weight = edge.getValue().getWordCount(); // generate the random seed for this edge Random rand2 = new Random(edge.getTargetVertexId().seed()); long seed2 = rand2.nextInt(); long seed = seed1 + seed2; Random rand = new Random(seed); double[] edgeValues = new double[config.numTopics()]; for (int i = 0; i < config.numTopics(); i++) { edgeValues[i] = rand.nextDouble(); } Vector vector = new DenseVector(edgeValues); vector = vector.normalize(1d); edge.getValue().setVector(vector); // find the max delta among all edges double delta = vector.norm(1d) / config.numTopics(); if (delta > maxDelta) { maxDelta = delta; } // the sum of weights from all edges sumWeights += weight; updatedVector = updateVector(updatedVector, edge); } // update vertex value vertex.getValue().setLdaResult(updatedVector); ; // aggregate max delta value aggregateWord(vertex); aggregate(MAX_DELTA, new DoubleWritable(maxDelta)); // collect graph statistics if (vertex.getId().isDocument()) { aggregate(SUM_DOC_VERTEX_COUNT, new LongWritable(1)); } else { aggregate(SUM_OCCURRENCE_COUNT, new DoubleWritable(sumWeights)); aggregate(SUM_WORD_VERTEX_COUNT, new LongWritable(1)); } // send out messages LdaMessage newMessage = new LdaMessage(vertex.getId().copy(), 
vertex.getValue().getLdaResult()); sendMessageToAllEdges(vertex, newMessage); }
From source file:org.trustedanalytics.atk.giraph.algorithms.lda.GiraphLdaComputation.java
License:Apache License
/** * Update vertex and outgoing edge values using current vertex values and messages * * @param vertex of the graph/*w w w . j a va2 s. co m*/ * @param map Map of vertices */ private void updateVertex(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex, HashMap<LdaVertexId, Vector> map) { Vector vector = vertex.getValue().getLdaResult(); Vector updatedVector = vertex.getValue().getLdaResult().clone().assign(0d); double maxDelta = 0d; for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getMutableEdges()) { Vector gamma = edge.getValue().getVector(); LdaVertexId id = edge.getTargetVertexId(); if (map.containsKey(id)) { Vector otherVector = map.get(id); Vector newGamma = null; if (vertex.getId().isDocument()) { newGamma = vector.minus(gamma).plus(config.alpha()) .times(otherVector.minus(gamma).plus(config.beta())) .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV)); } else { newGamma = vector.minus(gamma).plus(config.beta()) .times(otherVector.minus(gamma).plus(config.alpha())) .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV)); } newGamma = newGamma.normalize(1d); double delta = gamma.minus(newGamma).norm(1d) / config.numTopics(); if (delta > maxDelta) { maxDelta = delta; } // update edge vector edge.getValue().setVector(newGamma); } else { // this happens when you don't have your Vertex Id's being setup correctly throw new IllegalArgumentException( String.format("Vertex ID %s: A message is mis-matched.", vertex.getId())); } updatedVector = updateVector(updatedVector, edge); } vertex.getValue().setLdaResult(updatedVector); aggregateWord(vertex); aggregate(MAX_DELTA, new DoubleWritable(maxDelta)); }
From source file:org.trustedanalytics.atk.giraph.algorithms.lp.LabelPropagationComputation.java
License:Apache License
/** * initialize vertex and edges//from w ww. j av a 2 s .co m * * @param vertex a graph vertex */ private void initializeVertexEdges(Vertex<LongWritable, VertexData4LPWritable, DoubleWritable> vertex) { // normalize prior and initialize posterior VertexData4LPWritable vertexValue = vertex.getValue(); Vector priorValues = vertexValue.getPriorVector(); if (null != priorValues) { priorValues = priorValues.normalize(1d); initialVectorValues = priorValues; } else if (initialVectorValues != null) { priorValues = initialVectorValues; vertexValue.setLabeledStatus(false); } else { throw new RuntimeException("Vector labels missing from input data for vertex " + vertex.getId() + ". Add edge with vertex as first column."); } vertexValue.setPriorVector(priorValues); vertexValue.setPosteriorVector(priorValues.clone()); vertexValue.setDegree(initializeEdge(vertex)); // send out messages IdWithVectorMessage newMessage = new IdWithVectorMessage(vertex.getId().get(), priorValues); sendMessageToAllEdges(vertex, newMessage); }