Example usage for org.apache.mahout.math Vector plus

List of usage examples for org.apache.mahout.math.Vector plus

Introduction

On this page you can find example usages of the `plus` method of `org.apache.mahout.math.Vector`.

Prototype

Vector plus(Vector x);

Source Link

Document

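Returns a new vector containing the element-by-element sum of the recipient and the argument. Several examples below also use the overload `Vector plus(double x)`, which returns a new vector with the scalar `x` added to every element.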

Usage

From source file:org.trustedanalytics.atk.giraph.algorithms.lbp.LoopyBeliefPropagationComputation.java

License:Apache License

/**
 * Initializes a vertex in superstep 0 of loopy belief propagation.
 *
 * <p>The prior is validated and normalized:
 * <ul>
 *   <li>negative entries are rejected;</li>
 *   <li>entries below {@code MIN_PRIOR_VALUE} are raised to it;</li>
 *   <li>the normalized prior is copied into the posterior;</li>
 *   <li>the prior itself is then replaced, in place, by the log of the normalized values.</li>
 * </ul>
 *
 * <p>The vertex is counted in the train/validate/test vertex-count aggregator. If
 * {@code ignoreVertexType} is set, every vertex counts as a training vertex.
 *
 * <p>Non-training vertices get a uniform posterior and send no messages. Training vertices send
 * one log-space message along every out-edge. That message is shifted so its largest entry is 0.
 *
 * @param vertex of the graph
 * @throws IllegalArgumentException if a prior entry is negative, an edge weight is not positive,
 *     or the vertex type is unknown
 */
private void initializeVertex(Vertex<LongWritable, VertexData4LBPWritable, DoubleWritable> vertex) {
    // normalize prior and posterior
    Vector prior = vertex.getValue().getPriorVector();
    Vector posterior = vertex.getValue().getPosteriorVector();
    int nStates = prior.size();
    double sum = 0d;
    for (int i = 0; i < nStates; i++) {
        double v = prior.getQuick(i);
        if (v < 0d) {
            throw new IllegalArgumentException("Vertex ID: " + vertex.getId() + " has negative prior value.");
        } else if (v < MIN_PRIOR_VALUE) {
            v = MIN_PRIOR_VALUE;
            prior.setQuick(i, v);
        }
        sum += v;
    }
    for (int i = 0; i < nStates; i++) {
        posterior.setQuick(i, prior.getQuick(i) / sum);
        // from here on the prior is kept in log space
        prior.setQuick(i, Math.log(posterior.getQuick(i)));
    }
    // collect graph statistics
    VertexType vt = vertex.getValue().getType();
    vt = ignoreVertexType ? VertexType.TRAIN : vt;
    switch (vt) {
    case TRAIN:
        aggregate(SUM_TRAIN_VERTICES, new LongWritable(1));
        break;
    case VALIDATE:
        aggregate(SUM_VALIDATE_VERTICES, new LongWritable(1));
        break;
    case TEST:
        aggregate(SUM_TEST_VERTICES, new LongWritable(1));
        break;
    default:
        throw new IllegalArgumentException("Unknown vertex type: " + vt.toString());
    }
    // if it's not a training vertex, use uniform posterior and don't send out messages
    if (vt != VertexType.TRAIN) {
        posterior.assign(1.0 / nStates);
        return;
    }
    // Log-space value used when every term of a message entry underflows to zero: the log of the
    // smallest positive double. (Double.MIN_VALUE itself is ~0 in log space, i.e. probability ~1.)
    final double logFloor = Math.log(Double.MIN_VALUE);
    // scales |i - j| into [0, 1]; avoids 0 / 0 when there is only a single state
    final double stateDistanceScale = nStates > 1 ? nStates - 1 : 1;
    // calculate messages
    IdWithVectorMessage newMessage = new IdWithVectorMessage();
    newMessage.setData(vertex.getId().get());
    // calculate initial belief
    Vector belief = prior.clone();
    for (Edge<LongWritable, DoubleWritable> edge : vertex.getEdges()) {
        double weight = edge.getValue().get();
        if (weight <= 0d) {
            throw new IllegalArgumentException("Vertex ID: " + vertex.getId()
                    + " has an edge with negative or zero weight value " + weight);
        }
        for (int i = 0; i < nStates; i++) {
            double total = 0d;
            for (int j = 0; j < nStates; j++) {
                double msg = Math.exp(
                        prior.getQuick(j) + edgePotential(Math.abs(i - j) / stateDistanceScale, weight));
                if (maxProduct) {
                    total = total > msg ? total : msg;
                } else {
                    total += msg;
                }
            }
            belief.setQuick(i, total > 0d ? Math.log(total) : logFloor);
        }
        // shift so that the largest entry is 0
        belief = belief.plus(-belief.maxValue());
        // send out messages
        newMessage.setVector(belief);
        sendMessage(edge.getTargetVertexId(), newMessage);
    }
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lbp.LoopyBeliefPropagationComputation.java

License:Apache License

/**
 * Runs one superstep of loopy belief propagation for a vertex.
 *
 * <p>Superstep 0 delegates to {@code initializeVertex}.
 *
 * <p>Later supersteps do the following:
 * <ul>
 *   <li>The (log-space) prior is summed with every incoming log-space message. Validate/test
 *       vertices use a uniform log prior instead of their own, unless {@code ignoreVertexType}
 *       is set.</li>
 *   <li>If the largest prior entry is below {@code anchorThreshold}, the posterior is replaced by
 *       the exponentiated, L1-normalized sum. The L1 change is added to the delta aggregator for
 *       the vertex type.</li>
 *   <li>While the superstep is below {@code maxSupersteps}, training vertices send each neighbour
 *       a new log-space message. It is computed from the summed beliefs minus that neighbour's own
 *       last message.</li>
 *   <li>Otherwise the stored prior is exponentiated back to linear scale and the vertex votes to
 *       halt.</li>
 * </ul>
 *
 * @param vertex of the graph
 * @param messages log-space messages from neighbours, each tagged with the sender id
 * @throws IOException declared by the overridden signature
 * @throws IllegalArgumentException if the vertex type is unknown
 */
@Override
public void compute(Vertex<LongWritable, VertexData4LBPWritable, DoubleWritable> vertex,
        Iterable<IdWithVectorMessage> messages) throws IOException {
    long step = getSuperstep();
    if (step == 0) {
        initializeVertex(vertex);
        return;
    }

    // collect messages sent to this vertex, keyed by sender id
    HashMap<Long, Vector> map = new HashMap<Long, Vector>();
    for (IdWithVectorMessage message : messages) {
        map.put(message.getData(), message.getVector());
    }

    // update posterior according to prior and messages
    VertexData4LBPWritable vertexValue = vertex.getValue();
    VertexType vt = vertexValue.getType();
    vt = ignoreVertexType ? VertexType.TRAIN : vt;
    Vector prior = vertexValue.getPriorVector();
    int nStates = prior.size();
    if (vt != VertexType.TRAIN) {
        // assign a uniform prior for validate/test vertex
        prior = prior.clone().assign(Math.log(1.0 / nStates));
    }
    // sum of prior and messages
    Vector sumPosterior = prior;
    for (IdWithVectorMessage message : messages) {
        sumPosterior = sumPosterior.plus(message.getVector());
    }
    // shift so that the largest entry is 0
    sumPosterior = sumPosterior.plus(-sumPosterior.maxValue());
    // update posterior if this isn't an anchor vertex
    if (prior.maxValue() < anchorThreshold) {
        // normalize posterior
        Vector posterior = sumPosterior.clone().assign(Functions.EXP);
        posterior = posterior.normalize(1d);
        Vector oldPosterior = vertexValue.getPosteriorVector();
        double delta = posterior.minus(oldPosterior).norm(1d);
        // aggregate deltas
        switch (vt) {
        case TRAIN:
            aggregate(SUM_TRAIN_DELTA, new DoubleWritable(delta));
            break;
        case VALIDATE:
            aggregate(SUM_VALIDATE_DELTA, new DoubleWritable(delta));
            break;
        case TEST:
            aggregate(SUM_TEST_DELTA, new DoubleWritable(delta));
            break;
        default:
            throw new IllegalArgumentException("Unknown vertex type: " + vt.toString());
        }
        // update posterior
        vertexValue.setPosteriorVector(posterior);
    }

    if (step < maxSupersteps) {
        // if it's not a training vertex, don't send out messages
        if (vt != VertexType.TRAIN) {
            return;
        }
        // Log-space value used when every term of a message entry underflows to zero: the log of
        // the smallest positive double. (Double.MIN_VALUE itself is ~0 in log space, i.e.
        // probability ~1.)
        final double logFloor = Math.log(Double.MIN_VALUE);
        // scales |i - j| into [0, 1]; avoids 0 / 0 when there is only a single state
        final double stateDistanceScale = nStates > 1 ? nStates - 1 : 1;
        IdWithVectorMessage newMessage = new IdWithVectorMessage();
        newMessage.setData(vertex.getId().get());
        // update belief
        Vector belief = prior.clone();
        for (Edge<LongWritable, DoubleWritable> edge : vertex.getEdges()) {
            double weight = edge.getValue().get();
            long id = edge.getTargetVertexId().get();
            // leave out the message the target itself sent to this vertex
            Vector tempVector = sumPosterior;
            if (map.containsKey(id)) {
                tempVector = sumPosterior.minus(map.get(id));
            }
            for (int i = 0; i < nStates; i++) {
                double total = 0d;
                for (int j = 0; j < nStates; j++) {
                    double msg = Math.exp(
                            tempVector.getQuick(j) + edgePotential(Math.abs(i - j) / stateDistanceScale, weight));
                    if (maxProduct) {
                        total = total > msg ? total : msg;
                    } else {
                        total += msg;
                    }
                }
                belief.setQuick(i, total > 0d ? Math.log(total) : logFloor);
            }
            // shift so that the largest entry is 0
            belief = belief.plus(-belief.maxValue());
            newMessage.setVector(belief);
            sendMessage(edge.getTargetVertexId(), newMessage);
        }
    } else {
        // convert prior back to regular scale before output
        prior = vertexValue.getPriorVector();
        prior = prior.assign(Functions.EXP);
        vertexValue.setPriorVector(prior);
        vertex.voteToHalt();
    }
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java

License:Apache License

/**
 * Update vertex value according to edge value
 *
 * @param vertex of the graph/*from w  w w.  ja va2 s.  c o  m*/
 */
private void updateVertex(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex) {
    Vector vector = vertex.getValue().getLdaResult().clone().assign(0d);
    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getEdges()) {
        double weight = edge.getValue().getWordCount();
        Vector gamma = edge.getValue().getVector();
        vector = vector.plus(gamma.times(weight));
    }
    vertex.getValue().setLdaResult(vector);
    if (vertex.getId().isWord()) {
        aggregate(SUM_WORD_VERTEX_VALUE, new VectorWritable(vector));
    }
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java

License:Apache License

/**
 * Normalizes the LDA result stored on a vertex.
 *
 * <p>For document vertices, {@code alpha} is added to every entry and the result is scaled to
 * unit L1 norm.
 *
 * <p>For word vertices, {@code beta} is added to every entry, and entry {@code k} is then divided
 * by {@code nk[k] + numWords * beta}.
 *
 * @param vertex of the graph
 */
private void normalizeVertex(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex) {
    Vector current = vertex.getValue().getLdaResult();
    Vector normalized;
    if (!vertex.getId().isDocument()) {
        Vector smoothed = current.plus(config.beta());
        Vector inverseDenominator = nk.plus(numWords * config.beta()).assign(Functions.INV);
        normalized = smoothed.times(inverseDenominator);
    } else {
        normalized = current.plus(config.alpha()).normalize(1d);
    }
    // store the normalized vector back on the vertex
    vertex.getValue().setLdaResult(normalized);
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java

License:Apache License

/**
 * Evaluate cost according to vertex and messages
 *
 * @param vertex of the graph/*from  ww w  . j a  v  a  2s. com*/
 * @param messages of type iterable
 * @param map of type HashMap
 */
private void evaluateCost(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex, Iterable<LdaMessage> messages,
        HashMap<LdaVertexId, Vector> map) {

    if (vertex.getId().isDocument()) {
        return;
    }
    Vector vector = vertex.getValue().getLdaResult();
    vector = vector.plus(config.beta()).times(nk.plus(numWords * config.beta()).assign(Functions.INV));

    double cost = 0d;
    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getEdges()) {
        double weight = edge.getValue().getWordCount();
        LdaVertexId id = edge.getTargetVertexId();
        if (map.containsKey(id)) {
            Vector otherVector = map.get(id);
            otherVector = otherVector.plus(config.alpha()).normalize(1d);
            cost -= weight * Math.log(vector.dot(otherVector));
        } else {
            throw new IllegalArgumentException(
                    String.format("Vertex ID %s: A message is mis-matched", vertex.getId().getValue()));
        }
    }
    aggregate(SUM_COST, new DoubleWritable(cost));
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lda.GiraphLdaComputation.java

License:Apache License

/**
 * Adds one edge's word-count weighted vector to a vertex-value accumulator.
 *
 * @param vector vector of vertex value; not modified, since {@code plus} returns a new vector
 * @param edge of the graph
 * @return a new vector equal to {@code vector + wordCount * edgeVector}
 */
private Vector updateVector(Vector vector, Edge<LdaVertexId, LdaEdgeData> edge) {
    LdaEdgeData data = edge.getValue();
    double wordCount = data.getWordCount();
    return vector.plus(data.getVector().times(wordCount));
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lp.LabelPropagationComputation.java

License:Apache License

/**
 * Runs one superstep of label propagation for a vertex.
 *
 * <p>Superstep 0 only calls {@code initializeVertexEdges}.
 *
 * <p>Supersteps {@code 1..maxSupersteps} do the following:
 * <ul>
 *   <li>Neighbour vectors are summed, weighted by edge weight, into a new belief.</li>
 *   <li>A cost term built from the first component of the prior, the posterior and the
 *       neighbours' vectors is added to {@code SUM_COST}.</li>
 *   <li>Unlabeled vertices replace their posterior with the L1-normalized
 *       {@code (1 - lambda) * belief + lambda * prior}.</li>
 *   <li>Every superstep except the last sends the current posterior to all neighbours.</li>
 * </ul>
 *
 * <p>The vertex votes to halt at the end of every superstep.
 *
 * @param vertex of the graph
 * @param messages neighbour posteriors, each tagged with the sender id
 * @throws IOException declared by the overridden signature
 * @throws IllegalArgumentException if an edge has a non-positive weight
 */
@Override
public void compute(Vertex<LongWritable, VertexData4LPWritable, DoubleWritable> vertex,
        Iterable<IdWithVectorMessage> messages) throws IOException {
    long superStep = getSuperstep();

    if (superStep == 0) {
        initializeVertexEdges(vertex);
    } else if (superStep <= maxSupersteps) {
        VertexData4LPWritable vertexValue = vertex.getValue();
        Vector prior = vertexValue.getPriorVector();
        Vector posterior = vertexValue.getPosteriorVector();
        double degree = vertexValue.getDegree();

        // collect messages sent to this vertex, keyed by sender id
        HashMap<Long, Vector> map = new HashMap<Long, Vector>();
        for (IdWithVectorMessage message : messages) {
            map.put(message.getData(), message.getVector());
        }

        // Update belief and calculate cost (only component 0 enters the cost)
        double hi = prior.getQuick(0);
        double fi = posterior.getQuick(0);
        double crossSum = 0d;
        Vector newBelief = posterior.clone().assign(0d);

        for (Edge<LongWritable, DoubleWritable> edge : vertex.getEdges()) {
            double weight = edge.getValue().get();
            if (weight <= 0d) {
                throw new IllegalArgumentException(
                        "Vertex ID: " + vertex.getId() + " has an edge with negative or zero value");
            }
            long targetVertex = edge.getTargetVertexId().get();
            if (map.containsKey(targetVertex)) {
                Vector tempVector = map.get(targetVertex);
                newBelief = newBelief.plus(tempVector.times(weight));
                double fj = tempVector.getQuick(0);
                crossSum += weight * fi * fj;
            }
        }

        double cost = degree
                * ((1 - lambda) * (Math.pow(fi, 2) - crossSum) + 0.5 * lambda * Math.pow((fi - hi), 2));
        aggregate(SUM_COST, new DoubleWritable(cost));

        // Update posterior only if the vertex was not labeled
        if (!vertexValue.wasLabeled()) {
            newBelief = newBelief.times(1 - lambda).plus(prior.times(lambda)).normalize(1d);
            vertexValue.setPosteriorVector(newBelief);
        }

        // Send out messages if not the last step
        if (superStep != maxSupersteps) {
            IdWithVectorMessage newMessage = new IdWithVectorMessage(vertex.getId().get(),
                    vertexValue.getPosteriorVector());
            sendMessageToAllEdges(vertex, newMessage);
        }
    }

    vertex.voteToHalt();
}

From source file:zx.soft.mahout.knn.search.AbstractSearchTest.java

License:Apache License

/**
 * Checks that a slightly perturbed copy of a stored vector finds that vector as its nearest
 * neighbour.
 *
 * <p>For each of the first 100 test vectors, the test asserts that the best match lies at distance
 * {@code |epsilon|}, both as the reported weight and as the actual distance. It also asserts that
 * the second match is strictly farther away.
 */
@Test
public void testNearMatch() {
    List<MatrixSlice> queries = Lists.newArrayList(Iterables.limit(testData(), 100));
    Searcher s = getSearch(20);
    s.addAllMatrixSlicesAsWeightedVectors(testData());

    // noise source centred on the 20-dimensional zero vector; 0.01 is presumably its scale
    MultiNormal noise = new MultiNormal(0.01, new DenseVector(20));
    for (MatrixSlice slice : queries) {
        Vector query = slice.vector();
        final Vector epsilon = noise.sample();
        query = query.plus(epsilon);
        List<WeightedThing<Vector>> r = s.search(query, 2);
        assertEquals("Distance has to be small", epsilon.norm(2), r.get(0).getWeight(), 1e-5);
        assertEquals("Answer must be substantially the same as query", epsilon.norm(2),
                r.get(0).getValue().minus(query).norm(2), 1e-5);
        assertTrue("Wrong answer must be further away", r.get(1).getWeight() > r.get(0).getWeight());
    }
}