List of usage examples for the dot method of org.apache.mahout.math.Vector
double dot(Vector x);
From source file:org.trustedanalytics.atk.giraph.algorithms.als.AlternatingLeastSquaresComputation.java
License:Apache License
/**
 * Computes the bias for a vertex from the TRAIN messages it received.
 *
 * <p>For every TRAIN message the residual
 * {@code weight - (otherBias + value . vector)} is accumulated; the sum is then
 * divided by {@code (1 + lambda) * numberOfTrainMessages}. Messages of any
 * other edge type are ignored.
 *
 * @param value the vertex vector used to form the predictions
 * @param messages the messages received by the vertex
 * @return the computed bias, or {@code 0} when there is no TRAIN message
 */
private double computeBias(Vector value, Iterable<MessageData4CFWritable> messages) {
    double residualSum = 0d;
    int trainCount = 0;

    for (MessageData4CFWritable msg : messages) {
        if (msg.getType() != EdgeType.TRAIN) {
            continue;
        }
        double observed = msg.getWeight();
        Vector neighborVector = msg.getVector();
        double neighborBias = msg.getBias();
        double predicted = neighborBias + value.dot(neighborVector);
        residualSum += observed - predicted;
        trainCount++;
    }

    if (trainCount > 0) {
        return residualSum / ((1 + lambda) * trainCount);
    }
    return 0d;
}
From source file:org.trustedanalytics.atk.giraph.algorithms.als.AlternatingLeastSquaresComputation.java
License:Apache License
@Override public void compute(Vertex<CFVertexId, VertexData4CFWritable, EdgeData4CFWritable> vertex, Iterable<MessageData4CFWritable> messages) throws IOException { long step = getSuperstep(); if (step == 0) { initialize(vertex);// w w w. ja v a 2 s . com vertex.voteToHalt(); return; } Vector currentValue = vertex.getValue().getVector(); double currentBias = vertex.getValue().getBias(); // update aggregators every (2 * interval) super steps if ((step % (2 * learningCurveOutputInterval)) == 0) { double errorOnTrain = 0d; double errorOnValidate = 0d; double errorOnTest = 0d; int numTrain = 0; for (MessageData4CFWritable message : messages) { EdgeType et = message.getType(); double weight = message.getWeight(); Vector vector = message.getVector(); double otherBias = message.getBias(); double predict = currentBias + otherBias + currentValue.dot(vector); double e = weight - predict; switch (et) { case TRAIN: errorOnTrain += e * e; numTrain++; break; case VALIDATE: errorOnValidate += e * e; break; case TEST: errorOnTest += e * e; break; default: throw new IllegalArgumentException("Unknown recognized edge type: " + et.toString()); } } double costOnTrain = 0d; if (numTrain > 0) { costOnTrain = errorOnTrain / numTrain + lambda * (currentBias * currentBias + currentValue.dot(currentValue)); } aggregate(SUM_TRAIN_COST, new DoubleWritable(costOnTrain)); aggregate(SUM_VALIDATE_ERROR, new DoubleWritable(errorOnValidate)); aggregate(SUM_TEST_ERROR, new DoubleWritable(errorOnTest)); } // update vertex value if (step < maxSupersteps) { // xxt records the result of x times x transpose Matrix xxt = new DenseMatrix(featureDimension, featureDimension); xxt = xxt.assign(0d); // xr records the result of x times rating Vector xr = currentValue.clone().assign(0d); int numTrain = 0; for (MessageData4CFWritable message : messages) { EdgeType et = message.getType(); if (et == EdgeType.TRAIN) { double weight = message.getWeight(); Vector vector = message.getVector(); double otherBias = 
message.getBias(); xxt = xxt.plus(vector.cross(vector)); xr = xr.plus(vector.times(weight - currentBias - otherBias)); numTrain++; } } xxt = xxt.plus(new DiagonalMatrix(lambda * numTrain, featureDimension)); Matrix bMatrix = new DenseMatrix(featureDimension, 1).assignColumn(0, xr); Vector value = new QRDecomposition(xxt).solve(bMatrix).viewColumn(0); vertex.getValue().setVector(value); // update vertex bias if (biasOn) { double bias = computeBias(value, messages); vertex.getValue().setBias(bias); } // send out messages for (Edge<CFVertexId, EdgeData4CFWritable> edge : vertex.getEdges()) { MessageData4CFWritable newMessage = new MessageData4CFWritable(vertex.getValue(), edge.getValue()); sendMessage(edge.getTargetVertexId(), newMessage); } } vertex.voteToHalt(); }
From source file:org.trustedanalytics.atk.giraph.algorithms.cgd.ConjugateGradientDescentComputation.java
License:Apache License
/**
 * Computes the gradient at {@code value} from the TRAIN messages.
 *
 * <p>For every TRAIN message the error
 * {@code (bias + otherBias + value . vector) - weight} scales the message
 * vector, and the scaled vectors are summed. The result is that sum divided by
 * the number of TRAIN messages, plus {@code value * lambda}.
 *
 * @param bias the bias of the vertex
 * @param value the vertex vector at which the gradient is evaluated
 * @param messages the messages received by the vertex
 * @return the gradient, or an all-zero vector shaped like {@code value} when
 *         there is no TRAIN message
 */
