List of usage examples for org.apache.mahout.math Vector maxValueIndex
int maxValueIndex();
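maxValueIndex() returns the index of the element with the largest value in the vector. Across the examples below, the most common pattern is to call it on the score vector produced by a classifier's classifyFull(), turning per-category scores into a predicted label index. A minimal self-contained sketch of the call itself, using Mahout's DenseVector (the values are arbitrary illustration):

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class MaxValueIndexExample {
    public static void main(String[] args) {
        // an arbitrary example vector
        Vector v = new DenseVector(new double[] { 0.1, 0.7, 0.2 });

        // index of the largest element; here 1, since 0.7 is the maximum
        int winner = v.maxValueIndex();
        System.out.println("max value " + v.maxValue() + " at index " + winner);
    }
}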
From source file:edu.utsa.sifter.som.SelfOrganizingMap.java
License:Apache License
int maxTermDifference(final int tID, final int uID) {
    final Vector t = getCell(tID);
    final Vector u = getCell(uID);
    final Vector diff = t.minus(u);
    final double maxValue = diff.maxValue();
    final double minValue = diff.minValue();
    if (minValue < 0 && Math.abs(minValue) > maxValue) {
        // even if maxValue is negative, this will hold
        return -diff.minValueIndex();
    } else {
        return diff.maxValueIndex();
    }
}
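Note the design choice here: the sign of the returned index encodes the direction of the largest difference, negating minValueIndex() when the largest-magnitude entry of the difference vector is negative, and returning maxValueIndex() otherwise. One caveat worth keeping in mind with this encoding: if the winning entry sits at index 0, the negation is invisible to the caller, since -0 == 0 for ints.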
From source file:hk.newsRecommender.Classify.java
License:Open Source License
public static void test(Configuration conf, String testFile, int labelIndex) throws IOException {
    System.out.println("~~~ begin to test ~~~");
    AbstractNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(naiveBayesModel);
    FileSystem fsopen = FileSystem.get(conf);
    FSDataInputStream in = fsopen.open(new Path(testFile));
    CSVReader csv = new CSVReader(new InputStreamReader(in));
    csv.readNext(); // skip header
    String[] line = null;
    double totalSampleCount = 0.;
    double correctClsCount = 0.;
    while ((line = csv.readNext()) != null) {
        // split the record into the label column and the feature columns
        List<String> tmpList = Lists.newArrayList(line);
        String label = tmpList.get(labelIndex);
        tmpList.remove(labelIndex);
        totalSampleCount++;
        Vector vector = new RandomAccessSparseVector(tmpList.size(), tmpList.size());
        for (int i = 0; i < tmpList.size(); i++) {
            String tempStr = tmpList.get(i);
            if (StringUtils.isNumeric(tempStr)) {
                vector.set(i, Double.parseDouble(tempStr));
            } else {
                // non-numeric features are mapped to ids via a lookup table
                Long id = strOptionMap.get(tempStr);
                if (id != null) {
                    vector.set(i, id);
                } else {
                    System.out.println(StringUtils.join(tempStr, ","));
                    continue;
                }
            }
        }
        // classifyFull returns one score per label; the index of the
        // maximum score is the predicted label
        Vector resultVector = classifier.classifyFull(vector);
        int classifyResult = resultVector.maxValueIndex();
        if (StringUtils.equals(label, strLabelList.get(classifyResult))) {
            correctClsCount++;
        }
    }
    System.out.println("Correct Ratio:" + (correctClsCount / totalSampleCount));
}
From source file:technobium.OnlineLogisticRegressionTest.java
License:Apache License
public static void main(String[] args) throws Exception {
    // this test trains a 3-way classifier on the famous Iris dataset.
    // a similar exercise can be accomplished in R using this code:
    //    library(nnet)
    //    correct = rep(0,100)
    //    for (j in 1:100) {
    //      i = order(runif(150))
    //      train = iris[i[1:100],]
    //      test = iris[i[101:150],]
    //      m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
    //      correct[j] = mean(predict(m, newdata=test) == test$Species)
    //    }
    //    hist(correct)
    //
    // Note that depending on the training/test split, performance can be better or worse.
    // There is about a 5% chance of getting accuracy < 90% and about a 20% chance of getting
    // accuracy of 100%.
    //
    // This test uses a deterministic split that is neither outstandingly good nor bad.
    RandomUtils.useTestSeed();
    Splitter onComma = Splitter.on(",");

    // read the data
    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);

    // holds features
    List<Vector> data = Lists.newArrayList();

    // holds target variable
    List<Integer> target = Lists.newArrayList();

    // for decoding target values
    Dictionary dict = new Dictionary();

    // for permuting data later
    List<Integer> order = Lists.newArrayList();

    for (String line : raw.subList(1, raw.size())) {
        // order gets a list of indexes
        order.add(order.size());

        // parse the predictor variables
        Vector v = new DenseVector(5);
        v.set(0, 1);
        int i = 1;
        Iterable<String> values = onComma.split(line);
        for (String value : Iterables.limit(values, 4)) {
            v.set(i++, Double.parseDouble(value));
        }
        data.add(v);

        // and the target
        target.add(dict.intern(Iterables.get(values, 4)));
    }

    // randomize the order ... original data has each species all together
    // note that this randomization is deterministic
    Random random = RandomUtils.getRandom();
    Collections.shuffle(order, random);

    // select training and test data
    List<Integer> train = order.subList(0, 100);
    List<Integer> test = order.subList(100, 150);
    logger.warn("Training set = {}", train);
    logger.warn("Test set = {}", test);

    // now train many times and collect information on accuracy each time
    int[] correct = new int[test.size() + 1];
    for (int run = 0; run < 200; run++) {
        OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));

        // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
        for (int pass = 0; pass < 30; pass++) {
            Collections.shuffle(train, random);
            for (int k : train) {
                lr.train(target.get(k), data.get(k));
            }
        }

        // check the accuracy on held-out data: the index of the maximum
        // score in the classifyFull result is the predicted species
        int x = 0;
        int[] count = new int[3];
        for (Integer k : test) {
            Vector vt = lr.classifyFull(data.get(k));
            int r = vt.maxValueIndex();
            count[r]++;
            x += r == target.get(k) ? 1 : 0;
        }
        correct[x]++;
    }

    // verify we never saw worse than 95% correct,
    for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
        System.out.println(String.format("%d trials had unacceptable accuracy of only %.0f%%",
                correct[i], 100.0 * i / test.size()));
    }

    // nor perfect: correct[test.size()] counts the runs in which every
    // held-out example was classified correctly
    System.out.println(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size()]));
}
From source file:tv.floe.metronome.classification.logisticregression.iterativereduce.POLRWorkerNode.java
License:Apache License
/**
 * The IR::Compute method - this is where we do the next batch of records for SGD
 */
@Override
public ParameterVectorUpdatable compute() {
    Text value = new Text();
    long batch_vec_factory_time = 0;
    boolean result = true;

    while (this.lineParser.hasMoreRecords()) {
        try {
            result = this.lineParser.next(value);
        } catch (IOException e1) {
            e1.printStackTrace();
        }

        if (result) {
            long startTime = System.currentTimeMillis();

            // vectorize the input record
            Vector v = new RandomAccessSparseVector(this.FeatureVectorSize);
            int actual = -1;
            try {
                actual = this.VectorFactory.processLine(value.toString(), v);
            } catch (Exception e) {
                e.printStackTrace();
            }

            long endTime = System.currentTimeMillis();
            batch_vec_factory_time += (endTime - startTime);

            // update the running log-likelihood statistics
            double mu = Math.min(k + 1, 200);
            double ll = this.polr.logLikelihood(actual, v);
            metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu;
            if (Double.isNaN(metrics.AvgLogLikelihood)) {
                metrics.AvgLogLikelihood = 0;
            }

            // classify the record: the index of the maximum score in the
            // classifyFull output is the estimated category
            Vector p = new DenseVector(this.num_categories);
            this.polr.classifyFull(p, v);
            int estimated = p.maxValueIndex();

            int correct = (estimated == actual ? 1 : 0);
            metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu;

            this.polr.train(actual, v);

            k++;
            metrics.TotalRecordsProcessed = k;

            this.polr.close();
        }
        // else: nothing left to process in this split
    }

    System.err.printf(
            "Worker %s:\t Iteration: %s, Trained Recs: %10d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n",
            this.internalID, this.CurrentIteration, k, metrics.AvgLogLikelihood,
            metrics.AvgCorrect * 100, batch_vec_factory_time);

    return new ParameterVectorUpdatable(this.GenerateUpdate());
}