test.org.moa.opencl.IBk.java Source code

Java tutorial

Introduction

Here is the source code for the class test.org.moa.opencl.IBk (file test/org/moa/opencl/IBk.java).

Source

package test.org.moa.opencl;

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    IBk.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.neighboursearch.LinearNNSearch;
import weka.core.neighboursearch.NearestNeighbourSearch;

/**
 <!-- globalinfo-start -->
 * K-nearest neighbours classifier. Can select appropriate value of K based on cross-validation. Can also do distance weighting.<br/>
 * <br/>
 * For more information, see<br/>
 * <br/>
 * D. Aha, D. Kibler (1991). Instance-based learning algorithms. Machine Learning. 6:37-66.
 * <p/>
 <!-- globalinfo-end -->
 * 
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;article{Aha1991,
 *    author = {D. Aha and D. Kibler},
 *    journal = {Machine Learning},
 *    pages = {37-66},
 *    title = {Instance-based learning algorithms},
 *    volume = {6},
 *    year = {1991}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 * 
 * <pre> -I
 *  Weight neighbours by the inverse of their distance
 *  (use when k &gt; 1)</pre>
 * 
 * <pre> -F
 *  Weight neighbours by 1 - their distance
 *  (use when k &gt; 1)</pre>
 * 
 * <pre> -K &lt;number of neighbors&gt;
 *  Number of nearest neighbours (k) used in classification.
 *  (Default = 1)</pre>
 * 
 * <pre> -E
 *  Minimise mean squared error rather than mean absolute
 *  error when using -X option with numeric prediction.</pre>
 * 
 * <pre> -W &lt;window size&gt;
 *  Maximum number of training instances maintained.
 *  Training instances are dropped FIFO. (Default = no window)</pre>
 * 
 * <pre> -X
 *  Select the number of nearest neighbours between 1
 *  and the k value specified using hold-one-out evaluation
 *  on the training data (use when k &gt; 1)</pre>
 * 
 * <pre> -A
 *  The nearest neighbour search algorithm to use (default: weka.core.neighboursearch.LinearNNSearch).
 * </pre>
 * 
 <!-- options-end -->
 *
 * @author Stuart Inglis (singlis@cs.waikato.ac.nz)
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 8034 $
 */
