Example usage for weka.core Instances numClasses

List of usage examples for weka.core Instances numClasses

Introduction

In this page you can find the example usage for weka.core Instances numClasses.

Prototype


public int numClasses()

Source Link

Document

Returns the number of class labels.

Usage

From source file:tr.gov.ulakbim.jDenetX.experiments.wrappers.EvalActiveBoostingID.java

License:Open Source License

/**
 * Clusters {@code data} with X-means and keeps only the instances whose class label equals the
 * class that classes-to-clusters evaluation maps to their cluster.
 *
 * <p>The last attribute of {@code data} is used as the class; clustering is done on all other
 * attributes (the class is removed with a {@code Remove} filter). This method mutates
 * {@code data}: it sets its class index and removes its first instance, because in the MOA
 * Wavelet Stream the first element always arrives without a class. The minimum number of
 * clusters is derived from the {@code noOfClassesInPool} field, the maximum from
 * {@code data.numClasses() + 1}.
 *
 * @param data labelled instances to sample from; must not be null
 * @return a new dataset with the same header as {@code data} holding the retained instances;
 *         it is empty if {@code data} is empty or if clustering or evaluation failed
 * @throws NullPointerException if {@code data} is null
 */
public static Instances clusterInstances(Instances data) {
    if (data == null) {
        throw new NullPointerException("Data is null at clusteredInstances method");
    }
    data.setClassIndex(data.numAttributes() - 1);

    // Copy the complete header (every attribute, including the class). Rebuilding it from
    // data.enumerateAttributes() drops the class attribute whenever data already has a class
    // index, which made setClassIndex below throw, and re-indexes the Attribute objects that
    // are shared with data.
    Instances sampled_data = new Instances(data, 0);
    sampled_data.setClassIndex(sampled_data.numAttributes() - 1);

    if (data.numInstances() == 0) {
        return sampled_data;
    }
    data.remove(0); // In Wavelet Stream of MOA always the first element comes without class

    XMeans xmeans = new XMeans();
    Instances dataClusterer;
    try {
        // Cluster on every attribute except the class (Remove uses 1-based indices).
        Remove filter = new Remove();
        filter.setAttributeIndices("" + (data.classIndex() + 1));
        filter.setInputFormat(data);
        dataClusterer = Filter.useFilter(data, filter);

        xmeans.setMinNumClusters(noOfClassesInPool > 2 ? noOfClassesInPool - 1 : noOfClassesInPool);
        xmeans.setMaxNumClusters(data.numClasses() + 1);
        System.out.println("No of classes in the pool: " + noOfClassesInPool);
        xmeans.setUseKDTree(true);
        xmeans.buildClusterer(dataClusterer);
        System.out.println("Xmeans\n:" + xmeans);
    } catch (Exception e) {
        // Without a built clusterer there is nothing to evaluate.
        e.printStackTrace();
        return sampled_data;
    }

    try {
        ClusterEvaluation eval = new ClusterEvaluation();
        eval.setClusterer(xmeans);
        eval.evaluateClusterer(data);
        // Indexed by cluster number; each entry is the class value mapped to that cluster.
        int[] classesToClustersMap = eval.getClassesToClusters();
        for (int i = 0; i < data.size(); i++) {
            int clusterNo = xmeans.clusterInstance(dataClusterer.get(i));
            // Keep the instance only if its class agrees with the class of its cluster.
            if ((int) data.get(i).classValue() == classesToClustersMap[clusterNo]) {
                sampled_data.add(data.get(i));
            }
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
    return sampled_data;
}

From source file:trainableSegmentation.Trainable_Segmentation.java

License:GNU General Public License

/**
 * Creates a task that computes the class probability distribution of every instance in
 * {@code data}.
 *
 * @param data instances to evaluate
 * @param classifier trained classifier whose {@code distributionForInstance} is used
 * @param counter progress counter; incremented by 4000 whenever the instance index is a
 *        multiple of 4000 (including index 0)
 * @return a task whose result is indexed as {@code [class][instance]}; the task returns
 *         {@code null} if the classifier throws for any instance
 */
private static Callable<double[][]> probFromInstances(final Instances data, final AbstractClassifier classifier,
        final AtomicInteger counter) {
    return new Callable<double[][]>() {
        public double[][] call() {
            final int instanceCount = data.numInstances();
            final int classCount = data.numClasses();
            final double[][] distribution = new double[classCount][instanceCount];

            int index = 0;
            while (index < instanceCount) {
                try {
                    if (index % 4000 == 0) {
                        counter.addAndGet(4000);
                    }
                    final double[] perClass = classifier.distributionForInstance(data.instance(index));
                    for (int cls = 0; cls < classCount; cls++) {
                        distribution[cls][index] = perClass[cls];
                    }
                } catch (Exception e) {
                    IJ.showMessage("Could not apply Classifier!");
                    e.printStackTrace();
                    return null;
                }
                index++;
            }
            return distribution;
        }
    };
}

From source file:trainableSegmentation.WekaSegmentation.java

License:GNU General Public License

/**
 * Creates a task that computes the features of one slice and classifies each of its pixels.
 *
 * <p>The feature stack is configured with the same enabled features, sigma range and membrane
 * settings as the current classifier, then filtered by {@code featureNames}.
 *
 * @param slice image to classify
 * @param dataInfo empty set of instances describing the attributes and classes; every pixel
 *        instance is attached to it
 * @param classifier classifier applied to each pixel
 * @param counter progress counter for the tool bar; incremented by 4000 whenever the linear
 *        pixel index is a multiple of 4000
 * @param probabilityMaps if true, produce one probability slice per class; otherwise a single
 *        slice holding the predicted class value
 * @return the task, or {@code null} if the calling thread is already interrupted; the task
 *         itself returns {@code null} when feature creation or classification is interrupted or
 *         the classifier throws
 */
public Callable<ImagePlus> classifySlice(final ImagePlus slice, final Instances dataInfo,
        final AbstractClassifier classifier, final AtomicInteger counter, final boolean probabilityMaps) {
    if (Thread.currentThread().isInterrupted()) {
        return null;
    }

    return new Callable<ImagePlus>() {
        public ImagePlus call() {
            IJ.showStatus("Creating features...");
            IJ.log("Creating features of slice " + slice.getTitle() + "...");

            // Mirror the feature configuration the classifier was trained with.
            final FeatureStack features = new FeatureStack(slice);
            features.setEnabledFeatures(featureStackArray.getEnabledFeatures());
            features.setMaximumSigma(maximumSigma);
            features.setMinimumSigma(minimumSigma);
            features.setMembranePatchSize(membranePatchSize);
            features.setMembraneSize(membraneThickness);
            if (!features.updateFeaturesST()) {
                IJ.log("Classifier execution was interrupted.");
                return null;
            }
            filterFeatureStackByList(featureNames, features);

            final int w = slice.getWidth();
            final int h = slice.getHeight();
            final int classCount = dataInfo.numClasses();
            final ImageStack stack = new ImageStack(w, h);
            final double[][] output = new double[probabilityMaps ? classCount : 1][w * h];

            IJ.log("Classifying slice " + slice.getTitle() + "...");

            for (int x = 0; x < w; x++) {
                for (int y = 0; y < h; y++) {
                    final int pixel = x + y * w;
                    try {
                        if (pixel % 4000 == 0) {
                            if (Thread.currentThread().isInterrupted()) {
                                return null;
                            }
                            counter.addAndGet(4000);
                        }

                        final DenseInstance instance = features.createInstance(x, y, 0);
                        instance.setDataset(dataInfo);

                        if (!probabilityMaps) {
                            output[0][pixel] = classifier.classifyInstance(instance);
                        } else {
                            final double[] distribution = classifier.distributionForInstance(instance);
                            for (int c = 0; c < classCount; c++) {
                                output[c][pixel] = distribution[c];
                            }
                        }
                    } catch (Exception e) {
                        IJ.showMessage("Could not apply Classifier!");
                        e.printStackTrace();
                        return null;
                    }
                }
            }

            if (!probabilityMaps) {
                stack.addSlice("result", new FloatProcessor(w, h, output[0]));
            } else {
                for (int c = 0; c < classCount; c++) {
                    stack.addSlice("class-" + (c + 1), new FloatProcessor(w, h, output[c]));
                }
            }
            return new ImagePlus("classified-slice", stack);
        }
    };
}

From source file:trainableSegmentation.WekaSegmentation.java

License:GNU General Public License

/**
 * Creates a task that computes features for, and classifies every pixel of, each image in a list.
 *
 * <p>Each feature stack is configured with the same enabled features, sigma range and membrane
 * settings as the current classifier, then filtered by {@code featureNames}.
 *
 * @param images images to classify, processed in list order
 * @param dataInfo empty set of instances describing the attributes and classes; every pixel
 *        instance is attached to it
 * @param classifier classifier applied to each pixel
 * @param counter progress counter for the tool bar; incremented by 4000 whenever the linear
 *        pixel index within an image is a multiple of 4000
 * @param probabilityMaps if true, each result holds one probability slice per class; otherwise
 *        a single slice holding the predicted class value
 * @return the task, or {@code null} if the calling thread is already interrupted; the task
 *         returns one result image per input image, or {@code null} (discarding all results) if
 *         any feature creation or classification is interrupted or the classifier throws
 */
public Callable<ArrayList<ImagePlus>> classifyListOfImages(final ArrayList<ImagePlus> images,
        final Instances dataInfo, final AbstractClassifier classifier, final AtomicInteger counter,
        final boolean probabilityMaps) {
    if (Thread.currentThread().isInterrupted()) {
        return null;
    }

    return new Callable<ArrayList<ImagePlus>>() {
        public ArrayList<ImagePlus> call() {
            final ArrayList<ImagePlus> classified = new ArrayList<ImagePlus>();

            for (ImagePlus img : images) {
                IJ.showStatus("Creating features...");
                IJ.log("Creating features of slice " + img.getTitle() + ", size = " + img.getWidth() + "x"
                        + img.getHeight() + "...");

                // Mirror the feature configuration the classifier was trained with.
                final FeatureStack features = new FeatureStack(img);
                features.setEnabledFeatures(featureStackArray.getEnabledFeatures());
                features.setMaximumSigma(maximumSigma);
                features.setMinimumSigma(minimumSigma);
                features.setMembranePatchSize(membranePatchSize);
                features.setMembraneSize(membraneThickness);
                if (!features.updateFeaturesST()) {
                    IJ.log("Classifier execution was interrupted.");
                    return null;
                }
                filterFeatureStackByList(featureNames, features);

                final int w = img.getWidth();
                final int h = img.getHeight();
                final int classCount = dataInfo.numClasses();
                final ImageStack stack = new ImageStack(w, h);
                final double[][] output = new double[probabilityMaps ? classCount : 1][w * h];

                IJ.log("Classifying slice " + img.getTitle() + "...");

                for (int x = 0; x < w; x++) {
                    for (int y = 0; y < h; y++) {
                        final int pixel = x + y * w;
                        try {
                            if (pixel % 4000 == 0) {
                                if (Thread.currentThread().isInterrupted()) {
                                    return null;
                                }
                                counter.addAndGet(4000);
                            }

                            final DenseInstance instance = features.createInstance(x, y, 0);
                            instance.setDataset(dataInfo);

                            if (!probabilityMaps) {
                                output[0][pixel] = classifier.classifyInstance(instance);
                            } else {
                                final double[] distribution = classifier.distributionForInstance(instance);
                                for (int c = 0; c < classCount; c++) {
                                    output[c][pixel] = distribution[c];
                                }
                            }
                        } catch (Exception e) {
                            IJ.showMessage("Could not apply Classifier!");
                            e.printStackTrace();
                            return null;
                        }
                    }
                }

                if (!probabilityMaps) {
                    stack.addSlice("result", new FloatProcessor(w, h, output[0]));
                } else {
                    for (int c = 0; c < classCount; c++) {
                        stack.addSlice("class-" + (c + 1), new FloatProcessor(w, h, output[c]));
                    }
                }
                classified.add(new ImagePlus("classified-image-" + img.getTitle(), stack));
            }
            return classified;
        }
    };
}

From source file:trainableSegmentation.WekaSegmentation.java

License:GNU General Public License

/**
 * Apply current classifier to set of instances
 * @param data set of instances/*from   w  ww. j a va  2  s . co  m*/
 * @param w image width
 * @param h image height
 * @param numThreads The number of threads to use. Set to zero for
 * auto-detection.
 * @return result image
 */
public ImagePlus applyClassifier(final Instances data, int w, int h, int numThreads, boolean probabilityMaps) {
    if (numThreads == 0)
        numThreads = Prefs.getThreads();

    final int numClasses = data.numClasses();
    final int numInstances = data.numInstances();
    final int numChannels = (probabilityMaps ? numClasses : 1);
    final int numSlices = (numChannels * numInstances) / (w * h);

    IJ.showStatus("Classifying image...");

    final long start = System.currentTimeMillis();

    ExecutorService exe = Executors.newFixedThreadPool(numThreads);
    final double[][][] results = new double[numThreads][][];
    final Instances[] partialData = new Instances[numThreads];
    final int partialSize = numInstances / numThreads;
    Future<double[][]> fu[] = new Future[numThreads];

    final AtomicInteger counter = new AtomicInteger();

    for (int i = 0; i < numThreads; i++) {
        if (Thread.currentThread().isInterrupted()) {
            exe.shutdown();
            return null;
        }
        if (i == numThreads - 1)
            partialData[i] = new Instances(data, i * partialSize, numInstances - i * partialSize);
        else
            partialData[i] = new Instances(data, i * partialSize, partialSize);

        AbstractClassifier classifierCopy = null;
        try {
            // The Weka random forest classifiers do not need to be duplicated on each thread 
            // (that saves much memory)            
            if (classifier instanceof FastRandomForest || classifier instanceof RandomForest)
                classifierCopy = classifier;
            else
                classifierCopy = (AbstractClassifier) (AbstractClassifier.makeCopy(classifier));

        } catch (Exception e) {
            IJ.log("Error: classifier could not be copied to classify in a multi-thread way.");
            e.printStackTrace();
        }
        fu[i] = exe.submit(classifyInstances(partialData[i], classifierCopy, counter, probabilityMaps));
    }

    ScheduledExecutorService monitor = Executors.newScheduledThreadPool(1);
    ScheduledFuture task = monitor.scheduleWithFixedDelay(new Runnable() {
        public void run() {
            IJ.showProgress(counter.get(), numInstances);
        }
    }, 0, 1, TimeUnit.SECONDS);

    // Join threads
    for (int i = 0; i < numThreads; i++) {
        try {
            results[i] = fu[i].get();
        } catch (InterruptedException e) {
            //e.printStackTrace();
            return null;
        } catch (ExecutionException e) {
            e.printStackTrace();
            return null;
        } finally {
            exe.shutdown();
            task.cancel(true);
            monitor.shutdownNow();
            IJ.showProgress(1);
        }
    }

    exe.shutdown();

    // Create final array
    double[][] classificationResult;
    classificationResult = new double[numChannels][numInstances];

    for (int i = 0; i < numThreads; i++)
        for (int c = 0; c < numChannels; c++)
            System.arraycopy(results[i][c], 0, classificationResult[c], i * partialSize, results[i][c].length);

    IJ.showProgress(1.0);
    final long end = System.currentTimeMillis();
    IJ.log("Classifying whole image data took: " + (end - start) + "ms");

    double[] classifiedSlice = new double[w * h];
    final ImageStack classStack = new ImageStack(w, h);

    for (int i = 0; i < numSlices / numChannels; i++) {
        for (int c = 0; c < numChannels; c++) {
            System.arraycopy(classificationResult[c], i * (w * h), classifiedSlice, 0, w * h);
            ImageProcessor classifiedSliceProcessor = new FloatProcessor(w, h, classifiedSlice);
            classStack.addSlice(probabilityMaps ? getClassLabels()[c] : "", classifiedSliceProcessor);
        }
    }
    ImagePlus classImg = new ImagePlus(probabilityMaps ? "Probability maps" : "Classification result",
            classStack);

    return classImg;
}

From source file:trainableSegmentation.WekaSegmentation.java

License:GNU General Public License

/**
 * Creates a task that classifies a contiguous range of pixels taken from a feature stack array.
 *
 * <p>The feature stack array is addressed as one 1D sequence in which slice {@code s} occupies
 * positions {@code [s * width * height, (s + 1) * width * height)}, row by row.
 *
 * @param fsa feature stack array holding the feature vectors
 * @param dataInfo empty set of instances describing the attributes and classes
 * @param first 1D position of the first pixel to classify
 * @param numInstances number of pixels this task classifies
 * @param classifier current classifier
 * @param counter progress counter; incremented by 4000 whenever the local index is a multiple of 4000
 * @param probabilityMaps if true return one probability row per class, otherwise a single row of
 *        predicted class values
 * @return the task, or {@code null} if the calling thread is already interrupted; the task
 *         returns {@code null} if it is interrupted or the classifier throws
 */
private static Callable<double[][]> classifyInstances(final FeatureStackArray fsa, final Instances dataInfo,
        final int first, final int numInstances, final AbstractClassifier classifier,
        final AtomicInteger counter, final boolean probabilityMaps) {
    if (Thread.currentThread().isInterrupted()) {
        return null;
    }

    return new Callable<double[][]>() {

        public double[][] call() {
            final int w = fsa.getWidth();
            final int pixelsPerSlice = w * fsa.getHeight();
            final int classCount = dataInfo.numClasses();
            final double[][] out = new double[probabilityMaps ? classCount : 1][numInstances];

            for (int offset = 0; offset < numInstances; offset++) {
                try {
                    if (offset % 4000 == 0) {
                        if (Thread.currentThread().isInterrupted()) {
                            return null;
                        }
                        counter.addAndGet(4000);
                    }

                    // Map the 1D position onto (slice, x, y).
                    final int position = first + offset;
                    final int sliceIndex = position / pixelsPerSlice;
                    final int inSlice = position % pixelsPerSlice;
                    final DenseInstance ins = fsa.get(sliceIndex).createInstance(inSlice % w, inSlice / w, 0);
                    ins.setDataset(dataInfo);

                    if (!probabilityMaps) {
                        out[0][offset] = classifier.classifyInstance(ins);
                    } else {
                        final double[] dist = classifier.distributionForInstance(ins);
                        for (int c = 0; c < classCount; c++) {
                            out[c][offset] = dist[c];
                        }
                    }
                } catch (Exception e) {
                    IJ.showMessage("Could not apply Classifier!");
                    e.printStackTrace();
                    return null;
                }
            }
            return out;
        }
    };
}

From source file:trainableSegmentation.WekaSegmentation.java

License:GNU General Public License

/**
 * Classify instances concurrently/*  www  . j  ava 2 s  .  com*/
 * 
 * @param data set of instances to classify
 * @param classifier current classifier
 * @param counter auxiliary counter to be able to update the progress bar
 * @param probabilityMaps return a probability map for each class instead of a
 * classified image
 * @return classification result
 */
private static Callable<double[][]> classifyInstances(final Instances data, final AbstractClassifier classifier,
        final AtomicInteger counter, final boolean probabilityMaps) {
    if (Thread.currentThread().isInterrupted())
        return null;

    return new Callable<double[][]>() {

        public double[][] call() {

            final int numInstances = data.numInstances();
            final int numClasses = data.numClasses();

            final double[][] classificationResult;

            if (probabilityMaps)
                classificationResult = new double[numClasses][numInstances];
            else
                classificationResult = new double[1][numInstances];

            for (int i = 0; i < numInstances; i++) {
                try {

                    if (0 == i % 4000) {
                        if (Thread.currentThread().isInterrupted())
                            return null;
                        counter.addAndGet(4000);
                    }

                    if (probabilityMaps) {
                        double[] prob = classifier.distributionForInstance(data.get(i));
                        for (int k = 0; k < numClasses; k++)
                            classificationResult[k][i] = prob[k];
                    } else {
                        classificationResult[0][i] = classifier.classifyInstance(data.get(i));
                    }

                } catch (Exception e) {

                    IJ.showMessage("Could not apply Classifier!");
                    e.printStackTrace();
                    return null;
                }
            }
            return classificationResult;
        }
    };
}

From source file:tubes2ai.AIJKFFNN.java

/**
 * Builds and trains the feed-forward network on {@code instances}.
 *
 * <p>The network has one input neuron per non-class attribute and one output neuron per class
 * label. When {@code nHiddenLayer > 0`, a single hidden layer of {@code nHiddenNeuron} neurons is
 * inserted between them. Training makes {@code max(1, maxIterations)} passes over the data. For
 * each instance it loads the input, computes errors from the output layer backwards (target 1
 * for the neuron of the instance's class, 0 otherwise) and updates the weights with
 * {@code learningRate}.
 *
 * @param instances training data; instances with a missing class value are skipped
 * @throws Exception if the data does not satisfy the classifier's capabilities
 */
@Override
public void buildClassifier(Instances instances) throws Exception {
    getCapabilities().testWithFail(instances);

    /* Initialise every layer */
    final int nInputNeuron = instances.numAttributes() - 1;
    final int nOutputNeuron = instances.numClasses();
    inputLayer = new Vector<Neuron>(nInputNeuron);
    hiddenLayer = new Vector<Neuron>(nHiddenNeuron);
    outputLayer = new Vector<Neuron>(nOutputNeuron);

    Random random = new Random(getSeed());

    Enumeration<Attribute> attributeEnumeration = instances.enumerateAttributes();
    attributeList = Collections.list(attributeEnumeration);

    /* Fill the layers with neurons that have default weights */
    for (int k = 0; k < nOutputNeuron; k++) {
        outputLayer.add(new Neuron());
    }

    for (int k = 0; k < nInputNeuron; k++) {
        inputLayer.add(new Neuron());
    }

    /* Only when there is a hidden layer (at most one is ever built) */
    if (nHiddenLayer > 0) {
        for (int k = 0; k < nHiddenNeuron; k++) {
            hiddenLayer.add(new Neuron());
        }
    }

    /* Link the layers */
    if (nHiddenLayer > 0) {
        linkNeurons(inputLayer, hiddenLayer);
        linkNeurons(hiddenLayer, outputLayer);
    } else {
        linkNeurons(inputLayer, outputLayer);
    }

    // The order in which layers are initialised fixes how the seeded Random is consumed.
    for (Neuron neuron : inputLayer) {
        neuron.initialize(random);
    }

    inputLayerArray = inputLayer.toArray(new Neuron[nInputNeuron]);

    // Hidden neurons (if any) followed by output neurons; these are the neurons whose weights are
    // updated during training. Sized from the layers actually built: the former size
    // nHiddenLayer * nHiddenNeuron + nOutputNeuron left null slots (NPE in updateWeights)
    // whenever nHiddenLayer > 1, since only one hidden layer exists.
    outputCalculationArray = new Neuron[hiddenLayer.size() + outputLayer.size()];
    int j = 0;
    for (Neuron neuron : hiddenLayer) {
        outputCalculationArray[j++] = neuron;
    }
    for (Neuron neuron : outputLayer) {
        outputCalculationArray[j++] = neuron;
    }

    if (nHiddenLayer > 0) {
        for (Neuron neuron : hiddenLayer) {
            neuron.initialize(random);
        }
    }

    for (Neuron neuron : outputLayer) {
        neuron.initialize(random);
    }

    /* Learning */
    int iterations = 0;
    do {
        for (Instance instance : instances) {
            // Without a class value there is no target; (int) NaN would silently become class 0.
            if (instance.classIsMissing()) {
                continue;
            }

            /* Feed the instance into the input neurons */
            loadInput(instance);

            /* Compute the errors from the output layer back towards the input */
            /* Target values: 1 for the neuron of the instance's class, 0 for all others */
            final int targetClass = (int) instance.classValue();
            for (int ix = 0; ix < outputLayer.size(); ix++) {
                if (ix == targetClass) {
                    outputLayer.get(ix).errorFromTarget(1);
                } else {
                    outputLayer.get(ix).errorFromTarget(0);
                }
            }
            if (nHiddenLayer != 0) {
                for (Neuron nHid : hiddenLayer) {
                    nHid.calculateError();
                }
            }

            /* Update the weights */
            for (Neuron neuron : outputCalculationArray) {
                neuron.updateWeights(learningRate);
            }
        }

        iterations++;

        if (iterations % 500 == 0) {
            System.out.println("FFNN iteration " + iterations);
        }

    } while (iterations < maxIterations);

}

From source file:uzholdem.classifier.OnlineMultilayerPerceptron.java

License:Open Source License

/**
 * Call this function to build and train a neural network for the training
 * data provided.//from   w ww. j  a  v  a 2 s  . c  o m
 * @param i The training data.
 * @throws Exception if can't build classification properly.
 */
public void buildClassifier(Instances i) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(i);

    // remove instances with missing class
    i = new Instances(i);
    i.deleteWithMissingClass();

    // only class? -> build ZeroR model
    if (i.numAttributes() == 1) {
        System.err.println(
                "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!");
        m_ZeroR = new weka.classifiers.rules.ZeroR();
        m_ZeroR.buildClassifier(i);
        return;
    } else {
        m_ZeroR = null;
    }

    m_epoch = 0;
    m_error = 0;
    m_instances = null;
    m_currentInstance = null;
    m_controlPanel = null;
    m_nodePanel = null;

    m_outputs = new NeuralEnd[0];
    m_inputs = new NeuralEnd[0];
    m_numAttributes = 0;
    m_numClasses = 0;
    m_neuralNodes = new NeuralConnection[0];

    m_selected = new FastVector(4);
    m_graphers = new FastVector(2);
    m_nextId = 0;
    m_stopIt = true;
    m_stopped = true;
    m_accepted = false;
    m_instances = new Instances(i);
    m_random = new Random(m_randomSeed);
    m_instances.randomize(m_random);

    if (m_useNomToBin) {
        m_nominalToBinaryFilter = new NominalToBinary();
        m_nominalToBinaryFilter.setInputFormat(m_instances);
        m_instances = Filter.useFilter(m_instances, m_nominalToBinaryFilter);
    }
    m_numAttributes = m_instances.numAttributes() - 1;
    m_numClasses = m_instances.numClasses();

    setClassType(m_instances);

    //this sets up the validation set.
    Instances valSet = null;
    //numinval is needed later
    int numInVal = (int) (m_valSize / 100.0 * m_instances.numInstances());
    if (m_valSize > 0) {
        if (numInVal == 0) {
            numInVal = 1;
        }
        valSet = new Instances(m_instances, 0, numInVal);
    }
    ///////////

    setupInputs();

    setupOutputs();
    if (m_autoBuild) {
        setupHiddenLayer();
    }

    /////////////////////////////
    //this sets up the gui for usage
    if (m_gui) {
        m_win = new JFrame();

        m_win.addWindowListener(new WindowAdapter() {
            public void windowClosing(WindowEvent e) {
                boolean k = m_stopIt;
                m_stopIt = true;
                int well = JOptionPane
                        .showConfirmDialog(m_win,
                                "Are You Sure...\n" + "Click Yes To Accept" + " The Neural Network"
                                        + "\n Click No To Return",
                                "Accept Neural Network", JOptionPane.YES_NO_OPTION);

                if (well == 0) {
                    m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
                    m_accepted = true;
                    blocker(false);
                } else {
                    m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE);
                }
                m_stopIt = k;
            }
        });

        m_win.getContentPane().setLayout(new BorderLayout());
        m_win.setTitle("Neural Network");
        m_nodePanel = new NodePanel();
        // without the following two lines, the NodePanel.paintComponents(Graphics) 
        // method will go berserk if the network doesn't fit completely: it will
        // get called on a constant basis, using 100% of the CPU
        // see the following forum thread:
        // http://forum.java.sun.com/thread.jspa?threadID=580929&messageID=2945011
        m_nodePanel.setPreferredSize(new Dimension(640, 480));
        m_nodePanel.revalidate();

        JScrollPane sp = new JScrollPane(m_nodePanel, JScrollPane.VERTICAL_SCROLLBAR_ALWAYS,
                JScrollPane.HORIZONTAL_SCROLLBAR_NEVER);
        m_controlPanel = new ControlPanel();

        m_win.getContentPane().add(sp, BorderLayout.CENTER);
        m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH);
        m_win.setSize(640, 480);
        m_win.setVisible(true);
    }

    //This sets up the initial state of the gui
    if (m_gui) {
        blocker(true);
        m_controlPanel.m_changeEpochs.setEnabled(false);
        m_controlPanel.m_changeLearning.setEnabled(false);
        m_controlPanel.m_changeMomentum.setEnabled(false);
    }

    //For silly situations in which the network gets accepted before training
    //commenses
    if (m_numeric) {
        setEndsToLinear();
    }
    if (m_accepted) {
        m_win.dispose();
        m_controlPanel = null;
        m_nodePanel = null;
        m_instances = new Instances(m_instances, 0);
        return;
    }

    //connections done.
    double right = 0;
    double driftOff = 0;
    double lastRight = Double.POSITIVE_INFINITY;
    double bestError = Double.POSITIVE_INFINITY;
    double tempRate;
    double totalWeight = 0;
    double totalValWeight = 0;
    double origRate = m_learningRate; //only used for when reset

    //ensure that at least 1 instance is trained through.
    if (numInVal == m_instances.numInstances()) {
        numInVal--;
    }
    if (numInVal < 0) {
        numInVal = 0;
    }
    for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {
        if (!m_instances.instance(noa).classIsMissing()) {
            totalWeight += m_instances.instance(noa).weight();
        }
    }
    if (m_valSize != 0) {
        for (int noa = 0; noa < valSet.numInstances(); noa++) {
            if (!valSet.instance(noa).classIsMissing()) {
                totalValWeight += valSet.instance(noa).weight();
            }
        }
    }
    m_stopped = false;

    for (int noa = 1; noa < m_numEpochs + 1; noa++) {
        right = 0;
        for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {
            m_currentInstance = m_instances.instance(nob);

            if (!m_currentInstance.classIsMissing()) {

                //this is where the network updating (and training occurs, for the
                //training set
                resetNetwork();
                calculateOutputs();
                tempRate = m_learningRate * m_currentInstance.weight();
                if (m_decay) {
                    tempRate /= noa;
                }

                right += (calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight();
                updateNetworkWeights(tempRate, m_momentum);

            }

        }
        right /= totalWeight;
        if (Double.isInfinite(right) || Double.isNaN(right)) {
            if (!m_reset) {
                m_instances = null;
                throw new Exception("Network cannot train. Try restarting with a" + " smaller learning rate.");
            } else {
                //reset the network if possible
                if (m_learningRate <= Utils.SMALL)
                    throw new IllegalStateException(
                            "Learning rate got too small (" + m_learningRate + " <= " + Utils.SMALL + ")!");
                m_learningRate /= 2;
                buildClassifier(i);
                m_learningRate = origRate;
                m_instances = new Instances(m_instances, 0);
                return;
            }
        }

        ////////////////////////do validation testing if applicable
        if (m_valSize != 0) {
            right = 0;
            for (int nob = 0; nob < valSet.numInstances(); nob++) {
                m_currentInstance = valSet.instance(nob);
                if (!m_currentInstance.classIsMissing()) {
                    //this is where the network updating occurs, for the validation set
                    resetNetwork();
                    calculateOutputs();
                    right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight();
                    //note 'right' could be calculated here just using
                    //the calculate output values. This would be faster.
                    //be less modular
                }

            }

            if (right < lastRight) {
                if (right < bestError) {
                    bestError = right;
                    // save the network weights at this point
                    for (int noc = 0; noc < m_numClasses; noc++) {
                        m_outputs[noc].saveWeights();
                    }
                    driftOff = 0;
                }
            } else {
                driftOff++;
            }
            lastRight = right;
            if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
                for (int noc = 0; noc < m_numClasses; noc++) {
                    m_outputs[noc].restoreWeights();
                }
                m_accepted = true;
            }
            right /= totalValWeight;
        }
        m_epoch = noa;
        m_error = right;
        //shows what the neuralnet is upto if a gui exists. 
        updateDisplay();
        //This junction controls what state the gui is in at the end of each
        //epoch, Such as if it is paused, if it is resumable etc...
        if (m_gui) {
            while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) && !m_accepted) {
                m_stopIt = true;
                m_stopped = true;
                if (m_epoch >= m_numEpochs && m_valSize == 0) {

                    m_controlPanel.m_startStop.setEnabled(false);
                } else {
                    m_controlPanel.m_startStop.setEnabled(true);
                }
                m_controlPanel.m_startStop.setText("Start");
                m_controlPanel.m_startStop.setActionCommand("Start");
                m_controlPanel.m_changeEpochs.setEnabled(true);
                m_controlPanel.m_changeLearning.setEnabled(true);
                m_controlPanel.m_changeMomentum.setEnabled(true);

                blocker(true);
                if (m_numeric) {
                    setEndsToLinear();
                }
            }
            m_controlPanel.m_changeEpochs.setEnabled(false);
            m_controlPanel.m_changeLearning.setEnabled(false);
            m_controlPanel.m_changeMomentum.setEnabled(false);

            m_stopped = false;
            //if the network has been accepted stop the training loop
            if (m_accepted) {
                m_win.dispose();
                m_controlPanel = null;
                m_nodePanel = null;
                m_instances = new Instances(m_instances, 0);
                return;
            }
        }
        if (m_accepted) {
            m_instances = new Instances(m_instances, 0);
            return;
        }
    }
    if (m_gui) {
        m_win.dispose();
        m_controlPanel = null;
        m_nodePanel = null;
    }
    m_instances = new Instances(m_instances, 0);
}

