Example usage for weka.core Instances instance

List of usage examples for weka.core Instances instance

Introduction

On this page you can find example usages of the weka.core.Instances method instance.

Prototype



public Instance instance(int index)

Source Link

Document

Returns the instance at the given position.

Usage

From source file:milk.experiment.MIInstancesResultListener.java

License:Open Source License

/**
 * Performs any postprocessing. A call to this method signals that no more
 * results needing to be grouped together will be sent. An attribute header
 * is built from the producer's key and result names, the collected
 * instances are wrapped in a dataset, and the dataset is printed to the
 * output stream.
 *
 * @param rp the ResultProducer that generated the results
 * @exception Exception if an error occurs, e.g. an unknown attribute type
 */
public void postProcess(MIResultProducer rp) throws Exception {

    if (rp != m_RP) {
        throw new Error("Unrecognized ResultProducer sending results!!");
    }
    final String[] keys = m_RP.getKeyNames();
    final String[] results = m_RP.getResultNames();
    final FastVector header = new FastVector();
    for (int col = 0; col < m_AttributeTypes.length; col++) {
        // Key columns come first and get a "Key_" prefix; the remaining
        // columns take their names from the result names.
        final String name = (col < keys.length) ? "Key_" + keys[col] : results[col - keys.length];

        final int type = m_AttributeTypes[col];
        if (type == Attribute.NUMERIC) {
            header.addElement(new Attribute(name));
        } else if (type == Attribute.NOMINAL && m_NominalStrings[col].size() > 0) {
            header.addElement(new Attribute(name, m_NominalStrings[col]));
        } else if (type == Attribute.NOMINAL || type == Attribute.STRING) {
            // Nominal attributes without recorded values are declared the
            // same way as string attributes.
            header.addElement(new Attribute(name, (FastVector) null));
        } else {
            throw new Exception("Unknown attribute type");
        }
    }

    final Instances dataset = new Instances("InstanceResultListener", header, m_Instances.size());
    int row = 0;
    while (row < m_Instances.size()) {
        dataset.add((Instance) m_Instances.elementAt(row));
        row++;
    }

    // First a zero-capacity copy of the dataset (presumably just its
    // header), then every instance on its own line.
    m_Out.println(new Instances(dataset, 0));
    for (int r = 0; r < dataset.numInstances(); r++) {
        m_Out.println(dataset.instance(r));
    }

    // The stream is only closed when an output file other than "-" is set.
    final boolean writesToFile = m_OutputFile != null && !m_OutputFile.getName().equals("-");
    if (writesToFile) {
        m_Out.close();
    }
}

From source file:milk.gui.experiment.MIResultsPanel.java

License:Open Source License

/**
 * Queries the user enough to make a database query to retrieve experiment
 * results./*from   w w w  .  j  a v a2 s  . c o  m*/
 */
protected void setInstancesFromDBaseQuery() {

    try {
        if (m_InstanceQuery == null) {
            m_InstanceQuery = new MIInstanceQuery();
        }
        String dbaseURL = m_InstanceQuery.getDatabaseURL();
        dbaseURL = (String) JOptionPane.showInputDialog(this, "Enter the database URL", "Query Database",
                JOptionPane.PLAIN_MESSAGE, null, null, dbaseURL);
        if (dbaseURL == null) {
            m_FromLab.setText("Cancelled");
            return;
        }
        m_InstanceQuery.setDatabaseURL(dbaseURL);
        m_InstanceQuery.connectToDatabase();
        if (!m_InstanceQuery.experimentIndexExists()) {
            m_FromLab.setText("No experiment index");
            return;
        }
        m_FromLab.setText("Getting experiment index");
        Instances index = m_InstanceQuery.retrieveInstances("SELECT * FROM " + MIInstanceQuery.EXP_INDEX_TABLE);
        if (index.numInstances() == 0) {
            m_FromLab.setText("No experiments available");
            return;
        }
        m_FromLab.setText("Got experiment index");

        DefaultListModel lm = new DefaultListModel();
        for (int i = 0; i < index.numInstances(); i++) {
            lm.addElement(index.instance(i).toString());
        }
        JList jl = new JList(lm);
        ListSelectorDialog jd = new ListSelectorDialog(null, jl);
        int result = jd.showDialog();
        if (result != ListSelectorDialog.APPROVE_OPTION) {
            m_FromLab.setText("Cancelled");
            return;
        }
        Instance selInst = index.instance(jl.getSelectedIndex());
        Attribute tableAttr = index.attribute(MIInstanceQuery.EXP_RESULT_COL);
        String table = MIInstanceQuery.EXP_RESULT_PREFIX + selInst.toString(tableAttr);

        setInstancesFromDatabaseTable(table);
    } catch (Exception ex) {
        m_FromLab.setText("Problem reading database");
    }
}