private Vector computeGradient(double bias, Vector value, Iterable<MessageData4CFWritable> messages) {
    Vector errorWeightedSum = value.clone().assign(0d);
    int trainCount = 0;

    for (MessageData4CFWritable msg : messages) {
        if (msg.getType() != EdgeType.TRAIN) {
            continue;
        }
        double observed = msg.getWeight();
        Vector neighborVector = msg.getVector();
        double neighborBias = msg.getBias();
        double predicted = bias + neighborBias + value.dot(neighborVector);
        double error = predicted - observed;
        errorWeightedSum = errorWeightedSum.plus(neighborVector.times(error));
        trainCount++;
    }

    if (trainCount == 0) {
        return value.clone().assign(0d);
    }
    return errorWeightedSum.divide(trainCount).plus(value.times(lambda));
}
From source file:org.trustedanalytics.atk.giraph.algorithms.cgd.ConjugateGradientDescentComputation.java
License:Apache License
/**
 * Computes the step size alpha along the conjugate direction.
 *
 * <p>The result is
 * {@code -(gradient . conjugate) / (mean((conjugate . vector)^2) + lambda * (conjugate . conjugate))},
 * where the mean runs over the TRAIN messages only.
 *
 * <p>Zero is returned (no step is taken) when the conjugate direction is the
 * zero vector, when there is no TRAIN message, or when the denominator is
 * exactly zero. The last case can occur when {@code lambda} is zero and the
 * conjugate direction is orthogonal to every TRAIN vector; without the guard
 * the division would yield NaN or an infinity.
 *
 * @param gradient the current gradient
 * @param conjugate the current conjugate direction
 * @param messages the messages received by the vertex
 * @return the step size alpha, or {@code 0} in the degenerate cases above
 */
private double computeAlpha(Vector gradient, Vector conjugate, Iterable<MessageData4CFWritable> messages) {
    if (conjugate.norm(1d) == 0d) {
        return 0d;
    }

    double predictSquared = 0d;
    int numTrain = 0;
    for (MessageData4CFWritable message : messages) {
        if (message.getType() == EdgeType.TRAIN) {
            double predict = conjugate.dot(message.getVector());
            predictSquared += predict * predict;
            numTrain++;
        }
    }
    if (numTrain == 0) {
        return 0d;
    }

    double denominator = predictSquared / numTrain + lambda * conjugate.dot(conjugate);
    if (denominator == 0d) {
        // Dividing would produce NaN/Infinity and corrupt the vertex vector.
        return 0d;
    }
    return -gradient.dot(conjugate) / denominator;
}
From source file:org.trustedanalytics.atk.giraph.algorithms.cgd.ConjugateGradientDescentComputation.java
License:Apache License
/**
 * Computes beta according to the Hestenes-Stiefel formula.
 *
 * <p>With {@code delta = gradientNext - gradient}, the value returned is
 * {@code -(gradientNext . delta) / (conjugate . delta)}.
 *
 * <p>Zero is returned when the conjugate direction is the zero vector or when
 * {@code conjugate . delta} is exactly zero (for example when the gradient did
 * not change). Without the second guard the division would yield NaN or an
 * infinity. With beta equal to zero, the caller's update
 * {@code conjugate * beta - gradient} reduces to the negated gradient.
 *
 * <p>NOTE(review): for the update {@code d_next = beta * d - g_next} the
 * textbook Hestenes-Stiefel coefficient is
 * {@code (g_next . delta) / (d . delta)}, without the leading minus used here.
 * The sign is kept as in the original code; confirm whether it is intentional.
 *
 * @param gradient the gradient before the step
 * @param conjugate the conjugate direction used for the step
 * @param gradientNext the gradient after the step
 * @return beta, or {@code 0} in the degenerate cases above
 */
private double computeBeta(Vector gradient, Vector conjugate, Vector gradientNext) {
    if (conjugate.norm(1d) == 0d) {
        return 0d;
    }

    Vector deltaVector = gradientNext.minus(gradient);
    double denominator = conjugate.dot(deltaVector);
    if (denominator == 0d) {
        // Dividing would produce NaN/Infinity and corrupt the next direction.
        return 0d;
    }
    return -gradientNext.dot(deltaVector) / denominator;
}
From source file:org.trustedanalytics.atk.giraph.algorithms.cgd.ConjugateGradientDescentComputation.java
License:Apache License
/**
 * Performs one superstep of the conjugate-gradient-descent computation for a
 * single vertex.
 *
 * <p>Superstep 0 only initializes the vertex. On later supersteps the method
 * first (every {@code 2 * learningCurveOutputInterval} supersteps) aggregates
 * the training cost and the validate/test squared errors, and then, while the
 * superstep is below {@code maxSupersteps}, runs {@code numCGDIters}
 * conjugate-gradient iterations on the vertex vector, stores the resulting
 * vector, conjugate direction and gradient, optionally recomputes the bias,
 * and sends the updated vertex data along every outgoing edge. The vertex
 * votes to halt on every path.
 *
 * @param vertex the vertex being processed
 * @param messages the messages received by the vertex
 * @throws IOException as declared by the overridden method
 */
