Example usage for org.apache.mahout.math Vector dot

List of usage examples for org.apache.mahout.math Vector dot

Introduction

On this page you can find example usages of the org.apache.mahout.math.Vector.dot method.

Prototype

double dot(Vector x);

Source Link

Document

Returns the dot product of the recipient and the argument.

Usage

From source file:org.trustedanalytics.atk.giraph.algorithms.als.AlternatingLeastSquaresComputation.java

License:Apache License

/**
 * Estimates a bias term from the training messages received by a vertex.
 *
 * <p>For every message whose edge type is {@code TRAIN}, the residual
 * {@code weight - (otherBias + value.dot(vector))} is accumulated. The accumulated sum is
 * divided by {@code (1 + lambda) * numberOfTrainingMessages}.
 *
 * @param value the vector combined with each message vector through a dot product
 * @param messages the incoming messages; only those of type {@code TRAIN} contribute
 * @return the bias, or {@code 0} when no training message was seen
 */
private double computeBias(Vector value, Iterable<MessageData4CFWritable> messages) {
    double residualSum = 0d;
    int trainCount = 0;
    for (MessageData4CFWritable message : messages) {
        // validation and test edges do not take part in the bias estimate
        if (message.getType() != EdgeType.TRAIN) {
            continue;
        }
        double weight = message.getWeight();
        Vector vector = message.getVector();
        double otherBias = message.getBias();
        double predict = otherBias + value.dot(vector);
        residualSum += weight - predict;
        trainCount++;
    }
    if (trainCount == 0) {
        return 0d;
    }
    return residualSum / ((1 + lambda) * trainCount);
}

From source file:org.trustedanalytics.atk.giraph.algorithms.als.AlternatingLeastSquaresComputation.java

License:Apache License

/**
 * Performs one superstep of the alternating least squares computation for a vertex.
 *
 * <p>Superstep 0 only initializes the vertex. On later supersteps the method
 * <ol>
 *   <li>reports the training cost and the validation/test squared errors to the aggregators
 *       whenever the superstep is a multiple of {@code 2 * learningCurveOutputInterval};</li>
 *   <li>while {@code step < maxSupersteps}, solves the regularized least-squares system built
 *       from the {@code TRAIN} messages, optionally refreshes the bias, and sends the updated
 *       vertex data along every edge.</li>
 * </ol>
 * The vertex votes to halt before returning in every case.
 *
 * <p>If no {@code TRAIN} message was received, the vertex vector is left unchanged: the system
 * matrix would be all zeros (the regularization term is {@code lambda * numTrain}), so there is
 * nothing meaningful to solve.
 *
 * @param vertex the vertex being computed
 * @param messages the messages delivered to this vertex
 * @throws IOException declared by the overridden method
 * @throws IllegalArgumentException if, during cost reporting, a message carries an edge type
 *         other than {@code TRAIN}, {@code VALIDATE} or {@code TEST}
 */
