List of usage examples for weka.core Instances instance
public Instance instance(int index)
From source file:milk.experiment.MIInstancesResultListener.java
License:Open Source License
/**
 * Performs any postprocessing. Being called means that no further results
 * needing to be grouped together will arrive.
 *
 * Builds one attribute per entry of m_AttributeTypes (key columns are named
 * "Key_" + key name, the remaining columns take the result names), copies
 * every collected instance into a new dataset, prints the header followed by
 * each instance to m_Out, and finally closes m_Out unless no output file is
 * set or the output file is named "-".
 *
 * @param rp the ResultProducer that generated the results
 * @exception Exception if an attribute type is unknown or another error occurs
 */
public void postProcess(MIResultProducer rp) throws Exception {
  if (m_RP != rp) {
    throw new Error("Unrecognized ResultProducer sending results!!");
  }

  final String[] keys = m_RP.getKeyNames();
  final String[] results = m_RP.getResultNames();

  // Assemble the header, one attribute per recorded column type.
  final FastVector header = new FastVector();
  for (int col = 0; col < m_AttributeTypes.length; col++) {
    final String name = (col < keys.length)
        ? "Key_" + keys[col]
        : results[col - keys.length];

    final int type = m_AttributeTypes[col];
    if (type == Attribute.NOMINAL) {
      // A nominal column without any recorded labels is declared like a
      // string attribute (null value list).
      if (m_NominalStrings[col].size() > 0) {
        header.addElement(new Attribute(name, m_NominalStrings[col]));
      } else {
        header.addElement(new Attribute(name, (FastVector) null));
      }
    } else if (type == Attribute.NUMERIC) {
      header.addElement(new Attribute(name));
    } else if (type == Attribute.STRING) {
      header.addElement(new Attribute(name, (FastVector) null));
    } else {
      throw new Exception("Unknown attribute type");
    }
  }

  // Move the collected instances into a dataset that uses the new header.
  final int total = m_Instances.size();
  final Instances dataset = new Instances("InstanceResultListener", header, total);
  for (int row = 0; row < total; row++) {
    dataset.add((Instance) m_Instances.elementAt(row));
  }

  // Header first (an empty copy of the dataset), then one line per instance.
  m_Out.println(new Instances(dataset, 0));
  for (int row = 0; row < dataset.numInstances(); row++) {
    m_Out.println(dataset.instance(row));
  }

  final boolean hasRealFile = m_OutputFile != null && !m_OutputFile.getName().equals("-");
  if (hasRealFile) {
    m_Out.close();
  }
}
From source file:milk.gui.experiment.MIResultsPanel.java
License:Open Source License
/** * Queries the user enough to make a database query to retrieve experiment * results./*from w w w . j a v a2 s . c o m*/ */ protected void setInstancesFromDBaseQuery() { try { if (m_InstanceQuery == null) { m_InstanceQuery = new MIInstanceQuery(); } String dbaseURL = m_InstanceQuery.getDatabaseURL(); dbaseURL = (String) JOptionPane.showInputDialog(this, "Enter the database URL", "Query Database", JOptionPane.PLAIN_MESSAGE, null, null, dbaseURL); if (dbaseURL == null) { m_FromLab.setText("Cancelled"); return; } m_InstanceQuery.setDatabaseURL(dbaseURL); m_InstanceQuery.connectToDatabase(); if (!m_InstanceQuery.experimentIndexExists()) { m_FromLab.setText("No experiment index"); return; } m_FromLab.setText("Getting experiment index"); Instances index = m_InstanceQuery.retrieveInstances("SELECT * FROM " + MIInstanceQuery.EXP_INDEX_TABLE); if (index.numInstances() == 0) { m_FromLab.setText("No experiments available"); return; } m_FromLab.setText("Got experiment index"); DefaultListModel lm = new DefaultListModel(); for (int i = 0; i < index.numInstances(); i++) { lm.addElement(index.instance(i).toString()); } JList jl = new JList(lm); ListSelectorDialog jd = new ListSelectorDialog(null, jl); int result = jd.showDialog(); if (result != ListSelectorDialog.APPROVE_OPTION) { m_FromLab.setText("Cancelled"); return; } Instance selInst = index.instance(jl.getSelectedIndex()); Attribute tableAttr = index.attribute(MIInstanceQuery.EXP_RESULT_COL); String table = MIInstanceQuery.EXP_RESULT_PREFIX + selInst.toString(tableAttr); setInstancesFromDatabaseTable(table); } catch (Exception ex) { m_FromLab.setText("Problem reading database"); } }
From source file:milk.visualize.MIPlot2D.java
License:Open Source License
/** * Renders this component//from w w w. j a va2s . c om * @param gx the graphics context */ public void paintComponent(Graphics gx) { //if(!isEnabled()) // return; super.paintComponent(gx); if (plotExemplars != null) { gx.setColor(m_axisColour); // Draw the axis name String xname = plotExemplars.attribute(m_xIndex).name(), yname = plotExemplars.attribute(m_yIndex).name(); gx.drawString(yname, m_XaxisStart + m_labelMetrics.stringWidth("M"), m_YaxisStart + m_labelMetrics.getAscent() / 2 + m_tickSize); gx.drawString(xname, m_XaxisEnd - m_labelMetrics.stringWidth(yname) + m_tickSize, (int) (m_YaxisEnd - m_labelMetrics.getAscent() / 2)); // Draw points Attribute classAtt = plotExemplars.classAttribute(); for (int j = 0; j < m_plots.size(); j++) { PlotData2D temp_plot = (PlotData2D) (m_plots.elementAt(j)); Instances instances = temp_plot.getPlotInstances(); StringTokenizer st = new StringTokenizer( instances.firstInstance().stringValue(plotExemplars.idIndex()), "_"); //////////////////// TLD stuff ///////////////////////////////// /* double[] mu = new double[plotExemplars.numAttributes()], sgm = new double[plotExemplars.numAttributes()]; st.nextToken(); // Squeeze first element int p=0; while(p<mu.length){ if((p==plotExemplars.idIndex()) || (p==plotExemplars.classIndex())) p++; if(p<mu.length){ mu[p] = Double.parseDouble(st.nextToken()); sgm[p] = Double.parseDouble(st.nextToken()); p++; } } Instance ins = instances.firstInstance(); gx.setColor((Color)m_colorList.elementAt((int)ins.classValue())); double mux=mu[m_xIndex], muy=mu[m_yIndex], sgmx=sgm[m_xIndex], sgmy=sgm[m_yIndex]; double xs = convertToPanelX(mux-3*sgmx), xe = convertToPanelX(mux+3*sgmx), xleng = Math.abs(xe-xs); double ys = convertToPanelY(muy+3*sgmy), ye = convertToPanelY(muy-3*sgmy), yleng = Math.abs(ye-ys); // Draw oval gx.drawOval((int)xs,(int)ys,(int)xleng,(int)yleng); // Draw a dot gx.fillOval((int)convertToPanelX(mux)-2, (int)convertToPanelY(muy)-2, 4, 4); */ //////////////////// TLD stuff 
///////////////////////////////// //////////////////// instance-based stuff ///////////////////////////////// /* double[] core = new double[plotExemplars.numAttributes()], range=new double[plotExemplars.numAttributes()]; st.nextToken(); // Squeeze first element int p=0; while(p<range.length){ if((p==plotExemplars.idIndex()) || (p==plotExemplars.classIndex())) p++; if(p<range.length) range[p++] = Double.parseDouble(st.nextToken()); } p=0; while(st.hasMoreTokens()){ if((p==plotExemplars.idIndex()) || (p==plotExemplars.classIndex())) p++; core[p++] = Double.parseDouble(st.nextToken()); } Instance ins = instances.firstInstance(); gx.setColor((Color)m_colorList.elementAt((int)ins.classValue())); double rgx=range[m_xIndex], rgy=range[m_yIndex]; double x1 = convertToPanelX(core[m_xIndex]-rgx/2), y1 = convertToPanelY(core[m_yIndex]-rgy/2), x2 = convertToPanelX(core[m_xIndex]+rgx/2), y2 = convertToPanelY(core[m_yIndex]+rgy/2), x = convertToPanelX(core[m_xIndex]), y = convertToPanelY(core[m_yIndex]); // Draw a rectangle gx.drawLine((int)x1, (int)y1, (int)x2, (int)y1); gx.drawLine((int)x1, (int)y1, (int)x1, (int)y2); gx.drawLine((int)x2, (int)y1, (int)x2, (int)y2); gx.drawLine((int)x1, (int)y2, (int)x2, (int)y2); // Draw a dot gx.fillOval((int)x-3, (int)y-3, 6, 6); // Draw string StringBuffer text =new StringBuffer(temp_plot.getPlotName()+":"+instances.numInstances()); gx.drawString(text.toString(), (int)x1, (int)y2+m_labelMetrics.getHeight()); */ //////////////////// instance-based stuff ///////////////////////////////// //////////////////// normal graph ///////////////////////////////// // Paint numbers for (int i = 0; i < instances.numInstances(); i++) { Instance ins = instances.instance(i); if (!ins.isMissing(m_xIndex) && !ins.isMissing(m_yIndex)) { if (classAtt.isNominal()) gx.setColor((Color) m_colorList.elementAt((int) ins.classValue())); else { double r = (ins.classValue() - m_minC) / (m_maxC - m_minC); r = (r * 240) + 15; gx.setColor(new Color((int) r, 150, (int) 
(255 - r))); } double x = convertToPanelX(ins.value(m_xIndex)); double y = convertToPanelY(ins.value(m_yIndex)); String id = temp_plot.getPlotName(); gx.drawString(id, (int) (x - m_labelMetrics.stringWidth(id) / 2), (int) (y + m_labelMetrics.getHeight() / 2)); } } //////////////////// normal graph ///////////////////////////////// } } //////////////////// TLD stuff ///////////////////////////////// // Draw two Guassian contour with 3 stdDev // (-1, -1) with stdDev 1, 2 // (1, 1) with stdDev 2, 1 /*gx.setColor(Color.black); double mu=-1.5, sigmx, sigmy; // class 0 if(m_xIndex == 1) sigmx = 1; else sigmx = 2; if(m_yIndex == 1) sigmy = 1; else sigmy = 2; double x1 = convertToPanelX(mu-3*sigmx), x2 = convertToPanelX(mu+3*sigmx), xlen = Math.abs(x2-x1); double y1 = convertToPanelY(mu+3*sigmy), y2 = convertToPanelY(mu-3*sigmy), ylen = Math.abs(y2-y1); // Draw heavy oval gx.drawOval((int)x1,(int)y1,(int)xlen,(int)ylen); gx.drawOval((int)x1-1,(int)y1-1,(int)xlen+2,(int)ylen+2); gx.drawOval((int)x1+1,(int)y1+1,(int)xlen-2,(int)ylen-2); // Draw a dot gx.fillOval((int)convertToPanelX(mu)-3, (int)convertToPanelY(mu)-3, 6, 6); mu=1.5; // class 1 if(m_xIndex == 1) sigmx = 1; else sigmx = 2; if(m_yIndex == 1) sigmy = 1; else sigmy = 2; x1 = convertToPanelX(mu-3*sigmx); x2 = convertToPanelX(mu+3*sigmx); xlen = Math.abs(x2-x1); y1 = convertToPanelY(mu+3*sigmy); y2 = convertToPanelY(mu-3*sigmy); ylen = Math.abs(y2-y1); // Draw heavy oval gx.drawOval((int)x1,(int)y1,(int)xlen,(int)ylen); gx.drawOval((int)x1-1,(int)y1-1,(int)xlen+2,(int)ylen+2); gx.drawOval((int)x1+1,(int)y1+1,(int)xlen-2,(int)ylen-2); // Draw a dot gx.fillOval((int)convertToPanelX(mu)-3, (int)convertToPanelY(mu)-3, 6, 6); */ //////////////////// TLD stuff ///////////////////////////////// //////////////////// instance-based stuff ///////////////////////////////// /* // Paint a log-odds line: 1*x0+2*x1=0 double xstart, xend, ystart, yend, xCoeff, yCoeff; if(m_xIndex == 1) xCoeff = 1; else xCoeff = 2; if(m_yIndex == 1) 
yCoeff = 1; else yCoeff = 2; xstart = m_minX; ystart = -xstart*xCoeff/yCoeff; if(ystart > m_maxY){ ystart = m_maxY; xstart = -ystart*yCoeff/xCoeff; } yend = m_minY; xend = -yend*yCoeff/xCoeff; if(xend > m_maxX){ xend = m_maxX; yend = -xend*xCoeff/yCoeff; } // Draw a heavy line gx.setColor(Color.black); gx.drawLine((int)convertToPanelX(xstart), (int)convertToPanelY(ystart), (int)convertToPanelX(xend), (int)convertToPanelY(yend)); gx.drawLine((int)convertToPanelX(xstart)+1, (int)convertToPanelY(ystart)+1, (int)convertToPanelX(xend)+1, (int)convertToPanelY(yend)+1); gx.drawLine((int)convertToPanelX(xstart)-1, (int)convertToPanelY(ystart)-1, (int)convertToPanelX(xend)-1, (int)convertToPanelY(yend)-1); */ //////////////////// instance-based stuff ///////////////////////////////// }
From source file:ml.ann.MultiClassPTR.java
@Override public void buildClassifier(Instances instances) throws Exception { initAttributes(instances);//from w ww. j a v a 2 s. c om // REMEMBER: only works if class index is in the last position for (int instanceIdx = 0; instanceIdx < instances.numInstances(); instanceIdx++) { Instance instance = instances.get(instanceIdx); double[] inputInstance = inputInstances[instanceIdx]; inputInstance[0] = 1.0; // initialize bias value for (int attrIdx = 0; attrIdx < instance.numAttributes() - 1; attrIdx++) { inputInstance[attrIdx + 1] = instance.value(attrIdx); // the first index of input instance is for bias } } // Initialize target values if (instances.classAttribute().isNominal()) { for (int instanceIdx = 0; instanceIdx < instances.numInstances(); instanceIdx++) { Instance instance = instances.instance(instanceIdx); for (int classIdx = 0; classIdx < instances.numClasses(); classIdx++) { targetInstances[instanceIdx][classIdx] = 0.0; } targetInstances[instanceIdx][(int) instance.classValue()] = 1.0; } } else { for (int instanceIdx = 0; instanceIdx < instances.numInstances(); instanceIdx++) { Instance instance = instances.instance(instanceIdx); targetInstances[instanceIdx][0] = instance.classValue(); } } if (algo == 1) { setActFunction(); buildClassifier(); } else if (algo == 2) { buildClassifier(); } else if (algo == 3) { buildClassifierBatch(); } }
From source file:ml.ann.SinglePTR.java
@Override public void buildClassifier(Instances train) throws Exception { double[][] input; double weightawal = 0.0; input = new double[train.numInstances()][train.numAttributes()]; for (int i = 0; i < train.numInstances(); i++) { for (int j = 1; j < train.numAttributes(); j++) { System.out.println(train.attribute(j - 1)); input[i][j] = train.instance(i).value(j - 1); System.out.println("input[" + i + "][" + j + "]: " + input[i][j]); }//www .j a va 2 s.com } double[] target = new double[train.numInstances()]; for (int i = 0; i < train.numInstances(); i++) { target[i] = train.instance(i).classValue(); System.out.println("target[" + i + "]: " + target[i]); } double[][] weight = new double[train.numAttributes()][1]; for (int i = 0; i < train.numAttributes(); i++) { weight[i][0] = weightawal; } if (algo == 1) { SinglePTR testrun; testrun = new SinglePTR(train.numInstances(), train.numAttributes() - 1, 10, 0.1, 0.01, input, target, weight, 1, momentum, randomWeight); } else if (algo == 2) { SinglePTR testrun; testrun = new SinglePTR(train.numInstances(), train.numAttributes() - 1, 10, 0.1, 0.01, input, target, weight, 2, momentum, randomWeight); } else if (algo == 3) { SinglePTR testrun; testrun = new SinglePTR(train.numInstances(), train.numAttributes() - 1, 10, 0.1, 0.01, input, target, weight, 3, momentum, randomWeight); } }
From source file:ml.dataprocess.CorrelationAttributeEval.java
License:Open Source License
/**
 * Builds the correlation-based attribute evaluator. Works on a copy of the
 * data: instances with a missing class value are deleted and the
 * ReplaceMissingValues filter is applied before any correlation is computed.
 *
 * For every non-class attribute a score is stored in m_correlations:
 * <ul>
 * <li>numeric attribute, numeric class: the correlation with the class
 * (reset to 0 when it is exactly 1.0 and the attribute has zero variance);</li>
 * <li>nominal attribute: each value is turned into a 0/1 indicator vector,
 * the absolute correlation of each indicator is computed and the results
 * are averaged, weighted by how often each value occurs;</li>
 * <li>nominal class: the class is binarized per class value as well and the
 * absolute correlations are averaged, weighted by the class value counts.</li>
 * </ul>
 * When m_detailedOutput is set, the per-value scores of nominal attributes
 * are appended to m_detailedOutputBuff.
 *
 * @param data set of instances serving as training data
 * @throws Exception if the evaluator has not been generated successfully
 */
@Override
public void buildEvaluator(Instances data) throws Exception {
  // Work on a private copy so the caller's dataset is not modified.
  data = new Instances(data);
  data.deleteWithMissingClass();

  ReplaceMissingValues rmv = new ReplaceMissingValues();
  rmv.setInputFormat(data);
  data = Filter.useFilter(data, rmv);

  int numClasses = data.classAttribute().numValues();
  int classIndex = data.classIndex();
  int numInstances = data.numInstances();
  m_correlations = new double[data.numAttributes()];
  /*
   * boolean hasNominals = false; boolean hasNumerics = false;
   */
  List<Integer> numericIndexes = new ArrayList<Integer>();
  List<Integer> nominalIndexes = new ArrayList<Integer>();
  // NOTE(review): the buffer is only replaced when detailed output is on, so
  // a buffer left over from an earlier detailed run is kept - confirm intent.
  if (m_detailedOutput) {
    m_detailedOutputBuff = new StringBuffer();
  }

  // TODO for instance weights (folded into computing weighted correlations)
  // add another dimension just before the last [2] (0 for 0/1 binary vector
  // and 1 for corresponding instance weights for the 1's)

  // nomAtts[attribute][value][instance] is the 0/1 indicator of "instance
  // has this value"; entries stay null for non-nominal attributes and for
  // the class attribute.
  double[][][] nomAtts = new double[data.numAttributes()][][];
  for (int i = 0; i < data.numAttributes(); i++) {
    if (data.attribute(i).isNominal() && i != classIndex) {
      nomAtts[i] = new double[data.attribute(i).numValues()][data.numInstances()];
      // set zero index for this att to all 1's
      Arrays.fill(nomAtts[i][0], 1.0);
      nominalIndexes.add(i);
    } else if (data.attribute(i).isNumeric() && i != classIndex) {
      numericIndexes.add(i);
    }
  }

  // do the nominal attributes
  if (nominalIndexes.size() > 0) {
    for (int i = 0; i < data.numInstances(); i++) {
      Instance current = data.instance(i);
      for (int j = 0; j < current.numValues(); j++) {
        if (current.attribute(current.index(j)).isNominal() && current.index(j) != classIndex) {
          // Will need to check for zero in case this isn't a sparse
          // instance (unless we add 1 and subtract 1).
          // Adding 1 to the observed value and subtracting 1 from value 0
          // moves the indicator away from value 0; when the observed value
          // IS 0 the two updates cancel and the entry stays 1.
          nomAtts[current.index(j)][(int) current.valueSparse(j)][i] += 1;
          nomAtts[current.index(j)][0][i] -= 1;
        }
      }
    }
  }

  if (data.classAttribute().isNumeric()) {
    double[] classVals = data.attributeToDoubleArray(classIndex);

    // do the numeric attributes
    for (Integer i : numericIndexes) {
      double[] numAttVals = data.attributeToDoubleArray(i);
      m_correlations[i] = Utils.correlation(numAttVals, classVals, numAttVals.length);

      if (m_correlations[i] == 1.0) {
        // check for zero variance (useless numeric attribute)
        if (Utils.variance(numAttVals) == 0) {
          m_correlations[i] = 0;
        }
      }
    }

    // do the nominal attributes
    if (nominalIndexes.size() > 0) {

      // now compute the correlations for the binarized nominal attributes
      for (Integer i : nominalIndexes) {
        double sum = 0;
        double corr = 0;
        double sumCorr = 0;
        double sumForValue = 0;

        if (m_detailedOutput) {
          m_detailedOutputBuff.append("\n\n").append(data.attribute(i).name());
        }

        for (int j = 0; j < data.attribute(i).numValues(); j++) {
          // number of instances carrying value j of attribute i
          sumForValue = Utils.sum(nomAtts[i][j]);
          corr = Utils.correlation(nomAtts[i][j], classVals, classVals.length);

          // useless attribute - all instances have the same value
          if (sumForValue == numInstances || sumForValue == 0) {
            corr = 0;
          }
          // only the strength of the correlation matters, not its sign
          if (corr < 0.0) {
            corr = -corr;
          }
          sumCorr += sumForValue * corr;
          sum += sumForValue;

          if (m_detailedOutput) {
            m_detailedOutputBuff.append("\n\t").append(data.attribute(i).value(j)).append(": ");
            m_detailedOutputBuff.append(Utils.doubleToString(corr, 6));
          }
        }
        // average over the values, weighted by value frequency
        m_correlations[i] = (sum > 0) ? sumCorr / sum : 0;
      }
    }
  } else {
    // class is nominal
    // TODO extra dimension for storing instance weights too

    // binarizedClasses[classValue][instance] is 1 when the instance has
    // that class value, 0 otherwise
    double[][] binarizedClasses = new double[data.classAttribute().numValues()][data.numInstances()];

    // this is equal to the number of instances for all inst weights = 1
    double[] classValCounts = new double[data.classAttribute().numValues()];

    for (int i = 0; i < data.numInstances(); i++) {
      Instance current = data.instance(i);
      binarizedClasses[(int) current.classValue()][i] = 1;
    }
    for (int i = 0; i < data.classAttribute().numValues(); i++) {
      classValCounts[i] = Utils.sum(binarizedClasses[i]);
    }

    double sumClass = Utils.sum(classValCounts);

    // do numeric attributes first
    if (numericIndexes.size() > 0) {
      for (Integer i : numericIndexes) {
        double[] numAttVals = data.attributeToDoubleArray(i);
        double corr = 0;
        double sumCorr = 0;

        for (int j = 0; j < data.classAttribute().numValues(); j++) {
          corr = Utils.correlation(numAttVals, binarizedClasses[j], numAttVals.length);
          if (corr < 0.0) {
            corr = -corr;
          }

          if (corr == 1.0) {
            // check for zero variance (useless numeric attribute)
            if (Utils.variance(numAttVals) == 0) {
              corr = 0;
            }
          }

          sumCorr += classValCounts[j] * corr;
        }
        // average over the class values, weighted by class frequency
        m_correlations[i] = sumCorr / sumClass;
      }
    }

    if (nominalIndexes.size() > 0) {
      for (Integer i : nominalIndexes) {
        if (m_detailedOutput) {
          m_detailedOutputBuff.append("\n\n").append(data.attribute(i).name());
        }

        double sumForAtt = 0;
        double corrForAtt = 0;
        for (int j = 0; j < data.attribute(i).numValues(); j++) {
          double sumForValue = Utils.sum(nomAtts[i][j]);
          double corr = 0;
          double sumCorr = 0;
          double avgCorrForValue = 0;

          sumForAtt += sumForValue;
          for (int k = 0; k < numClasses; k++) {

            // corr between value j and class k
            corr = Utils.correlation(nomAtts[i][j], binarizedClasses[k], binarizedClasses[k].length);

            // useless attribute - all instances have the same value
            if (sumForValue == numInstances || sumForValue == 0) {
              corr = 0;
            }
            if (corr < 0.0) {
              corr = -corr;
            }
            sumCorr += classValCounts[k] * corr;
          }
          avgCorrForValue = sumCorr / sumClass;
          corrForAtt += sumForValue * avgCorrForValue;

          if (m_detailedOutput) {
            m_detailedOutputBuff.append("\n\t").append(data.attribute(i).value(j)).append(": ");
            m_detailedOutputBuff.append(Utils.doubleToString(avgCorrForValue, 6));
          }
        }

        // the weighted average corr for att i as
        // a whole (weighted by value frequencies)
        m_correlations[i] = (sumForAtt > 0) ? corrForAtt / sumForAtt : 0;
      }
    }
  }

  if (m_detailedOutputBuff != null && m_detailedOutputBuff.length() > 0) {
    m_detailedOutputBuff.append("\n");
  }
}
From source file:ml.engine.LibSVM.java
License:Open Source License
/** * builds the classifier//from www. ja v a2 s . co m * * @param insts the training instances * @throws Exception if libsvm classes not in classpath or libsvm encountered * a problem */ @Override public void buildClassifier(Instances insts) throws Exception { m_Filter = null; if (!isPresent()) { throw new Exception("libsvm classes not in CLASSPATH!"); } // remove instances with missing class insts = new Instances(insts); insts.deleteWithMissingClass(); if (!getDoNotReplaceMissingValues()) { m_ReplaceMissingValues = new ReplaceMissingValues(); m_ReplaceMissingValues.setInputFormat(insts); insts = Filter.useFilter(insts, m_ReplaceMissingValues); } // can classifier handle the data? // we check this here so that if the user turns off // replace missing values filtering, it will fail // if the data actually does have missing values getCapabilities().testWithFail(insts); if (getNormalize()) { m_Filter = new Normalize(); m_Filter.setInputFormat(insts); insts = Filter.useFilter(insts, m_Filter); } // nominal to binary m_NominalToBinary = new NominalToBinary(); m_NominalToBinary.setInputFormat(insts); insts = Filter.useFilter(insts, m_NominalToBinary); Vector vy = new Vector(); Vector vx = new Vector(); int max_index = 0; for (int d = 0; d < insts.numInstances(); d++) { Instance inst = insts.instance(d); Object x = instanceToArray(inst); int m = Array.getLength(x); if (m > 0) { max_index = Math.max(max_index, ((Integer) getField(Array.get(x, m - 1), "index")).intValue()); } vx.addElement(x); vy.addElement(new Double(inst.classValue())); } // calculate actual gamma if (getGamma() == 0) { m_GammaActual = 1.0 / max_index; } else { m_GammaActual = m_Gamma; } // check parameter String error_msg = (String) invokeMethod(Class.forName(CLASS_SVM).newInstance(), "svm_check_parameter", new Class[] { Class.forName(CLASS_SVMPROBLEM), Class.forName(CLASS_SVMPARAMETER) }, new Object[] { getProblem(vx, vy), getParameters() }); if (error_msg != null) { throw new Exception("Error: " + 
error_msg); } // make probability estimates deterministic from run to run Class svmClass = Class.forName(CLASS_SVM); Field randF = svmClass.getField("rand"); Random rand = (Random) randF.get(null); // static field rand.setSeed(m_Seed); // train model m_Model = invokeMethod(Class.forName(CLASS_SVM).newInstance(), "svm_train", new Class[] { Class.forName(CLASS_SVMPROBLEM), Class.forName(CLASS_SVMPARAMETER) }, new Object[] { getProblem(vx, vy), getParameters() }); }
From source file:mlda.util.Utils.java
License:Open Source License
/** * Get array of ImbalancedFeature with labels frequency * /* w w w. j a v a 2 s.c om*/ * @param dataset Multi-label dataset * @return Array of ImbalancedFeature with the labels frequency */ public static ImbalancedFeature[] getAppearancesPerLabel(MultiLabelInstances dataset) { int[] labelIndices = dataset.getLabelIndices(); ImbalancedFeature[] labels = new ImbalancedFeature[labelIndices.length]; Instances instances = dataset.getDataSet(); int appearances = 0; Attribute currentAtt; for (int i = 0; i < labelIndices.length; i++) { currentAtt = instances.attribute(labelIndices[i]); appearances = 0; for (int j = 0; j < instances.size(); j++) { if (instances.instance(j).value(currentAtt) == 1.0) { appearances++; } } labels[i] = new ImbalancedFeature(currentAtt.name(), appearances); } return labels; }
From source file:mlda.util.Utils.java
License:Open Source License
/** * Calculate IRs of the ImbalancedFeatures * /*from ww w.ja v a 2s. co m*/ * @param dataset Multi-label dataset * @param labels Labels of the dataset as ImbalancedFeature objects * @return Array of ImbalancedFeature objects with calculated IR */ public static ImbalancedFeature[] getImbalancedWithIR(MultiLabelInstances dataset, ImbalancedFeature[] labels) { int[] labelIndices = dataset.getLabelIndices(); ImbalancedFeature[] labels_imbalanced = new ImbalancedFeature[labelIndices.length]; Instances instances = dataset.getDataSet(); int nOnes = 0, nZeros = 0, maxAppearance = 0; double IRIntraClass; double variance; double IRInterClass; double mean = dataset.getNumInstances() / 2; Attribute current; ImbalancedFeature currentLabel; for (int i = 0; i < labelIndices.length; i++) //for each label { nZeros = 0; nOnes = 0; current = instances.attribute(labelIndices[i]); //current label for (int j = 0; j < instances.size(); j++) //for each instance { if (instances.instance(j).value(current) == 1.0) { nOnes++; } else { nZeros++; } } try { if (nZeros == 0 || nOnes == 0) { IRIntraClass = 0; } else if (nZeros > nOnes) { IRIntraClass = (double) nZeros / nOnes; } else { IRIntraClass = (double) nOnes / nZeros; } } catch (Exception e1) { IRIntraClass = 0; } variance = (Math.pow((nZeros - mean), 2) + Math.pow((nOnes - mean), 2)) / 2; currentLabel = getLabelByName(current.name(), labels); maxAppearance = labels[0].getAppearances(); if (currentLabel.getAppearances() <= 0) { IRInterClass = Double.NaN; } else { IRInterClass = (double) maxAppearance / currentLabel.getAppearances(); } labels_imbalanced[i] = new ImbalancedFeature(current.name(), currentLabel.getAppearances(), IRInterClass, IRIntraClass, variance); } return labels_imbalanced; }