Java tutorial
/** * $RCSfile: Classify.java * $Revision: 1.0 * $Date: 2015-6-24 * * Copyright (C) 2015 EastHope, Inc. All rights reserved. * * Use is subject to license terms. */ package hk.newsRecommender; import hk.mahout.bayes.kdd99.Kdd99CsvToSeqFile; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.List; import java.util.Map; import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.Mapper; import org.apache.hadoop.mapreduce.Reducer; import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; import org.apache.hadoop.mapreduce.lib.input.FileSplit; import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier; import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier; import org.apache.mahout.classifier.naivebayes.NaiveBayesModel; import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier; import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob; import org.apache.mahout.common.HadoopUtil; import org.apache.mahout.math.RandomAccessSparseVector; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import au.com.bytecode.opencsv.CSVReader; public class Classify { private static NaiveBayesModel naiveBayesModel = null; private static Map<String, Long> strOptionMap = Maps.newHashMap(); private static List<String> strLabelList = Lists.newArrayList(); public static void main(String[] args) throws Exception { Configuration conf = new Configuration(); String hdfsUrl = conf.get("fs.defaultFS"); // part1-------------------???mahout??-------------------------------------- // Job job1 = Job.getInstance(conf, "generateUserNewsTaggedMatrix"); // Path output1=new Path(hdfsUrl + "/data/recommend/class2"); // HadoopUtil.delete(conf, output1); // job1.setJarByClass(TFIDF.class); // job1.setMapperClass(Mapper_Part1.class); // job1.setReducerClass(Reduce_Part1.class); // job1.setMapOutputKeyClass(Text.class); // job1.setMapOutputValueClass(Text.class); // job1.setOutputKeyClass(Text.class); // job1.setOutputValueClass(NullWritable.class); // FileInputFormat.addInputPath(job1, new Path(hdfsUrl + "/data/recommend/matrix2")); // FileInputFormat.addInputPath(job1, new Path(hdfsUrl + "/data/recommend/ClusterPointsInfo.txt")); // FileOutputFormat.setOutputPath(job1, output1); // job1.waitForCompletion(true); // part1--------------------------------------------------------------- String trainFile = hdfsUrl + "/data/recommend/class2/part-r-00000"; String trainSeqPath = hdfsUrl + "/data/recommend/class3"; String trainSeqFile = trainSeqPath + "/matrixSeq.seq"; String testFile = hdfsUrl + "/data/recommend/class1/matrix2/part-r-00000"; // String testFile = hdfsUrl+"/data/recommend/class2/part-r-00000"; String outputPath = hdfsUrl + "/data/recommend/class4"; HadoopUtil.delete(conf, new Path[] { new Path(outputPath), new Path(trainSeqPath) }); classify(conf, trainFile, trainSeqFile, testFile, outputPath, 0); } // part1---------------??--------------------------------------------------------- public static class Mapper_Part1 extends Mapper<LongWritable, Text, Text, Text> { private String flag; @Override protected void setup(Context context) throws IOException, InterruptedException { FileSplit split = (FileSplit) context.getInputSplit(); flag = split.getPath().getName();// ? } public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { if (flag.equals("ClusterPointsInfo.txt")) { String[] lineSplits = value.toString().split("\t"); context.write(new Text(lineSplits[0]), new Text(lineSplits[1])); } else { int index = value.toString().indexOf(" "); String keyStr = value.toString().substring(0, index); String valStr = value.toString().substring(index + 1); context.write(new Text(keyStr), new Text(valStr)); } } } public static class Reduce_Part1 extends Reducer<Text, Text, Text, NullWritable> { public void reduce(Text key, Iterable<Text> values, Context context) throws IOException, InterruptedException { String label = ""; StringBuilder sb = new StringBuilder(); for (Text text : values) { if (text.toString().contains(" ")) { String[] valArray = text.toString().split(" "); for (String temp : valArray) sb.append(temp).append(","); } else label = text.toString(); } sb.insert(0, label + ","); context.write(new Text(sb.deleteCharAt(sb.length() - 1).toString()), NullWritable.get()); } } // part2---------------bayes--------------------------------------------------------- public static void classify(Configuration conf, String trainFile, String trainSeqFile, String testFile, String outputPath, int labelIndex) throws Exception { // Step 1 : Convert CSV to Sequence file genNaiveBayesModel(conf, labelIndex, trainFile, trainSeqFile, false); // Step 2: Train NB train(conf, trainSeqFile, outputPath); // Step 3: Test to see result test(conf, testFile, labelIndex); } public static void genNaiveBayesModel(Configuration conf, int labelIndex, String trainFile, String trainSeqFile, boolean hasHeader) { CSVReader reader = null; try { FileSystem fs = FileSystem.get(conf); if (fs.exists(new Path(trainSeqFile))) fs.delete(new Path(trainSeqFile), true); SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, new Path(trainSeqFile), Text.class, VectorWritable.class); FileSystem fsopen = FileSystem.get(conf); FSDataInputStream in = fsopen.open(new Path(trainFile)); reader = new CSVReader(new InputStreamReader(in)); String[] header = null; if (hasHeader) header = reader.readNext(); String[] line = null; Long l = 0L; while ((line = reader.readNext()) != null) { if (labelIndex > line.length) break; l++; List<String> tmpList = Lists.newArrayList(line); String label = tmpList.get(labelIndex); if (!strLabelList.contains(label)) strLabelList.add(label); // Text key = new Text("/" + label + "/" + l); Text key = new Text("/" + label + "/"); tmpList.remove(labelIndex); VectorWritable vectorWritable = new VectorWritable(); Vector vector = new RandomAccessSparseVector(tmpList.size(), tmpList.size());//??? for (int i = 0; i < tmpList.size(); i++) { String tmpStr = tmpList.get(i); if (StringUtils.isNumeric(tmpStr)) vector.set(i, Double.parseDouble(tmpStr)); else vector.set(i, parseStrCell(tmpStr)); } vectorWritable.set(vector); writer.append(key, vectorWritable); } writer.close(); } catch (IOException e) { e.printStackTrace(); } } public static void train(Configuration conf, String trainSeqFile, String outputPath) throws Exception { System.out.println("~~~ begin to train ~~~"); String outputDirectory = outputPath + "/result"; String tempDirectory = outputPath + "/temp"; FileSystem fs = FileSystem.get(conf); TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob(); trainNaiveBayes.setConf(conf); fs.delete(new Path(outputDirectory), true); fs.delete(new Path(tempDirectory), true); // cmd sample: mahout trainnb -i train-vectors -el -li labelindex -o model -ow -c trainNaiveBayes.run(new String[] { "--input", trainSeqFile, "--output", outputDirectory, "-el", "--labelIndex", "labelIndex", "--overwrite", "--tempDir", tempDirectory }); // Train the classifier naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDirectory), conf); System.out.println("features: " + naiveBayesModel.numFeatures()); System.out.println("labels: " + naiveBayesModel.numLabels()); } public static void test(Configuration conf, String testFile, int labelIndex) throws IOException { System.out.println("~~~ begin to test ~~~"); AbstractNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(naiveBayesModel); FileSystem fsopen = FileSystem.get(conf); FSDataInputStream in = fsopen.open(new Path(testFile)); CSVReader csv = new CSVReader(new InputStreamReader(in)); csv.readNext(); // skip header String[] line = null; double totalSampleCount = 0.; double correctClsCount = 0.; // String str="10,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,8,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,27,0,0,0,0,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,28,0,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,0,0,0,7,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6,0,0,0,0,0,0,0,0,0,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7,0,0,0,0,0,0,0,0,0,0,7,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,0,0,0,0,0,0,0,0,0,0,0,0,0,4,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,37,0,0,0,0,4,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,0,0,0,0,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7,0,0,0,0,0,0,0,8,7,5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,0,0,4,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,12,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,0,0,0,0,0,14,0,8"; // List<String> newsList=new ArrayList<String>(); // newsList.add(str); // for(int j=0;j<newsList.size();j++){ // line=newsList.get(j).split(","); while ((line = csv.readNext()) != null) { // ??ID???ID? // ???? List<String> tmpList = Lists.newArrayList(line); String label = tmpList.get(labelIndex); tmpList.remove(labelIndex); totalSampleCount++; Vector vector = new RandomAccessSparseVector(tmpList.size(), tmpList.size()); for (int i = 0; i < tmpList.size(); i++) { String tempStr = tmpList.get(i); if (StringUtils.isNumeric(tempStr)) { vector.set(i, Double.parseDouble(tempStr)); } else { Long id = strOptionMap.get(tempStr); if (id != null) vector.set(i, id); else { System.out.println(StringUtils.join(tempStr, ",")); continue; } } } Vector resultVector = classifier.classifyFull(vector); int classifyResult = resultVector.maxValueIndex(); if (StringUtils.equals(label, strLabelList.get(classifyResult))) { correctClsCount++; } else { // line[labelIndex]????ID?? // ??????? // System.out.println("CorrectORItem=" + label + "\tClassify=" + strLabelList.get(classifyResult)); } } // System.out.println("Correct Ratio:" + (correctClsCount / totalSampleCount)); } private static Long parseStrCell(String str) { Long id = strOptionMap.get(str); if (id == null) { id = (long) (strOptionMap.size() + 1); strOptionMap.put(str, id); } return id; } }