Java tutorial
/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.knn.tools; import java.io.FileOutputStream; import java.io.IOException; import java.io.PrintStream; import java.util.List; import java.util.Map; import com.google.common.base.Preconditions; import com.google.common.collect.*; import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; import org.apache.commons.cli.PosixParser; import org.apache.commons.cli.CommandLine; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.Text; import org.apache.mahout.classifier.sgd.*; import org.apache.mahout.common.Pair; import org.apache.mahout.common.iterator.sequencefile.PathType; import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable; import org.apache.mahout.knn.experimental.CentroidWritable; import org.apache.mahout.math.Centroid; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; /** * Reads and trains an adaptive logistic regression model on the 20 newsgroups data. * The features used are the distances to the centers of the 20 newsgroup clusters. * These are computed from the current vector to: * <ul> * <li> * the centroids of the 20 clusters computed using the path-based assignment. * </li> * <li> * the centroids of the 20 clusters computed using ball k-means clustering; * </li> * <li> * the centroids of the ~86 clusters (k log n) computed using streaming k-means (without a * ball k-means step) to get back to 20 clusters; * </li> * </ul> */ public final class TrainNewsGroupsKMeansLogisticRegression { /** * Constants are set knowing what the 20 newsgroups data set looks like. */ private static final int NUM_VECTORS = 18898; private static final int NUM_CLASSES = 20; private static final int NUM_FEATURES_ACTUAL = 20; private static final int NUM_FEATURES_BKM = 20; private static final int NUM_FEATURES_SKM = NUM_CLASSES * (int) Math.log(NUM_VECTORS); private TrainNewsGroupsKMeansLogisticRegression() { } public static void trainActual(Iterable<Pair<Text, VectorWritable>> inputIterable, String outBase, Map<String, Integer> clusterNamesToIds) throws IOException { Map<String, Centroid> actualClusters = Maps.newHashMap(); CreateCentroids.computeActualClusters(inputIterable, actualClusters); AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(NUM_CLASSES, NUM_FEATURES_ACTUAL, new L1()); int vectorId = 0; for (Pair<Text, VectorWritable> pair : inputIterable) { Vector actualCentroid = pair.getSecond().get(); Vector features = new DenseVector(NUM_FEATURES_ACTUAL); int i = 0; for (Centroid centroid : actualClusters.values()) { features.set(i++, centroid.getDistanceSquared(actualCentroid)); } String clusterName = pair.getFirst().toString(); learningAlgorithm.train(clusterNamesToIds.get(clusterName), features); ++vectorId; if (vectorId % 100 == 0) { System.out.printf("[actual] Training %f complete\n", (float) vectorId / NUM_VECTORS); } } learningAlgorithm.close(); ModelSerializer.writeBinary(outBase + "-actual.model", learningAlgorithm); } public static void trainComputed(Iterable<Pair<Text, VectorWritable>> inputIterable, String outBase, String suffix, Map<String, Integer> clusterNamesToIds, Pair<Integer, Iterable<Centroid>> numFeaturesCentroidsPair) throws IOException { final int numFeatures = numFeaturesCentroidsPair.getFirst(); System.out.printf("[%s] Starting training with %d features\n", suffix, numFeatures); AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(clusterNamesToIds.size(), numFeatures, new L1()); int vectorId = 0; for (Pair<Text, VectorWritable> pair : inputIterable) { Vector actualCentroid = pair.getSecond().get(); Vector features = new DenseVector(numFeatures); int i = 0; for (Centroid centroid : numFeaturesCentroidsPair.getSecond()) { features.set(i++, centroid.getDistanceSquared(actualCentroid)); } String clusterName = pair.getFirst().toString(); learningAlgorithm.train(clusterNamesToIds.get(clusterName), features); ++vectorId; if (vectorId % 100 == 0) { System.out.printf("[%s] Training %f complete\n", suffix, (float) vectorId / NUM_VECTORS); } } learningAlgorithm.close(); ModelSerializer.writeBinary(outBase + "-" + suffix + ".model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); } public static void main(String[] args) throws IOException, ParseException { Options options = new Options(); options.addOption("i", "input", true, "Path to the input folder containing the training set's" + " sequence files."); options.addOption("o", "output", true, "Base path to the output file. The name will be " + "appended with a suffix for each type of training."); options.addOption("a", "actual", false, "If set, runs the training with the actual cluster " + "assignments and outputs the model to the output path with a -actual suffix."); options.addOption("b", "ballkmeans", false, "If set, runs the training with the ball k-means " + "cluster assignments and outputs the model to the output path with a -ballkmeans suffix."); options.addOption("s", "streamingkmeans", false, "If set, runs the training with the " + "streaming k-means cluster assignments and outputs the model to the output path with a " + "-streamingkmeans suffix."); options.addOption("c", "centroids", true, "Path to the centroids seqfile"); CommandLine cmd = (new PosixParser()).parse(options, args); String inputPath = cmd.getOptionValue("input"); Preconditions.checkNotNull(inputPath); String outputBase = cmd.getOptionValue("output"); Preconditions.checkNotNull(outputBase); String centroidsPath = cmd.getOptionValue("centroids"); Preconditions.checkNotNull(centroidsPath); Configuration conf = new Configuration(); SequenceFileDirIterable<Text, VectorWritable> inputIterable = new SequenceFileDirIterable<Text, VectorWritable>( new Path(inputPath), PathType.LIST, conf); PrintStream clusterIdOut = new PrintStream(new FileOutputStream("cluster-ids.csv")); clusterIdOut.printf("clusterName, clusterId\n"); int clusterId = 0; Map<String, Integer> clusterNamesToIds = Maps.newHashMapWithExpectedSize(NUM_CLASSES); for (Pair<Text, VectorWritable> pair : inputIterable) { String clusterName = pair.getFirst().toString(); if (!clusterNamesToIds.containsKey(clusterName)) { clusterIdOut.printf("%s, %d\n", clusterName, clusterId); clusterNamesToIds.put(clusterName, clusterId++); } } clusterIdOut.close(); if (cmd.hasOption("actual")) { System.out.printf("\nActual clusters models\n"); System.out.printf("----------------------\n"); long start = System.currentTimeMillis(); trainActual(inputIterable, outputBase, clusterNamesToIds); long end = System.currentTimeMillis(); System.out.printf("Trained models for actual clusters. Took %d ms\n", end - start); } if (cmd.hasOption("ballkmeans") || cmd.hasOption("streamingkmeans")) { SequenceFileValueIterable<CentroidWritable> centroidIterable = new SequenceFileValueIterable<CentroidWritable>( new Path(centroidsPath), conf); List<Centroid> centroids = Lists .newArrayList(CreateCentroids.getCentroidsFromCentroidWritableIterable(centroidIterable)); if (cmd.hasOption("ballkmeans")) { System.out.printf("\nBall k-means clusters models\n"); System.out.printf("----------------------------\n"); long start = System.currentTimeMillis(); trainComputed(inputIterable, outputBase, "ballkmeans", clusterNamesToIds, new Pair<Integer, Iterable<Centroid>>(NUM_FEATURES_BKM, centroids)); long end = System.currentTimeMillis(); System.out.printf("Trained models for ballkmeans clusters. Took %d ms\n", end - start); } if (cmd.hasOption("streamingkmeans")) { System.out.printf("\nStreaming k-means clusters models\n"); System.out.printf("---------------------------------\n"); long start = System.currentTimeMillis(); trainComputed(inputIterable, outputBase, "streamingkmeans", clusterNamesToIds, new Pair<Integer, Iterable<Centroid>>(centroids.size(), centroids)); long end = System.currentTimeMillis(); System.out.printf("Trained models for streamingkmeans clusters. Took %d ms\n", end - start); } } } }