Example usage for org.apache.mahout.math Vector cross

List of usage examples for org.apache.mahout.math Vector cross

Introduction

On this page you can find example usages of the cross method of org.apache.mahout.math.Vector.

Prototype

Matrix cross(Vector other);

Source Link

Document

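Returns the cross product of the receiver and the other vector. Because the result is a Matrix, this is the outer product (the receiver times the transpose of the other vector), not the three-dimensional vector cross product.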

Usage

From source file:edu.snu.dolphin.bsp.examples.ml.algorithms.clustering.em.EMMainCmpTask.java

License:Apache License

/**
 * Accumulates, for every cluster, the partial statistics contributed by the points of this task.
 *
 * <p>For each point the unnormalised weight of cluster {@code i} is
 * {@code prior / sqrt(det(covariance)) * exp(d . (inverse(covariance) * d) / -2)}, where
 * {@code d} is the point minus the cluster centroid. The weights are normalised into
 * posteriors (falling back to a uniform posterior when they sum to zero), and a
 * {@code ClusterStats} built from the posterior-scaled outer product, the posterior-scaled
 * point and the posterior is put into, or added to, {@code clusterToStats}.
 * {@code clusterToStats} is replaced by a fresh map on every call.
 *
 * @param iteration the iteration number; not read by this method
 */
@Override
public void run(final int iteration) {
    clusterToStats = new HashMap<>();
    final int numClusters = clusterSummaries.size();

    // Per-cluster terms that do not depend on the point. They are filled in lazily on the
    // first point so that an empty point set still triggers no determinant/inverse work.
    // NOTE(review): assumes inverse(...) is a pure function of its argument and that the
    // cluster summaries are not modified while this method runs -- confirm.
    double[] scales = null;
    Matrix[] precisions = null;

    // Compute the partial statistics of each cluster
    for (final Vector vector : points) {
        if (scales == null) {
            scales = new double[numClusters];
            precisions = new Matrix[numClusters];
            for (int i = 0; i < numClusters; i++) {
                final ClusterSummary clusterSummary = clusterSummaries.get(i);
                final Matrix covariance = clusterSummary.getCovariance();
                final Double prior = clusterSummary.getPrior();
                scales[i] = prior / Math.sqrt(covariance.determinant());
                precisions[i] = inverse(covariance);
            }
        }

        final int dimension = vector.size();
        final Matrix outProd;
        if (isCovarianceDiagonal) {
            // Only the diagonal of the outer product is kept in the diagonal-covariance case.
            outProd = new SparseMatrix(dimension, dimension);
            for (int j = 0; j < dimension; j++) {
                outProd.set(j, j, vector.get(j) * vector.get(j));
            }
        } else {
            outProd = vector.cross(vector);
        }

        double denominator = 0;
        final double[] numerators = new double[numClusters];
        for (int i = 0; i < numClusters; i++) {
            final Vector differ = vector.minus(clusterSummaries.get(i).getCentroid());
            numerators[i] = scales[i] * Math.exp(differ.dot(precisions[i].times(differ)) / (-2));
            denominator += numerators[i];
        }

        for (int i = 0; i < numClusters; i++) {
            // A zero denominator means every weight is zero; spread the point evenly instead.
            final double posterior = denominator == 0 ? 1.0 / numerators.length : numerators[i] / denominator;
            final ClusterStats contribution =
                    new ClusterStats(times(outProd, posterior), vector.times(posterior), posterior, false);
            if (!clusterToStats.containsKey(i)) {
                clusterToStats.put(i, contribution);
            } else {
                clusterToStats.get(i).add(contribution);
            }
        }
    }
}

From source file:org.trustedanalytics.atk.giraph.algorithms.als.AlternatingLeastSquaresComputation.java

License:Apache License

/**
 * Runs one superstep of alternating least squares for a single vertex.
 *
 * <p>At superstep 0 the vertex is initialized and halts. On every later superstep the method
 * <ol>
 *   <li>every {@code 2 * learningCurveOutputInterval} supersteps, aggregates the regularised
 *       training cost and the summed squared validation and test errors computed from the
 *       incoming messages;</li>
 *   <li>while {@code superstep < maxSupersteps}, solves a least-squares system built from the
 *       TRAIN messages for a new vertex vector, optionally recomputes the bias, and sends the
 *       updated vertex value along every edge;</li>
 *   <li>votes to halt.</li>
 * </ol>
 *
 * @param vertex the vertex being computed
 * @param messages the messages received by the vertex; iterated more than once
 * @throws IOException declared by the overridden method
 */
@Override
public void compute(Vertex<CFVertexId, VertexData4CFWritable, EdgeData4CFWritable> vertex,
        Iterable<MessageData4CFWritable> messages) throws IOException {
    final long superstep = getSuperstep();
    if (superstep == 0) {
        // Superstep 0 only initializes the vertex.
        initialize(vertex);
        vertex.voteToHalt();
        return;
    }

    final Vector ownVector = vertex.getValue().getVector();
    final double ownBias = vertex.getValue().getBias();

    // update aggregators every (2 * interval) super steps
    if (superstep % (2 * learningCurveOutputInterval) == 0) {
        double trainSquaredError = 0d;
        double validateSquaredError = 0d;
        double testSquaredError = 0d;
        int trainCount = 0;
        for (MessageData4CFWritable incoming : messages) {
            final EdgeType edgeType = incoming.getType();
            final double rating = incoming.getWeight();
            final Vector neighborVector = incoming.getVector();
            final double neighborBias = incoming.getBias();
            final double prediction = ownBias + neighborBias + ownVector.dot(neighborVector);
            final double residual = rating - prediction;
            final double squaredResidual = residual * residual;
            if (edgeType == EdgeType.TRAIN) {
                trainSquaredError += squaredResidual;
                trainCount++;
            } else if (edgeType == EdgeType.VALIDATE) {
                validateSquaredError += squaredResidual;
            } else if (edgeType == EdgeType.TEST) {
                testSquaredError += squaredResidual;
            } else {
                throw new IllegalArgumentException("Unknown recognized edge type: " + edgeType.toString());
            }
        }
        // The training cost is the mean squared error plus the lambda-weighted squared norm
        // of this vertex's bias and vector; it stays zero when no TRAIN message was seen.
        final double trainCost = trainCount > 0
                ? trainSquaredError / trainCount + lambda * (ownBias * ownBias + ownVector.dot(ownVector))
                : 0d;
        aggregate(SUM_TRAIN_COST, new DoubleWritable(trainCost));
        aggregate(SUM_VALIDATE_ERROR, new DoubleWritable(validateSquaredError));
        aggregate(SUM_TEST_ERROR, new DoubleWritable(testSquaredError));
    }

    // update vertex value
    if (superstep < maxSupersteps) {
        // gram accumulates x * x^T over the TRAIN messages
        Matrix gram = new DenseMatrix(featureDimension, featureDimension).assign(0d);
        // weighted accumulates x * (rating - biases) over the TRAIN messages
        Vector weighted = ownVector.clone().assign(0d);
        int trainCount = 0;
        for (MessageData4CFWritable incoming : messages) {
            if (incoming.getType() != EdgeType.TRAIN) {
                continue;
            }
            final double rating = incoming.getWeight();
            final Vector neighborVector = incoming.getVector();
            final double neighborBias = incoming.getBias();
            gram = gram.plus(neighborVector.cross(neighborVector));
            weighted = weighted.plus(neighborVector.times(rating - ownBias - neighborBias));
            trainCount++;
        }
        // Regularise the diagonal in proportion to the number of TRAIN messages.
        gram = gram.plus(new DiagonalMatrix(lambda * trainCount, featureDimension));
        final Matrix rightHandSide = new DenseMatrix(featureDimension, 1).assignColumn(0, weighted);
        final Matrix solution = new QRDecomposition(gram).solve(rightHandSide);
        final Vector updated = solution.viewColumn(0);
        vertex.getValue().setVector(updated);

        // update vertex bias
        if (biasOn) {
            vertex.getValue().setBias(computeBias(updated, messages));
        }

        // send out messages
        for (Edge<CFVertexId, EdgeData4CFWritable> outgoing : vertex.getEdges()) {
            sendMessage(outgoing.getTargetVertexId(),
                    new MessageData4CFWritable(vertex.getValue(), outgoing.getValue()));
        }
    }

    vertex.voteToHalt();
}