Example usage for org.apache.mahout.math Vector get

List of usage examples for org.apache.mahout.math Vector get

Introduction

On this page you can find example usage for org.apache.mahout.math Vector get.

Prototype

double get(int index);

Source Link

Document

Return the value at the given index

Usage

From source file:com.elex.dmp.core.TopicModel.java

License:Apache License

/**
 * Computes {@code p(topic x|term a, document i)} distributions given input document {@code i}.
 * {@code pTGT[x][a]} is the (un-normalized) {@code p(x|a,i)}, or if docTopics is {@code null},
 * {@code p(a|x)} (also un-normalized).
 *
 * @param document doc-term vector encoding {@code w(term a|document i)}.
 * @param docTopics {@code docTopics[x]} is the overall weight of topic {@code x} in given
 *          document. If {@code null}, a topic weight of {@code 1.0} is used for all topics.
 * @param termTopicDist storage for output {@code p(x|a,i)} distributions.
 */
private void pTopicGivenTerm(Vector document, Vector docTopics, Matrix termTopicDist) {
    for (int topic = 0; topic < numTopics; topic++) {
        // p(topic | document i); a uniform weight of 1.0 when no doc-topic distribution given
        double docTopicWeight = docTopics == null ? 1.0 : docTopics.get(topic);
        // per-topic term counts w(term a | topic) and their sum over all terms
        Vector termCountsForTopic = topicTermCounts.viewRow(topic);
        double termCountSum = topicSums.get(topic);
        // output row: un-normalized p(topic | term a, document i) to be filled in
        Vector outputRow = termTopicDist.viewRow(topic);

        // only the terms actually present (non-zero) in document i contribute
        Iterator<Vector.Element> nonZeroTerms = document.iterateNonZero();
        while (nonZeroTerms.hasNext()) {
            int term = nonZeroTerms.next().index();
            // smoothed likelihood: (count + eta) * (topic weight + alpha) / (sum + eta*V)
            double likelihood = (termCountsForTopic.get(term) + eta) * (docTopicWeight + alpha)
                    / (termCountSum + eta * numTerms);
            outputRow.set(term, likelihood);
        }
    }
}

From source file:com.elex.dmp.core.TopicModel.java

License:Apache License

/**
 * Computes {@code -sum_a sum_x (c_ai * log(p(x|i) * p(a|x)))}: the negated log-likelihood
 * of the document under the current (alpha/eta-smoothed) model.
 *
 * @param document doc-term count vector {@code c_ai} for document i.
 * @param docTopics topic weights for document i (normalized internally with alpha smoothing).
 * @return the negated log-likelihood of the document.
 */
public double perplexity(Vector document, Vector docTopics) {
    double normalizer = docTopics.norm(1) + (docTopics.size() * alpha);
    double logLikelihood = 0;
    for (Iterator<Vector.Element> terms = document.iterateNonZero(); terms.hasNext();) {
        Vector.Element entry = terms.next();
        int termIndex = entry.index();
        // p(term) = sum over topics of p(topic|doc) * p(term|topic), both smoothed
        double termProbability = 0;
        for (int topic = 0; topic < numTopics; topic++) {
            double pTopicGivenDoc = (docTopics.get(topic) + alpha) / normalizer;
            termProbability += pTopicGivenDoc * (topicTermCounts.viewRow(topic).get(termIndex) + eta)
                    / (topicSums.get(topic) + eta * numTerms);
        }
        logLikelihood += entry.get() * Math.log(termProbability);
    }
    return -logLikelihood;
}

From source file:com.mapr.stats.bandit.ContextualBayesBandit.java

License:Apache License

/**
 * Draws one sampled weight per feature row of {@code state} and maps the samples back
 * through the feature map. For each row, a value is drawn between the row's two entries
 * (columns 0 and 1 — presumably lower/upper bounds; confirm against {@code state}'s layout)
 * and passed through the inverse logistic link.
 *
 * @return {@code featureMap} times the per-row sampled weight vector.
 */
private Vector sampleNoLink() {
    VectorFunction rowSampler = new VectorFunction() {
        final DoubleFunction inverseLink = new InverseLogisticFunction();

        @Override
        public double apply(Vector row) {
            return inverseLink.apply(rand.nextDouble(row.get(0), row.get(1)));
        }
    };
    return featureMap.times(state.aggregateRows(rowSampler));
}

From source file:com.mapr.stats.bandit.ContextualBayesBanditTest.java

License:Apache License

/**
 * Smoke test: trains the contextual bandit for 1000 rounds against a fixed linear
 * reward model and verifies sampling/training complete without error.
 */
