List of usage examples for org.apache.mahout.math.Vector#maxValueIndex()
/**
 * Returns the index of the element holding the largest value in this vector.
 *
 * <p>The examples below use it to turn a classifier's per-category score vector into a predicted
 * category ID, and to fetch the largest entry via {@code getQuick(maxValueIndex())}.
 * NOTE(review): tie-breaking and the result for an empty vector are not shown here; check the
 * Mahout Vector javadoc before relying on them.
 */
int maxValueIndex();
From source file:com.cloudera.knittingboar.metrics.POLRModelTester.java
License:Apache License
/** * Runs the next training batch to prep the gamma buffer to send to the * mstr_node/*from ww w. j ava 2 s. co m*/ * * TODO: need to provide stats, group measurements into struct * * @throws Exception * @throws IOException */ public void RunThroughTestRecords() throws IOException, Exception { Text value = new Text(); long batch_vec_factory_time = 0; k = 0; int num_correct = 0; for (int x = 0; x < this.BatchSize; x++) { if (this.input_split.next(value)) { long startTime = System.currentTimeMillis(); Vector v = new RandomAccessSparseVector(this.FeatureVectorSize); int actual = this.VectorFactory.processLine(value.toString(), v); long endTime = System.currentTimeMillis(); // System.out.println("That took " + (endTime - startTime) + // " milliseconds"); batch_vec_factory_time += (endTime - startTime); String ng = this.VectorFactory.GetClassnameByID(actual); // .GetNewsgroupNameByID( // actual ); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = this.polr.logLikelihood(actual, v); if (Double.isNaN(ll)) { /* * System.out.println(" --------- NaN -----------"); * * System.out.println( "k: " + k ); System.out.println( "ll: " + ll ); * System.out.println( "mu: " + mu ); */ // return; } else { metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu; } Vector p = new DenseVector(20); this.polr.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 
1 : 0); if (estimated == actual) { num_correct++; } // averageCorrect = averageCorrect + (correct - averageCorrect) / mu; metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu; // this.polr.train(actual, v); k++; // if (x == this.BatchSize - 1) { int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; System.out.printf( "Worker %s:\t Trained Recs: %10d, numCorrect: %d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n", this.internalID, k, num_correct, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100, batch_vec_factory_time); } this.polr.close(); } else { // nothing else to process in split! break; } // if } // for the number of passes in the run }
From source file:com.cloudera.knittingboar.records.TestTwentyNewsgroupsCustomRecordParseOLRRun.java
License:Apache License
@Test public void testRecordFactoryOnDatasetShard() throws Exception { // TODO a test with assertions is not a test // p.270 ----- metrics to track lucene's parsing mechanics, progress, // performance of OLR ------------ double averageLL = 0.0; double averageCorrect = 0.0; int k = 0;/*from ww w . ja va 2s . c o m*/ double step = 0.0; int[] bumps = new int[] { 1, 2, 5 }; TwentyNewsgroupsRecordFactory rec_factory = new TwentyNewsgroupsRecordFactory("\t"); // rec_factory.setClassSplitString("\t"); JobConf job = new JobConf(defaultConf); long block_size = localFs.getDefaultBlockSize(workDir); LOG.info("default block size: " + (block_size / 1024 / 1024) + "MB"); // matches the OLR setup on p.269 --------------- // stepOffset, decay, and alpha --- describe how the learning rate decreases // lambda: amount of regularization // learningRate: amount of initial learning rate @SuppressWarnings("resource") OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1) .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20); FileInputFormat.setInputPaths(job, workDir); // try splitting the file in a variety of sizes TextInputFormat format = new TextInputFormat(); format.configure(job); Text value = new Text(); int numSplits = 1; InputSplit[] splits = format.getSplits(job, numSplits); LOG.info("requested " + numSplits + " splits, splitting: got = " + splits.length); LOG.info("---- debug splits --------- "); rec_factory.Debug(); int total_read = 0; for (int x = 0; x < splits.length; x++) { LOG.info("> Split [" + x + "]: " + splits[x].getLength()); int count = 0; InputRecordsSplit custom_reader = new InputRecordsSplit(job, splits[x]); while (custom_reader.next(value)) { Vector v = new RandomAccessSparseVector(TwentyNewsgroupsRecordFactory.FEATURES); int actual = rec_factory.processLine(value.toString(), v); String ng = rec_factory.GetNewsgroupNameByID(actual); // calc stats --------- double mu = Math.min(k + 1, 200); double 
ll = learningAlgorithm.logLikelihood(actual, v); averageLL = averageLL + (ll - averageLL) / mu; Vector p = new DenseVector(20); learningAlgorithm.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 1 : 0); averageCorrect = averageCorrect + (correct - averageCorrect) / mu; learningAlgorithm.train(actual, v); k++; int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; LOG.info(String.format("%10d %10.3f %10.3f %10.2f %s %s", k, ll, averageLL, averageCorrect * 100, ng, rec_factory.GetNewsgroupNameByID(estimated))); } learningAlgorithm.close(); count++; } LOG.info("read: " + count + " records for split " + x); total_read += count; } // for each split LOG.info("total read across all splits: " + total_read); rec_factory.Debug(); }
From source file:com.cloudera.knittingboar.sgd.iterativereduce.POLRWorkerNode.java
License:Apache License
/** * The IR::Compute method - this is where we do the next batch of records for * SGD//from w ww . jav a2 s . c o m */ @Override public ParameterVectorGradientUpdatable compute() { Text value = new Text(); long batch_vec_factory_time = 0; boolean result = true; //boolean processBatch = false; /* if (this.LocalPassCount > this.GlobalPassCount) { // we need to sit this one out System.out.println("Worker " + this.internalID + " is ahead of global pass count [" + this.LocalPassCount + ":" + this.GlobalPassCount + "] "); processBatch = true; } if (this.LocalPassCount >= this.NumberPasses) { // learning is done, terminate System.out.println("Worker " + this.internalID + " is done [" + this.LocalPassCount + ":" + this.GlobalPassCount + "] "); processBatch = false; } if (processBatch) { */ // if (this.lineParser.hasMoreRecords()) { //for (int x = 0; x < this.BatchSize; x++) { while (this.lineParser.hasMoreRecords()) { try { result = this.lineParser.next(value); } catch (IOException e1) { // TODO Auto-generated catch block e1.printStackTrace(); } if (result) { long startTime = System.currentTimeMillis(); Vector v = new RandomAccessSparseVector(this.FeatureVectorSize); int actual = -1; try { actual = this.VectorFactory.processLine(value.toString(), v); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } long endTime = System.currentTimeMillis(); batch_vec_factory_time += (endTime - startTime); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = this.polr.logLikelihood(actual, v); metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu; if (Double.isNaN(metrics.AvgLogLikelihood)) { metrics.AvgLogLikelihood = 0; } Vector p = new DenseVector(this.num_categories); this.polr.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 
1 : 0); metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu; this.polr.train(actual, v); k++; metrics.TotalRecordsProcessed = k; // if (x == this.BatchSize - 1) { /* System.err .printf( "Worker %s:\t Iteration: %s, Trained Recs: %10d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n", this.internalID, this.CurrentIteration, k, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100, batch_vec_factory_time); */ // } this.polr.close(); } else { // this.LocalBatchCountForIteration++; // this.input_split.ResetToStartOfSplit(); // nothing else to process in split! // break; } // if } // for the batch size System.err.printf( "Worker %s:\t Iteration: %s, Trained Recs: %10d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n", this.internalID, this.CurrentIteration, k, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100, batch_vec_factory_time); /* } else { System.err .printf( "Worker %s:\t Trained Recs: %10d, AvgLL: %10.3f, Percent Correct: %10.2f, [Done With Iteration]\n", this.internalID, k, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100); } // if */ return new ParameterVectorGradientUpdatable(this.GenerateUpdate()); }
From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLRTest20Newsgroups.java
License:Apache License
public void testResults() throws Exception { OnlineLogisticRegression classifier = ModelSerializer .readBinary(new FileInputStream(model20News.toString()), OnlineLogisticRegression.class); Text value = new Text(); long batch_vec_factory_time = 0; int k = 0;//from ww w . ja v a 2 s . c o m int num_correct = 0; // ---- this all needs to be done in JobConf job = new JobConf(defaultConf); // TODO: work on this, splits are generating for everything in dir // InputSplit[] splits = generateDebugSplits(inputDir, job); //fullRCV1Dir InputSplit[] splits = generateDebugSplits(testData20News, job); System.out.println("split count: " + splits.length); InputRecordsSplit custom_reader_0 = new InputRecordsSplit(job, splits[0]); TwentyNewsgroupsRecordFactory VectorFactory = new TwentyNewsgroupsRecordFactory("\t"); for (int x = 0; x < 8000; x++) { if (custom_reader_0.next(value)) { long startTime = System.currentTimeMillis(); Vector v = new RandomAccessSparseVector(FEATURES); int actual = VectorFactory.processLine(value.toString(), v); long endTime = System.currentTimeMillis(); //System.out.println("That took " + (endTime - startTime) + " milliseconds"); batch_vec_factory_time += (endTime - startTime); String ng = VectorFactory.GetClassnameByID(actual); //.GetNewsgroupNameByID( actual ); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = classifier.logLikelihood(actual, v); //averageLL = averageLL + (ll - averageLL) / mu; metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu; Vector p = new DenseVector(20); classifier.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 
1 : 0); if (estimated == actual) { num_correct++; } //averageCorrect = averageCorrect + (correct - averageCorrect) / mu; metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu; //this.polr.train(actual, v); k++; // if (x == this.BatchSize - 1) { int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; System.out.printf( "Worker %s:\t Tested Recs: %10d, numCorrect: %d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n", "OLR-standard-test", k, num_correct, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100, batch_vec_factory_time); } classifier.close(); } else { // nothing else to process in split! break; } // if } // for the number of passes in the run }
From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLR_Train20Newsgroups.java
License:Apache License
public void testTrainNewsGroups() throws IOException { File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/"); overallCounts = HashMultiset.create(); long startTime = System.currentTimeMillis(); // p.269 --------------------------------------------------------- Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>(); // encodes the text content in both the subject and the body of the email FeatureVectorEncoder encoder = new StaticWordValueEncoder("body"); encoder.setProbes(2);/*from w ww . j a v a2 s. co m*/ encoder.setTraceDictionary(traceDictionary); // provides a constant offset that the model can use to encode the average frequency // of each class FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept"); bias.setTraceDictionary(traceDictionary); // used to encode the number of lines in a message FeatureVectorEncoder lines = new ConstantValueEncoder("Lines"); lines.setTraceDictionary(traceDictionary); FeatureVectorEncoder logLines = new ConstantValueEncoder("LogLines"); logLines.setTraceDictionary(traceDictionary); Dictionary newsGroups = new Dictionary(); // matches the OLR setup on p.269 --------------- // stepOffset, decay, and alpha --- describe how the learning rate decreases // lambda: amount of regularization // learningRate: amount of initial learning rate OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1) .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20); // bottom of p.269 ------------------------------ // because OLR expects to get integer class IDs for the target variable during training // we need a dictionary to convert the target variable (the newsgroup name) // to an integer, which is the newsGroup object List<File> files = new ArrayList<File>(); for (File newsgroup : base.listFiles()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); } // mix up the 
files, helps training in OLR Collections.shuffle(files); System.out.printf("%d training files\n", files.size()); // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------ double averageLL = 0.0; double averageCorrect = 0.0; double averageLineCount = 0.0; int k = 0; double step = 0.0; int[] bumps = new int[] { 1, 2, 5 }; double lineCount = 0; // last line on p.269 Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31); Splitter onColon = Splitter.on(":").trimResults(); int input_file_count = 0; // ----- p.270 ------------ "reading and tokenzing the data" --------- for (File file : files) { BufferedReader reader = new BufferedReader(new FileReader(file)); input_file_count++; // identify newsgroup ---------------- // convert newsgroup name to unique id // ----------------------------------- String ng = file.getParentFile().getName(); int actual = newsGroups.intern(ng); Multiset<String> words = ConcurrentHashMultiset.create(); // check for line count header ------- String line = reader.readLine(); while (line != null && line.length() > 0) { // if this is a line that has a line count, let's pull that value out ------ if (line.startsWith("Lines:")) { String count = Iterables.get(onColon.split(line), 1); try { lineCount = Integer.parseInt(count); averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000); } catch (NumberFormatException e) { // if anything goes wrong in parse: just use the avg count lineCount = averageLineCount; } } boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:") || line.startsWith("Keywords:") || line.startsWith("Summary:")); // loop through the lines in the file, while the line starts with: " " do { // get a reader for this specific string ------ StringReader in = new StringReader(line); // ---- count words in header --------- if (countHeader) { countWords(analyzer, words, in); } // iterate to the next string ---- line = reader.readLine(); } while 
(line.startsWith(" ")); } // while (lines in header) { // -------- count words in body ---------- countWords(analyzer, words, reader); reader.close(); // ----- p.271 ----------- Vector v = new RandomAccessSparseVector(FEATURES); // original value does nothing in a ContantValueEncoder bias.addToVector("", 1, v); // original value does nothing in a ContantValueEncoder lines.addToVector("", lineCount / 30, v); // original value does nothing in a ContantValueEncoder logLines.addToVector("", Math.log(lineCount + 1), v); // now scan through all the words and add them for (String word : words.elementSet()) { encoder.addToVector(word, Math.log(1 + words.count(word)), v); } //Utils.PrintVectorNonZero(v); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = learningAlgorithm.logLikelihood(actual, v); averageLL = averageLL + (ll - averageLL) / mu; Vector p = new DenseVector(20); learningAlgorithm.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 1 : 0); averageCorrect = averageCorrect + (correct - averageCorrect) / mu; learningAlgorithm.train(actual, v); k++; int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n", k, ll, averageLL, averageCorrect * 100, ng, newsGroups.values().get(estimated)); } learningAlgorithm.close(); /* if (k>4) { break; } */ } Utils.PrintVectorSection(learningAlgorithm.getBeta().viewRow(0), 3); long endTime = System.currentTimeMillis(); //System.out.println("That took " + (endTime - startTime) + " milliseconds"); long duration = (endTime - startTime); System.out.println("Processed Input Files: " + input_file_count + ", time: " + duration + "ms"); ModelSerializer.writeBinary("/tmp/olr-news-group.model", learningAlgorithm); // learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); }
From source file:com.cloudera.knittingboar.sgd.POLRWorkerDriver.java
License:Apache License
/** * Runs the next training batch to prep the gamma buffer to send to the * mstr_node//from w ww.ja v a 2s . c o m * * TODO: need to provide stats, group measurements into struct * * @throws Exception * @throws IOException */ public boolean RunNextTrainingBatch() throws IOException, Exception { Text value = new Text(); long batch_vec_factory_time = 0; if (this.LocalPassCount > this.GlobalPassCount) { // we need to sit this one out System.out.println("Worker " + this.internalID + " is ahead of global pass count [" + this.LocalPassCount + ":" + this.GlobalPassCount + "] "); return true; } if (this.LocalPassCount >= this.NumberPasses) { // learning is done, terminate System.out.println("Worker " + this.internalID + " is done [" + this.LocalPassCount + ":" + this.GlobalPassCount + "] "); return false; } for (int x = 0; x < this.BatchSize; x++) { if (this.input_split.next(value)) { long startTime = System.currentTimeMillis(); Vector v = new RandomAccessSparseVector(this.FeatureVectorSize); int actual = this.VectorFactory.processLine(value.toString(), v); long endTime = System.currentTimeMillis(); batch_vec_factory_time += (endTime - startTime); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = this.polr.logLikelihood(actual, v); metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu; Vector p = new DenseVector(20); this.polr.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 1 : 0); metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu; this.polr.train(actual, v); k++; if (x == this.BatchSize - 1) { System.out.printf( "Worker %s:\t Trained Recs: %10d, loglikelihood: %10.3f, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n", this.internalID, k, ll, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100, batch_vec_factory_time); } this.polr.close(); } else { this.LocalPassCount++; this.input_split.ResetToStartOfSplit(); // nothing else to process in split! 
break; } // if } // for the batch size return true; }
From source file:com.mapr.stats.bandit.ContextualBayesBandit.java
License:Apache License
/**
 * Draws a sample via {@code sampleNoLink()} and returns the position of its largest entry
 * (presumably the index of the chosen arm; confirm against sampleNoLink's contract).
 */
public int sample() {
  return sampleNoLink().maxValueIndex();
}
From source file:com.memonews.mahout.sentiment.SentimentModelTester.java
License:Apache License
public void run(final PrintWriter output) throws IOException { final File base = new File(inputFile); // contains the best model final OnlineLogisticRegression classifier = ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class); final Dictionary newsGroups = new Dictionary(); final Multiset<String> overallCounts = HashMultiset.create(); final List<File> files = Lists.newArrayList(); for (final File newsgroup : base.listFiles()) { if (newsgroup.isDirectory()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); }//from w w w. ja v a 2 s .c om } System.out.printf("%d test files\n", files.size()); final ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT"); for (final File file : files) { final String ng = file.getParentFile().getName(); final int actual = newsGroups.intern(ng); final SentimentModelHelper helper = new SentimentModelHelper(); final Vector input = helper.encodeFeatureVector(file, overallCounts);// no // leak // type // ensures // this // is // a // normal // vector final Vector result = classifier.classifyFull(input); final int cat = result.maxValueIndex(); final double score = result.maxValue(); final double ll = classifier.logLikelihood(actual, input); final ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll); ra.addInstance(newsGroups.values().get(actual), cr); } output.printf("%s\n\n", ra.toString()); }
From source file:com.scaleunlimited.classify.vectors.BaseNormalizer.java
License:Apache License
/**
 * Logs at debug level every term whose document frequency exceeds
 * {@code MIN_FREQUENCY_REPORT_RATIO} times the largest document frequency. Terms are listed most
 * frequent first, formatted as {@code 'term'=count, }.
 *
 * <p>Nothing is computed or logged when debug logging is disabled or the vector is empty.
 *
 * @param docFrequencies document frequency per term index
 * @param uniqueTerms term text per index; must have an entry for every index of
 *     {@code docFrequencies}
 */
public static void dumpTopTerms(final Vector docFrequencies, List<String> uniqueTerms) {
  // building the report sorts every term, so skip it entirely when nobody will see it
  if (!LOGGER.isDebugEnabled()) {
    return;
  }

  int cardinality = docFrequencies.size();
  if (cardinality == 0) {
    return;
  }

  List<Integer> sortedDocFrequencyIndexes = new ArrayList<Integer>(cardinality);
  for (int i = 0; i < cardinality; i++) {
    sortedDocFrequencyIndexes.add(i);
  }

  // Descending by frequency. Double.compare is required here: casting the difference to int
  // treats gaps smaller than 1 as "equal". That is not transitive, and Collections.sort may
  // reject it with IllegalArgumentException.
  Collections.sort(sortedDocFrequencyIndexes, new Comparator<Integer>() {
    @Override
    public int compare(Integer o1, Integer o2) {
      return Double.compare(docFrequencies.getQuick(o2), docFrequencies.getQuick(o1));
    }
  });

  double maxFrequency = docFrequencies.getQuick(docFrequencies.maxValueIndex());

  StringBuilder topTermsReport = new StringBuilder();
  for (int i = 0; i < cardinality; i++) {
    int index = sortedDocFrequencyIndexes.get(i);
    double frequency = docFrequencies.getQuick(index);
    if ((frequency / maxFrequency) > MIN_FREQUENCY_REPORT_RATIO) {
      topTermsReport.append(String.format("'%s'=%d, ", uniqueTerms.get(index), (int) frequency));
    }
  }

  LOGGER.debug(topTermsReport.toString());
}
From source file:com.technobium.MultinomialLogisticRegression.java
License:Apache License
public static void main(String[] args) throws Exception { // this test trains a 3-way classifier on the famous Iris dataset. // a similar exercise can be accomplished in R using this code: // library(nnet) // correct = rep(0,100) // for (j in 1:100) { // i = order(runif(150)) // train = iris[i[1:100],] // test = iris[i[101:150],] // m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train) // correct[j] = mean(predict(m, newdata=test) == test$Species) // }// www.ja v a2 s . c o m // hist(correct) // // Note that depending on the training/test split, performance can be better or worse. // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy // of 100% // // This test uses a deterministic split that is neither outstandingly good nor bad RandomUtils.useTestSeed(); Splitter onComma = Splitter.on(","); // read the data List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8); // holds features List<Vector> data = Lists.newArrayList(); // holds target variable List<Integer> target = Lists.newArrayList(); // for decoding target values Dictionary dict = new Dictionary(); // for permuting data later List<Integer> order = Lists.newArrayList(); for (String line : raw.subList(1, raw.size())) { // order gets a list of indexes order.add(order.size()); // parse the predictor variables Vector v = new DenseVector(5); v.set(0, 1); int i = 1; Iterable<String> values = onComma.split(line); for (String value : Iterables.limit(values, 4)) { v.set(i++, Double.parseDouble(value)); } data.add(v); // and the target target.add(dict.intern(Iterables.get(values, 4))); } // randomize the order ... 
original data has each species all together // note that this randomization is deterministic Random random = RandomUtils.getRandom(); Collections.shuffle(order, random); // select training and test data List<Integer> train = order.subList(0, 100); List<Integer> test = order.subList(100, 150); logger.warn("Training set = {}", train); logger.warn("Test set = {}", test); // now train many times and collect information on accuracy each time int[] correct = new int[test.size() + 1]; for (int run = 0; run < 200; run++) { OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1)); // 30 training passes should converge to > 95% accuracy nearly always but never to 100% for (int pass = 0; pass < 30; pass++) { Collections.shuffle(train, random); for (int k : train) { lr.train(target.get(k), data.get(k)); } } // check the accuracy on held out data int x = 0; int[] count = new int[3]; for (Integer k : test) { Vector vt = lr.classifyFull(data.get(k)); int r = vt.maxValueIndex(); count[r]++; x += r == target.get(k) ? 
1 : 0; } correct[x]++; if (run == 199) { Vector v = new DenseVector(5); v.set(0, 1); int i = 1; Iterable<String> values = onComma.split("6.0,2.7,5.1,1.6,versicolor"); for (String value : Iterables.limit(values, 4)) { v.set(i++, Double.parseDouble(value)); } Vector vt = lr.classifyFull(v); for (String value : dict.values()) { System.out.println("target:" + value); } int t = dict.intern(Iterables.get(values, 4)); int r = vt.maxValueIndex(); boolean flag = r == t; lr.close(); Closer closer = Closer.create(); try { FileOutputStream byteArrayOutputStream = closer .register(new FileOutputStream(new File("model.txt"))); DataOutputStream dataOutputStream = closer .register(new DataOutputStream(byteArrayOutputStream)); PolymorphicWritable.write(dataOutputStream, lr); } finally { closer.close(); } } } // verify we never saw worse than 95% correct, for (int i = 0; i < Math.floor(0.95 * test.size()); i++) { System.out.println(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size())); } // nor perfect System.out.println(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1])); }