From source file:milk.visualize.MIPlot2D.java

License:Open Source License

/**
 * Renders this component. When exemplars are set, the names of the two
 * plotted attributes are drawn next to the axes and, for every plot, the
 * plot's name is drawn at the position of each instance that has values
 * for both plotted attributes. The text colour is taken from the colour
 * list for a nominal class, or computed from the class value otherwise.
 *
 * @param gx the graphics context
 */
public void paintComponent(Graphics gx) {

    super.paintComponent(gx);

    if (plotExemplars == null) {
        return;
    }

    gx.setColor(m_axisColour);
    // Draw the axis names
    String xname = plotExemplars.attribute(m_xIndex).name();
    String yname = plotExemplars.attribute(m_yIndex).name();
    gx.drawString(yname, m_XaxisStart + m_labelMetrics.stringWidth("M"),
            m_YaxisStart + m_labelMetrics.getAscent() / 2 + m_tickSize);
    // The x-axis name is placed relative to the end of the x axis, so its
    // offset has to be based on its own width, not on the y name's width.
    gx.drawString(xname, m_XaxisEnd - m_labelMetrics.stringWidth(xname) + m_tickSize,
            (int) (m_YaxisEnd - m_labelMetrics.getAscent() / 2));

    // Draw points
    Attribute classAtt = plotExemplars.classAttribute();
    for (int j = 0; j < m_plots.size(); j++) {
        PlotData2D temp_plot = (PlotData2D) (m_plots.elementAt(j));
        Instances instances = temp_plot.getPlotInstances();
        // Every point of a plot is labelled with the plot's name.
        String id = temp_plot.getPlotName();

        for (int i = 0; i < instances.numInstances(); i++) {
            Instance ins = instances.instance(i);
            // Points lacking either coordinate cannot be placed.
            if (ins.isMissing(m_xIndex) || ins.isMissing(m_yIndex)) {
                continue;
            }

            if (classAtt.isNominal()) {
                gx.setColor((Color) m_colorList.elementAt((int) ins.classValue()));
            } else {
                // Scale the class value to [0, 1] using the m_minC/m_maxC
                // bounds, then spread it over the red and blue channels.
                double r = (ins.classValue() - m_minC) / (m_maxC - m_minC);
                r = (r * 240) + 15;
                gx.setColor(new Color((int) r, 150, (int) (255 - r)));
            }

            double x = convertToPanelX(ins.value(m_xIndex));
            double y = convertToPanelY(ins.value(m_yIndex));

            // Centre the label on the point.
            gx.drawString(id, (int) (x - m_labelMetrics.stringWidth(id) / 2),
                    (int) (y + m_labelMetrics.getHeight() / 2));
        }
    }
}

From source file:ml.ann.MultiClassPTR.java

/**
 * Builds the classifier from the given training data. The attribute values
 * are copied into inputInstances (column 0 being the bias), the class
 * values are encoded into targetInstances, and the training routine
 * selected by algo is run.
 *
 * @param instances the training data
 * @throws Exception if the classifier cannot be built
 */