From source file:uzholdem.classifier.OnlineMultilayerPerceptron.java

License:Open Source License

/**
 * Incrementally trains the network on one batch of instances.
 *
 * <p>The batch is passed through the nominal-to-binary filter (when {@code m_useNomToBin} is set)
 * and shuffled. When validation is enabled ({@code m_valSize != 0}), the first 30% of the
 * shuffled batch is held out as a validation set and the rest is used for training; otherwise
 * every instance is trained on. The network is then trained for up to {@code numIterations}
 * epochs. With validation, training stops once the validation error has failed to improve for
 * more than {@code m_driftThreshold} epochs, or when the epoch budget is (nearly) used up. At that
 * point the weights saved at the best validation error are restored and {@code m_accepted} is set.
 *
 * <p>On normal return {@code m_instances} is reduced to a header-only copy, so no batch data is
 * retained. A batch with no labelled weight is a no-op.
 *
 * @param aInstances    the batch to train on
 * @param numIterations maximum number of training epochs; values {@code <= 0} train nothing
 * @throws Exception if the filter cannot be initialised, or if the training error becomes
 *                   infinite/NaN (in which case {@code m_instances} is set to {@code null})
 */
public void trainModel(Instances aInstances, int numIterations) throws Exception {
    // Fraction of the batch held out for early stopping when validation is enabled.
    final double validationFraction = 0.3;

    // Each call is an independent training run. An early stop accepted during a previous batch
    // must not cut this one short after its first epoch.
    m_accepted = false;

    // The filter's input format is the raw (unfiltered) structure; a header is sufficient.
    if (m_instances == null) {
        m_instances = new Instances(aInstances, 0);
    }

    if (m_useNomToBin) {
        if (m_nominalToBinaryFilter == null) {
            // Publish the filter only once it is fully initialised. A failure is reported to the
            // caller instead of leaving a half-configured filter in the field.
            NominalToBinary filter = new NominalToBinary();
            filter.setInputFormat(m_instances);
            m_nominalToBinaryFilter = filter;
        }
        aInstances = Filter.useFilter(aInstances, m_nominalToBinaryFilter);
    }

    Instances shuffled = new Instances(aInstances);
    shuffled.randomize(new Random());

    // Hold out validation data only when it will actually be used. The held-out instances are the
    // front of the shuffled batch and the remainder is the training set. Two linear copies replace
    // the repeated delete(0) calls.
    int numInVal = (m_valSize != 0) ? (int) (shuffled.numInstances() * validationFraction) : 0;
    Instances valSet = new Instances(shuffled, 0, numInVal);
    m_instances = new Instances(shuffled, numInVal, shuffled.numInstances() - numInVal);

    double totalWeight = sumLabelledWeights(m_instances);
    double totalValWeight = sumLabelledWeights(valSet);
    if (totalWeight == 0) {
        // Nothing labelled to learn from. Without this check, 0/0 would be reported as divergence.
        m_instances = new Instances(m_instances, 0);
        return;
    }
    // A validation set without labelled weight cannot measure anything.
    boolean useValidation = m_valSize != 0 && totalValWeight > 0;

    double lastRight = Double.POSITIVE_INFINITY;
    double bestError = Double.POSITIVE_INFINITY;
    double driftOff = 0;
    m_stopped = false;

    for (int noa = 1; noa <= numIterations; noa++) {
        // Training pass: forward, accumulate the weighted error, then back-propagate.
        double right = 0;
        for (int nob = 0; nob < m_instances.numInstances(); nob++) {
            m_currentInstance = m_instances.instance(nob);
            if (m_currentInstance.classIsMissing()) {
                continue;
            }
            resetNetwork();
            calculateOutputs();
            double tempRate = m_learningRate * m_currentInstance.weight();
            if (m_decay) {
                tempRate /= noa;
            }
            right += (calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight();
            updateNetworkWeights(tempRate, m_momentum);
        }
        right /= totalWeight;
        if (Double.isInfinite(right) || Double.isNaN(right)) {
            m_instances = null;
            throw new Exception("Network cannot train. Try restarting with a" + " smaller learning rate.");
        }

        if (useValidation) {
            // Validation pass: forward only, with no weight updates.
            right = 0;
            for (int nob = 0; nob < valSet.numInstances(); nob++) {
                m_currentInstance = valSet.instance(nob);
                if (m_currentInstance.classIsMissing()) {
                    continue;
                }
                resetNetwork();
                calculateOutputs();
                right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight();
            }

            if (right < lastRight) {
                if (right < bestError) {
                    bestError = right;
                    // Remember the weights with the best validation error so far.
                    for (int noc = 0; noc < m_numClasses; noc++) {
                        m_outputs[noc].saveWeights();
                    }
                    driftOff = 0;
                }
            } else {
                driftOff++;
            }
            lastRight = right;
            // Stop when the validation error has drifted too long or the epoch budget is
            // (nearly) exhausted, and roll back to the best weights seen.
            if (driftOff > m_driftThreshold || noa + 1 >= numIterations) {
                for (int noc = 0; noc < m_numClasses; noc++) {
                    m_outputs[noc].restoreWeights();
                }
                m_accepted = true;
            }
            right /= totalValWeight;
        }
        m_epoch = noa;
        m_error = right;

        if (m_accepted) {
            break;
        }
    }
    // Drop the batch data on every normal exit path; later calls only need the header.
    m_instances = new Instances(m_instances, 0);
}

/**
 * Sums the weights of the instances in {@code data} whose class value is present.
 *
 * @param data the instances to sum over
 * @return total weight of labelled instances (0 for an empty set)
 */
private static double sumLabelledWeights(Instances data) {
    double total = 0;
    for (int i = 0; i < data.numInstances(); i++) {
        if (!data.instance(i).classIsMissing()) {
            total += data.instance(i).weight();
        }
    }
    return total;
}