@Override
public void compute(Vertex<CFVertexId, VertexData4CGDWritable, EdgeData4CFWritable> vertex,
        Iterable<MessageData4CFWritable> messages) throws IOException {
    final long superstep = getSuperstep();
    if (superstep == 0) {
        initialize(vertex);
        vertex.voteToHalt();
        return;
    }

    final Vector currentValue = vertex.getValue().getVector();
    final double currentBias = vertex.getValue().getBias();

    // Publish learning-curve statistics every (2 * interval) supersteps.
    if ((superstep % (2 * learningCurveOutputInterval)) == 0) {
        double trainSquaredError = 0d;
        double validateSquaredError = 0d;
        double testSquaredError = 0d;
        int trainCount = 0;

        for (MessageData4CFWritable msg : messages) {
            EdgeType edgeType = msg.getType();
            double observed = msg.getWeight();
            Vector neighborVector = msg.getVector();
            double neighborBias = msg.getBias();
            double predicted = currentBias + neighborBias + currentValue.dot(neighborVector);
            double residual = observed - predicted;
            double squaredResidual = residual * residual;

            switch (edgeType) {
            case TRAIN:
                trainSquaredError += squaredResidual;
                trainCount++;
                break;
            case VALIDATE:
                validateSquaredError += squaredResidual;
                break;
            case TEST:
                testSquaredError += squaredResidual;
                break;
            default:
                throw new IllegalArgumentException("Unknown recognized edge type: " + edgeType);
            }
        }

        // Mean squared training error plus the regularization penalty on bias and vector.
        double trainCost = 0d;
        if (trainCount > 0) {
            double penalty = lambda * (currentBias * currentBias + currentValue.dot(currentValue));
            trainCost = trainSquaredError / trainCount + penalty;
        }

        aggregate(SUM_TRAIN_COST, new DoubleWritable(trainCost));
        aggregate(SUM_VALIDATE_ERROR, new DoubleWritable(validateSquaredError));
        aggregate(SUM_TEST_ERROR, new DoubleWritable(testSquaredError));
    }

    if (superstep < maxSupersteps) {
        // Running state of the conjugate-gradient iterations, seeded from the vertex.
        Vector runningValue = vertex.getValue().getVector();
        Vector runningGradient = vertex.getValue().getGradient();
        Vector runningConjugate = vertex.getValue().getConjugate();
        final double startBias = vertex.getValue().getBias();

        for (int iter = 0; iter < numCGDIters; iter++) {
            // Step along the current direction, then derive the next direction.
            double alpha = computeAlpha(runningGradient, runningConjugate, messages);
            Vector nextValue = runningValue.plus(runningConjugate.times(alpha));
            Vector nextGradient = computeGradient(startBias, nextValue, messages);
            double beta = computeBeta(runningGradient, runningConjugate, nextGradient);
            Vector nextConjugate = runningConjugate.times(beta).minus(nextGradient);

            runningValue = nextValue;
            runningGradient = nextGradient;
            runningConjugate = nextConjugate;
        }

        // Persist the iteration result on the vertex.
        vertex.getValue().setVector(runningValue);
        vertex.getValue().setConjugate(runningConjugate);
        vertex.getValue().setGradient(runningGradient);

        // The bias is recomputed against the updated vector.
        if (biasOn) {
            double updatedBias = computeBias(runningValue, messages);
            vertex.getValue().setBias(updatedBias);
        }

        // Send the updated vertex data along every outgoing edge.
        for (Edge<CFVertexId, EdgeData4CFWritable> edge : vertex.getEdges()) {
            MessageData4CFWritable outgoing = new MessageData4CFWritable(vertex.getValue(), edge.getValue());
            sendMessage(edge.getTargetVertexId(), outgoing);
        }
    }

    vertex.voteToHalt();
}
From source file:org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java
License:Apache License
/** * Evaluate cost according to vertex and messages * * @param vertex of the graph/*from ww w. jav a 2 s .c o m*/ * @param messages of type iterable * @param map of type HashMap */ private void evaluateCost(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex, Iterable<LdaMessage> messages, HashMap<LdaVertexId, Vector> map) { if (vertex.getId().isDocument()) { return; } Vector vector = vertex.getValue().getLdaResult(); vector = vector.plus(config.beta()).times(nk.plus(numWords * config.beta()).assign(Functions.INV)); double cost = 0d; for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getEdges()) { double weight = edge.getValue().getWordCount(); LdaVertexId id = edge.getTargetVertexId(); if (map.containsKey(id)) { Vector otherVector = map.get(id); otherVector = otherVector.plus(config.alpha()).normalize(1d); cost -= weight * Math.log(vector.dot(otherVector)); } else { throw new IllegalArgumentException( String.format("Vertex ID %s: A message is mis-matched", vertex.getId().getValue())); } } aggregate(SUM_COST, new DoubleWritable(cost)); }