@Override
public void buildClassifier(Instances instances) throws Exception {
    initAttributes(instances);

    final int numRows = instances.numInstances();

    // REMEMBER: only works if class index is in the last position
    for (int row = 0; row < numRows; row++) {
        Instance current = instances.get(row);
        double[] inputs = inputInstances[row];
        inputs[0] = 1.0; // bias occupies the first slot
        // Attribute k lands in slot k + 1; the last attribute is skipped.
        int col = 1;
        while (col < current.numAttributes()) {
            inputs[col] = current.value(col - 1);
            col++;
        }
    }

    // Initialize target values: one slot per class set to 0.0 except the
    // instance's own class (1.0) for a nominal class, otherwise the raw
    // class value in slot 0.
    final boolean nominalClass = instances.classAttribute().isNominal();
    for (int row = 0; row < numRows; row++) {
        Instance current = instances.instance(row);
        if (nominalClass) {
            for (int cls = 0; cls < instances.numClasses(); cls++) {
                targetInstances[row][cls] = 0.0;
            }
            targetInstances[row][(int) current.classValue()] = 1.0;
        } else {
            targetInstances[row][0] = current.classValue();
        }
    }

    // algo 1 and 2 share buildClassifier(); only algo 1 sets the activation
    // function first. algo 3 uses the batch variant; other values do nothing.
    final boolean needsActFunction = (algo == 1);
    final boolean usesBuildClassifier = needsActFunction || (algo == 2);
    if (needsActFunction) {
        setActFunction();
    }
    if (usesBuildClassifier) {
        buildClassifier();
    } else if (algo == 3) {
        buildClassifierBatch();
    }
}

From source file:ml.ann.SinglePTR.java

/**
 * Builds the classifier from the given training data. The attribute values
 * and class values are copied into plain arrays (printing each value to
 * standard output), a zero weight matrix is created, and for algo 1, 2 or 3
 * a new SinglePTR is constructed with these arrays.
 *
 * @param train the training data
 * @throws Exception if the classifier cannot be built
 */
@Override
public void buildClassifier(Instances train) throws Exception {
    final double initialWeight = 0.0;

    final double[][] input = new double[train.numInstances()][train.numAttributes()];
    for (int row = 0; row < train.numInstances(); row++) {
        // Column 0 keeps its default 0.0; attribute k is stored in column
        // k + 1, so the last attribute is never copied.
        for (int attr = 0; attr + 1 < train.numAttributes(); attr++) {
            final int col = attr + 1;
            System.out.println(train.attribute(attr));
            input[row][col] = train.instance(row).value(attr);
            System.out.println("input[" + row + "][" + col + "]: " + input[row][col]);
        }
    }

    final double[] target = new double[train.numInstances()];
    int row = 0;
    while (row < train.numInstances()) {
        target[row] = train.instance(row).classValue();
        System.out.println("target[" + row + "]: " + target[row]);
        row++;
    }

    final double[][] weight = new double[train.numAttributes()][1];
    for (int w = 0; w < train.numAttributes(); w++) {
        weight[w][0] = initialWeight;
    }

    // The created object is discarded.
    // NOTE(review): presumably the constructor itself runs the training - confirm.
    for (int mode = 1; mode <= 3; mode++) {
        if (algo == mode) {
            new SinglePTR(train.numInstances(), train.numAttributes() - 1, 10, 0.1, 0.01, input, target,
                    weight, mode, momentum, randomWeight);
            break;
        }
    }
}

From source file:ml.dataprocess.CorrelationAttributeEval.java

License:Open Source License

/**
 * Initializes an information gain attribute evaluator. Replaces missing
 * values with means/modes; Deletes instances with missing class values.
 * /*from   w  w w  .  j av  a  2 s  .com*/
 * @param data set of instances serving as training data
 * @throws Exception if the evaluator has not been generated successfully
 */
