Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.memonews.mahout.sentiment; import java.io.File; import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression; import org.apache.mahout.classifier.sgd.CrossFoldLearner; import org.apache.mahout.classifier.sgd.ModelDissector; import org.apache.mahout.classifier.sgd.ModelSerializer; import org.apache.mahout.classifier.sgd.OnlineLogisticRegression; import org.apache.mahout.ep.State; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.Vector; import org.apache.mahout.math.function.DoubleFunction; import org.apache.mahout.math.function.Functions; import org.apache.mahout.vectorizer.encoders.Dictionary; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Multiset; public final class SGDHelper { private static final String[] LEAK_LABELS = { "none", "month-year", "day-month-year" }; private SGDHelper() { } public static void dissect(final int leakType, final Dictionary dictionary, final AdaptiveLogisticRegression learningAlgorithm, final Iterable<File> files, final Multiset<String> overallCounts) throws IOException { final CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner(); model.close(); final Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap(); final ModelDissector md = new ModelDissector(); final SentimentModelHelper helper = new SentimentModelHelper(); helper.getEncoder().setTraceDictionary(traceDictionary); helper.getBias().setTraceDictionary(traceDictionary); for (final File file : permute(files, helper.getRandom()).subList(0, 500)) { traceDictionary.clear(); final Vector v = helper.encodeFeatureVector(file, overallCounts); md.update(v, traceDictionary, model); } final List<String> ngNames = Lists.newArrayList(dictionary.values()); final List<ModelDissector.Weight> weights = md.summary(100); System.out.println("============"); System.out.println("Model Dissection"); for (final ModelDissector.Weight w : weights) { System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact()), w.getCategory(0), w.getWeight(0)); } } public static List<File> permute(final Iterable<File> files, final Random rand) { final List<File> r = Lists.newArrayList(); for (final File file : files) { final int i = rand.nextInt(r.size() + 1); if (i == r.size()) { r.add(file); } else { r.add(r.get(i)); r.set(i, file); } } return r; } static void analyzeState(final SGDInfo info, final int leakType, final int k, final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best) throws IOException { final int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length]; final int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length)); double maxBeta; double nonZeros; double positive; double norm; double lambda = 0; double mu = 0; if (best != null) { final CrossFoldLearner state = best.getPayload().getLearner(); info.setAverageCorrect(state.percentCorrect()); info.setAverageLL(state.logLikelihood()); final OnlineLogisticRegression model = state.getModels().get(0); // finish off pending regularization model.close(); final Matrix beta = model.getBeta(); maxBeta = beta.aggregate(Functions.MAX, Functions.ABS); nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(final double v) { return Math.abs(v) > 1.0e-6 ? 1 : 0; } }); positive = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(final double v) { return v > 0 ? 1 : 0; } }); norm = beta.aggregate(Functions.PLUS, Functions.ABS); lambda = best.getMappedParams()[0]; mu = best.getMappedParams()[1]; } else { maxBeta = 0; nonZeros = 0; positive = 0; norm = 0; } if (k % (bump * scale) == 0) { if (best != null) { ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model", best.getPayload().getLearner().getModels().get(0)); } info.setStep(info.getStep() + 0.25); System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu); System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, info.getAverageLL(), info.getAverageCorrect() * 100, LEAK_LABELS[leakType % 3]); } } }