@Override
public void compute(Vertex<CFVertexId, VertexData4CFWritable, EdgeData4CFWritable> vertex,
        Iterable<MessageData4CFWritable> messages) throws IOException {
    long step = getSuperstep();
    if (step == 0) {
        initialize(vertex);
        vertex.voteToHalt();
        return;
    }

    Vector currentValue = vertex.getValue().getVector();
    double currentBias = vertex.getValue().getBias();
    // update aggregators every (2 * interval) super steps
    if ((step % (2 * learningCurveOutputInterval)) == 0) {
        double errorOnTrain = 0d;
        double errorOnValidate = 0d;
        double errorOnTest = 0d;
        int numTrain = 0;
        for (MessageData4CFWritable message : messages) {
            EdgeType et = message.getType();
            double weight = message.getWeight();
            Vector vector = message.getVector();
            double otherBias = message.getBias();
            double predict = currentBias + otherBias + currentValue.dot(vector);
            double e = weight - predict;
            switch (et) {
            case TRAIN:
                errorOnTrain += e * e;
                numTrain++;
                break;
            case VALIDATE:
                errorOnValidate += e * e;
                break;
            case TEST:
                errorOnTest += e * e;
                break;
            default:
                throw new IllegalArgumentException("Unknown recognized edge type: " + et.toString());
            }
        }
        // training cost = mean squared error + lambda * (bias^2 + |value|^2)
        double costOnTrain = 0d;
        if (numTrain > 0) {
            costOnTrain = errorOnTrain / numTrain
                    + lambda * (currentBias * currentBias + currentValue.dot(currentValue));
        }
        aggregate(SUM_TRAIN_COST, new DoubleWritable(costOnTrain));
        aggregate(SUM_VALIDATE_ERROR, new DoubleWritable(errorOnValidate));
        aggregate(SUM_TEST_ERROR, new DoubleWritable(errorOnTest));
    }

    // update vertex value
    if (step < maxSupersteps) {
        // xxt records the result of x times x transpose
        Matrix xxt = new DenseMatrix(featureDimension, featureDimension);
        xxt = xxt.assign(0d);
        // xr records the result of x times rating
        Vector xr = currentValue.clone().assign(0d);
        int numTrain = 0;
        for (MessageData4CFWritable message : messages) {
            EdgeType et = message.getType();
            if (et == EdgeType.TRAIN) {
                double weight = message.getWeight();
                Vector vector = message.getVector();
                double otherBias = message.getBias();
                xxt = xxt.plus(vector.cross(vector));
                xr = xr.plus(vector.times(weight - currentBias - otherBias));
                numTrain++;
            }
        }

        // Without training messages both xxt and the regularization term (lambda * 0) are
        // zero, so the system is singular; keep the current vector instead of solving it.
        Vector value = currentValue;
        if (numTrain > 0) {
            xxt = xxt.plus(new DiagonalMatrix(lambda * numTrain, featureDimension));
            Matrix bMatrix = new DenseMatrix(featureDimension, 1).assignColumn(0, xr);
            value = new QRDecomposition(xxt).solve(bMatrix).viewColumn(0);
            vertex.getValue().setVector(value);
        }

        // update vertex bias
        if (biasOn) {
            double bias = computeBias(value, messages);
            vertex.getValue().setBias(bias);
        }

        // send out messages
        for (Edge<CFVertexId, EdgeData4CFWritable> edge : vertex.getEdges()) {
            MessageData4CFWritable newMessage = new MessageData4CFWritable(vertex.getValue(), edge.getValue());
            sendMessage(edge.getTargetVertexId(), newMessage);
        }
    }

    vertex.voteToHalt();
}

From source file:org.trustedanalytics.atk.giraph.algorithms.cgd.ConjugateGradientDescentComputation.java

License:Apache License

/**
 * Computes the gradient for a vertex from its training messages.
 *
 * <p>For every {@code TRAIN} message the signed error
 * {@code (bias + otherBias + value.dot(vector)) - weight} scales the message vector, and the
 * scaled vectors are summed. The result is that sum divided by the number of training messages,
 * plus {@code value.times(lambda)}.
 *
 * @param bias the bias added to every prediction
 * @param value the vector at which the gradient is evaluated
 * @param messages the incoming messages; only those of type {@code TRAIN} contribute
 * @return the gradient, or a zero vector shaped like {@code value} when no training message
 *         was seen
 */
private Vector computeGradient(double bias, Vector value, Iterable<MessageData4CFWritable> messages) {
    Vector errorWeightedSum = value.clone().assign(0d);
    int trainCount = 0;
    for (MessageData4CFWritable message : messages) {
        if (message.getType() != EdgeType.TRAIN) {
            continue;
        }
        double weight = message.getWeight();
        Vector vector = message.getVector();
        double otherBias = message.getBias();
        double predict = bias + otherBias + value.dot(vector);
        double error = predict - weight;
        errorWeightedSum = errorWeightedSum.plus(vector.times(error));
        trainCount++;
    }
    if (trainCount == 0) {
        return value.clone().assign(0d);
    }
    // mean data term plus the regularization term
    return errorWeightedSum.divide(trainCount).plus(value.times(lambda));
}

From source file:org.trustedanalytics.atk.giraph.algorithms.cgd.ConjugateGradientDescentComputation.java

License:Apache License