@Override
public void buildEvaluator(Instances data) throws Exception {
    data = new Instances(data);
    data.deleteWithMissingClass();

    ReplaceMissingValues rmv = new ReplaceMissingValues();
    rmv.setInputFormat(data);
    data = Filter.useFilter(data, rmv);

    int numClasses = data.classAttribute().numValues();
    int classIndex = data.classIndex();
    int numInstances = data.numInstances();
    m_correlations = new double[data.numAttributes()];
    /*
     * boolean hasNominals = false; boolean hasNumerics = false;
     */
    List<Integer> numericIndexes = new ArrayList<Integer>();
    List<Integer> nominalIndexes = new ArrayList<Integer>();
    if (m_detailedOutput) {
        m_detailedOutputBuff = new StringBuffer();
    }

    // TODO for instance weights (folded into computing weighted correlations)
    // add another dimension just before the last [2] (0 for 0/1 binary vector
    // and
    // 1 for corresponding instance weights for the 1's)
    double[][][] nomAtts = new double[data.numAttributes()][][];
    for (int i = 0; i < data.numAttributes(); i++) {
        if (data.attribute(i).isNominal() && i != classIndex) {
            nomAtts[i] = new double[data.attribute(i).numValues()][data.numInstances()];
            Arrays.fill(nomAtts[i][0], 1.0); // set zero index for this att to all
                                             // 1's
            nominalIndexes.add(i);
        } else if (data.attribute(i).isNumeric() && i != classIndex) {
            numericIndexes.add(i);
        }
    }

    // do the nominal attributes
    if (nominalIndexes.size() > 0) {
        for (int i = 0; i < data.numInstances(); i++) {
            Instance current = data.instance(i);
            for (int j = 0; j < current.numValues(); j++) {
                if (current.attribute(current.index(j)).isNominal() && current.index(j) != classIndex) {
                    // Will need to check for zero in case this isn't a sparse
                    // instance (unless we add 1 and subtract 1)
                    nomAtts[current.index(j)][(int) current.valueSparse(j)][i] += 1;
                    nomAtts[current.index(j)][0][i] -= 1;
                }
            }
        }
    }

    if (data.classAttribute().isNumeric()) {
        double[] classVals = data.attributeToDoubleArray(classIndex);

        // do the numeric attributes
        for (Integer i : numericIndexes) {
            double[] numAttVals = data.attributeToDoubleArray(i);
            m_correlations[i] = Utils.correlation(numAttVals, classVals, numAttVals.length);

            if (m_correlations[i] == 1.0) {
                // check for zero variance (useless numeric attribute)
                if (Utils.variance(numAttVals) == 0) {
                    m_correlations[i] = 0;
                }
            }
        }

        // do the nominal attributes
        if (nominalIndexes.size() > 0) {

            // now compute the correlations for the binarized nominal attributes
            for (Integer i : nominalIndexes) {
                double sum = 0;
                double corr = 0;
                double sumCorr = 0;
                double sumForValue = 0;

                if (m_detailedOutput) {
                    m_detailedOutputBuff.append("\n\n").append(data.attribute(i).name());
                }

                for (int j = 0; j < data.attribute(i).numValues(); j++) {
                    sumForValue = Utils.sum(nomAtts[i][j]);
                    corr = Utils.correlation(nomAtts[i][j], classVals, classVals.length);

                    // useless attribute - all instances have the same value
                    if (sumForValue == numInstances || sumForValue == 0) {
                        corr = 0;
                    }
                    if (corr < 0.0) {
                        corr = -corr;
                    }
                    sumCorr += sumForValue * corr;
                    sum += sumForValue;

                    if (m_detailedOutput) {
                        m_detailedOutputBuff.append("\n\t").append(data.attribute(i).value(j)).append(": ");
                        m_detailedOutputBuff.append(Utils.doubleToString(corr, 6));
                    }
                }
                m_correlations[i] = (sum > 0) ? sumCorr / sum : 0;
            }
        }
    } else {
        // class is nominal
        // TODO extra dimension for storing instance weights too
        double[][] binarizedClasses = new double[data.classAttribute().numValues()][data.numInstances()];

        // this is equal to the number of instances for all inst weights = 1
        double[] classValCounts = new double[data.classAttribute().numValues()];

        for (int i = 0; i < data.numInstances(); i++) {
            Instance current = data.instance(i);
            binarizedClasses[(int) current.classValue()][i] = 1;
        }
        for (int i = 0; i < data.classAttribute().numValues(); i++) {
            classValCounts[i] = Utils.sum(binarizedClasses[i]);
        }

        double sumClass = Utils.sum(classValCounts);

        // do numeric attributes first
        if (numericIndexes.size() > 0) {
            for (Integer i : numericIndexes) {
                double[] numAttVals = data.attributeToDoubleArray(i);
                double corr = 0;
                double sumCorr = 0;

                for (int j = 0; j < data.classAttribute().numValues(); j++) {
                    corr = Utils.correlation(numAttVals, binarizedClasses[j], numAttVals.length);
                    if (corr < 0.0) {
                        corr = -corr;
                    }

                    if (corr == 1.0) {
                        // check for zero variance (useless numeric attribute)
                        if (Utils.variance(numAttVals) == 0) {
                            corr = 0;
                        }
                    }

                    sumCorr += classValCounts[j] * corr;
                }
                m_correlations[i] = sumCorr / sumClass;
            }
        }

        if (nominalIndexes.size() > 0) {
            for (Integer i : nominalIndexes) {
                if (m_detailedOutput) {
                    m_detailedOutputBuff.append("\n\n").append(data.attribute(i).name());
                }

                double sumForAtt = 0;
                double corrForAtt = 0;
                for (int j = 0; j < data.attribute(i).numValues(); j++) {
                    double sumForValue = Utils.sum(nomAtts[i][j]);
                    double corr = 0;
                    double sumCorr = 0;
                    double avgCorrForValue = 0;

                    sumForAtt += sumForValue;
                    for (int k = 0; k < numClasses; k++) {

                        // corr between value j and class k
                        corr = Utils.correlation(nomAtts[i][j], binarizedClasses[k],
                                binarizedClasses[k].length);

                        // useless attribute - all instances have the same value
                        if (sumForValue == numInstances || sumForValue == 0) {
                            corr = 0;
                        }
                        if (corr < 0.0) {
                            corr = -corr;
                        }
                        sumCorr += classValCounts[k] * corr;
                    }
                    avgCorrForValue = sumCorr / sumClass;
                    corrForAtt += sumForValue * avgCorrForValue;

                    if (m_detailedOutput) {
                        m_detailedOutputBuff.append("\n\t").append(data.attribute(i).value(j)).append(": ");
                        m_detailedOutputBuff.append(Utils.doubleToString(avgCorrForValue, 6));
                    }
                }

                // the weighted average corr for att i as
                // a whole (wighted by value frequencies)
                m_correlations[i] = (sumForAtt > 0) ? corrForAtt / sumForAtt : 0;
            }
        }
    }

    if (m_detailedOutputBuff != null && m_detailedOutputBuff.length() > 0) {
        m_detailedOutputBuff.append("\n");
    }
}

