Example usage for weka.core Instances numAttributes

A list of usage examples for the weka.core.Instances.numAttributes() method, drawn from open-source projects.

Introduction

On this page you can find usage examples for weka.core.Instances.numAttributes().

Prototype


public int numAttributes()

Document

Returns the number of attributes in the dataset, including the class attribute.
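
A recurring idiom in the examples below is to mark the last column of a freshly loaded dataset as the class attribute. A minimal, self-contained sketch of that pattern (the file path is a placeholder):

import java.io.BufferedReader;
import java.io.FileReader;
import weka.core.Instances;

public class NumAttributesDemo {
    public static void main(String[] args) throws Exception {
        // Load a dataset; "iris.arff" is a placeholder path.
        Instances data = new Instances(new BufferedReader(new FileReader("iris.arff")));
        // numAttributes() counts every column, including the class attribute.
        System.out.println("Attributes: " + data.numAttributes());
        // By convention the class is stored as the last attribute.
        data.setClassIndex(data.numAttributes() - 1);
        System.out.println("Class attribute: " + data.classAttribute().name());
    }
}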

Usage

From source file:com.sliit.neuralnetwork.RecurrentNN.java

public String trainModel(String modelName, String filePath, int outputs, int inputsTot) throws NeuralException {
    System.out.println("calling trainModel");
    try {

        System.out.println("Neural Network Training start");
        loadSaveNN(modelName, false);
        if (model == null) {

            buildModel();
        }

        File fileGeneral = new File(filePath);
        CSVLoader loader = new CSVLoader();
        loader.setSource(fileGeneral);
        Instances instances = loader.getDataSet();
        instances.setClassIndex(instances.numAttributes() - 1);
        StratifiedRemoveFolds stratified = new StratifiedRemoveFolds();
        String[] options = new String[6];
        options[0] = "-N";
        options[1] = Integer.toString(5);
        options[2] = "-F";
        options[3] = Integer.toString(1);
        options[4] = "-S";
        options[5] = Integer.toString(1);
        stratified.setOptions(options);
        stratified.setInputFormat(instances);
        stratified.setInvertSelection(false);
        Instances testInstances = Filter.useFilter(instances, stratified);
        stratified.setInvertSelection(true);
        Instances trainInstances = Filter.useFilter(instances, stratified);
        String directory = fileGeneral.getParent();
        CSVSaver saver = new CSVSaver();
        File trainFile = new File(directory + "/" + "normtrainadded.csv");
        File testFile = new File(directory + "/" + "normtestadded.csv");
        if (trainFile.exists()) {

            trainFile.delete();
        }
        trainFile.createNewFile();
        if (testFile.exists()) {

            testFile.delete();
        }
        testFile.createNewFile();
        saver.setFile(trainFile);
        saver.setInstances(trainInstances);
        saver.writeBatch();
        saver = new CSVSaver();
        saver.setFile(testFile);
        saver.setInstances(testInstances);
        saver.writeBatch();
        SequenceRecordReader recordReader = new CSVSequenceRecordReader(0, ",");
        recordReader.initialize(new org.datavec.api.split.FileSplit(trainFile));
        SequenceRecordReader testReader = new CSVSequenceRecordReader(0, ",");
        testReader.initialize(new org.datavec.api.split.FileSplit(testFile));
        DataSetIterator iterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                recordReader, 2, outputs, inputsTot, false);
        DataSetIterator testIterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                testReader, 2, outputs, inputsTot, false);
        roc = new ArrayList<Map<String, Double>>();
        String statMsg = "";
        Evaluation evaluation;

        for (int i = 0; i < 100; i++) {
            if (i % 2 == 0) {

                model.fit(iterator);
                evaluation = model.evaluate(testIterator);
            } else {

                model.fit(testIterator);
                evaluation = model.evaluate(iterator);
            }
            Map<String, Double> map = new HashMap<String, Double>();
            Map<Integer, Integer> falsePositives = evaluation.falsePositives();
            Map<Integer, Integer> trueNegatives = evaluation.trueNegatives();
            Map<Integer, Integer> truePositives = evaluation.truePositives();
            Map<Integer, Integer> falseNegatives = evaluation.falseNegatives();
            // Cast to double: the counts are Integers, so plain int division would truncate to 0 or 1.
            double fpr = falsePositives.get(1) / (double) (falsePositives.get(1) + trueNegatives.get(1));
            double tpr = truePositives.get(1) / (double) (truePositives.get(1) + falseNegatives.get(1));
            map.put("FPR", fpr);
            map.put("TPR", tpr);
            roc.add(map);
            statMsg = evaluation.stats();
            iterator.reset();
            testIterator.reset();
        }
        loadSaveNN(modelName, true);
        System.out.println("ROC " + roc);

        return statMsg;

    } catch (Exception e) {
        e.printStackTrace();
        System.out.println("Error ocuured while building neural netowrk :" + e.getMessage());
        throw new NeuralException(e.getLocalizedMessage(), e);
    }
}
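
The StratifiedRemoveFolds setup above is the standard Weka recipe for a stratified train/test split: keep one of N folds as the test set, then invert the selection to get the training set. A condensed sketch of just that step, assuming instances is already loaded with its class index set:

StratifiedRemoveFolds folds = new StratifiedRemoveFolds();
folds.setOptions(new String[] { "-N", "5", "-F", "1", "-S", "1" }); // 5 folds, keep fold 1, seed 1
folds.setInputFormat(instances);
Instances test = Filter.useFilter(instances, folds);

folds = new StratifiedRemoveFolds(); // fresh filter for the complement
folds.setOptions(new String[] { "-N", "5", "-F", "1", "-S", "1" });
folds.setInvertSelection(true);      // output the remaining four folds instead
folds.setInputFormat(instances);
Instances train = Filter.useFilter(instances, folds);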

From source file:com.sliit.normalize.NormalizeDataset.java

public String normalizeDataset() {
    System.out.println("start normalizing data");

    String filePathOut = "";
    try {

        CSVLoader loader = new CSVLoader();
        if (reducedDiemensionFile != null) {

            loader.setSource(reducedDiemensionFile);
        } else {
            if (tempFIle != null && tempFIle.exists()) {

                loader.setSource(tempFIle);
            } else {

                loader.setSource(csvFile);
            }
        }
        Instances dataInstance = loader.getDataSet();
        Normalize normalize = new Normalize();
        dataInstance.setClassIndex(dataInstance.numAttributes() - 1);
        normalize.setInputFormat(dataInstance);
        String directory = csvFile.getParent();
        outputFile = new File(directory + "/" + "normalized" + csvFile.getName());
        if (!outputFile.exists()) {

            outputFile.createNewFile();
        }
        CSVSaver saver = new CSVSaver();
        saver.setFile(outputFile);
        // Start at 0 so the first instance is not skipped.
        for (int i = 0; i < dataInstance.numInstances(); i++) {
            normalize.input(dataInstance.instance(i));
        }
        normalize.batchFinished();
        Instances outPut = new Instances(dataInstance, 0);
        for (int i = 0; i < dataInstance.numInstances(); i++) {
            outPut.add(normalize.output());
        }
        // Take the attribute from the output dataset so the rename below applies to outPut.
        Attribute attribute = outPut.attribute(outPut.numAttributes() - 1);
        for (int j = 0; j < attribute.numValues(); j++) {

            if (attribute.value(j).equals("normal.")) {
                outPut.renameAttributeValue(attribute, attribute.value(j), "0");
            } else {
                outPut.renameAttributeValue(attribute, attribute.value(j), "1");
            }
        }
        saver.setInstances(outPut);
        saver.writeBatch();
        writeToNewFile(directory);
        filePathOut = directory + "/norm" + csvFile.getName(); // getParent() has no trailing separator
        if (tempFIle != null) {

            tempFIle.delete();
        }
        if (reducedDiemensionFile != null) {

            reducedDiemensionFile.delete();
        }
        outputFile.delete();
    } catch (Exception e) {

        log.error("Error occurred: " + e.getMessage());
    }
    return filePathOut;
}
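
For reference, Normalize is more often applied in a single batch call than with the manual input/output loop above; a minimal sketch, assuming dataInstance is a loaded dataset:

Normalize normalize = new Normalize();
normalize.setInputFormat(dataInstance);                           // captures the header
Instances normalized = Filter.useFilter(dataInstance, normalize); // scales numeric attributes to [0,1]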

From source file:com.sliit.normalize.NormalizeDataset.java

public int whiteningData() {
    System.out.println("whiteningData");

    int nums = 0;
    try {

        if (tempFIle != null && tempFIle.exists()) {

            csv.setSource(tempFIle);
        } else {

            csv.setSource(csvFile);
        }
        Instances instances = csv.getDataSet();
        if (instances.numAttributes() > 10) {
            instances.setClassIndex(instances.numAttributes() - 1);
            RandomProjection random = new RandomProjection();
            random.setDistribution(
                    new SelectedTag(RandomProjection.GAUSSIAN, RandomProjection.TAGS_DSTRS_TYPE));
            reducedDiemensionFile = new File(csvFile.getParent() + "/tempwhite.csv");
            if (!reducedDiemensionFile.exists()) {

                reducedDiemensionFile.createNewFile();
            }
            // Filter options must be configured before setInputFormat(), not inside the loop.
            random.setNumberOfAttributes(10);
            random.setReplaceMissingValues(true);
            random.setInputFormat(instances);
            BufferedWriter writer = new BufferedWriter(new FileWriter(reducedDiemensionFile));
            for (int i = 0; i < instances.numInstances(); i++) {
                random.input(instances.instance(i));
                writer.write(random.output().toString());
                writer.newLine();
            }
            writer.flush();
            writer.close();
            nums = random.getNumberOfAttributes();
        } else {

            nums = instances.numAttributes();
        }
    } catch (IOException e) {

        log.error("Error occurred:" + e.getMessage());
    } catch (Exception e) {

        log.error("Error occurred:" + e.getMessage());
    }
    return nums;
}
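
The batch shortcut works for RandomProjection as well and sidesteps the option-ordering pitfall entirely; a sketch assuming instances is loaded with its class index set:

RandomProjection random = new RandomProjection();
random.setNumberOfAttributes(10);       // target dimensionality
random.setReplaceMissingValues(true);
random.setDistribution(new SelectedTag(RandomProjection.GAUSSIAN, RandomProjection.TAGS_DSTRS_TYPE));
random.setInputFormat(instances);       // the options above must already be set
Instances reduced = Filter.useFilter(instances, random);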

From source file:com.sliit.rules.RuleContainer.java

public String predictionResult(String filePath) throws Exception {

    File testPath = new File(filePath);
    CSVLoader loader = new CSVLoader();
    loader.setSource(testPath);
    Instances testInstances = loader.getDataSet();
    testInstances.setClassIndex(testInstances.numAttributes() - 1);
    Evaluation eval = new Evaluation(testInstances);
    eval.evaluateModel(ruleMoldel, testInstances);
    ArrayList<Prediction> predictions = eval.predictions();
    int predictedVal = (int) predictions.get(0).predicted();
    String cdetails = testInstances.classAttribute().value(predictedVal); // look up the label on the loaded test set
    return cdetails;
}
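
Note that predictions.get(0) inspects only the first test row. To label every row, iterate the whole list; a short sketch using the same variables:

for (Prediction p : eval.predictions()) {
    int idx = (int) p.predicted();
    System.out.println(testInstances.classAttribute().value(idx));
}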

From source file:com.sliit.views.DataVisualizerPanel.java

void getRocCurve() {
    try {
        Instances data;
        data = new Instances(new BufferedReader(new FileReader(datasetPathText.getText())));
        data.setClassIndex(data.numAttributes() - 1);

        // train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        // display curve
        String plotName = vmc.getName();
        final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
        jf.setSize(500, 400);
        jf.getContentPane().setLayout(new BorderLayout());
        jf.getContentPane().add(vmc, BorderLayout.CENTER);
        jf.addWindowListener(new java.awt.event.WindowAdapter() {
            public void windowClosing(java.awt.event.WindowEvent e) {
                jf.dispose();
            }
        });
        jf.setVisible(true);
    } catch (IOException ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    }
}
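
The same ThresholdCurve recipe reappears in the next two examples; only the display target differs. The essential computation reduces to a few lines, sketched here for the class at index 0:

ThresholdCurve tc = new ThresholdCurve();
Instances curve = tc.getCurve(eval.predictions(), 0); // one row per score threshold
double auc = ThresholdCurve.getROCArea(curve);        // area under the ROC curve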

From source file:com.sliit.views.KNNView.java

void getRocCurve() {
    try {
        Instances data;
        data = new Instances(new BufferedReader(new java.io.FileReader(PredictorPanel.modalText.getText())));
        data.setClassIndex(data.numAttributes() - 1);

        // train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        rocPanel.removeAll();
        rocPanel.add(vmc, "vmc", 0);
        rocPanel.revalidate();

    } catch (IOException ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:com.sliit.views.SVMView.java

/**
 * Draws the ROC curve.
 */
void getRocCurve() {
    try {
        Instances data;
        data = new Instances(new BufferedReader(new FileReader(PredictorPanel.modalText.getText())));
        data.setClassIndex(data.numAttributes() - 1);

        //train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        //            rocPanel.removeAll();
        //            rocPanel.add(vmc, "vmc", 0);
        //            rocPanel.revalidate();
    } catch (IOException ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:com.spread.experiment.tempuntilofficialrelease.ClassificationViaClustering108.java

License:Open Source License

/**
 * Returns class probability distribution for the given instance.
 *
 * @param instance the instance to be classified
 * @return the class probabilities
 * @throws Exception if an error occurred during the prediction
 */
@Override
public double[] distributionForInstance(Instance instance) throws Exception {

    if (m_ZeroR != null) {
        return m_ZeroR.distributionForInstance(instance);
    } else {
        double[] result = new double[instance.numClasses()];

        if (m_ActualClusterer != null) {
            // build new instance
            Instances tempData = m_ClusteringHeader.stringFreeStructure();
            double[] values = new double[tempData.numAttributes()];
            int n = 0;
            for (int i = 0; i < instance.numAttributes(); i++) {
                if (i == instance.classIndex()) {
                    continue;
                }
                if (instance.attribute(i).isString()) {
                    values[n] = tempData.attribute(n).addStringValue(instance.stringValue(i));
                } else if (instance.attribute(i).isRelationValued()) {
                    values[n] = tempData.attribute(n).addRelation(instance.relationalValue(i));
                } else {
                    values[n] = instance.value(i);
                }
                n++;
            }
            Instance newInst = new DenseInstance(instance.weight(), values);
            newInst.setDataset(tempData);

            if (!getLabelAllClusters()) {

                // determine cluster/class
                double r = m_ClustersToClasses[m_ActualClusterer.clusterInstance(newInst)];
                if (r == -1) {
                    return result; // Unclassified
                } else {
                    result[(int) r] = 1.0;
                    return result;
                }
            } else {
                double[] classProbs = new double[instance.numClasses()];
                double[] dist = m_ActualClusterer.distributionForInstance(newInst);
                for (int i = 0; i < dist.length; i++) {
                    for (int j = 0; j < instance.numClasses(); j++) {
                        classProbs[j] += dist[i] * m_ClusterClassProbs[i][j];
                    }
                }
                Utils.normalize(classProbs);
                return classProbs;
            }
        } else {
            return result; // Unclassified
        }
    }
}
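
A minimal sketch of consuming this method, assuming a trained classifier cvc and a test instance inst (both names are placeholders):

double[] dist = cvc.distributionForInstance(inst);
int best = weka.core.Utils.maxIndex(dist); // an all-zero distribution means "unclassified"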

From source file:com.tum.classifiertest.DataCache.java

License:Open Source License

/**
 * Creates a DataCache by copying data from a weka.core.Instances object.
 */
public DataCache(Instances origData) throws Exception {

    classIndex = origData.classIndex();
    numAttributes = origData.numAttributes();
    numClasses = origData.numClasses();
    numInstances = origData.numInstances();

    attNumVals = new int[origData.numAttributes()];
    for (int i = 0; i < attNumVals.length; i++) {
        if (origData.attribute(i).isNumeric()) {
            attNumVals[i] = 0;
        } else if (origData.attribute(i).isNominal()) {
            attNumVals[i] = origData.attribute(i).numValues();
        } else
            throw new Exception("Only numeric and nominal attributes are supported.");
    }

    /* Array is indexed by attribute first, to speed access in RF splitting. */
    vals = new float[numAttributes][numInstances];
    for (int a = 0; a < numAttributes; a++) {
        for (int i = 0; i < numInstances; i++) {
            if (origData.instance(i).isMissing(a))
                vals[a][i] = Float.MAX_VALUE; // to make sure missing values go to the end
            else
                vals[a][i] = (float) origData.instance(i).value(a); // deep copy
        }
    }

    instWeights = new double[numInstances];
    instClassValues = new int[numInstances];
    for (int i = 0; i < numInstances; i++) {
        instWeights[i] = origData.instance(i).weight();
        instClassValues[i] = (int) origData.instance(i).classValue();
    }

    /* compute the sortedInstances for the whole dataset */

    sortedIndices = new int[numAttributes][];

    for (int a = 0; a < numAttributes; a++) { // ================= attr by attr

        if (a == classIndex)
            continue;

        if (attNumVals[a] > 0) { // ------------------------------------- nominal

            // Handling nominal attributes: as of FastRF 0.99, they're sorted as well
            // missing values are coded as Float.MAX_VALUE and go to the end

            sortedIndices[a] = FastRfUtils.sort(vals[a]);

        } else { // ----------------------------------------------------- numeric

            // Sorted indices are computed for numeric attributes
            // missing values are coded as Float.MAX_VALUE and go to the end
            sortedIndices[a] = FastRfUtils.sort(vals[a]);

        } // ---------------------------------------------------------- attr kind

    } // ========================================================= attr by attr

}
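
The attribute-major layout above (vals[attribute][instance]) means a split evaluation scans one contiguous row per attribute. A hedged access sketch, assuming same-package visibility of the fields:

// Visit instances in ascending order of attribute a, missing values last:
for (int idx : cache.sortedIndices[a]) {
    float v = cache.vals[a][idx];
    // ... evaluate a candidate split at v ...
}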

From source file:com.tum.classifiertest.FastRandomForest.java

License:Open Source License

/**
 * Builds a classifier for a set of instances.
 *
 * @param data the instances to train the classifier with
 *
 * @throws Exception if something goes wrong
 */
public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    // only class? -> build ZeroR model
    if (data.numAttributes() == 1) {
        System.err.println(
                "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!");
        m_ZeroR = new weka.classifiers.rules.ZeroR();
        m_ZeroR.buildClassifier(data);
        return;
    } else {
        m_ZeroR = null;
    }

    /* Save header with attribute info. Can be accessed later by FastRfTrees
     * through their m_MotherForest field. */
    setM_Info(new Instances(data, 0));

    m_bagger = new FastRfBagging();

    // Set up the tree options which are held in the motherForest.
    m_KValue = m_numFeatures;
    if (m_KValue > data.numAttributes() - 1)
        m_KValue = data.numAttributes() - 1;
    if (m_KValue < 1)
        m_KValue = (int) Utils.log2(data.numAttributes()) + 1;

    FastRandomTree rTree = new FastRandomTree();
    rTree.m_MotherForest = this; // allows to retrieve KValue and MaxDepth
    // some temporary arrays which need to be separate for every tree, so
    // that the trees can be trained in parallel in different threads

    // set up the bagger and build the forest
    m_bagger.setClassifier(rTree);
    m_bagger.setSeed(m_randomSeed);
    m_bagger.setNumIterations(m_numTrees);
    m_bagger.setCalcOutOfBag(true);
    m_bagger.setComputeImportances(this.getComputeImportances());

    m_bagger.buildClassifier(data, m_NumThreads, this);

}
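
Once built, the forest is used like any other Weka classifier. A usage sketch, assuming train and test are loaded Instances with class indices set (FastRandomForest's option setters are omitted and may differ by version):

FastRandomForest rf = new FastRandomForest();
rf.buildClassifier(train);
for (int i = 0; i < test.numInstances(); i++) {
    double pred = rf.classifyInstance(test.instance(i));
    System.out.println(test.classAttribute().value((int) pred));
}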