public class IBk extends AbstractClassifier implements OptionHandler, UpdateableClassifier,
        WeightedInstancesHandler, TechnicalInformationHandler, AdditionalMeasureProducer {

    /** for serialization. */
    static final long serialVersionUID = -3080186098777067172L;

    /** The training instances used for classification. */
    protected Instances m_Train;

    /** The number of class values (or 1 if predicting numeric). */
    protected int m_NumClasses;

    /** The class attribute type. */
    protected int m_ClassType;

    /** The number of neighbours to use for classification (currently). */
    protected int m_kNN;

    /**
     * The value of kNN provided by the user. This may differ from
     * m_kNN if cross-validation is being used.
     */
    protected int m_kNNUpper;

    /**
     * Whether the value of k selected by cross validation has
     * been invalidated by a change in the training instances.
     */
    protected boolean m_kNNValid;

    /**
     * The maximum number of training instances allowed. When
     * this limit is reached, old training instances are removed,
     * so the training data is "windowed". Set to 0 for unlimited
     * numbers of instances.
     */
    protected int m_WindowSize;

    /** Whether the neighbours should be distance-weighted. */
    protected int m_DistanceWeighting;

    /** Whether to select k by cross validation. */
    protected boolean m_CrossValidate;

    /**
     * Whether to minimise mean squared error rather than mean absolute
     * error when cross-validating on numeric prediction tasks.
     */
    protected boolean m_MeanSquared;

    /** Default ZeroR model to use when there are no training instances */
    protected ZeroR m_defaultModel;

    /** no weighting. */
    public static final int WEIGHT_NONE = 1;
    /** weight by 1/distance. */
    public static final int WEIGHT_INVERSE = 2;
    /** weight by 1-distance. */
    public static final int WEIGHT_SIMILARITY = 4;
    /** possible instance weighting methods. */
    public static final Tag[] TAGS_WEIGHTING = { new Tag(WEIGHT_NONE, "No distance weighting"),
            new Tag(WEIGHT_INVERSE, "Weight by 1/distance"), new Tag(WEIGHT_SIMILARITY, "Weight by 1-distance") };

    /** for nearest-neighbor search. */
    protected NearestNeighbourSearch m_NNSearch = new LinearNNSearch();

    /** The number of attributes that contribute to a prediction. */
    protected double m_NumAttributesUsed;

    /**
     * IBk classifier. Simple instance-based learner that uses the class
     * of the nearest k training instances for the class of the test
     * instances.
     *
     * @param k the number of nearest neighbors to use for prediction
     */
    public IBk(int k) {

        init();
        setKNN(k);
    }

    /**
     * IB1 classifier. Instance-based learner. Predicts the class of the
     * single nearest training instance for each test instance.
     */
    public IBk() {

        init();
    }

    /**
     * Returns a string describing classifier.
     * @return a description suitable for
     * displaying in the explorer/experimenter gui
     */
    public String globalInfo() {

        return "K-nearest neighbours classifier. Can "
                + "select appropriate value of K based on cross-validation. Can also do "
                + "distance weighting.\n\n" + "For more information, see\n\n"
                + getTechnicalInformation().toString();
    }

    /**
     * Returns an instance of a TechnicalInformation object, containing 
     * detailed information about the technical background of this class,
     * e.g., paper reference or book this class is based on.
     * 
     * @return the technical information about this class
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result;

        result = new TechnicalInformation(Type.ARTICLE);
        result.setValue(Field.AUTHOR, "D. Aha and D. Kibler");
        result.setValue(Field.YEAR, "1991");
        result.setValue(Field.TITLE, "Instance-based learning algorithms");
        result.setValue(Field.JOURNAL, "Machine Learning");
        result.setValue(Field.VOLUME, "6");
        result.setValue(Field.PAGES, "37-66");

        return result;
    }

    /**
     * Returns the tip text for this property.
     * @return tip text for this property suitable for
     * displaying in the explorer/experimenter gui
     */
    public String KNNTipText() {
        return "The number of neighbours to use.";
    }

    /**
     * Set the number of neighbours the learner is to use. This also resets
     * the upper bound used by cross-validation and invalidates any k
     * previously selected by cross-validation.
     *
     * @param k the number of neighbours.
     */
    public void setKNN(int k) {
        m_kNN = k;
        m_kNNUpper = k;
        m_kNNValid = false;
    }

    /**
     * Gets the number of neighbours the learner will use.
     *
     * @return the number of neighbours.
     */
    public int getKNN() {

        return m_kNN;
    }

    /**
     * Returns the tip text for this property.
     * @return tip text for this property suitable for
     * displaying in the explorer/experimenter gui
     */
    public String windowSizeTipText() {
        return "Gets the maximum number of instances allowed in the training "
                + "pool. The addition of new instances above this value will result "
                + "in old instances being removed. A value of 0 signifies no limit "
                + "to the number of training instances.";
    }

    /**
     * Gets the maximum number of instances allowed in the training
     * pool. The addition of new instances above this value will result
     * in old instances being removed. A value of 0 signifies no limit
     * to the number of training instances.
     *
     * @return Value of WindowSize.
     */
    public int getWindowSize() {

        return m_WindowSize;
    }

    /**
     * Sets the maximum number of instances allowed in the training
     * pool. The addition of new instances above this value will result
     * in old instances being removed. A value of 0 signifies no limit
     * to the number of training instances.
     *
     * @param newWindowSize Value to assign to WindowSize.
     */
    public void setWindowSize(int newWindowSize) {

        m_WindowSize = newWindowSize;
    }

    /**
     * Returns the tip text for this property.
     * @return tip text for this property suitable for
     * displaying in the explorer/experimenter gui
     */
    public String distanceWeightingTipText() {

        return "Gets the distance weighting method used.";
    }

    /**
     * Gets the distance weighting method used. Will be one of
     * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY
     *
     * @return the distance weighting method used.
     */
    public SelectedTag getDistanceWeighting() {

        return new SelectedTag(m_DistanceWeighting, TAGS_WEIGHTING);
    }

    /**
     * Sets the distance weighting method used. Tags that do not come from
     * {@link #TAGS_WEIGHTING} (checked by reference) are ignored.
     *
     * @param newMethod the distance weighting method to use
     */
    public void setDistanceWeighting(SelectedTag newMethod) {

        if (newMethod.getTags() == TAGS_WEIGHTING) {
            m_DistanceWeighting = newMethod.getSelectedTag().getID();
        }
    }

    /**
     * Returns the tip text for this property.
     * @return tip text for this property suitable for
     * displaying in the explorer/experimenter gui
     */
    public String meanSquaredTipText() {

        return "Whether the mean squared error is used rather than mean "
                + "absolute error when doing cross-validation for regression problems.";
    }

    /**
     * Gets whether the mean squared error is used rather than mean
     * absolute error when doing cross-validation.
     *
     * @return true if so.
     */
    public boolean getMeanSquared() {

        return m_MeanSquared;
    }

    /**
     * Sets whether the mean squared error is used rather than mean
     * absolute error when doing cross-validation.
     *
     * @param newMeanSquared true if so.
     */
    public void setMeanSquared(boolean newMeanSquared) {

        m_MeanSquared = newMeanSquared;
    }

    /**
     * Returns the tip text for this property.
     * @return tip text for this property suitable for
     * displaying in the explorer/experimenter gui
     */
    public String crossValidateTipText() {

        return "Whether hold-one-out cross-validation will be used " + "to select the best k value.";
    }

    /**
     * Gets whether hold-one-out cross-validation will be used
     * to select the best k value.
     *
     * @return true if cross-validation will be used.
     */
    public boolean getCrossValidate() {

        return m_CrossValidate;
    }

    /**
     * Sets whether hold-one-out cross-validation will be used
     * to select the best k value.
     *
     * @param newCrossValidate true if cross-validation should be used.
     */
    public void setCrossValidate(boolean newCrossValidate) {

        m_CrossValidate = newCrossValidate;
    }

    /**
     * Returns the tip text for this property.
     * @return tip text for this property suitable for
     * displaying in the explorer/experimenter gui
     */
    public String nearestNeighbourSearchAlgorithmTipText() {
        return "The nearest neighbour search algorithm to use "
                + "(Default: weka.core.neighboursearch.LinearNNSearch).";
    }

    /**
     * Returns the current nearestNeighbourSearch algorithm in use.
     * @return the NearestNeighbourSearch algorithm currently in use.
     */
    public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() {
        return m_NNSearch;
    }

    /**
     * Sets the nearestNeighbourSearch algorithm to be used for finding nearest
     * neighbour(s).
     * @param nearestNeighbourSearchAlgorithm - The NearestNeighbourSearch class.
     */
    public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearchAlgorithm) {
        m_NNSearch = nearestNeighbourSearchAlgorithm;
    }

    /**
     * Get the number of training instances the classifier is currently using.
     * 
     * @return the number of training instances the classifier is currently using,
     *         or 0 if no model has been built yet
     */
    public int getNumTraining() {

        // Before buildClassifier() has run there is no training set; report 0 instead of failing.
        return (m_Train == null) ? 0 : m_Train.numInstances();
    }

    /**
     * Returns default capabilities of the classifier.
     *
     * @return      the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capability.DATE_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        // class
        result.enable(Capability.NOMINAL_CLASS);
        result.enable(Capability.NUMERIC_CLASS);
        result.enable(Capability.DATE_CLASS);
        result.enable(Capability.MISSING_CLASS_VALUES);

        // instances
        result.setMinimumNumberInstances(0);

        return result;
    }

    /**
     * Generates the classifier: stores a (possibly windowed) copy of the
     * training data, counts the nominal/numeric non-class attributes, and
     * initialises the neighbour search and the fallback ZeroR model.
     *
     * @param instances set of instances serving as training data 
     * @throws Exception if the classifier has not been generated successfully
     */
    public void buildClassifier(Instances instances) throws Exception {

        // can classifier handle the data?
        getCapabilities().testWithFail(instances);

        // remove instances with missing class (work on a copy so the caller's data is untouched)
        instances = new Instances(instances);
        instances.deleteWithMissingClass();

        m_NumClasses = instances.numClasses();
        m_ClassType = instances.classAttribute().type();
        m_Train = new Instances(instances, 0, instances.numInstances());

        // Throw away initial instances until within the specified window size
        if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
            m_Train = new Instances(m_Train, m_Train.numInstances() - m_WindowSize, m_WindowSize);
        }

        // Only nominal and numeric non-class attributes contribute to the distance normalisation.
        m_NumAttributesUsed = 0.0;
        for (int i = 0; i < m_Train.numAttributes(); i++) {
            if ((i != m_Train.classIndex())
                    && (m_Train.attribute(i).isNominal() || m_Train.attribute(i).isNumeric())) {
                m_NumAttributesUsed += 1.0;
            }
        }

        m_NNSearch.setInstances(m_Train);

        // Invalidate any currently cross-validation selected k
        m_kNNValid = false;

        m_defaultModel = new ZeroR();
        m_defaultModel.buildClassifier(instances);
    }

    /**
     * Adds the supplied instance to the training set. Instances with a missing
     * class value are ignored. If a window size is set, the oldest instances
     * are dropped once the window is exceeded.
     *
     * @param instance the instance to add
     * @throws Exception if instance could not be incorporated
     * successfully
     */
    public void updateClassifier(Instance instance) throws Exception {

        if (!m_Train.equalHeaders(instance.dataset())) {
            throw new Exception("Incompatible instance types\n" + m_Train.equalHeadersMsg(instance.dataset()));
        }
        if (instance.classIsMissing()) {
            return;
        }

        m_Train.add(instance);
        m_NNSearch.update(instance);
        m_kNNValid = false;
        trimToWindow();
    }

    /**
     * Drops the oldest training instances (FIFO) until at most
     * {@code m_WindowSize} remain, then rebuilds the neighbour search over the
     * remaining data, because some search structures (e.g. KDTree) cannot
     * delete instances incrementally. Any k selected by cross-validation is
     * invalidated. Does nothing when no window is configured
     * ({@code m_WindowSize <= 0}) or the training set already fits.
     *
     * @return true if at least one instance was removed
     * @throws Exception if the neighbour search cannot be rebuilt
     */
    private boolean trimToWindow() throws Exception {
        if (m_WindowSize <= 0 || m_Train.numInstances() <= m_WindowSize) {
            return false;
        }
        while (m_Train.numInstances() > m_WindowSize) {
            m_Train.delete(0);
        }
        m_NNSearch.setInstances(m_Train);
        m_kNNValid = false;
        return true;
    }

    /**
     * Calculates the class membership probabilities for the given test instance.
     * Falls back to the ZeroR model when there are no training instances.
     *
     * @param instance the instance to be classified
     * @return predicted class probability distribution
     * @throws Exception if an error occurred during the prediction
     */
    public double[] distributionForInstance(Instance instance) throws Exception {

        if (m_Train.numInstances() == 0) {
            return m_defaultModel.distributionForInstance(instance);
        }

        // Enforce the window and rebuild the search structure if anything was dropped.
        trimToWindow();

        // Select k by cross validation
        if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
            crossValidate();
        }

        m_NNSearch.addInstanceInfo(instance);

        Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
        double[] distances = m_NNSearch.getDistances();

        // Diagnostic output: raw search distances and the neighbours found, printed before
        // makeDistribution() rescales the distances in place.
        // NOTE(review): presumably used by the surrounding test harness for comparison; confirm
        // whether it is still needed, as it runs on every prediction.
        System.out.print("distances weka ");
        for (int i = 0; i < distances.length; ++i) {
            System.out.print(" " + distances[i]);
        }
        System.out.println();
        System.out.println("Neighbours");
        for (int i = 0; i < neighbours.size(); ++i) {
            System.out.println(neighbours.get(i));
        }

        return makeDistribution(neighbours, distances);
    }

    /**
     * Returns an enumeration describing the available options.
     *
     * @return an enumeration of all the available options.
     */
    public Enumeration listOptions() {

        Vector<Option> newVector = new Vector<Option>(7);

        newVector.addElement(new Option(
                "\tWeight neighbours by the inverse of their distance\n" + "\t(use when k > 1)", "I", 0, "-I"));
        newVector.addElement(
                new Option("\tWeight neighbours by 1 - their distance\n" + "\t(use when k > 1)", "F", 0, "-F"));
        newVector.addElement(
                new Option("\tNumber of nearest neighbours (k) used in classification.\n" + "\t(Default = 1)", "K",
                        1, "-K <number of neighbors>"));
        newVector.addElement(new Option("\tMinimise mean squared error rather than mean absolute\n"
                + "\terror when using -X option with numeric prediction.", "E", 0, "-E"));
        newVector.addElement(new Option(
                "\tMaximum number of training instances maintained.\n"
                        + "\tTraining instances are dropped FIFO. (Default = no window)",
                "W", 1, "-W <window size>"));
        newVector.addElement(new Option("\tSelect the number of nearest neighbours between 1\n"
                + "\tand the k value specified using hold-one-out evaluation\n"
                + "\ton the training data (use when k > 1)", "X", 0, "-X"));
        // -A takes one argument (the search class plus its options); setOptions reads it with getOption.
        newVector.addElement(new Option("\tThe nearest neighbour search algorithm to use "
                + "(default: weka.core.neighboursearch.LinearNNSearch).\n", "A", 1, "-A"));

        return newVector.elements();
    }

    /**
     * Parses a given list of options. <p/>
     *
     <!-- options-start -->
     * Valid options are: <p/>
     * 
     * <pre> -I
     *  Weight neighbours by the inverse of their distance
     *  (use when k &gt; 1)</pre>
     * 
     * <pre> -F
     *  Weight neighbours by 1 - their distance
     *  (use when k &gt; 1)</pre>
     * 
     * <pre> -K &lt;number of neighbors&gt;
     *  Number of nearest neighbours (k) used in classification.
     *  (Default = 1)</pre>
     * 
     * <pre> -E
     *  Minimise mean squared error rather than mean absolute
     *  error when using -X option with numeric prediction.</pre>
     * 
     * <pre> -W &lt;window size&gt;
     *  Maximum number of training instances maintained.
     *  Training instances are dropped FIFO. (Default = no window)</pre>
     * 
     * <pre> -X
     *  Select the number of nearest neighbours between 1
     *  and the k value specified using hold-one-out evaluation
     *  on the training data (use when k &gt; 1)</pre>
     * 
     * <pre> -A
     *  The nearest neighbour search algorithm to use (default: weka.core.neighboursearch.LinearNNSearch).
     * </pre>
     * 
     <!-- options-end -->
     *
     * @param options the list of options as an array of strings
     * @throws Exception if an option is not supported
     */
    public void setOptions(String[] options) throws Exception {

        String knnString = Utils.getOption('K', options);
        if (knnString.length() != 0) {
            setKNN(Integer.parseInt(knnString));
        } else {
            setKNN(1);
        }
        String windowString = Utils.getOption('W', options);
        if (windowString.length() != 0) {
            setWindowSize(Integer.parseInt(windowString));
        } else {
            setWindowSize(0);
        }
        // -I takes precedence over -F when both are given.
        if (Utils.getFlag('I', options)) {
            setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
        } else if (Utils.getFlag('F', options)) {
            setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));
        } else {
            setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));
        }
        setCrossValidate(Utils.getFlag('X', options));
        setMeanSquared(Utils.getFlag('E', options));

        String nnSearchClass = Utils.getOption('A', options);
        if (nnSearchClass.length() != 0) {
            String[] nnSearchClassSpec = Utils.splitOptions(nnSearchClass);
            if (nnSearchClassSpec.length == 0) {
                throw new Exception("Invalid NearestNeighbourSearch algorithm " + "specification string.");
            }
            // First token is the class name; the remaining tokens are its own options.
            String className = nnSearchClassSpec[0];
            nnSearchClassSpec[0] = "";

            setNearestNeighbourSearchAlgorithm((NearestNeighbourSearch) Utils.forName(NearestNeighbourSearch.class,
                    className, nnSearchClassSpec));
        } else {
            this.setNearestNeighbourSearchAlgorithm(new LinearNNSearch());
        }

        Utils.checkForRemainingOptions(options);
    }

    /**
     * Gets the current settings of IBk.
     *
     * @return an array of strings suitable for passing to setOptions()
     */
    public String[] getOptions() {

        // At most 9 entries are used; unused trailing slots are padded with "".
        String[] options = new String[11];
        int current = 0;
        options[current++] = "-K";
        options[current++] = "" + getKNN();
        options[current++] = "-W";
        options[current++] = "" + m_WindowSize;
        if (getCrossValidate()) {
            options[current++] = "-X";
        }
        if (getMeanSquared()) {
            options[current++] = "-E";
        }
        if (m_DistanceWeighting == WEIGHT_INVERSE) {
            options[current++] = "-I";
        } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
            options[current++] = "-F";
        }

        options[current++] = "-A";
        options[current++] = m_NNSearch.getClass().getName() + " " + Utils.joinOptions(m_NNSearch.getOptions());

        while (current < options.length) {
            options[current++] = "";
        }

        return options;
    }

    /**
     * Returns an enumeration of the additional measure names 
     * produced by the neighbour search algorithm, plus the chosen K in case
     * cross-validation is enabled.
     * 
     * @return an enumeration of the measure names
     */
    public Enumeration enumerateMeasures() {
        if (!m_CrossValidate) {
            return m_NNSearch.enumerateMeasures();
        }
        Vector<Object> measures = new Vector<Object>();
        Enumeration<?> enm = m_NNSearch.enumerateMeasures();
        while (enm.hasMoreElements()) {
            measures.add(enm.nextElement());
        }
        measures.add("measureKNN");
        return measures.elements();
    }

    /**
     * Returns the value of the named measure from the 
     * neighbour search algorithm, plus the chosen K in case
     * cross-validation is enabled.
     * 
     * @param additionalMeasureName the name of the measure to query for its value
     * @return the value of the named measure
     * @throws IllegalArgumentException if the named measure is not supported
     */
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.equals("measureKNN")) {
            return m_kNN;
        }
        return m_NNSearch.getMeasure(additionalMeasureName);
    }

    /**
     * Returns a description of this classifier. If cross-validation is
     * enabled and the selected k is stale, this re-runs the cross-validation
     * first so that the reported k is current.
     *
     * @return a description of this classifier as a string.
     */
    public String toString() {

        if (m_Train == null) {
            return "IBk: No model built yet.";
        }

        if (m_Train.numInstances() == 0) {
            return "Warning: no training instances - ZeroR model used.";
        }

        if (!m_kNNValid && m_CrossValidate) {
            crossValidate();
        }

        String result = "IB1 instance-based classifier\n" + "using " + m_kNN;

        switch (m_DistanceWeighting) {
        case WEIGHT_INVERSE:
            result += " inverse-distance-weighted";
            break;
        case WEIGHT_SIMILARITY:
            result += " similarity-weighted";
            break;
        }
        result += " nearest neighbour(s) for classification\n";

        if (m_WindowSize != 0) {
            result += "using a maximum of " + m_WindowSize + " (windowed) training instances\n";
        }
        return result;
    }

    /**
     * Initialise scheme variables: k = 1, no window, no distance weighting,
     * no cross-validation, mean absolute error.
     */
    protected void init() {

        setKNN(1);
        m_WindowSize = 0;
        m_DistanceWeighting = WEIGHT_NONE;
        m_CrossValidate = false;
        m_MeanSquared = false;
    }

    /**
     * Turn the list of nearest neighbors into a probability distribution.
     * For a nominal class, every class starts with a small prior of
     * 1/numTrainingInstances. Each neighbour contributes its instance weight
     * times the configured distance weight.
     * <p/>
     * Note: the {@code distances} array is modified in place. Each entry is
     * replaced by {@code sqrt(d*d / m_NumAttributesUsed)}.
     *
     * @param neighbours the list of nearest neighboring instances
     * @param distances the distances of the neighbors
     * @return the probability distribution
     * @throws Exception if computation goes wrong or has no class attribute
     */
    protected double[] makeDistribution(Instances neighbours, double[] distances) throws Exception {

        double total = 0, weight;
        double[] distribution = new double[m_NumClasses];

        // Set up a correction to the estimator
        if (m_ClassType == Attribute.NOMINAL) {
            for (int i = 0; i < m_NumClasses; i++) {
                distribution[i] = 1.0 / Math.max(1, m_Train.numInstances());
            }
            total = (double) m_NumClasses / Math.max(1, m_Train.numInstances());
        }

        for (int i = 0; i < neighbours.numInstances(); i++) {
            // Collect class counts
            Instance current = neighbours.instance(i);
            distances[i] = distances[i] * distances[i];
            distances[i] = Math.sqrt(distances[i] / m_NumAttributesUsed);
            switch (m_DistanceWeighting) {
            case WEIGHT_INVERSE:
                weight = 1.0 / (distances[i] + 0.001); // to avoid div by zero
                break;
            case WEIGHT_SIMILARITY:
                weight = 1.0 - distances[i];
                break;
            default: // WEIGHT_NONE:
                weight = 1.0;
                break;
            }
            weight *= current.weight();
            try {
                switch (m_ClassType) {
                case Attribute.NOMINAL:
                    distribution[(int) current.classValue()] += weight;
                    break;
                case Attribute.NUMERIC:
                    distribution[0] += current.classValue() * weight;
                    break;
                }
            } catch (Exception ex) {
                // Keep the original failure as the cause for easier diagnosis.
                throw new Error("Data has no class attribute!", ex);
            }
            total += weight;
        }

        // Normalise distribution
        if (total > 0) {
            Utils.normalize(distribution, total);
        }
        return distribution;
    }

    /**
     * Select the best value for k by hold-one-out cross-validation.
     * If the class attribute is nominal, classification error is
     * minimised. If the class attribute is numeric, mean absolute
     * error is minimised, or mean squared error when m_MeanSquared is set.
     * Ties are broken in favour of the smallest k.
     */
    protected void crossValidate() {

        try {
            if (m_NNSearch instanceof weka.core.neighboursearch.CoverTree)
                throw new Exception("CoverTree doesn't support hold-one-out "
                        + "cross-validation. Use some other NN " + "method.");

            // Java zero-initialises these arrays.
            double[] performanceStats = new double[m_kNNUpper];
            double[] performanceStatsSq = new double[m_kNNUpper];

            m_kNN = m_kNNUpper;
            Instance instance;
            Instances neighbours;
            double[] origDistances, convertedDistances;
            for (int i = 0; i < m_Train.numInstances(); i++) {
                if (m_Debug && (i % 50 == 0)) {
                    System.err.print("Cross validating " + i + "/" + m_Train.numInstances() + "\r");
                }
                instance = m_Train.instance(i);
                neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
                origDistances = m_NNSearch.getDistances();

                // Evaluate k = m_kNNUpper down to 1 by successively pruning the same neighbour list.
                for (int j = m_kNNUpper - 1; j >= 0; j--) {
                    // Update the performance stats; copy because makeDistribution rewrites distances
                    convertedDistances = new double[origDistances.length];
                    System.arraycopy(origDistances, 0, convertedDistances, 0, origDistances.length);
                    double[] distribution = makeDistribution(neighbours, convertedDistances);
                    double thisPrediction = Utils.maxIndex(distribution);
                    if (m_Train.classAttribute().isNumeric()) {
                        thisPrediction = distribution[0];
                        double err = thisPrediction - instance.classValue();
                        performanceStatsSq[j] += err * err; // Squared error
                        performanceStats[j] += Math.abs(err); // Absolute error
                    } else {
                        if (thisPrediction != instance.classValue()) {
                            performanceStats[j]++; // Classification error
                        }
                    }
                    if (j >= 1) {
                        neighbours = pruneToK(neighbours, convertedDistances, j);
                    }
                }
            }

            // Display the results of the cross-validation
            for (int i = 0; i < m_kNNUpper; i++) {
                if (m_Debug) {
                    System.err.print("Hold-one-out performance of " + (i + 1) + " neighbors ");
                }
                if (m_Train.classAttribute().isNumeric()) {
                    if (m_Debug) {
                        if (m_MeanSquared) {
                            System.err.println(
                                    "(RMSE) = " + Math.sqrt(performanceStatsSq[i] / m_Train.numInstances()));
                        } else {
                            System.err.println("(MAE) = " + performanceStats[i] / m_Train.numInstances());
                        }
                    }
                } else {
                    if (m_Debug) {
                        System.err.println("(%ERR) = " + 100.0 * performanceStats[i] / m_Train.numInstances());
                    }
                }
            }

            // Check through the performance stats and select the best
            // k value (or the lowest k if more than one best)
            double[] searchStats = performanceStats;
            if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
                searchStats = performanceStatsSq;
            }
            double bestPerformance = Double.NaN;
            int bestK = 1;
            for (int i = 0; i < m_kNNUpper; i++) {
                if (Double.isNaN(bestPerformance) || (bestPerformance > searchStats[i])) {
                    bestPerformance = searchStats[i];
                    bestK = i + 1;
                }
            }
            m_kNN = bestK;
            if (m_Debug) {
                System.err.println("Selected k = " + bestK);
            }

            m_kNNValid = true;
        } catch (Exception ex) {
            // Preserve the underlying cause instead of flattening it to a message.
            throw new Error("Couldn't optimize by cross-validation: " + ex.getMessage(), ex);
        }
    }

    /**
     * Prunes the list to contain the k nearest neighbors. If there are
     * multiple neighbors at the k'th distance, all will be kept.
     * Assumes {@code distances} is sorted in the same order as {@code neighbours}.
     *
     * @param neighbours the neighbour instances.
     * @param distances the distances of the neighbours from target instance.
     * @param k the number of neighbors to keep.
     * @return the pruned neighbours, or null if there are no neighbours or distances
     */
    public Instances pruneToK(Instances neighbours, double[] distances, int k) {

        if (neighbours == null || distances == null || neighbours.numInstances() == 0) {
            return null;
        }
        if (k < 1) {
            k = 1;
        }

        int currentK = 0;
        double currentDist;
        for (int i = 0; i < neighbours.numInstances(); i++) {
            currentK++;
            currentDist = distances[i];
            // Cut only once past k AND the distance changes, so ties at the k'th distance survive.
            if (currentK > k && currentDist != distances[i - 1]) {
                currentK--;
                neighbours = new Instances(neighbours, 0, currentK);
                break;
            }
        }

        return neighbours;
    }

    /**
     * Returns the revision string.
     * 
     * @return      the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8034 $");
    }

    /**
     * Main method for testing this class.
     *
     * @param argv should contain command line options (see setOptions)
     */
    public static void main(String[] argv) {
        runClassifier(new IBk(), argv);
    }
}