@Test
public void testConvergence() {
    final Random rand = RandomUtils.getRandom();
    // 100 candidate recipes over 10 binary features; each feature is on with prob 0.2...
    Matrix recipes = new DenseMatrix(100, 10).assign(new DoubleFunction() {
        @Override
        public double apply(double ignored) {
            return rand.nextDouble() < 0.2 ? 1 : 0;
        }
    });
    // ...except the last column, which is always set (acts like a bias feature)
    recipes.viewColumn(9).assign(1);

    // ground-truth weights; recipes * weights gives each recipe's expected reward
    Vector trueWeights = new DenseVector(new double[] { 1, 0.25, -0.25, 0, 0, 0, 0, 0, 0, -1 });
    Vector rewardProbs = recipes.times(trueWeights);

    ContextualBayesBandit bandit = new ContextualBayesBandit(recipes);
    for (int round = 0; round < 1000; round++) {
        int arm = bandit.sample();
        bandit.train(arm, rand.nextDouble() < rewardProbs.get(arm));
    }
}

From source file:com.scaleunlimited.classify.vectors.VectorUtilsTest.java

License:Apache License

/**
 * Exercises makeExtraVector: with unique terms {"a","b"} and doc terms {a:1, c:5},
 * the result has a single entry holding the count (5) of the term not in the
 * unique list — presumably "c"; confirm against makeExtraVector's contract.
 */
@Test
public void testMakeExtraVector() {
    List<String> uniqueTerms = new ArrayList<String>(2);
    uniqueTerms.add("a");
    uniqueTerms.add("b");

    Map<String, Integer> docTerms = new HashMap<String, Integer>();
    docTerms.put("a", 1);
    docTerms.put("c", 5);

    Vector v = VectorUtils.makeExtraVector(uniqueTerms, docTerms);
    Assert.assertEquals(1, v.size());
    // Compare as doubles with an exact delta instead of boxing through the
    // deprecated new Double(...) constructor and truncating.
    Assert.assertEquals(5.0, v.get(0), 0.0);
}

From source file:com.scaleunlimited.classify.vectors.VectorUtilsTest.java

License:Apache License

/**
 * Exercises extendVector: a 2-term vector (both entries 1) extended by 1 slot
 * yields [1, 1, 0] — existing entries preserved, new slot zero-initialized.
 */
@Test
public void testExtendVector() {
    List<String> uniqueTerms = new ArrayList<String>(2);
    uniqueTerms.add("a");
    uniqueTerms.add("b");

    Vector v1 = VectorUtils.makeVector(uniqueTerms, uniqueTerms);
    Vector v2 = VectorUtils.extendVector(v1, 1);

    // Compare as doubles with an exact delta instead of boxing through the
    // deprecated new Double(...) constructor and truncating.
    Assert.assertEquals(1.0, v2.get(0), 0.0);
    Assert.assertEquals(1.0, v2.get(1), 0.0);
    Assert.assertEquals(0.0, v2.get(2), 0.0);
}

From source file:com.sixgroup.samplerecommender.Point.java

/**
 * Demo: trains a logistic-regression classifier on two well-separated 2-D point
 * clusters (labels 0 and 1), then classifies a probe point and prints the
 * per-category probabilities.
 */
public static void main(String[] args) {

    // Training set: cluster near the origin labeled 0, cluster near (8..9, 8..9) labeled 1.
    Map<Point, Integer> points = new HashMap<Point, Integer>();

    points.put(new Point(0, 0), 0);
    points.put(new Point(1, 1), 0);
    points.put(new Point(1, 0), 0);
    points.put(new Point(0, 1), 0);
    points.put(new Point(2, 2), 0);

    points.put(new Point(8, 8), 1);
    points.put(new Point(8, 9), 1);
    points.put(new Point(9, 8), 1);
    points.put(new Point(9, 9), 1);

    // 2 categories, 3 features, L1 prior. (The original code also created a
    // default-constructed instance first and immediately overwrote it — dead store, removed.)
    OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
    learningAlgo.lambda(0.1);
    learningAlgo.learningRate(10);

    System.out.println("training model  \n");

    for (Point point : points.keySet()) {
        Vector v = getVector(point);
        System.out.println(point + " belongs to " + points.get(point));
        learningAlgo.train(points.get(point), v);
    }

    learningAlgo.close();

    // Probe point encoded the same way getVector encodes training points
    // (3 features — presumably bias + coordinates; confirm against getVector).
    Vector v = new RandomAccessSparseVector(3);
    v.set(0, 0.5);
    v.set(1, 0.5);
    v.set(2, 1);

    Vector r = learningAlgo.classifyFull(v);
    System.out.println(r);

    System.out.println("ans = ");
    System.out.println("no of categories = " + learningAlgo.numCategories());
    System.out.println("no of features = " + learningAlgo.numFeatures());
    System.out.println("Probability of cluster 0 = " + r.get(0));
    System.out.println("Probability of cluster 1 = " + r.get(1));

}

From source file:com.skp.experiment.common.MathHelper.java

License:Apache License

/**
 * Checks whether the {@link Vector} is equivalent to the given set of
 * {@link Vector.Element}s: same number of non-zero, non-NaN entries, and every
 * element's value matches the vector at that element's index within epsilon.
 *
 * @param vector the vector to compare against.
 * @param elements the expected elements.
 * @return {@code true} iff the vector consists of exactly these elements.
 */
