List of usage examples for weka.core Instance classAttribute
public Attribute classAttribute();
From source file:gyc.SMOTEBagging.java
License:Open Source License
/** * Calculates the class membership probabilities for the given test * instance./* ww w.j a v a 2s .c o m*/ * * @param instance the instance to be classified * @return preedicted class probability distribution * @throws Exception if distribution can't be computed successfully */ public double[] distributionForInstance(Instance instance) throws Exception { double[] sums = new double[instance.numClasses()], newProbs; for (int i = 0; i < m_NumIterations; i++) { if (instance.classAttribute().isNumeric() == true) { sums[0] += m_Classifiers[i].classifyInstance(instance); } else { newProbs = m_Classifiers[i].distributionForInstance(instance); for (int j = 0; j < newProbs.length; j++) sums[j] += newProbs[j]; } } if (instance.classAttribute().isNumeric() == true) { sums[0] /= (double) m_NumIterations; return sums; } else if (Utils.eq(Utils.sum(sums), 0)) { return sums; } else { Utils.normalize(sums); return sums; } }
From source file:hr.irb.fastRandomForest.FastRfBagging.java
License:Open Source License
/**
 * Combines the votes of all base classifiers into one prediction for the
 * given test instance.
 *
 * <p>Numeric class: index 0 of the result is the sum of the base classifiers'
 * predictions divided by the number of iterations. Any other class type: the
 * base classifiers' distributions are added up and then normalized, unless
 * their total is zero (as judged by {@code Utils.eq}), in which case the raw
 * sums are returned.
 *
 * @param instance the instance to be classified
 *
 * @return predicted class probability distribution
 *
 * @throws Exception if distribution can't be computed successfully
 */
