// Java tutorial
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    LeastMedSq.java
 *    Copyright (C) 2001 University of Waikato
 */

package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.instance.RemoveRange;

/**
 * <!-- globalinfo-start --> Implements a least median squared linear regression
 * utilising the existing weka LinearRegression class to form predictions. <br/>
 * Least squared regression functions are generated from random subsamples of
 * the data. The least squared regression with the lowest median squared error
 * is chosen as the final model.<br/>
 * <br/>
 * The basis of the algorithm is <br/>
 * <br/>
 * Peter J. Rousseeuw, Annick M. Leroy (1987). Robust regression and outlier
 * detection.
 * <p/>
 * <!-- globalinfo-end -->
 *
 * <!-- technical-bibtex-start --> BibTeX:
 *
 * <pre>
 * &#64;book{Rousseeuw1987,
 *    author = {Peter J. Rousseeuw and Annick M. Leroy},
 *    title = {Robust regression and outlier detection},
 *    year = {1987}
 * }
 * </pre>
 * <p/>
 * <!-- technical-bibtex-end -->
 *
 * <!-- options-start --> Valid options are:
 * <p/>
 *
 * <pre>
 * -S &lt;sample size&gt;
 *  Set sample size
 *  (default: 4)
 * </pre>
 *
 * <pre>
 * -G &lt;seed&gt;
 *  Set the seed used to generate samples
 *  (default: 0)
 * </pre>
 *
 * <pre>
 * -D
 *  Produce debugging output
 *  (default no debugging output)
 * </pre>
 *
 * <!-- options-end -->
 *
 * @author Tony Voyle (tv6@waikato.ac.nz)
 * @version $Revision$
 */
