Example usage for org.apache.mahout.math Vector plus

Introduction

On this page you can find example usages of the org.apache.mahout.math Vector.plus method.

Prototype

Vector plus(Vector x);

Document

Returns a new vector containing the element-by-element sum of the recipient and the argument.
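
A minimal sketch of that contract, assuming mahout-math is on the classpath (the class name is illustrative):

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class VectorPlusDemo {
    public static void main(String[] args) {
        Vector a = new DenseVector(new double[] { 1.0, 2.0, 3.0 });
        Vector b = new DenseVector(new double[] { 0.5, 0.5, 0.5 });
        // plus() allocates a new vector; neither operand is modified
        Vector sum = a.plus(b);
        System.out.println(sum); // element-by-element sum: 1.5, 2.5, 3.5
    }
}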

Usage

From source file: cn.edu.bjtu.cit.recommender.Recommender.java

License: Apache License

@SuppressWarnings("unchecked")
public int run(String[] args) throws Exception {
    if (args.length < 2) {
        System.err.println();
        System.err.println("Usage: " + this.getClass().getName()
                + " [generic options] input output [profiling] [estimation] [clustersize]");
        System.err.println();
        printUsage();
        GenericOptionsParser.printGenericCommandUsage(System.err);

        return 1;
    }
    OptionParser parser = new OptionParser(args);

    Pipeline pipeline = new MRPipeline(Recommender.class, getConf());

    if (parser.hasOption(CLUSTER_SIZE)) {
        pipeline.getConfiguration().setInt(ClusterOracle.CLUSTER_SIZE,
                Integer.parseInt(parser.getOption(CLUSTER_SIZE).getValue()));
    }

    if (parser.hasOption(PROFILING)) {
        pipeline.getConfiguration().setBoolean(Profiler.IS_PROFILE, true);
        this.profileFilePath = parser.getOption(PROFILING).getValue();

    }

    if (parser.hasOption(ESTIMATION)) {
        estFile = parser.getOption(ESTIMATION).getValue();
        est = new Estimator(estFile, clusterSize);
    }

    if (parser.hasOption(OPT_REDUCE)) {
        pipeline.getConfiguration().setBoolean(OPT_REDUCE, true);
    }

    if (parser.hasOption(OPT_MSCR)) {
        pipeline.getConfiguration().setBoolean(OPT_MSCR, true);
    }

    if (parser.hasOption(ACTIVE_THRESHOLD)) {
        threshold = Integer.parseInt(parser.getOption("at").getValue());
    }

    if (parser.hasOption(TOP)) {
        top = Integer.parseInt(parser.getOption("top").getValue());
    }

    profiler = new Profiler(pipeline);
    /*
     * input node
     */
    PCollection<String> lines = pipeline.readTextFile(args[0]);

    if (profiler.isProfiling() && lines.getSize() > 10 * 1024 * 1024) {
        lines = lines.sample(0.1);
    }

    /*
     * S0 + GBK
     */
    PGroupedTable<Long, Long> userWithPrefs = lines.parallelDo(new MapFn<String, Pair<Long, Long>>() {

        @Override
        public Pair<Long, Long> map(String input) {
            String[] split = input.split(Estimator.DELM);
            long userID = Long.parseLong(split[0]);
            long itemID = Long.parseLong(split[1]);
            return Pair.of(userID, itemID);
        }

        @Override
        public float scaleFactor() {
            return est.getScaleFactor("S0").sizeFactor;
        }

        @Override
        public float scaleFactorByRecord() {
            return est.getScaleFactor("S0").recsFactor;
        }
    }, Writables.tableOf(Writables.longs(), Writables.longs())).groupByKey(est.getClusterSize());

    /*
     * S1
     */
    PTable<Long, Vector> userVector = userWithPrefs
            .parallelDo(new MapFn<Pair<Long, Iterable<Long>>, Pair<Long, Vector>>() {
                @Override
                public Pair<Long, Vector> map(Pair<Long, Iterable<Long>> input) {
                    Vector userVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
                    for (long itemPref : input.second()) {
                        userVector.set((int) itemPref, 1.0f);
                    }
                    return Pair.of(input.first(), userVector);
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S1").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S1").recsFactor;
                }
            }, Writables.tableOf(Writables.longs(), Writables.vectors()));

    userVector = profiler.profile("S0-S1", pipeline, userVector, ProfileConverter.long_vector(),
            Writables.tableOf(Writables.longs(), Writables.vectors()));

    /*
     * S2
     */
    PTable<Long, Vector> filteredUserVector = userVector
            .parallelDo(new DoFn<Pair<Long, Vector>, Pair<Long, Vector>>() {

                @Override
                public void process(Pair<Long, Vector> input, Emitter<Pair<Long, Vector>> emitter) {
                    if (input.second().getNumNondefaultElements() > threshold) {
                        emitter.emit(input);
                    }
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S2").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S2").recsFactor;
                }

            }, Writables.tableOf(Writables.longs(), Writables.vectors()));

    filteredUserVector = profiler.profile("S2", pipeline, filteredUserVector, ProfileConverter.long_vector(),
            Writables.tableOf(Writables.longs(), Writables.vectors()));

    /*
     * S3 + GBK
     */
    PGroupedTable<Integer, Integer> coOccurencePairs = filteredUserVector
            .parallelDo(new DoFn<Pair<Long, Vector>, Pair<Integer, Integer>>() {
                @Override
                public void process(Pair<Long, Vector> input, Emitter<Pair<Integer, Integer>> emitter) {
                    Iterator<Vector.Element> it = input.second().iterateNonZero();
                    while (it.hasNext()) {
                        int index1 = it.next().index();
                        Iterator<Vector.Element> it2 = input.second().iterateNonZero();
                        while (it2.hasNext()) {
                            int index2 = it2.next().index();
                            emitter.emit(Pair.of(index1, index2));
                        }
                    }
                }

                @Override
                public float scaleFactor() {
                    float size = est.getScaleFactor("S3").sizeFactor;
                    return size;
                }

                @Override
                public float scaleFactorByRecord() {
                    float recs = est.getScaleFactor("S3").recsFactor;
                    return recs;
                }
            }, Writables.tableOf(Writables.ints(), Writables.ints())).groupByKey(est.getClusterSize());

    /*
     * S4
     */
    PTable<Integer, Vector> coOccurenceVector = coOccurencePairs
            .parallelDo(new MapFn<Pair<Integer, Iterable<Integer>>, Pair<Integer, Vector>>() {
                @Override
                public Pair<Integer, Vector> map(Pair<Integer, Iterable<Integer>> input) {
                    Vector cooccurrenceRow = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
                    for (int itemIndex2 : input.second()) {
                        cooccurrenceRow.set(itemIndex2, cooccurrenceRow.get(itemIndex2) + 1.0);
                    }
                    return Pair.of(input.first(), cooccurrenceRow);
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S4").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S4").recsFactor;
                }
            }, Writables.tableOf(Writables.ints(), Writables.vectors()));

    coOccurenceVector = profiler.profile("S3-S4", pipeline, coOccurenceVector, ProfileConverter.int_vector(),
            Writables.tableOf(Writables.ints(), Writables.vectors()));

    /*
     * S5 Wrapping co-occurrence columns
     */
    PTable<Integer, VectorOrPref> wrappedCooccurrence = coOccurenceVector
            .parallelDo(new MapFn<Pair<Integer, Vector>, Pair<Integer, VectorOrPref>>() {

                @Override
                public Pair<Integer, VectorOrPref> map(Pair<Integer, Vector> input) {
                    return Pair.of(input.first(), new VectorOrPref(input.second()));
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S5").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S5").recsFactor;
                }

            }, Writables.tableOf(Writables.ints(), VectorOrPref.vectorOrPrefs()));

    wrappedCooccurrence = profiler.profile("S5", pipeline, wrappedCooccurrence, ProfileConverter.int_vopv(),
            Writables.tableOf(Writables.ints(), VectorOrPref.vectorOrPrefs()));

    /*
     * S6 Splitting user vectors
     */
    PTable<Integer, VectorOrPref> userVectorSplit = filteredUserVector
            .parallelDo(new DoFn<Pair<Long, Vector>, Pair<Integer, VectorOrPref>>() {

                @Override
                public void process(Pair<Long, Vector> input, Emitter<Pair<Integer, VectorOrPref>> emitter) {
                    long userID = input.first();
                    Vector userVector = input.second();
                    Iterator<Vector.Element> it = userVector.iterateNonZero();
                    while (it.hasNext()) {
                        Vector.Element e = it.next();
                        int itemIndex = e.index();
                        float preferenceValue = (float) e.get();
                        emitter.emit(Pair.of(itemIndex, new VectorOrPref(userID, preferenceValue)));
                    }
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S6").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S6").recsFactor;
                }
            }, Writables.tableOf(Writables.ints(), VectorOrPref.vectorOrPrefs()));

    userVectorSplit = profiler.profile("S6", pipeline, userVectorSplit, ProfileConverter.int_vopp(),
            Writables.tableOf(Writables.ints(), VectorOrPref.vectorOrPrefs()));

    /*
     * S7 Combine VectorOrPrefs
     */
    PTable<Integer, VectorAndPrefs> combinedVectorOrPref = wrappedCooccurrence.union(userVectorSplit)
            .groupByKey(est.getClusterSize())
            .parallelDo(new DoFn<Pair<Integer, Iterable<VectorOrPref>>, Pair<Integer, VectorAndPrefs>>() {

                @Override
                public void process(Pair<Integer, Iterable<VectorOrPref>> input,
                        Emitter<Pair<Integer, VectorAndPrefs>> emitter) {
                    Vector vector = null;
                    List<Long> userIDs = Lists.newArrayList();
                    List<Float> values = Lists.newArrayList();
                    for (VectorOrPref vop : input.second()) {
                        if (vector == null) {
                            vector = vop.getVector();
                        }
                        long userID = vop.getUserID();
                        if (userID != Long.MIN_VALUE) {
                            userIDs.add(vop.getUserID());
                        }
                        float value = vop.getValue();
                        if (!Float.isNaN(value)) {
                            values.add(vop.getValue());
                        }
                    }
                    emitter.emit(Pair.of(input.first(), new VectorAndPrefs(vector, userIDs, values)));
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S7").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S7").recsFactor;
                }
            }, Writables.tableOf(Writables.ints(), VectorAndPrefs.vectorAndPrefs()));

    combinedVectorOrPref = profiler.profile("S5+S6-S7", pipeline, combinedVectorOrPref,
            ProfileConverter.int_vap(), Writables.tableOf(Writables.ints(), VectorAndPrefs.vectorAndPrefs()));
    /*
     * S8 Computing partial recommendation vectors
     */
    PTable<Long, Vector> partialMultiply = combinedVectorOrPref
            .parallelDo(new DoFn<Pair<Integer, VectorAndPrefs>, Pair<Long, Vector>>() {
                @Override
                public void process(Pair<Integer, VectorAndPrefs> input, Emitter<Pair<Long, Vector>> emitter) {
                    Vector cooccurrenceColumn = input.second().getVector();
                    List<Long> userIDs = input.second().getUserIDs();
                    List<Float> prefValues = input.second().getValues();
                    for (int i = 0; i < userIDs.size(); i++) {
                        long userID = userIDs.get(i);
                        if (userID != Long.MIN_VALUE) {
                            float prefValue = prefValues.get(i);
                            Vector partialProduct = cooccurrenceColumn.times(prefValue);
                            emitter.emit(Pair.of(userID, partialProduct));
                        }
                    }
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S8").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S8").recsFactor;
                }

            }, Writables.tableOf(Writables.longs(), Writables.vectors())).groupByKey(est.getClusterSize())
            .combineValues(new CombineFn<Long, Vector>() {

                @Override
                public void process(Pair<Long, Iterable<Vector>> input, Emitter<Pair<Long, Vector>> emitter) {
                    Vector partial = null;
                    for (Vector vector : input.second()) {
                        partial = partial == null ? vector : partial.plus(vector);
                    }
                    emitter.emit(Pair.of(input.first(), partial));
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("combine").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("combine").recsFactor;
                }
            });

    partialMultiply = profiler.profile("S8-combine", pipeline, partialMultiply, ProfileConverter.long_vector(),
            Writables.tableOf(Writables.longs(), Writables.vectors()));

    /*
     * S9 Producing recommendations from vectors
     */
    PTable<Long, RecommendedItems> recommendedItems = partialMultiply
            .parallelDo(new DoFn<Pair<Long, Vector>, Pair<Long, RecommendedItems>>() {

                @Override
                public void process(Pair<Long, Vector> input, Emitter<Pair<Long, RecommendedItems>> emitter) {
                    Queue<RecommendedItem> topItems = new PriorityQueue<RecommendedItem>(11,
                            Collections.reverseOrder(BY_PREFERENCE_VALUE));
                    Iterator<Vector.Element> recommendationVectorIterator = input.second().iterateNonZero();
                    while (recommendationVectorIterator.hasNext()) {
                        Vector.Element element = recommendationVectorIterator.next();
                        int index = element.index();
                        float value = (float) element.get();
                        if (topItems.size() < top) {
                            topItems.add(new GenericRecommendedItem(index, value));
                        } else if (value > topItems.peek().getValue()) {
                            topItems.add(new GenericRecommendedItem(index, value));
                            topItems.poll();
                        }
                    }
                    List<RecommendedItem> recommendations = new ArrayList<RecommendedItem>(topItems.size());
                    recommendations.addAll(topItems);
                    Collections.sort(recommendations, BY_PREFERENCE_VALUE);
                    emitter.emit(Pair.of(input.first(), new RecommendedItems(recommendations)));
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S9").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S9").recsFactor;
                }

            }, Writables.tableOf(Writables.longs(), RecommendedItems.recommendedItems()));

    recommendedItems = profiler.profile("S9", pipeline, recommendedItems, ProfileConverter.long_ri(),
            Writables.tableOf(Writables.longs(), RecommendedItems.recommendedItems()));

    /*
     * Profiling
     */
    if (profiler.isProfiling()) {
        profiler.writeResultToFile(profileFilePath);
        profiler.cleanup(pipeline.getConfiguration());
        return 0;
    }
    /*
     * asText
     */
    pipeline.writeTextFile(recommendedItems, args[1]);
    PipelineResult result = pipeline.done();
    return result.succeeded() ? 0 : 1;
}
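
The combineValues step in S8 above folds the partial product vectors together with plus. A standalone sketch of the same null-seeded fold, outside Crunch (the sample vectors are hypothetical):

import java.util.Arrays;
import java.util.List;

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class PartialSumSketch {
    public static void main(String[] args) {
        List<Vector> partials = Arrays.<Vector>asList(
                new DenseVector(new double[] { 1, 0, 2 }),
                new DenseVector(new double[] { 0, 3, 1 }));
        // same accumulation pattern as the CombineFn above
        Vector sum = null;
        for (Vector v : partials) {
            sum = (sum == null) ? v : sum.plus(v);
        }
        System.out.println(sum); // 1.0, 3.0, 3.0
    }
}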

From source file: com.cloudera.science.ml.kmeans.core.KMeans.java

License: Open Source License

/**
 * Compute the {@code Vector} that is the centroid of the given weighted points.
 *
 * @param points The weighted points
 * @return The centroid of the weighted points
 */
public <V extends Vector> Vector centroid(Collection<Weighted<V>> points) {
    Vector center = null;
    long sz = 0;
    for (Weighted<V> v : points) {
        Vector weighted = v.thing().times(v.weight());
        if (center == null) {
            center = weighted;
        } else {
            center = center.plus(weighted);
        }
        sz += v.weight();
    }
    return center.divide(sz);
}
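
A standalone sketch of the same weighted-centroid computation (Weighted and thing() belong to Cloudera ML; plain arrays stand in for them here, and the same accumulation reappears verbatim in LloydsUpdateStrategy below):

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class WeightedCentroidSketch {
    public static void main(String[] args) {
        Vector[] points = { new DenseVector(new double[] { 0, 0 }),
                            new DenseVector(new double[] { 4, 2 }) };
        double[] weights = { 1.0, 3.0 };
        Vector center = null;
        double total = 0.0;
        for (int i = 0; i < points.length; i++) {
            Vector weighted = points[i].times(weights[i]); // scale each point by its weight
            center = (center == null) ? weighted : center.plus(weighted);
            total += weights[i];
        }
        System.out.println(center.divide(total)); // weighted mean: 3.0, 1.5
    }
}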

From source file: com.cloudera.science.ml.kmeans.core.LloydsUpdateStrategy.java

License: Open Source License

/**
 * Compute the {@code Vector} that is the centroid of the given weighted points.
 *
 * @param points The weighted points
 * @return The centroid of the weighted points
 */
public <V extends Vector> Vector centroid(Collection<Weighted<V>> points) {
    Vector center = null;
    double sz = 0.0;
    for (Weighted<V> v : points) {
        Vector weighted = v.thing().times(v.weight());
        if (center == null) {
            center = weighted;
        } else {
            center = center.plus(weighted);
        }
        sz += v.weight();
    }
    return center.divide(sz);
}

From source file: com.cloudera.science.ml.kmeans.core.MiniBatchUpdateStrategy.java

License: Open Source License

@Override
public <V extends Vector> Centers update(List<Weighted<V>> points, Centers centers) {
    int[] perCenterStepCounts = new int[centers.size()];
    WeightedSampler<V> sampler = new WeightedSampler<V>(points, random);
    for (int iter = 0; iter < numIterations; iter++) {
        // Compute the closest center for each sample in the mini-batch
        List<List<V>> centerAssignments = Lists.newArrayList();
        for (int i = 0; i < centers.size(); i++) {
            centerAssignments.add(Lists.<V>newArrayList());
        }
        for (int i = 0; i < miniBatchSize; i++) {
            V sample = sampler.sample();
            int closestId = centers.indexOfClosest(sample);
            centerAssignments.get(closestId).add(sample);
        }
        // Apply the mini-batch
        List<Vector> nextCenters = Lists.newArrayList();
        for (int i = 0; i < centerAssignments.size(); i++) {
            Vector currentCenter = centers.get(i);
            for (int j = 0; j < centerAssignments.get(i).size(); j++) {
                double eta = 1.0 / (++perCenterStepCounts[i] + 1.0);
                currentCenter = currentCenter.times(1.0 - eta);
                currentCenter = currentCenter.plus(centerAssignments.get(i).get(j).times(eta));
            }
            nextCenters.add(currentCenter);
        }
        centers = new Centers(nextCenters);
    }
    return centers;
}
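
Each inner step above moves a center toward one assigned sample with a shrinking learning rate: center becomes (1 - eta) * center + eta * sample, with eta = 1 / (stepCount + 1) after the pre-increment. A minimal sketch of that update in isolation (the sample data is hypothetical):

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class MiniBatchStepSketch {
    public static void main(String[] args) {
        Vector center = new DenseVector(new double[] { 0.0, 0.0 });
        Vector[] assigned = { new DenseVector(new double[] { 2, 2 }),
                              new DenseVector(new double[] { 4, 0 }) };
        int steps = 0;
        for (Vector x : assigned) {
            double eta = 1.0 / (++steps + 1.0); // shrinking step size, as in the loop above
            center = center.times(1.0 - eta).plus(x.times(eta));
        }
        System.out.println(center); // center drawn toward the assigned samples
    }
}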

From source file: nl.gridline.zieook.inx.movielens.AggregateAndRecommendReducer.java

License: Apache License

private void reduceBooleanData(VarLongWritable userID, Iterable<PrefAndSimilarityColumnWritable> values,
        Context context) throws IOException, InterruptedException {
    /*
     * With boolean data each estimated preference can only be 1; that alone
     * can't rank the recommended items, so the sum of similarities is used instead.
     */
    Vector predictionVector = null;
    for (PrefAndSimilarityColumnWritable prefAndSimilarityColumn : values) {
        predictionVector = predictionVector == null ? prefAndSimilarityColumn.getSimilarityColumn()
                : predictionVector.plus(prefAndSimilarityColumn.getSimilarityColumn());
    }
    writeRecommendedItems(userID, predictionVector, context);
}

From source file: nl.gridline.zieook.inx.movielens.AggregateAndRecommendReducer.java

License: Apache License

private void reduceNonBooleanData(VarLongWritable userID, Iterable<PrefAndSimilarityColumnWritable> values,
        Context context) throws IOException, InterruptedException {
    /* each entry here is the sum in the numerator of the prediction formula */
    Vector numerators = null;
    /* each entry here is the sum in the denominator of the prediction formula */
    Vector denominators = null;
    /* each entry here is the number of similar items used in the prediction formula */
    Vector numberOfSimilarItemsUsed = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);

    for (PrefAndSimilarityColumnWritable prefAndSimilarityColumn : values) {
        Vector simColumn = prefAndSimilarityColumn.getSimilarityColumn();
        float prefValue = prefAndSimilarityColumn.getPrefValue();
        /* count the number of items used for each prediction */
        Iterator<Vector.Element> usedItemsIterator = simColumn.iterateNonZero();
        while (usedItemsIterator.hasNext()) {
            int itemIDIndex = usedItemsIterator.next().index();
            numberOfSimilarItemsUsed.setQuick(itemIDIndex, numberOfSimilarItemsUsed.getQuick(itemIDIndex) + 1);
        }

        numerators = numerators == null
                ? prefValue == BOOLEAN_PREF_VALUE ? simColumn.clone() : simColumn.times(prefValue)
                : numerators.plus(prefValue == BOOLEAN_PREF_VALUE ? simColumn : simColumn.times(prefValue));

        simColumn.assign(ABSOLUTE_VALUES);
        denominators = denominators == null ? simColumn : denominators.plus(simColumn);
    }

    if (numerators == null) {
        return;
    }

    Vector recommendationVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
    Iterator<Vector.Element> iterator = numerators.iterateNonZero();
    while (iterator.hasNext()) {
        Vector.Element element = iterator.next();
        int itemIDIndex = element.index();
        /* preference estimations must be based on at least 2 datapoints */
        if (numberOfSimilarItemsUsed.getQuick(itemIDIndex) > 1) {
            /* compute normalized prediction */
            double prediction = element.get() / denominators.getQuick(itemIDIndex);
            recommendationVector.setQuick(itemIDIndex, prediction);
        }
    }
    writeRecommendedItems(userID, recommendationVector, context);
}
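
The reducer above implements a weighted-average prediction: for each item, prediction = sum of sim * pref divided by sum of |sim| (and the real reducer additionally requires at least two contributing items). A hedged in-memory sketch of that normalization, with Functions.ABS standing in for the class's ABSOLUTE_VALUES constant and hypothetical sample data:

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

public class PredictionSketch {
    public static void main(String[] args) {
        Vector[] simColumns = { new DenseVector(new double[] { 0.9, 0.1 }),
                                new DenseVector(new double[] { 0.4, 0.6 }) };
        double[] prefValues = { 4.0, 2.0 };
        Vector numerators = null;
        Vector denominators = null;
        for (int i = 0; i < simColumns.length; i++) {
            Vector weighted = simColumns[i].times(prefValues[i]);
            numerators = (numerators == null) ? weighted : numerators.plus(weighted);
            Vector abs = simColumns[i].clone().assign(Functions.ABS); // |sim| for the denominator
            denominators = (denominators == null) ? abs : denominators.plus(abs);
        }
        for (int item = 0; item < numerators.size(); item++) {
            System.out.println("item " + item + " -> " + numerators.get(item) / denominators.get(item));
        }
    }
}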

From source file: org.trustedanalytics.atk.giraph.aggregators.VectorSumAggregator.java

License: Apache License

@Override
public void aggregate(VectorWritable value) {
    Vector currentValue = getAggregatedValue().get();
    if (currentValue.size() == 0) {
        getAggregatedValue().set(value.get());
    } else if (value.get().size() > 0) {
        getAggregatedValue().set(currentValue.plus(value.get()));
    }
}

From source file: org.trustedanalytics.atk.giraph.algorithms.als.AlternatingLeastSquaresComputation.java

License: Apache License

@Override
public void compute(Vertex<CFVertexId, VertexData4CFWritable, EdgeData4CFWritable> vertex,
        Iterable<MessageData4CFWritable> messages) throws IOException {
    long step = getSuperstep();
    if (step == 0) {
        initialize(vertex);
        vertex.voteToHalt();
        return;
    }

    Vector currentValue = vertex.getValue().getVector();
    double currentBias = vertex.getValue().getBias();
    // update aggregators every (2 * interval) supersteps
    if ((step % (2 * learningCurveOutputInterval)) == 0) {
        double errorOnTrain = 0d;
        double errorOnValidate = 0d;
        double errorOnTest = 0d;
        int numTrain = 0;
        for (MessageData4CFWritable message : messages) {
            EdgeType et = message.getType();
            double weight = message.getWeight();
            Vector vector = message.getVector();
            double otherBias = message.getBias();
            double predict = currentBias + otherBias + currentValue.dot(vector);
            double e = weight - predict;
            switch (et) {
            case TRAIN:
                errorOnTrain += e * e;
                numTrain++;
                break;
            case VALIDATE:
                errorOnValidate += e * e;
                break;
            case TEST:
                errorOnTest += e * e;
                break;
            default:
                throw new IllegalArgumentException("Unknown recognized edge type: " + et.toString());
            }
        }
        double costOnTrain = 0d;
        if (numTrain > 0) {
            costOnTrain = errorOnTrain / numTrain
                    + lambda * (currentBias * currentBias + currentValue.dot(currentValue));
        }
        aggregate(SUM_TRAIN_COST, new DoubleWritable(costOnTrain));
        aggregate(SUM_VALIDATE_ERROR, new DoubleWritable(errorOnValidate));
        aggregate(SUM_TEST_ERROR, new DoubleWritable(errorOnTest));
    }

    // update vertex value
    if (step < maxSupersteps) {
        // xxt records the result of x times x transpose
        Matrix xxt = new DenseMatrix(featureDimension, featureDimension);
        xxt = xxt.assign(0d);
        // xr records the result of x times rating
        Vector xr = currentValue.clone().assign(0d);
        int numTrain = 0;
        for (MessageData4CFWritable message : messages) {
            EdgeType et = message.getType();
            if (et == EdgeType.TRAIN) {
                double weight = message.getWeight();
                Vector vector = message.getVector();
                double otherBias = message.getBias();
                xxt = xxt.plus(vector.cross(vector));
                xr = xr.plus(vector.times(weight - currentBias - otherBias));
                numTrain++;
            }
        }
        xxt = xxt.plus(new DiagonalMatrix(lambda * numTrain, featureDimension));
        Matrix bMatrix = new DenseMatrix(featureDimension, 1).assignColumn(0, xr);
        Vector value = new QRDecomposition(xxt).solve(bMatrix).viewColumn(0);
        vertex.getValue().setVector(value);

        // update vertex bias
        if (biasOn) {
            double bias = computeBias(value, messages);
            vertex.getValue().setBias(bias);
        }

        // send out messages
        for (Edge<CFVertexId, EdgeData4CFWritable> edge : vertex.getEdges()) {
            MessageData4CFWritable newMessage = new MessageData4CFWritable(vertex.getValue(), edge.getValue());
            sendMessage(edge.getTargetVertexId(), newMessage);
        }
    }

    vertex.voteToHalt();
}
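
The training branch above assembles the regularized normal equations (sum of x * x^T + lambda * n * I) * v = sum of x * r and solves them by QR decomposition. A small self-contained sketch of that solve with toy data (dimension, lambda, and values are hypothetical):

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.DiagonalMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.Vector;

public class AlsSolveSketch {
    public static void main(String[] args) {
        int k = 2;                 // feature dimension
        double lambda = 0.1;       // regularization weight
        Vector[] x = { new DenseVector(new double[] { 1, 0 }),
                       new DenseVector(new double[] { 0, 1 }) };
        double[] r = { 3.0, 4.0 }; // bias-corrected ratings
        Matrix xxt = new DenseMatrix(k, k).assign(0d);
        Vector xr = new DenseVector(k);
        for (int i = 0; i < x.length; i++) {
            xxt = xxt.plus(x[i].cross(x[i])); // accumulate x times x transpose
            xr = xr.plus(x[i].times(r[i]));   // accumulate x times rating
        }
        xxt = xxt.plus(new DiagonalMatrix(lambda * x.length, k)); // ridge term
        Matrix b = new DenseMatrix(k, 1).assignColumn(0, xr);
        Vector v = new QRDecomposition(xxt).solve(b).viewColumn(0);
        System.out.println(v); // the updated feature vector
    }
}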

From source file: org.trustedanalytics.atk.giraph.algorithms.cgd.ConjugateGradientDescentComputation.java

License: Apache License

/**
 * Compute the gradient.
 *
 * @param bias of type double
 * @param value of type Vector
 * @param messages of type Iterable
 * @return gradient of type Vector
 */
private Vector computeGradient(double bias, Vector value, Iterable<MessageData4CFWritable> messages) {
    Vector xr = value.clone().assign(0d);
    int numTrain = 0;
    for (MessageData4CFWritable message : messages) {
        EdgeType et = message.getType();
        if (et == EdgeType.TRAIN) {
            double weight = message.getWeight();
            Vector vector = message.getVector();
            double otherBias = message.getBias();
            double predict = bias + otherBias + value.dot(vector);
            double e = predict - weight;
            xr = xr.plus(vector.times(e));
            numTrain++;
        }
    }
    Vector gradient = value.clone().assign(0d);
    if (numTrain > 0) {
        gradient = xr.divide(numTrain).plus(value.times(lambda));
    }
    return gradient;
}
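
In closed form the loop above computes gradient = xr / numTrain + lambda * value, where xr is the sum over TRAIN edges of (predict - weight) * vector: the derivative of the regularized squared-error cost accumulated in compute below.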

From source file: org.trustedanalytics.atk.giraph.algorithms.cgd.ConjugateGradientDescentComputation.java

License: Apache License

@Override
public void compute(Vertex<CFVertexId, VertexData4CGDWritable, EdgeData4CFWritable> vertex,
        Iterable<MessageData4CFWritable> messages) throws IOException {
    long step = getSuperstep();
    if (step == 0) {
        initialize(vertex);
        vertex.voteToHalt();
        return;
    }

    Vector currentValue = vertex.getValue().getVector();
    double currentBias = vertex.getValue().getBias();
    // update aggregators every (2 * interval) supersteps
    if ((step % (2 * learningCurveOutputInterval)) == 0) {
        double errorOnTrain = 0d;
        double errorOnValidate = 0d;
        double errorOnTest = 0d;
        int numTrain = 0;
        for (MessageData4CFWritable message : messages) {
            EdgeType et = message.getType();
            double weight = message.getWeight();
            Vector vector = message.getVector();
            double otherBias = message.getBias();
            double predict = currentBias + otherBias + currentValue.dot(vector);
            double e = weight - predict;
            switch (et) {
            case TRAIN:
                errorOnTrain += e * e;
                numTrain++;
                break;
            case VALIDATE:
                errorOnValidate += e * e;
                break;
            case TEST:
                errorOnTest += e * e;
                break;
            default:
                throw new IllegalArgumentException("Unknown recognized edge type: " + et.toString());
            }
        }
        double costOnTrain = 0d;
        if (numTrain > 0) {
            costOnTrain = errorOnTrain / numTrain
                    + lambda * (currentBias * currentBias + currentValue.dot(currentValue));
        }
        aggregate(SUM_TRAIN_COST, new DoubleWritable(costOnTrain));
        aggregate(SUM_VALIDATE_ERROR, new DoubleWritable(errorOnValidate));
        aggregate(SUM_TEST_ERROR, new DoubleWritable(errorOnTest));
    }

    if (step < maxSupersteps) {
        // implement CGD iterations
        Vector value0 = vertex.getValue().getVector();
        Vector gradient0 = vertex.getValue().getGradient();
        Vector conjugate0 = vertex.getValue().getConjugate();
        double bias0 = vertex.getValue().getBias();
        for (int i = 0; i < numCGDIters; i++) {
            double alpha = computeAlpha(gradient0, conjugate0, messages);
            Vector value = value0.plus(conjugate0.times(alpha));
            Vector gradient = computeGradient(bias0, value, messages);
            double beta = computeBeta(gradient0, conjugate0, gradient);
            Vector conjugate = conjugate0.times(beta).minus(gradient);
            value0 = value;
            gradient0 = gradient;
            conjugate0 = conjugate;
        }
        // update vertex values
        vertex.getValue().setVector(value0);
        vertex.getValue().setConjugate(conjugate0);
        vertex.getValue().setGradient(gradient0);

        // update vertex bias
        if (biasOn) {
            double bias = computeBias(value0, messages);
            vertex.getValue().setBias(bias);
        }

        // send out messages
        for (Edge<CFVertexId, EdgeData4CFWritable> edge : vertex.getEdges()) {
            MessageData4CFWritable newMessage = new MessageData4CFWritable(vertex.getValue(), edge.getValue());
            sendMessage(edge.getTargetVertexId(), newMessage);
        }
    }

    vertex.voteToHalt();
}
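
Each of the numCGDIters passes above is one conjugate-gradient step: computeAlpha (not shown here) presumably picks the step length along the current search direction, value0.plus(conjugate0.times(alpha)) moves the value along that direction, a fresh gradient is taken at the new value, and computeBeta remixes the old direction so the next one is conjugate0.times(beta).minus(gradient).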