@Override
public double[] distributionForInstance(Instance instance) throws Exception {
    final double[] totals = new double[instance.numClasses()];
    final boolean numericClass = instance.classAttribute().isNumeric();

    for (int member = 0; member < m_NumIterations; member++) {
        if (numericClass) {
            totals[0] += m_Classifiers[member].classifyInstance(instance);
            continue;
        }
        final double[] memberDist = m_Classifiers[member].distributionForInstance(instance);
        for (int cls = 0; cls < memberDist.length; cls++) {
            totals[cls] += memberDist[cls];
        }
    }

    if (numericClass) {
        totals[0] /= (double) m_NumIterations;
    } else if (!Utils.eq(Utils.sum(totals), 0)) {
        Utils.normalize(totals);
    }
    return totals;
}
From source file:LogReg.FilteredLogRegClassifier.java
License:Open Source License
/** * Classifies a given instance after filtering. * * @param instance the instance to be classified * @return the class distribution for the given instance * @throws Exception if instance could not be classified * successfully//from w w w . ja va2 s.c om */ public double[] distributionForInstance(Instance instance) throws Exception { /* System.err.println("FilteredClassifier:: " + m_Filter.getClass().getName() + " in: " + instance); */ if (m_Filter.numPendingOutput() > 0) { throw new Exception("Filter output queue not empty!"); } /* String fname = m_Filter.getClass().getName(); fname = fname.substring(fname.lastIndexOf('.') + 1); util.Timer t = util.Timer.getTimer("FilteredClassifier::" + fname); t.start(); */ if (!m_Filter.input(instance)) { if (!m_Filter.mayRemoveInstanceAfterFirstBatchDone()) { throw new Exception("Filter didn't make the test instance" + " immediately available!"); } else { // filter has consumed the instance (e.g. RemoveWithValues // may do this). We will indicate no prediction for this // instance double[] unclassified = null; if (instance.classAttribute().isNumeric()) { unclassified = new double[1]; unclassified[0] = Utils.missingValue(); } else { // all zeros unclassified = new double[instance.classAttribute().numValues()]; } m_Filter.batchFinished(); return unclassified; } } m_Filter.batchFinished(); Instance newInstance = m_Filter.output(); //t.stop(); /* System.err.println("FilteredClassifier:: " + m_Filter.getClass().getName() + " out: " + newInstance); */ return m_Classifier.distributionForInstance(newInstance); }
From source file:lu.lippmann.cdb.lab.mds.MDSViewBuilder.java
License:Open Source License
/** * //from w ww.jav a2 s. c om */ private static void buildFilteredSeries(final MDSResult mdsResult, final XYPlot xyPlot, final String... attrNameToUseAsPointTitle) throws Exception { final CollapsedInstances distMdsRes = mdsResult.getCInstances(); final Instances instances = distMdsRes.getInstances(); final SimpleMatrix coordinates = mdsResult.getCoordinates(); final Instances collapsedInstances = mdsResult.getCollapsedInstances(); int maxSize = 0; if (distMdsRes.isCollapsed()) { final List<Instances> clusters = distMdsRes.getCentroidMap().getClusters(); final int nbCentroids = clusters.size(); maxSize = clusters.get(0).size(); for (int i = 1; i < nbCentroids; i++) { final int currentSize = clusters.get(i).size(); if (currentSize > maxSize) { maxSize = currentSize; } } } Attribute clsAttribute = null; int nbClass = 1; if (instances.classIndex() != -1) { clsAttribute = instances.classAttribute(); nbClass = clsAttribute.numValues(); } final XYSeriesCollection dataset = (XYSeriesCollection) xyPlot.getDataset(); final int fMaxSize = maxSize; final List<XYSeries> lseries = new ArrayList<XYSeries>(); //No class : add one dummy serie if (nbClass <= 1) { lseries.add(new XYSeries("Serie #1", false)); } else { //Some class : add one serie per class for (int i = 0; i < nbClass; i++) { lseries.add(new XYSeries(clsAttribute.value(i), false)); } } dataset.removeAllSeries(); /** * Initialize filtered series */ final List<Instances> filteredInstances = new ArrayList<Instances>(); for (int i = 0; i < lseries.size(); i++) { filteredInstances.add(new Instances(collapsedInstances, 0)); } final Map<Tuple<Integer, Integer>, Integer> correspondanceMap = new HashMap<Tuple<Integer, Integer>, Integer>(); for (int i = 0; i < collapsedInstances.numInstances(); i++) { final Instance oInst = collapsedInstances.instance(i); int indexOfSerie = 0; if (oInst.classIndex() != -1) { if (distMdsRes.isCollapsed()) { indexOfSerie = getStrongestClass(i, distMdsRes); } else { indexOfSerie = (int) 
oInst.value(oInst.classAttribute()); } } lseries.get(indexOfSerie).add(coordinates.get(i, 0), coordinates.get(i, 1)); filteredInstances.get(indexOfSerie).add(oInst); if (distMdsRes.isCollapsed()) { correspondanceMap.put(new Tuple<Integer, Integer>(indexOfSerie, filteredInstances.get(indexOfSerie).numInstances() - 1), i); } } final List<Paint> colors = new ArrayList<Paint>(); for (final XYSeries series : lseries) { dataset.addSeries(series); } if (distMdsRes.isCollapsed()) { final XYLineAndShapeRenderer xyRenderer = new XYLineAndShapeRenderer(false, true) { private static final long serialVersionUID = -6019883886470934528L; @Override public void drawItem(Graphics2D g2, XYItemRendererState state, java.awt.geom.Rectangle2D dataArea, PlotRenderingInfo info, XYPlot plot, ValueAxis domainAxis, ValueAxis rangeAxis, XYDataset dataset, int series, int item, CrosshairState crosshairState, int pass) { if (distMdsRes.isCollapsed()) { final Integer centroidIndex = correspondanceMap .get(new Tuple<Integer, Integer>(series, item)); final Instances cluster = distMdsRes.getCentroidMap().getClusters().get(centroidIndex); int size = cluster.size(); final int shapeSize = (int) (MAX_POINT_SIZE * size / fMaxSize + 1); final double x1 = plot.getDataset().getX(series, item).doubleValue(); final double y1 = plot.getDataset().getY(series, item).doubleValue(); Map<Object, Integer> mapRepartition = new HashMap<Object, Integer>(); mapRepartition.put("No class", size); if (cluster.classIndex() != -1) { mapRepartition = WekaDataStatsUtil.getClassRepartition(cluster); } final RectangleEdge xAxisLocation = plot.getDomainAxisEdge(); final RectangleEdge yAxisLocation = plot.getRangeAxisEdge(); final double fx = domainAxis.valueToJava2D(x1, dataArea, xAxisLocation); final double fy = rangeAxis.valueToJava2D(y1, dataArea, yAxisLocation); setSeriesShape(series, new Ellipse2D.Double(-shapeSize / 2, -shapeSize / 2, shapeSize, shapeSize)); super.drawItem(g2, state, dataArea, info, plot, domainAxis, 
rangeAxis, dataset, series, item, crosshairState, pass); //Draw pie if (ENABLE_PIE_SHART) { createPieChart(g2, (int) (fx - shapeSize / 2), (int) (fy - shapeSize / 2), shapeSize, mapRepartition, size, colors); } } else { super.drawItem(g2, state, dataArea, info, plot, domainAxis, rangeAxis, dataset, series, item, crosshairState, pass); } } }; xyPlot.setRenderer(xyRenderer); } final XYToolTipGenerator gen = new XYToolTipGenerator() { @Override public String generateToolTip(XYDataset dataset, int series, int item) { if (distMdsRes.isCollapsed()) { final StringBuilder res = new StringBuilder("<html>"); final Integer centroidIndex = correspondanceMap.get(new Tuple<Integer, Integer>(series, item)); final Instance centroid = distMdsRes.getCentroidMap().getCentroids().get(centroidIndex); final Instances cluster = distMdsRes.getCentroidMap().getClusters().get(centroidIndex); //Set same class index for cluster than for original instances //System.out.println("Cluster index = " + cluster.classIndex() + "/" + instances.classIndex()); cluster.setClassIndex(instances.classIndex()); Map<Object, Integer> mapRepartition = new HashMap<Object, Integer>(); mapRepartition.put("No class", cluster.size()); if (cluster.classIndex() != -1) { mapRepartition = WekaDataStatsUtil.getClassRepartition(cluster); } res.append(InstanceFormatter.htmlFormat(centroid, false)).append("<br/>"); for (final Map.Entry<Object, Integer> entry : mapRepartition.entrySet()) { if (entry.getValue() != 0) { res.append("Class :<b>'" + StringEscapeUtils.escapeHtml(entry.getKey().toString()) + "</b>' -> " + entry.getValue()).append("<br/>"); } } res.append("</html>"); return res.toString(); } else { //return InstanceFormatter.htmlFormat(filteredInstances.get(series).instance(item),true); return InstanceFormatter.shortHtmlFormat(filteredInstances.get(series).instance(item)); } } }; final Shape shape = new Ellipse2D.Float(0f, 0f, MAX_POINT_SIZE, MAX_POINT_SIZE); ((XYLineAndShapeRenderer) 
xyPlot.getRenderer()).setUseOutlinePaint(true); for (int p = 0; p < nbClass; p++) { xyPlot.getRenderer().setSeriesToolTipGenerator(p, gen); ((XYLineAndShapeRenderer) xyPlot.getRenderer()).setLegendShape(p, shape); xyPlot.getRenderer().setSeriesOutlinePaint(p, Color.BLACK); } for (int ii = 0; ii < nbClass; ii++) { colors.add(xyPlot.getRenderer().getItemPaint(ii, 0)); } if (attrNameToUseAsPointTitle.length > 0) { final Attribute attrToUseAsPointTitle = instances.attribute(attrNameToUseAsPointTitle[0]); if (attrToUseAsPointTitle != null) { final XYItemLabelGenerator lg = new XYItemLabelGenerator() { @Override public String generateLabel(final XYDataset dataset, final int series, final int item) { return filteredInstances.get(series).instance(item).stringValue(attrToUseAsPointTitle); } }; xyPlot.getRenderer().setBaseItemLabelGenerator(lg); xyPlot.getRenderer().setBaseItemLabelsVisible(true); } } }
From source file:lu.lippmann.cdb.lab.mds.UniversalMDS.java
License:Open Source License
public JXPanel buildMDSViewFromDataSet(Instances ds, MDSTypeEnum type) throws Exception { final XYSeriesCollection dataset = new XYSeriesCollection(); final JFreeChart chart = ChartFactory.createScatterPlot("", // title "X", "Y", // axis labels dataset, // dataset PlotOrientation.VERTICAL, true, // legend? yes true, // tooltips? yes false // URLs? no );//from w w w .j av a 2 s . c o m final XYPlot xyPlot = (XYPlot) chart.getPlot(); chart.setTitle(type.name() + " MDS"); Attribute clsAttribute = null; int nbClass = 1; if (ds.classIndex() != -1) { clsAttribute = ds.classAttribute(); nbClass = clsAttribute.numValues(); } final List<XYSeries> lseries = new ArrayList<XYSeries>(); if (nbClass <= 1) { lseries.add(new XYSeries("Serie #1", false)); } else { for (int i = 0; i < nbClass; i++) { lseries.add(new XYSeries(clsAttribute.value(i), false)); } } dataset.removeAllSeries(); /** * Initialize filtered series */ final List<Instances> filteredInstances = new ArrayList<Instances>(); for (int i = 0; i < lseries.size(); i++) { filteredInstances.add(new Instances(ds, 0)); } for (int i = 0; i < ds.numInstances(); i++) { final Instance oInst = ds.instance(i); int indexOfSerie = 0; if (oInst.classIndex() != -1) { indexOfSerie = (int) oInst.value(oInst.classAttribute()); } lseries.get(indexOfSerie).add(coordinates[i][0], coordinates[i][1]); filteredInstances.get(indexOfSerie).add(oInst); } final List<Paint> colors = new ArrayList<Paint>(); for (final XYSeries series : lseries) { dataset.addSeries(series); } final XYToolTipGenerator gen = new XYToolTipGenerator() { @Override public String generateToolTip(XYDataset dataset, int series, int item) { return InstanceFormatter.htmlFormat(filteredInstances.get(series).instance(item), true); } }; final Shape shape = new Ellipse2D.Float(0f, 0f, 5f, 5f); ((XYLineAndShapeRenderer) xyPlot.getRenderer()).setUseOutlinePaint(true); for (int p = 0; p < nbClass; p++) { xyPlot.getRenderer().setSeriesToolTipGenerator(p, gen); ((XYLineAndShapeRenderer) 
xyPlot.getRenderer()).setLegendShape(p, shape); xyPlot.getRenderer().setSeriesOutlinePaint(p, Color.BLACK); } for (int ii = 0; ii < nbClass; ii++) { colors.add(xyPlot.getRenderer().getItemPaint(ii, 0)); } final ChartPanel chartPanel = new ChartPanel(chart); chartPanel.setMouseWheelEnabled(true); chartPanel.setPreferredSize(new Dimension(1200, 900)); chartPanel.setBorder(new TitledBorder("MDS Projection")); chartPanel.setBackground(Color.WHITE); final JXPanel allPanel = new JXPanel(); allPanel.setLayout(new BorderLayout()); allPanel.add(chartPanel, BorderLayout.CENTER); return allPanel; }
From source file:ml.engine.LibSVM.java
License:Open Source License
/** * Computes the distribution for a given instance. In case of 1-class * classification, 1 is returned at index 0 if libsvm returns 1 and NaN (= * missing) if libsvm returns -1./*from w ww . j av a 2s . c o m*/ * * @param instance the instance for which distribution is computed * @return the distribution * @throws Exception if the distribution can't be computed successfully */ @Override public double[] distributionForInstance(Instance instance) throws Exception { int[] labels = new int[instance.numClasses()]; double[] prob_estimates = null; if (m_ProbabilityEstimates) { invokeMethod(Class.forName(CLASS_SVM).newInstance(), "svm_get_labels", new Class[] { Class.forName(CLASS_SVMMODEL), Array.newInstance(Integer.TYPE, instance.numClasses()).getClass() }, new Object[] { m_Model, labels }); prob_estimates = new double[instance.numClasses()]; } if (!getDoNotReplaceMissingValues()) { m_ReplaceMissingValues.input(instance); m_ReplaceMissingValues.batchFinished(); instance = m_ReplaceMissingValues.output(); } if (m_Filter != null) { m_Filter.input(instance); m_Filter.batchFinished(); instance = m_Filter.output(); } m_NominalToBinary.input(instance); m_NominalToBinary.batchFinished(); instance = m_NominalToBinary.output(); Object x = instanceToArray(instance); double v; double[] result = new double[instance.numClasses()]; if (m_ProbabilityEstimates && ((m_SVMType == SVMTYPE_C_SVC) || (m_SVMType == SVMTYPE_NU_SVC))) { v = ((Double) invokeMethod(Class.forName(CLASS_SVM).newInstance(), "svm_predict_probability", new Class[] { Class.forName(CLASS_SVMMODEL), Array.newInstance(Class.forName(CLASS_SVMNODE), Array.getLength(x)).getClass(), Array.newInstance(Double.TYPE, prob_estimates.length).getClass() }, new Object[] { m_Model, x, prob_estimates })).doubleValue(); // Return order of probabilities to canonical weka attribute order for (int k = 0; k < prob_estimates.length; k++) { result[labels[k]] = prob_estimates[k]; } } else { v = ((Double) 
invokeMethod(Class.forName(CLASS_SVM).newInstance(), "svm_predict", new Class[] { Class.forName(CLASS_SVMMODEL), Array.newInstance(Class.forName(CLASS_SVMNODE), Array.getLength(x)).getClass() }, new Object[] { m_Model, x })).doubleValue(); if (instance.classAttribute().isNominal()) { if (m_SVMType == SVMTYPE_ONE_CLASS_SVM) { if (v > 0) { result[0] = 1; } else { // outlier (interface for Classifier specifies that unclassified // instances // should return a distribution of all zeros) result[0] = 0; } } else { result[(int) v] = 1; } } else { result[0] = v; } } return result; }
From source file:mlflex.WekaInMemoryLearner.java
License:Open Source License
@Override protected ModelPredictions TrainTest(ArrayList<String> classificationParameters, DataInstanceCollection trainData, DataInstanceCollection testData, DataInstanceCollection dependentVariableInstances) throws Exception { ArrayList<String> dataPointNames = Lists.SortStringList(trainData.GetDataPointNames()); FastVector attVector = GetAttributeVector(dependentVariableInstances, dataPointNames, trainData, testData); Instances wekaTrainingInstances = GetInstances(dependentVariableInstances, attVector, trainData); Instances wekaTestInstances = GetInstances(dependentVariableInstances, attVector, testData); ArrayList<String> dependentVariableClasses = Utilities.ProcessorVault.DependentVariableDataProcessor .GetUniqueDependentVariableValues(); Classifier classifier = GetClassifier(classificationParameters); classifier.buildClassifier(wekaTrainingInstances); Predictions predictions = new Predictions(); for (DataValues testInstance : testData) { String dependentVariableValue = dependentVariableInstances.Get(testInstance.GetID()) .GetDataPointValue(0);//from ww w. j a va2 s . c o m // This is the default before the prediction is made Prediction prediction = new Prediction(testInstance.GetID(), dependentVariableValue, Lists.PickRandomValue(dependentVariableClasses), Lists.CreateDoubleList(0.5, dependentVariableClasses.size())); if (!testInstance.HasOnlyMissingValues()) { Instance wekaTestInstance = GetInstance(wekaTestInstances, attVector, testInstance, null); double clsLabel = classifier.classifyInstance(wekaTestInstance); String predictedClass = wekaTestInstance.classAttribute().value((int) clsLabel); double[] probabilities = classifier.distributionForInstance(wekaTestInstance); ArrayList<Double> classProbabilities = Lists.CreateDoubleList(probabilities); prediction = new Prediction(testInstance.GetID(), dependentVariableValue, predictedClass, classProbabilities); } predictions.Add(prediction); } classifier = null; return new ModelPredictions("", predictions); }
From source file:moa.classifiers.featureselection.OFSL.java
License:Open Source License
/**
 * Returns the votes for an instance from the current linear model.
 *
 * <p>The vote array has two slots for a nominal class and one slot otherwise.
 * Without weights the array is returned all zero. For a numeric class, slot 0
 * holds the score (dot product with the weights plus the bias). Otherwise a
 * score of at most zero votes for index 0 and a positive score for index 1.
 *
 * @param inst the instance to score
 * @return the vote array described above
 */
@Override
public double[] getVotesForInstance(Instance inst) {
    final double[] votes = inst.classAttribute().isNominal() ? new double[2] : new double[1];
    if (this.weights == null) {
        // Nothing learned yet: no vote for any class.
        return votes;
    }

    final double score = dot(inst.toDoubleArray(), this.weights) + this.bias;

    if (inst.classAttribute().isNumeric()) {
        votes[0] = score;
        return votes;
    }

    final int winner = (score <= 0) ? 0 : 1;
    votes[winner] = 1;
    return votes;
}
From source file:moa.classifiers.featureselection.OFSL.java
License:Open Source License
@Override public void trainOnInstanceImpl(Instance inst) { double y_t, m_bias_p1, m_bias_p2, m_bias; double[] m_weights_p1, m_weights_p2, m_weights; if (this.weights == null) { this.weights = new double[inst.numValues()]; for (int i = 0; i < this.weights.length; i++) this.weights[i] = 0.0; this.bias = 0.0; }//from w ww . j a v a2 s . c o m if (inst.classAttribute().isNominal()) { y_t = (inst.classValue() == 0) ? -1 : 1; } else { y_t = inst.classValue(); } double f_t = dot(inst.toDoubleArray(), this.weights); f_t += this.bias; if (y_t * f_t < 0) { m_weights_p1 = scalar_vector(1.0 - this.stepSizeOption.getValue() * this.learningRateOption.getValue(), this.weights); m_bias_p1 = (1.0 - this.stepSizeOption.getValue() * this.learningRateOption.getValue()) * this.bias; m_weights_p2 = scalar_vector(this.learningRateOption.getValue() * y_t, inst.toDoubleArray()); m_bias_p2 = this.learningRateOption.getValue() * y_t; m_weights = vector_add(m_weights_p1, m_weights_p2); m_bias = m_bias_p1 + m_bias_p2; m_weights = l2_projection(m_weights, m_bias, this.learningRateOption.getValue()); m_weights = truncate(m_weights, this.numSelectOption.getValue()); for (int i = 0; i < m_weights_p1.length; i++) this.weights[i] = m_weights[i]; this.bias = m_weights[m_weights.length - 1]; } else { this.weights = scalar_vector(1.0 - this.stepSizeOption.getValue() * this.learningRateOption.getValue(), this.weights); this.bias = (1.0 - this.stepSizeOption.getValue() * this.learningRateOption.getValue()) * this.bias; } }
From source file:moa.classifiers.featureselection.OFSP.java
License:Open Source License
/**
 * Returns the votes for an instance from the current linear model.
 *
 * <p>The vote array has two slots for a nominal class and one slot otherwise.
 * Without weights the array is returned all zero. When the first evaluation
 * option is chosen, the score is the dot product with the weights plus the
 * bias. Otherwise random attribute indices are drawn and the score stays 0.
 * For a numeric class, slot 0 holds the score; otherwise a score of at most
 * zero votes for index 0 and a positive score for index 1.
 *
 * @param inst the instance to score
 * @return the vote array described above
 */
@Override
public double[] getVotesForInstance(Instance inst) {
    final double[] votes = inst.classAttribute().isNominal() ? new double[2] : new double[1];
    if (this.weights == null) {
        // Nothing learned yet: no vote for any class.
        return votes;
    }

    double score = 0;
    final int[] sampled = new int[this.numSelectOption.getValue()];

    final boolean fullEvaluation = this.evalOption.getChosenIndex() == 0;
    if (fullEvaluation) {
        score = dot(inst.toDoubleArray(), this.weights) + this.bias;
    } else {
        // NOTE(review): the sampled indices are never read afterwards; the
        // draws are kept because they advance the random generator.
        for (int k = 0; k < sampled.length; k++) {
            sampled[k] = this.rand.nextInt(inst.numAttributes());
        }
    }

    if (inst.classAttribute().isNumeric()) {
        votes[0] = score;
        return votes;
    }

    final int winner = (score <= 0) ? 0 : 1;
    votes[winner] = 1;
    return votes;
}