public class LeastMedSq extends AbstractClassifier implements OptionHandler,
  TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = 4288954049987652970L;

  /**
   * For sample sizes 1..6: if the (cleaned) training set has fewer instances
   * than SAMPLE_THRESHOLDS[sampleSize - 1], the number of random subsamples is
   * nCr(numInstances, sampleSize); otherwise it is sampleSize *
   * SAMPLES_PER_UNIT.
   */
  private static final int[] SAMPLE_THRESHOLDS = { 500, 50, 22, 17, 15, 14 };

  /** Subsamples drawn per unit of sample size on larger datasets. */
  private static final int SAMPLES_PER_UNIT = 500;

  /** Number of subsamples drawn when the sample size exceeds 6. */
  private static final int MAX_SAMPLES = 3000;

  /**
   * Gaussian consistency factor for the median-based scale estimate
   * (approximately 1 / Phi^-1(0.75)).
   */
  private static final double CONSISTENCY_FACTOR = 1.4826;

  /** Instances whose |residual| / scale is not below this are dropped. */
  private static final double REJECTION_CUTOFF = 2.5;

  /** Squared residuals of m_currentRegression on m_Data (see findResiduals). */
  private double[] m_Residuals;

  /** 1.0 for instances kept by buildWeight, 0.0 for rejected ones. */
  private double[] m_weight;

  /** Robust scale estimate computed by buildWeight. */
  private double m_scalefactor;

  /** Lowest median squared residual seen so far. */
  private double m_bestMedian = Double.POSITIVE_INFINITY;

  /** Regression currently being evaluated / most recently selected. */
  private LinearRegression m_currentRegression;

  /** Regression with the lowest median squared residual. */
  private LinearRegression m_bestRegression;

  /** Final model used for prediction. */
  private LinearRegression m_ls;

  /** Training data after nominal-to-binary and missing-value filtering. */
  private Instances m_Data;

  /** Training instances retained after outlier rejection. */
  private Instances m_RLSData;

  /** Current random subsample of m_Data. */
  private Instances m_SubSample;

  private ReplaceMissingValues m_MissingFilter;

  private NominalToBinary m_TransformFilter;

  private RemoveRange m_SplitFilter;

  /** Number of instances in each random subsample. */
  private int m_samplesize = 4;

  /** Number of random subsamples to evaluate. */
  private int m_samples;

  private Random m_random;

  private long m_randomseed = 0;

  /**
   * Attribute selection method passed to every LinearRegression built here.
   * Tag 1 matches the old "-S 1" option; presumably this disables attribute
   * selection -- verify against LinearRegression.TAGS_SELECTION.
   */
  private weka.core.SelectedTag m_tag = new weka.core.SelectedTag(1,
    LinearRegression.TAGS_SELECTION);

  /**
   * Returns a string describing this classifier
   *
   * @return a description of the classifier suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String globalInfo() {
    return "Implements a least median squared linear regression utilising the "
      + "existing weka LinearRegression class to form predictions. \n"
      + "Least squared regression functions are generated from random subsamples of "
      + "the data. The least squared regression with the lowest median squared error "
      + "is chosen as the final model.\n\n" + "The basis of the algorithm is \n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing detailed
   * information about the technical background of this class, e.g., paper
   * reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  @Override
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.BOOK);
    result.setValue(Field.AUTHOR, "Peter J. Rousseeuw and Annick M. Leroy");
    result.setValue(Field.YEAR, "1987");
    result.setValue(Field.TITLE, "Robust regression and outlier detection");

    return result;
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  @Override
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Build lms regression
   *
   * @param data training data
   * @throws Exception if an error occurs
   */
  @Override
  public void buildClassifier(Instances data) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    cleanUpData(data);

    getSamples();

    findBestRegression();

    buildRLSRegression();

  } // buildClassifier

  /**
   * Classify a given instance using the best generated LinearRegression
   * Classifier. The instance is passed through the same nominal-to-binary and
   * missing-value filters that were fitted on the training data.
   *
   * @param instance instance to be classified
   * @return class value
   * @throws IllegalStateException if the classifier has not been built
   * @throws Exception if an error occurs
   */
  @Override
  public double classifyInstance(Instance instance) throws Exception {
    if (m_ls == null) {
      throw new IllegalStateException(
        "No model built yet: call buildClassifier() first.");
    }

    m_TransformFilter.input(instance);
    Instance transformed = m_TransformFilter.output();
    m_MissingFilter.input(transformed);
    transformed = m_MissingFilter.output();

    return m_ls.classifyInstance(transformed);
  } // classifyInstance

  /**
   * Cleans up data: converts nominal attributes to binary ones, replaces
   * missing values and drops instances with a missing class. The fitted filters
   * are kept for use in classifyInstance.
   *
   * @param data data to be cleaned up
   * @throws Exception if an error occurs
   */
  private void cleanUpData(Instances data) throws Exception {

    m_Data = data;
    m_TransformFilter = new NominalToBinary();
    m_TransformFilter.setInputFormat(m_Data);
    m_Data = Filter.useFilter(m_Data, m_TransformFilter);
    m_MissingFilter = new ReplaceMissingValues();
    m_MissingFilter.setInputFormat(m_Data);
    m_Data = Filter.useFilter(m_Data, m_MissingFilter);
    m_Data.deleteWithMissingClass();
  }

  /**
   * Computes the number of random subsamples (m_samples) to evaluate.
   *
   * @throws IllegalArgumentException if the sample size is less than 1
   * @throws Exception if there are no usable training instances, or (via
   *           combinations) if there are fewer instances than the sample size
   */
  private void getSamples() throws Exception {
    if (m_samplesize < 1) {
      throw new IllegalArgumentException("Sample size must be at least 1, but is "
        + m_samplesize + ".");
    }
    int numInstances = m_Data.numInstances();
    if (numInstances == 0) {
      throw new Exception("No training instances with a known class value.");
    }

    if (m_samplesize <= SAMPLE_THRESHOLDS.length) {
      if (numInstances < SAMPLE_THRESHOLDS[m_samplesize - 1]) {
        m_samples = combinations(numInstances, m_samplesize);
      } else {
        m_samples = m_samplesize * SAMPLES_PER_UNIT;
      }
    } else {
      m_samples = MAX_SAMPLES;
    }

    if (m_Debug) {
      System.out.println("m_samplesize: " + m_samplesize);
      System.out.println("m_samples: " + m_samples);
      System.out.println("m_randomseed: " + m_randomseed);
    }
  }

  /**
   * Set up the random number generator
   *
   */
  private void setRandom() {
    m_random = new Random(getRandomSeed());
  }

  /**
   * Finds the best regression generated from m_samples random samples from the
   * training data
   *
   * @throws Exception if an error occurs, or if no candidate regression yields
   *           a finite median squared residual
   */
  private void findBestRegression() throws Exception {
    setRandom();
    m_bestMedian = Double.POSITIVE_INFINITY;
    // don't let a model from a previous buildClassifier call survive
    m_bestRegression = null;

    if (m_Debug) {
      System.out.println("Starting:");
    }
    // one '*' per ~1% of progress; at least 1 so that fewer than 100 samples
    // does not cause a division by zero
    int progressStep = Math.max(1, m_samples / 100);
    for (int s = 0; s < m_samples; s++) {
      if (m_Debug && s % progressStep == 0) {
        System.out.print("*");
      }
      genRegression();
      getMedian();
    }
    if (m_Debug) {
      System.out.println("");
    }

    if (m_bestRegression == null) {
      throw new Exception(
        "No candidate regression produced a finite median squared residual.");
    }
    m_currentRegression = m_bestRegression;
  }

  /**
   * Generates a LinearRegression classifier from a fresh random subsample of
   * m_Data (stored in m_SubSample).
   *
   * @throws Exception if an error occurs
   */
  private void genRegression() throws Exception {
    m_currentRegression = new LinearRegression();
    m_currentRegression.setAttributeSelectionMethod(m_tag);
    selectSubSample(m_Data);
    m_currentRegression.buildClassifier(m_SubSample);
  }

  /**
   * Finds residuals (squared) of the current regression on every instance of
   * m_Data.
   *
   * @throws Exception if an error occurs
   */
  private void findResiduals() throws Exception {
    int numInstances = m_Data.numInstances();
    m_Residuals = new double[numInstances];
    for (int i = 0; i < numInstances; i++) {
      Instance current = m_Data.instance(i);
      double error = m_currentRegression.classifyInstance(current)
        - current.value(m_Data.classAttribute());
      m_Residuals[i] = error * error;
    }
  }

  /**
   * Finds the median squared residual for the current regression and records
   * the regression as the best one if its median is lower than any seen so far.
   * The element at index p / 2 after selection is used (the upper median for
   * an even count).
   *
   * @throws Exception if an error occurs
   */
  private void getMedian() throws Exception {

    findResiduals();
    int p = m_Residuals.length;
    select(m_Residuals, 0, p - 1, p / 2);
    if (m_Residuals[p / 2] < m_bestMedian) {
      m_bestMedian = m_Residuals[p / 2];
      m_bestRegression = m_currentRegression;
    }
  }

  /**
   * Returns a string representing the best LinearRegression classifier found.
   *
   * @return String representing the regression
   */
  @Override
  public String toString() {

    if (m_ls == null) {
      return "model has not been built";
    }
    return m_ls.toString();
  }

  /**
   * Builds a weight function removing instances with an abnormally high scaled
   * residual. The scale is CONSISTENCY_FACTOR * (1 + 5 / (n - p)) *
   * sqrt(best median squared residual), and an instance is kept when
   * |residual| / scale &lt; REJECTION_CUTOFF.
   *
   * @throws Exception if weight building fails
   */
  private void buildWeight() throws Exception {

    findResiduals();
    // numAttributes() includes the class attribute; presumably this is meant
    // to count the intercept as one of the p fitted parameters
    int degreesOfFreedom = m_Data.numInstances() - m_Data.numAttributes();
    boolean keepAll = degreesOfFreedom <= 0;

    if (keepAll) {
      // no residual degrees of freedom: outliers cannot be identified
      m_scalefactor = Double.POSITIVE_INFINITY;
    } else {
      // 5.0 (not 5) so that the finite-sample correction is not truncated by
      // integer division
      m_scalefactor = CONSISTENCY_FACTOR * (1.0 + 5.0 / degreesOfFreedom)
        * Math.sqrt(m_bestMedian);
    }

    m_weight = new double[m_Residuals.length];
    for (int i = 0; i < m_Residuals.length; i++) {
      double absResidual = Math.sqrt(m_Residuals[i]);
      boolean inlier;
      if (keepAll) {
        inlier = true;
      } else if (m_scalefactor > 0) {
        inlier = absResidual / m_scalefactor < REJECTION_CUTOFF;
      } else {
        // zero scale (exact fit for at least half the data): only points lying
        // exactly on the fit are inliers; avoids 0/0 = NaN rejecting them
        inlier = absResidual == 0.0;
      }
      m_weight[i] = inlier ? 1.0 : 0.0;
    }
  }

  /**
   * Builds a new LinearRegression without the 'bad' data found by buildWeight.
   * If every instance is rejected, the best least-median regression is used as
   * the final model instead.
   *
   * @throws Exception if building fails
   */
  private void buildRLSRegression() throws Exception {

    buildWeight();

    // copy the header, then add only the instances flagged as inliers
    m_RLSData = new Instances(m_Data, m_Data.numInstances());
    for (int i = 0; i < m_Data.numInstances(); i++) {
      if (m_weight[i] != 0) {
        m_RLSData.add(m_Data.instance(i));
      }
    }

    if (m_RLSData.numInstances() == 0) {
      System.err.println("rls regression unbuilt");
      m_ls = m_currentRegression;
    } else {
      m_ls = new LinearRegression();
      m_ls.setAttributeSelectionMethod(m_tag);
      m_ls.buildClassifier(m_RLSData);
      m_currentRegression = m_ls;
    }

  }

  /**
   * Partially sorts a[l..r] (quickselect) so that a[k] holds the value it would
   * have if the range were fully sorted.
   *
   * @param a an array of numbers
   * @param l left pointer
   * @param r right pointer
   * @param k position of number to be found
   */
  private static void select(double[] a, int l, int r, int k) {

    if (r <= l) {
      return;
    }
    int i = partition(a, l, r);
    if (i > k) {
      select(a, l, i - 1, k);
    }
    if (i < k) {
      select(a, i + 1, r, k);
    }
  }

  /**
   * Partitions an array of numbers such that all numbers less than that at
   * index r, between indexes l and r will have a smaller index and all numbers
   * greater than will have a larger index
   *
   * @param a an array of numbers
   * @param l left pointer
   * @param r right pointer
   * @return final index of number originally at r
   */
  private static int partition(double[] a, int l, int r) {

    int i = l - 1, j = r;
    double v = a[r], temp;
    while (true) {
      while (a[++i] < v) {
        // scan right; a[r] == v stops the scan
      }
      while (v < a[--j]) {
        if (j == l) {
          break;
        }
      }
      if (i >= j) {
        break;
      }
      temp = a[i];
      a[i] = a[j];
      a[j] = temp;
    }
    temp = a[i];
    a[i] = a[r];
    a[r] = temp;
    return i;
  }

  /**
   * Produces a random sample from m_Data in m_SubSample
   *
   * @param data data from which to take sample
   * @throws Exception if an error occurs
   */
  private void selectSubSample(Instances data) throws Exception {

    m_SplitFilter = new RemoveRange();
    // invert: keep (rather than remove) the selected instances
    m_SplitFilter.setInvertSelection(true);
    m_SubSample = data;
    m_SplitFilter.setInputFormat(m_SubSample);
    m_SplitFilter.setInstancesIndices(selectIndices(m_SubSample));
    m_SubSample = Filter.useFilter(m_SubSample, m_SplitFilter);
  }

  /**
   * Returns a string suitable for passing to RemoveRange consisting of
   * m_samplesize 1-based indices drawn uniformly (with replacement) from
   * 1..data.numInstances().
   *
   * @param data dataset from which to take indices
   * @return string of indices suitable for passing to RemoveRange
   */
  private String selectIndices(Instances data) {
    int numInstances = data.numInstances();
    StringBuilder text = new StringBuilder();

    for (int i = 0; i < m_samplesize; i++) {
      // RemoveRange indices are 1-based, so every instance 1..n is eligible;
      // the previous scheme never produced index n and looped forever for n == 1
      int index = m_random.nextInt(numInstances) + 1;
      text.append(index);
      text.append(i < m_samplesize - 1 ? "," : "\n");
    }
    return text.toString();
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String sampleSizeTipText() {
    return "Set the size of the random samples used to generate the least squared "
      + "regression functions.";
  }

  /**
   * sets number of samples
   *
   * @param samplesize value
   */
  public void setSampleSize(int samplesize) {
    m_samplesize = samplesize;
  }

  /**
   * gets number of samples
   *
   * @return value
   */
  public int getSampleSize() {
    return m_samplesize;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String randomSeedTipText() {
    return "Set the seed for selecting random subsamples of the training data.";
  }

  /**
   * Set the seed for the random number generator
   *
   * @param randomseed the seed
   */
  public void setRandomSeed(long randomseed) {
    m_randomseed = randomseed;
  }

  /**
   * get the seed for the random number generator
   *
   * @return the seed value
   */
  public long getRandomSeed() {
    return m_randomseed;
  }

  /**
   * Returns an enumeration of all the available options.
   *
   * @return an enumeration of all available options.
   */
  @Override
  public Enumeration<Option> listOptions() {

    Vector<Option> newVector = new Vector<Option>(4);
    // each of -S and -G takes exactly one argument
    newVector.addElement(new Option("\tSet sample size\n" + "\t(default: 4)\n",
      "S", 1, "-S <sample size>"));
    newVector.addElement(new Option("\tSet the seed used to generate samples\n"
      + "\t(default: 0)\n", "G", 1, "-G <seed>"));

    newVector.addAll(Collections.list(super.listOptions()));

    return newVector.elements();
  }

  /**
   * Sets the OptionHandler's options using the given list. All options will be
   * set (or reset) during this call (i.e. incremental setting of options is not
   * possible).
   *
   * <!-- options-start --> Valid options are:
   * <p/>
   *
   * <pre>
   * -S &lt;sample size&gt;
   *  Set sample size
   *  (default: 4)
   * </pre>
   *
   * <pre>
   * -G &lt;seed&gt;
   *  Set the seed used to generate samples
   *  (default: 0)
   * </pre>
   *
   * <pre>
   * -D
   *  Produce debugging output
   *  (default no debugging output)
   * </pre>
   *
   * <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  @Override
  public void setOptions(String[] options) throws Exception {

    String curropt = Utils.getOption('S', options);
    if (curropt.length() != 0) {
      setSampleSize(Integer.parseInt(curropt));
    } else {
      setSampleSize(4);
    }

    curropt = Utils.getOption('G', options);
    if (curropt.length() != 0) {
      setRandomSeed(Long.parseLong(curropt));
    } else {
      setRandomSeed(0);
    }

    super.setOptions(options);

    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current option settings for the OptionHandler.
   *
   * @return the list of current option settings as an array of strings
   */
  @Override
  public String[] getOptions() {

    Vector<String> options = new Vector<String>();

    options.add("-S");
    options.add("" + getSampleSize());

    options.add("-G");
    options.add("" + getRandomSeed());

    Collections.addAll(options, super.getOptions());

    return options.toArray(new String[0]);
  }

  /**
   * Produces the combination nCr (binomial coefficient), computed exactly with
   * the multiplicative formula so intermediate values do not overflow.
   *
   * @param n size of the set
   * @param r number of elements chosen
   * @return the combination
   * @throws Exception if r is negative or greater than n
   * @throws ArithmeticException if the result does not fit in an int
   */
  public static int combinations(int n, int r) throws Exception {

    if (r > n) {
      throw new Exception("r must be less than or equal to n.");
    }
    if (r < 0) {
      throw new Exception("r must be non-negative.");
    }

    int k = Math.min(r, n - r);
    long c = 1;
    for (int i = 1; i <= k; i++) {
      // invariant: before this step c == C(n - k + i - 1, i - 1); the product
      // is exactly divisible by i and yields C(n - k + i, i). c <= MAX_INT and
      // the factor <= n keep the product well inside the long range.
      c = c * (n - k + i) / i;
      if (c > Integer.MAX_VALUE) {
        // C(m, i) only grows along this sequence, so the final value overflows too
        throw new ArithmeticException("nCr does not fit in an int: n=" + n
          + ", r=" + r);
      }
    }
    return (int) c;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision$");
  }

  /**
   * generate a Linear regression predictor for testing
   *
   * @param argv options
   */
  public static void main(String[] argv) {
    runClassifier(new LeastMedSq(), argv);
  } // main
} // lmr