From source file:ml.engine.LibSVM.java

License:Open Source License

/**
 * Builds the classifier. The training data is copied, filtered (missing
 * class removal, optional missing value replacement, optional
 * normalization, nominal to binary), converted into the arrays handed to
 * libsvm, and the model is trained through reflective calls.
 *
 * @param insts the training instances
 * @throws Exception if libsvm classes not in classpath or libsvm encountered
 *           a problem
 */
@Override
public void buildClassifier(Instances insts) throws Exception {
    m_Filter = null;

    if (!isPresent()) {
        throw new Exception("libsvm classes not in CLASSPATH!");
    }

    // work on a copy without instances that lack a class value
    insts = new Instances(insts);
    insts.deleteWithMissingClass();

    final boolean replaceMissing = !getDoNotReplaceMissingValues();
    if (replaceMissing) {
        m_ReplaceMissingValues = new ReplaceMissingValues();
        m_ReplaceMissingValues.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_ReplaceMissingValues);
    }

    // The capability test runs only now, so that it fails on data that
    // really contains missing values when the user has switched the
    // replace-missing-values filtering off.
    getCapabilities().testWithFail(insts);

    if (getNormalize()) {
        m_Filter = new Normalize();
        m_Filter.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Filter);
    }

    // nominal to binary
    m_NominalToBinary = new NominalToBinary();
    m_NominalToBinary.setInputFormat(insts);
    insts = Filter.useFilter(insts, m_NominalToBinary);

    // Convert every instance; remember the "index" field of the last
    // element of each converted array to find the largest one.
    Vector labels = new Vector();
    Vector features = new Vector();
    int highestIndex = 0;

    for (int row = 0; row < insts.numInstances(); row++) {
        Instance current = insts.instance(row);
        Object nodes = instanceToArray(current);
        int length = Array.getLength(nodes);

        if (length > 0) {
            Object lastNode = Array.get(nodes, length - 1);
            int lastIndex = ((Integer) getField(lastNode, "index")).intValue();
            highestIndex = Math.max(highestIndex, lastIndex);
        }
        features.addElement(nodes);
        labels.addElement(Double.valueOf(current.classValue()));
    }

    // calculate actual gamma: a gamma of 0 means 1 / (largest index)
    m_GammaActual = (getGamma() == 0) ? 1.0 / highestIndex : m_Gamma;

    // check parameter
    Object checker = Class.forName(CLASS_SVM).newInstance();
    Class[] checkSignature = new Class[] { Class.forName(CLASS_SVMPROBLEM), Class.forName(CLASS_SVMPARAMETER) };
    Object[] checkArguments = new Object[] { getProblem(features, labels), getParameters() };
    String complaint = (String) invokeMethod(checker, "svm_check_parameter", checkSignature, checkArguments);

    if (complaint != null) {
        throw new Exception("Error: " + complaint);
    }

    // make probability estimates deterministic from run to run
    Field randomField = Class.forName(CLASS_SVM).getField("rand");
    Random libsvmRandom = (Random) randomField.get(null); // static field
    libsvmRandom.setSeed(m_Seed);

    // train model
    Object trainer = Class.forName(CLASS_SVM).newInstance();
    Class[] trainSignature = new Class[] { Class.forName(CLASS_SVMPROBLEM), Class.forName(CLASS_SVMPARAMETER) };
    Object[] trainArguments = new Object[] { getProblem(features, labels), getParameters() };
    m_Model = invokeMethod(trainer, "svm_train", trainSignature, trainArguments);
}

