Usage examples for org.apache.mahout.math.Vector#set
// Signature of org.apache.mahout.math.Vector#set exercised by the examples below; presumably it
// stores value at position index (confirm against the Mahout Vector Javadoc).
void set(int index, double value);
From source file: com.mozilla.grouperfish.pig.storage.DocumentVectorStorage.java
License: Apache License
@SuppressWarnings("unchecked") @Override//from ww w . j av a2 s . co m public void putNext(Tuple tuple) throws IOException { outputKey.set((String) tuple.get(0)); Tuple vectorTuple = (Tuple) tuple.get(1); Vector vector = new NamedVector(new RandomAccessSparseVector(dimensions, vectorTuple.size()), outputKey.toString()); for (int i = 0; i < vectorTuple.size(); i++) { Object o = vectorTuple.get(i); switch (vectorTuple.getType(i)) { case DataType.INTEGER: // If this is just an integer then we just want to set the index to 1.0 vector.set((Integer) o, 1.0); break; case DataType.TUPLE: // If this is a tuple then we want to set the index and the weight Tuple subt = (Tuple) o; vector.set((Integer) subt.get(0), (Double) subt.get(1)); break; default: throw new RuntimeException("Unexpected tuple form"); } } outputValue.set(vector); try { writer.write(outputKey, outputValue); } catch (InterruptedException e) { LOG.error("Interrupted while writing", e); } }
From source file: com.mozilla.grouperfish.transforms.coclustering.pig.storage.MahoutVectorStorage.java
License: Apache License
@Override public void putNext(Tuple t) throws IOException { IntWritable outputKey = new IntWritable(); VectorWritable outputValue = new VectorWritable(); outputKey.set((Integer) t.get(0)); Tuple currRow = (Tuple) t.get(1);//from w ww. j a va 2s.c o m Vector currRowVector; if (dimensions == 0) { throw new IllegalArgumentException("Trying to create 0 dimension vector"); } if (STORE_AS_DENSE) { currRowVector = new NamedVector(new DenseVector(dimensions), outputKey.toString()); } else if (STORE_AS_SEQUENTIAL) { currRowVector = new NamedVector(new SequentialAccessSparseVector(dimensions, currRow.size()), outputKey.toString()); } else { currRowVector = new NamedVector(new RandomAccessSparseVector(dimensions, currRow.size()), outputKey.toString()); } for (int ii = 0; ii < currRow.size(); ii++) { Object o = currRow.get(ii); switch (currRow.getType(ii)) { case DataType.INTEGER: case DataType.LONG: case DataType.FLOAT: case DataType.DOUBLE: currRowVector.set(ii, (Double) o); break; case DataType.TUPLE: // If this is a tuple then we want to set column and element Tuple subt = (Tuple) o; currRowVector.set((Integer) subt.get(0), (Double) subt.get(1)); break; default: throw new RuntimeException("Unexpected tuple form"); } } outputValue.set(currRowVector); try { writer.write(outputKey, outputValue); } catch (InterruptedException e) { LOG.error("Interrupted while writing", e); } }
From source file: com.sixgroup.samplerecommender.Point.java
/**
 * Demo: trains a two-category {@link OnlineLogisticRegression} on nine labelled points, then prints
 * the class probabilities it assigns to the point (0.5, 0.5).
 *
 * <p>Five points near the origin are labelled 0; four points in the (8..9, 8..9) square are
 * labelled 1.
 *
 * @param args ignored
 */
public static void main(String[] args) {
  // Insertion-ordered so SGD sees the points in the same order on every run. HashMap iteration
  // order depends on Point.hashCode, and the training order may change the learned model.
  Map<Point, Integer> points = new java.util.LinkedHashMap<Point, Integer>();
  points.put(new Point(0, 0), 0);
  points.put(new Point(1, 1), 0);
  points.put(new Point(1, 0), 0);
  points.put(new Point(0, 1), 0);
  points.put(new Point(2, 2), 0);
  points.put(new Point(8, 8), 1);
  points.put(new Point(8, 9), 1);
  points.put(new Point(9, 8), 1);
  points.put(new Point(9, 9), 1);

  // 2 categories, 3 features (x, y and the constant 1 appended by getVector), L1 prior.
  OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
  learningAlgo.lambda(0.1);
  learningAlgo.learningRate(10);

  System.out.println("training model \n");
  for (Map.Entry<Point, Integer> entry : points.entrySet()) {
    Point point = entry.getKey();
    int label = entry.getValue();
    Vector v = getVector(point);
    System.out.println(point + " belongs to " + label);
    learningAlgo.train(label, v);
  }
  learningAlgo.close();

  // Query point laid out like getVector's output: [x, y, 1].
  Vector v = new RandomAccessSparseVector(3);
  v.set(0, 0.5);
  v.set(1, 0.5);
  v.set(2, 1);
  Vector r = learningAlgo.classifyFull(v);
  System.out.println(r);

  System.out.println("ans = ");
  System.out.println("no of categories = " + learningAlgo.numCategories());
  System.out.println("no of features = " + learningAlgo.numFeatures());
  System.out.println("Probability of cluster 0 = " + r.get(0));
  System.out.println("Probability of cluster 1 = " + r.get(1));
}
From source file: com.sixgroup.samplerecommender.Point.java
public static Vector getVector(Point point) { Vector v = new DenseVector(3); v.set(0, point.x); v.set(1, point.y);//from w ww .j a v a2 s. c o m v.set(2, 1); return v; }
From source file: com.tamingtext.mahout.VectorExamplesTest.java
License: Apache License
@Test public void testProgrammatic() throws Exception { //<start id="vec.examples.programmatic"/> double[] vals = new double[] { 0.3, 1.8, 200.228 }; Vector dense = new DenseVector(vals);//<co id="vec.exam.dense"/> assertTrue(dense.size() == 3);/* w w w. j a va 2 s . com*/ Vector sparseSame = new SequentialAccessSparseVector(3);//<co id="vec.exam.sparse.same"/> Vector sparse = new SequentialAccessSparseVector(3000);//<co id="vec.exam.sparse"/> for (int i = 0; i < vals.length; i++) {//<co id="vec.exam.assign.sparse"/> sparseSame.set(i, vals[i]); sparse.set(i, vals[i]); } assertFalse(dense.equals(sparse));//<co id="vec.exam.notequals.d.s"/> assertEquals(dense, sparseSame);//<co id="vec.exam.equals.d.s"/> assertFalse(sparse.equals(sparseSame)); /* <calloutlist> <callout arearefs="vec.exam.dense"><para>Create a <classname>DenseVector</classname> with a label of "my-dense" and 3 values. The cardinality of this vector is 3 </para></callout> <callout arearefs="vec.exam.sparse.same"><para>Create a <classname>SparseVector</classname> with a label of my-sparse-same that has cardinality of 3</para></callout> <callout arearefs="vec.exam.sparse"><para>Create a <classname>SparseVector</classname> with a label of my-sparse and a cardinality of 3000.</para></callout> <callout arearefs="vec.exam.assign.sparse"><para>Set the values to the first 3 items in the sparse vectors.</para></callout> <callout arearefs="vec.exam.notequals.d.s"><para>The dense and the sparse <classname>Vector</classname>s are not equal because they have different cardinality.</para></callout> <callout arearefs="vec.exam.equals.d.s"><para>The dense and sparseSame <classname>Vector</classname>s are equal because they have the same values and cardinality</para></callout> </calloutlist> */ //<end id="vec.examples.programmatic"/> //<start id="vec.examples.seq.file"/> File tmpDir = new File(System.getProperty("java.io.tmpdir")); File tmpLoc = new File(tmpDir, "sfvwt"); tmpLoc.mkdirs(); File tmpFile = 
File.createTempFile("sfvwt", ".dat", tmpLoc); Path path = new Path(tmpFile.getAbsolutePath()); Configuration conf = new Configuration();//<co id="vec.examples.seq.conf"/> FileSystem fs = FileSystem.get(conf); SequenceFile.Writer seqWriter = SequenceFile.createWriter(fs, conf, path, LongWritable.class, VectorWritable.class);//<co id="vec.examples.seq.writer"/> VectorWriter vecWriter = new SequenceFileVectorWriter(seqWriter);//<co id="vec.examples.seq.vecwriter"/> List<Vector> vectors = new ArrayList<Vector>(); vectors.add(sparse); vectors.add(sparseSame); vecWriter.write(vectors);//<co id="vec.examples.seq.write"/> vecWriter.close(); /* <calloutlist> <callout arearefs="vec.examples.seq.conf"><para>Create a <classname>Configuration</classname> for Hadoop</para></callout> <callout arearefs="vec.examples.seq.writer"><para>Create a Hadoop <classname>SequenceFile.Writer</classname> to handle the job of physically writing out the vectors to a file in HDFS</para></callout> <callout arearefs="vec.examples.seq.vecwriter"><para>A <classname>VectorWriter</classname> processes the <classname>Vector</classname>s and invokes the underlying write methods on the <classname>SequenceFile.Writer</classname></para></callout> <callout arearefs="vec.examples.seq.write"><para>Do the work of writing out the files</para></callout> </calloutlist> */ //<end id="vec.examples.seq.file"/> }
From source file: com.technobium.MultinomialLogisticRegression.java
License: Apache License
public static void main(String[] args) throws Exception { // this test trains a 3-way classifier on the famous Iris dataset. // a similar exercise can be accomplished in R using this code: // library(nnet) // correct = rep(0,100) // for (j in 1:100) { // i = order(runif(150)) // train = iris[i[1:100],] // test = iris[i[101:150],] // m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train) // correct[j] = mean(predict(m, newdata=test) == test$Species) // }//ww w .j a v a 2 s . com // hist(correct) // // Note that depending on the training/test split, performance can be better or worse. // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy // of 100% // // This test uses a deterministic split that is neither outstandingly good nor bad RandomUtils.useTestSeed(); Splitter onComma = Splitter.on(","); // read the data List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8); // holds features List<Vector> data = Lists.newArrayList(); // holds target variable List<Integer> target = Lists.newArrayList(); // for decoding target values Dictionary dict = new Dictionary(); // for permuting data later List<Integer> order = Lists.newArrayList(); for (String line : raw.subList(1, raw.size())) { // order gets a list of indexes order.add(order.size()); // parse the predictor variables Vector v = new DenseVector(5); v.set(0, 1); int i = 1; Iterable<String> values = onComma.split(line); for (String value : Iterables.limit(values, 4)) { v.set(i++, Double.parseDouble(value)); } data.add(v); // and the target target.add(dict.intern(Iterables.get(values, 4))); } // randomize the order ... 
original data has each species all together // note that this randomization is deterministic Random random = RandomUtils.getRandom(); Collections.shuffle(order, random); // select training and test data List<Integer> train = order.subList(0, 100); List<Integer> test = order.subList(100, 150); logger.warn("Training set = {}", train); logger.warn("Test set = {}", test); // now train many times and collect information on accuracy each time int[] correct = new int[test.size() + 1]; for (int run = 0; run < 200; run++) { OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1)); // 30 training passes should converge to > 95% accuracy nearly always but never to 100% for (int pass = 0; pass < 30; pass++) { Collections.shuffle(train, random); for (int k : train) { lr.train(target.get(k), data.get(k)); } } // check the accuracy on held out data int x = 0; int[] count = new int[3]; for (Integer k : test) { Vector vt = lr.classifyFull(data.get(k)); int r = vt.maxValueIndex(); count[r]++; x += r == target.get(k) ? 
1 : 0; } correct[x]++; if (run == 199) { Vector v = new DenseVector(5); v.set(0, 1); int i = 1; Iterable<String> values = onComma.split("6.0,2.7,5.1,1.6,versicolor"); for (String value : Iterables.limit(values, 4)) { v.set(i++, Double.parseDouble(value)); } Vector vt = lr.classifyFull(v); for (String value : dict.values()) { System.out.println("target:" + value); } int t = dict.intern(Iterables.get(values, 4)); int r = vt.maxValueIndex(); boolean flag = r == t; lr.close(); Closer closer = Closer.create(); try { FileOutputStream byteArrayOutputStream = closer .register(new FileOutputStream(new File("model.txt"))); DataOutputStream dataOutputStream = closer .register(new DataOutputStream(byteArrayOutputStream)); PolymorphicWritable.write(dataOutputStream, lr); } finally { closer.close(); } } } // verify we never saw worse than 95% correct, for (int i = 0; i < Math.floor(0.95 * test.size()); i++) { System.out.println(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size())); } // nor perfect System.out.println(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1])); }
From source file: com.twitter.algebra.AlgebraCommon.java
License: Apache License
/**
 * Computes {@code V * M} when given {@code M}'s transpose.
 *
 * <p>Entry {@code c} of the result is the dot product of {@code vector} with row {@code c} of
 * {@code matrixTranspose}. Rows with no non-default elements yield 0 without calling
 * {@code dot}.
 *
 * @param vector the row vector {@code V}
 * @param matrixTranspose the transpose of matrix {@code M}
 * @param resVector output vector, overwritten at indices {@code 0 .. matrixTranspose.numRows()-1}
 * @return {@code resVector}, now holding {@code V * M}
 */
public static Vector vectorTimesMatrixTranspose(Vector vector,
    Matrix matrixTranspose, Vector resVector) {
  final int rowCount = matrixTranspose.numRows();
  int row = 0;
  while (row < rowCount) {
    Vector mRow = matrixTranspose.viewRow(row);
    double product = mRow.getNumNondefaultElements() == 0 ? 0d : vector.dot(mRow);
    resVector.set(row, product);
    row++;
  }
  return resVector;
}
From source file: edu.snu.cms.reef.ml.kmeans.CentroidListCodecTest.java
License: Apache License
/**
 * Fills {@code list} with a random number (0 to 999) of centroids.
 *
 * <p>Each centroid has a random cluster id (0 to 999999) and a {@link DenseVector} of random size
 * (0 to 999) whose entries are drawn from {@code [0, 1)}.
 */
@Before
public final void setUp() {
  // Draw the count once. Calling Math.random() in the loop condition re-drew the bound on every
  // iteration, which made the list size a random stopping time skewed toward small values.
  final int numCentroids = (int) (Math.random() * 1000);
  for (int j = 0; j < numCentroids; j++) {
    final Vector vector = new DenseVector((int) (Math.random() * 1000));
    for (int i = 0; i < vector.size(); i++) {
      vector.set(i, Math.random());
    }
    final Centroid centroid = new Centroid((int) (Math.random() * 1000000), vector);
    list.add(centroid);
  }
}
From source file: edu.snu.cms.reef.ml.kmeans.data.Centroid.java
License: Apache License
/** * A copy constructor that creates a deep copy of a centroid. * * The newly created KMeansCentroid does not reference * anything from the original KMeansCentroid. *///from w ww. j a va2s . com public Centroid(final Centroid centroid) { this.clusterId = centroid.clusterId; final Vector vector = new DenseVector(centroid.vector.size()); for (int i = 0; i < vector.size(); i++) { vector.set(i, centroid.vector.get(i)); } this.vector = vector; }
From source file: edu.snu.cms.reef.ml.kmeans.data.KMeansDataParser.java
License: Apache License
/**
 * Parses every record of {@code dataSet} into dense vectors of whitespace-separated doubles.
 *
 * <p>A line whose first token is {@code *} is a centroid; the {@code *} is dropped. Every other
 * non-blank line is a data point, and blank lines are skipped. On success {@code result} is set to
 * {@code (centroids, points)}. At the first malformed number, {@code parseException} is set,
 * parsing stops, and {@code result} is left unchanged.
 */
@Override
public final void parse() {
  final List<Vector> centroids = new ArrayList<>();
  final List<Vector> points = new ArrayList<>();
  for (final Pair<LongWritable, Text> keyValue : dataSet) {
    final String line = keyValue.second.toString().trim();
    // "".split("\\s+") returns [""] (length 1), not an empty array, so a length check never
    // catches blank lines. They must be skipped explicitly, or Double parsing of "" fails on them.
    if (line.isEmpty()) {
      continue;
    }
    final String[] split = line.split("\\s+");
    try {
      if (split[0].equals("*")) {
        centroids.add(parseVector(split, 1));
      } else {
        points.add(parseVector(split, 0));
      }
    } catch (final NumberFormatException e) {
      parseException = new ParseException("Parse failed: numbers should be DOUBLE");
      return;
    }
  }
  // Published once, after all lines are read; this also covers an empty data set.
  result = new Pair<>(centroids, points);
}

/**
 * Builds a {@link DenseVector} from {@code tokens[from..]}, each parsed as a double.
 *
 * @throws NumberFormatException if any token is not a valid double
 */
private static Vector parseVector(final String[] tokens, final int from) {
  final Vector vector = new DenseVector(tokens.length - from);
  for (int i = from; i < tokens.length; i++) {
    vector.set(i - from, Double.parseDouble(tokens[i]));
  }
  return vector;
}