/**
 * Computes the step size alpha along the conjugate direction.
 *
 * <p>The value is {@code -gradient.dot(conjugate)} divided by the mean of
 * {@code conjugate.dot(vector)^2} over the {@code TRAIN} messages plus
 * {@code lambda * conjugate.dot(conjugate)}.
 *
 * @param gradient the current gradient
 * @param conjugate the current conjugate direction
 * @param messages the incoming messages; only those of type {@code TRAIN} contribute
 * @return alpha, or {@code 0} when the conjugate direction has zero 1-norm or no training
 *         message was seen
 */
private double computeAlpha(Vector gradient, Vector conjugate, Iterable<MessageData4CFWritable> messages) {
    // a zero direction cannot be stepped along
    if (conjugate.norm(1d) == 0d) {
        return 0d;
    }
    double projectionSquaredSum = 0d;
    int trainCount = 0;
    for (MessageData4CFWritable message : messages) {
        if (message.getType() != EdgeType.TRAIN) {
            continue;
        }
        Vector vector = message.getVector();
        double projection = conjugate.dot(vector);
        projectionSquaredSum += projection * projection;
        trainCount++;
    }
    if (trainCount == 0) {
        return 0d;
    }
    return -gradient.dot(conjugate)
            / (projectionSquaredSum / trainCount + lambda * conjugate.dot(conjugate));
}

From source file:org.trustedanalytics.atk.giraph.algorithms.cgd.ConjugateGradientDescentComputation.java

License:Apache License

/**
 * Compute beta according to Hestenes-Stiefel formula
 *
 * <p>The value is {@code -gradientNext.dot(delta) / conjugate.dot(delta)} with
 * {@code delta = gradientNext - gradient}. When the conjugate direction has zero 1-norm, or
 * when the denominator is exactly zero (for example because the gradient did not change),
 * {@code 0} is returned so that the caller's next direction depends only on the new gradient
 * and no NaN or infinite value is produced.
 *
 * <p>NOTE(review): with the caller updating directions as {@code conjugate * beta - gradient},
 * the textbook Hestenes-Stiefel coefficient is usually written without the leading minus sign.
 * The sign is kept as found; confirm against the intended derivation.
 *
 * @param gradient of type Vector
 * @param conjugate of type Vector
 * @param gradientNext of type Vector
 * @return beta of type double
 */
private double computeBeta(Vector gradient, Vector conjugate, Vector gradientNext) {
    if (conjugate.norm(1d) == 0d) {
        return 0d;
    }
    Vector deltaVector = gradientNext.minus(gradient);
    double denominator = conjugate.dot(deltaVector);
    if (denominator == 0d) {
        // dividing would give NaN (0/0) or an infinite beta
        return 0d;
    }
    return -gradientNext.dot(deltaVector) / denominator;
}

From source file:org.trustedanalytics.atk.giraph.algorithms.cgd.ConjugateGradientDescentComputation.java

License:Apache License

/**
 * Performs one superstep of the conjugate gradient descent computation for a vertex.
 *
 * <p>Superstep 0 only initializes the vertex. On later supersteps the method reports the
 * training cost and the validation/test squared errors to the aggregators whenever the
 * superstep is a multiple of {@code 2 * learningCurveOutputInterval}. While the superstep is
 * below {@code maxSupersteps} it then runs {@code numCGDIters} conjugate gradient iterations,
 * stores the resulting vector, conjugate and gradient on the vertex, optionally refreshes the
 * bias, and sends the updated vertex data along every edge. The vertex votes to halt before
 * returning in every case.
 *
 * @param vertex the vertex being computed
 * @param messages the messages delivered to this vertex
 * @throws IOException declared by the overridden method
 * @throws IllegalArgumentException if, during cost reporting, a message carries an edge type
 *         other than {@code TRAIN}, {@code VALIDATE} or {@code TEST}
 */