From source file:mlda.util.Utils.java

License:Open Source License

/**
 * Get array of ImbalancedFeature with labels frequency
 * /*  w  w  w. j  a v  a 2  s.c  om*/
 * @param dataset Multi-label dataset
 * @return Array of ImbalancedFeature with the labels frequency
 */
public static ImbalancedFeature[] getAppearancesPerLabel(MultiLabelInstances dataset) {
    int[] labelIndices = dataset.getLabelIndices();

    ImbalancedFeature[] labels = new ImbalancedFeature[labelIndices.length];

    Instances instances = dataset.getDataSet();

    int appearances = 0;
    Attribute currentAtt;

    for (int i = 0; i < labelIndices.length; i++) {
        currentAtt = instances.attribute(labelIndices[i]);
        appearances = 0;

        for (int j = 0; j < instances.size(); j++) {
            if (instances.instance(j).value(currentAtt) == 1.0) {
                appearances++;
            }
        }
        labels[i] = new ImbalancedFeature(currentAtt.name(), appearances);
    }

    return labels;
}

From source file:mlda.util.Utils.java

License:Open Source License

/**
 * Calculate IRs of the ImbalancedFeatures
 * /*from   ww  w.ja v  a 2s. co  m*/
 * @param dataset Multi-label dataset
 * @param labels Labels of the dataset as ImbalancedFeature objects
 * @return Array of ImbalancedFeature objects with calculated IR
 */
public static ImbalancedFeature[] getImbalancedWithIR(MultiLabelInstances dataset, ImbalancedFeature[] labels) {
    int[] labelIndices = dataset.getLabelIndices();

    ImbalancedFeature[] labels_imbalanced = new ImbalancedFeature[labelIndices.length];

    Instances instances = dataset.getDataSet();

    int nOnes = 0, nZeros = 0, maxAppearance = 0;
    double IRIntraClass;
    double variance;
    double IRInterClass;
    double mean = dataset.getNumInstances() / 2;

    Attribute current;
    ImbalancedFeature currentLabel;

    for (int i = 0; i < labelIndices.length; i++) //for each label
    {
        nZeros = 0;
        nOnes = 0;
        current = instances.attribute(labelIndices[i]); //current label

        for (int j = 0; j < instances.size(); j++) //for each instance
        {
            if (instances.instance(j).value(current) == 1.0) {
                nOnes++;
            } else {
                nZeros++;
            }
        }

        try {
            if (nZeros == 0 || nOnes == 0) {
                IRIntraClass = 0;
            } else if (nZeros > nOnes) {
                IRIntraClass = (double) nZeros / nOnes;
            } else {
                IRIntraClass = (double) nOnes / nZeros;
            }
        } catch (Exception e1) {
            IRIntraClass = 0;
        }

        variance = (Math.pow((nZeros - mean), 2) + Math.pow((nOnes - mean), 2)) / 2;

        currentLabel = getLabelByName(current.name(), labels);

        maxAppearance = labels[0].getAppearances();

        if (currentLabel.getAppearances() <= 0) {
            IRInterClass = Double.NaN;
        } else {
            IRInterClass = (double) maxAppearance / currentLabel.getAppearances();
        }

        labels_imbalanced[i] = new ImbalancedFeature(current.name(), currentLabel.getAppearances(),
                IRInterClass, IRIntraClass, variance);
    }

    return labels_imbalanced;
}