public static boolean consistsOf(Vector vector, Vector.Element... elements) {
    // Cardinality check first: rules out extra non-zero entries in the vector.
    if (numberOfNoNZeroNonNaNElements(vector) != elements.length) {
        return false;
    }
    for (Vector.Element expected : elements) {
        double actual = vector.get(expected.index());
        if (Math.abs(expected.get() - actual) > MahoutTestCase.EPSILON) {
            return false;
        }
    }
    return true;
}

From source file:com.twitter.algebra.nmf.ErrDMJ.java

License:Apache License

/**
 * Runs (or skips, if output already exists) the error-measurement job for the
 * factorization X ~= A * Yt', then reads back the per-column squared-error vector
 * and logs per-column, max, and average relative column errors.
 *
 * @param conf Hadoop configuration.
 * @param X the original matrix.
 * @param xColSumVec per-column sum of squared values of X (denominator for relative error).
 * @param A the left factor.
 * @param Yt the (transposed) right factor.
 * @param label name for the output directory under A's temp path.
 * @return the total absolute error from the job's counter, or -1 if the job was skipped.
 */
public static long run(Configuration conf, DistributedRowMatrix X, Vector xColSumVec, DistributedRowMatrix A,
        DistributedRowMatrix Yt, String label)
        throws IOException, InterruptedException, ClassNotFoundException {
    log.info("running " + ErrDMJ.class.getName());
    if (X.numRows() != A.numRows()) {
        // BUG FIX: was CardinalityException(A.numRows(), A.numRows()), which reported
        // the same value twice; report the mismatched pair X vs A.
        throw new CardinalityException(X.numRows(), A.numRows());
    }
    if (A.numCols() != Yt.numCols()) {
        throw new CardinalityException(A.numCols(), Yt.numCols());
    }
    if (X.numCols() != Yt.numRows()) {
        throw new CardinalityException(X.numCols(), Yt.numRows());
    }
    Path outPath = new Path(A.getOutputTempPath(), label);
    FileSystem fs = FileSystem.get(outPath.toUri(), conf);
    ErrDMJ job = new ErrDMJ();
    long totalErr = -1;
    if (!fs.exists(outPath)) {
        Job hJob = job.run(conf, X.getRowPath(), A.getRowPath(), Yt.getRowPath(), outPath, A.numRows(),
                Yt.numRows(), Yt.numCols());
        Counters counters = hJob.getCounters();
        // BUG FIX: the counter value was previously read but never assigned, so the
        // method always logged and returned -1 even when the job ran.
        totalErr = counters.findCounter("Result", "sumAbs").getValue();
        log.info("FINAL ERR is " + totalErr);
    } else {
        log.warn("----------- Skip already exists: " + outPath);
    }
    Vector sumErrVec = AlgebraCommon.mapDirToSparseVector(outPath, 1, X.numCols(), conf);
    double maxColErr = Double.MIN_VALUE;
    double sumColErr = 0;
    int cntColErr = 0;
    Iterator<Vector.Element> it = sumErrVec.nonZeroes().iterator();
    while (it.hasNext()) {
        Vector.Element el = it.next();
        double errP2 = el.get();
        double origP2 = xColSumVec.get(el.index());
        // relative column error: sqrt(sum(err^2) / sum(val^2))
        double colErr = Math.sqrt(errP2 / origP2);
        log.info("col: " + el.index() + " sum(err^2): " + errP2 + " sum(val^2): " + origP2 + " colErr: "
                + colErr);
        maxColErr = Math.max(colErr, maxColErr);
        sumColErr += colErr;
        cntColErr++;
    }
    log.info(" Max Col Err: " + maxColErr);
    log.info(" Avg Col Err: " + sumColErr / cntColErr);
    return totalErr;
}

From source file:edu.indiana.d2i.htrc.io.SparseVectorsToMemcached.java

License:Apache License

/**
 * Tokenizes the given text with the supplied analyzer and accumulates a sparse
 * term-frequency vector over the dictionary's index space. Tokens are lowercased
 * and must pass the filter to be counted.
 *
 * @param text the raw document text.
 * @param field the Lucene field name used to pick the analyzer's token stream.
 * @param analyzer tokenizer/analyzer to apply.
 * @param filter gate deciding which terms are counted.
 * @param dictionary term-to-index mapping; also fixes the vector's dimensionality.
 * @return sparse vector of term counts.
 * @throws IOException if tokenization fails.
 */
private static Vector transform2Vector(String text, String field, Analyzer analyzer, HTRCFilter filter,
        Dictionary dictionary) throws IOException {
    Vector termFreqs = new RandomAccessSparseVector(dictionary.size());

    TokenStream tokens = analyzer.reusableTokenStream(field, new StringReader(text.toString()));
    CharTermAttribute termAttr = tokens.addAttribute(CharTermAttribute.class);
    tokens.reset();
    while (tokens.incrementToken()) {
        String token = new String(termAttr.buffer(), 0, termAttr.length()).toLowerCase();
        if (!filter.accept(token, 0)) {
            continue;
        }
        // NOTE(review): dictionary.get may return null for terms absent from the
        // dictionary, which would NPE on unboxing — confirm all accepted terms are present.
        int index = dictionary.get(token);
        termFreqs.setQuick(index, termFreqs.get(index) + 1);
    }

    return termFreqs;
}