@Override
public void compute(Vertex<CFVertexId, VertexData4CGDWritable, EdgeData4CFWritable> vertex,
        Iterable<MessageData4CFWritable> messages) throws IOException {
    final long superstep = getSuperstep();
    if (superstep == 0) {
        initialize(vertex);
        vertex.voteToHalt();
        return;
    }

    Vector currentValue = vertex.getValue().getVector();
    double currentBias = vertex.getValue().getBias();

    // report learning-curve statistics every (2 * interval) supersteps
    boolean reportingStep = (superstep % (2 * learningCurveOutputInterval)) == 0;
    if (reportingStep) {
        double trainSquaredError = 0d;
        double validateSquaredError = 0d;
        double testSquaredError = 0d;
        int trainCount = 0;
        for (MessageData4CFWritable message : messages) {
            EdgeType type = message.getType();
            double weight = message.getWeight();
            Vector vector = message.getVector();
            double otherBias = message.getBias();
            double predict = currentBias + otherBias + currentValue.dot(vector);
            double residual = weight - predict;
            double squaredResidual = residual * residual;
            if (type == EdgeType.TRAIN) {
                trainSquaredError += squaredResidual;
                trainCount++;
            } else if (type == EdgeType.VALIDATE) {
                validateSquaredError += squaredResidual;
            } else if (type == EdgeType.TEST) {
                testSquaredError += squaredResidual;
            } else {
                throw new IllegalArgumentException("Unknown recognized edge type: " + type.toString());
            }
        }
        // training cost = mean squared error + lambda * (bias^2 + |value|^2)
        double trainCost = trainCount > 0
                ? trainSquaredError / trainCount
                        + lambda * (currentBias * currentBias + currentValue.dot(currentValue))
                : 0d;
        aggregate(SUM_TRAIN_COST, new DoubleWritable(trainCost));
        aggregate(SUM_VALIDATE_ERROR, new DoubleWritable(validateSquaredError));
        aggregate(SUM_TEST_ERROR, new DoubleWritable(testSquaredError));
    }

    if (superstep < maxSupersteps) {
        // run the CGD iterations starting from the state stored on the vertex
        Vector point = vertex.getValue().getVector();
        Vector grad = vertex.getValue().getGradient();
        Vector direction = vertex.getValue().getConjugate();
        double fixedBias = vertex.getValue().getBias();
        int iteration = 0;
        while (iteration < numCGDIters) {
            double alpha = computeAlpha(grad, direction, messages);
            Vector nextPoint = point.plus(direction.times(alpha));
            Vector nextGrad = computeGradient(fixedBias, nextPoint, messages);
            // beta needs the previous gradient and direction, so compute it before overwriting
            double beta = computeBeta(grad, direction, nextGrad);
            direction = direction.times(beta).minus(nextGrad);
            point = nextPoint;
            grad = nextGrad;
            iteration++;
        }

        // persist the iteration result on the vertex
        vertex.getValue().setVector(point);
        vertex.getValue().setConjugate(direction);
        vertex.getValue().setGradient(grad);

        // refresh the bias when bias estimation is enabled
        if (biasOn) {
            vertex.getValue().setBias(computeBias(point, messages));
        }

        // propagate the updated vertex data to every neighbor
        for (Edge<CFVertexId, EdgeData4CFWritable> edge : vertex.getEdges()) {
            sendMessage(edge.getTargetVertexId(),
                    new MessageData4CFWritable(vertex.getValue(), edge.getValue()));
        }
    }

    vertex.voteToHalt();
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java

License:Apache License

/**
 * Evaluate cost according to vertex and messages
 *
 * @param vertex of the graph/*from ww w. jav a  2 s  .c o  m*/
 * @param messages of type iterable
 * @param map of type HashMap
 */
private void evaluateCost(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex, Iterable<LdaMessage> messages,
        HashMap<LdaVertexId, Vector> map) {

    if (vertex.getId().isDocument()) {
        return;
    }
    Vector vector = vertex.getValue().getLdaResult();
    vector = vector.plus(config.beta()).times(nk.plus(numWords * config.beta()).assign(Functions.INV));

    double cost = 0d;
    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getEdges()) {
        double weight = edge.getValue().getWordCount();
        LdaVertexId id = edge.getTargetVertexId();
        if (map.containsKey(id)) {
            Vector otherVector = map.get(id);
            otherVector = otherVector.plus(config.alpha()).normalize(1d);
            cost -= weight * Math.log(vector.dot(otherVector));
        } else {
            throw new IllegalArgumentException(
                    String.format("Vertex ID %s: A message is mis-matched", vertex.getId().getValue()));
        }
    }
    aggregate(SUM_COST, new DoubleWritable(cost));
}