Example usage for weka.core Instance setValue

List of usage examples for weka.core Instance setValue

Introduction

In this page you can find the example usage for weka.core Instance setValue.

Prototype

public void setValue(Attribute att, String value);

Source Link

Document

Sets the value of a nominal or string attribute to the given value.

Usage

From source file:machinelearning_cw.KNN.java

/**
 * Trains the k-NN classifier: optionally standardises every non-class
 * attribute of the training data (storing per-attribute mean and standard
 * deviation for later use on test instances), keeps a copy of the data,
 * and optionally auto-determines k.
 *
 * @param data the training instances; standardisation mutates them in place
 * @throws Exception if training fails (propagated from helpers)
 */
@Override
public void buildClassifier(Instances data) throws Exception {

    if (useStandardisedAttributes) {
        mean = new double[data.numAttributes() - 1];
        standardDeviation = new double[data.numAttributes() - 1];

        // For each non-class attribute: compute statistics, then
        // standardise the attribute's value in every instance.
        for (int i = 0; i < data.numAttributes() - 1; i++) {
            double[] meanAndStdDev = Helpers.meanAndStandardDeviation(data, i);
            // Named attrMean/attrStdDev to avoid shadowing the
            // mean/standardDeviation fields (the original local "mean"
            // shadowed the field of the same name).
            double attrMean = meanAndStdDev[0];
            double attrStdDev = meanAndStdDev[1];
            this.mean[i] = attrMean;
            this.standardDeviation[i] = attrStdDev;

            for (Instance eachInstance : data) {
                double value = eachInstance.value(i);
                // Guard against a constant attribute (stdDev == 0): the
                // original divided by zero here and wrote NaN into the data.
                double standardisedValue = (attrStdDev == 0) ? 0.0 : (value - attrMean) / attrStdDev;
                eachInstance.setValue(i, standardisedValue);
            }
        }
        // Set once after all attributes are processed (was redundantly
        // assigned on every loop iteration).
        isMeanAndStdDevInitialised = true;
    }

    trainingData = new Instances(data);

    if (autoDetermineK) {
        determineK();
    }
}

From source file:machinelearning_cw.KNN.java

/**
 * Classifies a single test instance with the k-nearest-neighbour rule.
 *
 * Depending on configuration flags this method: (1) standardises the test
 * instance with the mean/standard deviation computed in buildClassifier,
 * (2) either delegates to the superclass (unweighted voting) or performs
 * distance-weighted voting over the k nearest neighbours, and (3) uses
 * either a linear scan or Orchard's algorithm for the neighbour search.
 *
 * @param instance the instance to classify (mutated in place if
 *                 standardisation is enabled)
 * @return the predicted class value (class index as a double)
 * @throws Exception if buildClassifier was never called
 */
@Override
public double classifyInstance(Instance instance) throws Exception {
    // Check that classifier has been trained
    if (trainingData == null) {
        throw new Exception("Classifier has not been trained." + " No call to buildClassifier() was made");
    }

    if (useStandardisedAttributes) {
        if (!isMeanAndStdDevInitialised) {
            // NOTE(review): an exception was apparently intended here, but
            // this branch currently falls through and classifies anyway
            // with unstandardised attributes.
        } else {
            /* Standardise test instance using the statistics recorded
             * during training, so it is comparable to the training data. */
            for (int i = 0; i < instance.numAttributes() - 1; i++) {
                double value = instance.value(i);
                double standardisedValue = (value - mean[i]) / standardDeviation[i];
                instance.setValue(i, standardisedValue);
            }
        }
    }

    if (!useWeightedVoting) {
        return super.classifyInstance(instance);
    } else {

        if (!useAcceleratedNNSearch) {
            /* Calculate euclidean distances */
            double[] distances = Helpers.findEuclideanDistances(trainingData, instance);

            /* 
             * Create a list of dictionaries where each dictionary contains
             * the keys "distance", "weight" and "id".
             * The distance key stores the euclidean distance for an  
             * instance and the id key stores the hashcode for that 
             * instance object.
             */
            ArrayList<HashMap<String, Object>> table = Helpers.buildDistanceTable(trainingData, distances);

            /* Find the k smallest distances */
            Object[] kClosestRows = new Object[k];
            Object[] kClosestInstances = new Object[k];
            // NOTE(review): this array is never populated or read in this
            // branch; a fresh one is declared in the accelerated branch below.
            double[] classValues = new double[k];

            for (int i = 1; i <= k; i++) {
                ArrayList<Integer> tieIndices = new ArrayList<Integer>();

                /* Find the positions in the table of the ith closest 
                 * neighbour.
                 */
                int[] closestRowIndices = this.findNthClosestNeighbourByWeights(table, i);

                if (closestRowIndices.length > 0) {
                    /* Keep track of distance ties */
                    for (int j = 0; j < closestRowIndices.length; j++) {
                        tieIndices.add(closestRowIndices[j]);
                    }

                    /* Break ties (by choosing winner at random) */
                    Random rand = new Random();
                    int matchingNeighbourPosition = tieIndices.get(rand.nextInt(tieIndices.size()));
                    HashMap<String, Object> matchingRow = table.get(matchingNeighbourPosition);
                    kClosestRows[i - 1] = matchingRow;
                }
            }

            /* 
             * Find the closestInstances from their rows in the table and 
             * also get their class values.
             */
            for (int i = 0; i < kClosestRows.length; i++) {
                /* Build up closestInstances array */
                for (int j = 0; j < trainingData.numInstances(); j++) {
                    Instance inst = trainingData.get(j);
                    HashMap<String, Object> row = (HashMap<String, Object>) kClosestRows[i];
                    // Rows identify instances by the hex string of their hashcode.
                    if (Integer.toHexString(inst.hashCode()).equals(row.get("id"))) {
                        kClosestInstances[i] = inst;
                    }
                }
            }

            /* Vote by weights */
            /* Get max class value */
            double[] possibleClassValues = trainingData.attributeToDoubleArray(trainingData.classIndex());
            int maxClassIndex = Utils.maxIndex(possibleClassValues);
            double maxClassValue = possibleClassValues[maxClassIndex];
            ArrayList<Double> weightedVotes = new ArrayList<Double>();

            /* Calculate the sum of votes for each class */
            // NOTE(review): double loop counter compared with == against
            // classValue(); this works only because nominal class values
            // are small exact integers stored as doubles.
            for (double i = 0; i <= maxClassValue; i++) {
                double weightCount = 0;

                /* Calculate sum */
                for (int j = 0; j < kClosestInstances.length; j++) {
                    Instance candidateInstance = (Instance) kClosestInstances[j];
                    if (candidateInstance.classValue() == i) {
                        // Get weight
                        HashMap<String, Object> row = (HashMap<String, Object>) kClosestRows[(int) j];
                        weightCount += (double) row.get("weight");
                    }
                }

                weightedVotes.add(weightCount);
            }

            /* Select instance with highest vote */
            Double[] votesArray = new Double[weightedVotes.size()];
            weightedVotes.toArray(votesArray);
            double greatestSoFar = votesArray[0];
            int greatestIndex = 0;
            for (int i = 0; i < votesArray.length; i++) {
                if (votesArray[i] > greatestSoFar) {
                    greatestSoFar = votesArray[i];
                    greatestIndex = i;
                }
            }

            /* 
             * Class value will be the index because classes are indexed 
             * from 0 upwards.
             */
            return greatestIndex;

        }
        /* Use Orchards algorithm to accelerate NN search */
        else {
            // find k nearest neighbours
            ArrayList<Instance> nearestNeighbours = new ArrayList<Instance>();
            for (int i = 0; i < k; i++) {
                nearestNeighbours.add(findNthClosestWithOrchards(instance, trainingData, i));
            }

            // Find their class values
            double[] classValues = new double[nearestNeighbours.size()];

            for (int i = 0; i < nearestNeighbours.size(); i++) {
                classValues[i] = nearestNeighbours.get(i).classValue();
            }

            // Plain majority vote among the k neighbours.
            return Helpers.mode(Helpers.arrayToArrayList(classValues));
        }

    }

}

From source file:machine_learing_clasifier.MyC45.java

/**
 * Builds the C4.5-style decision tree: validates that the class attribute
 * is nominal, fills in missing attribute values, drops instances with a
 * missing class, and grows the tree.
 *
 * @param i the training instances (no longer mutated by this method)
 * @throws Exception if the class attribute is not nominal
 */
@Override
public void buildClassifier(Instances i) throws Exception {
    if (!i.classAttribute().isNominal()) {
        throw new Exception("Class not nominal");
    }

    // Copy FIRST so the caller's dataset is not mutated. The original
    // code filled missing values directly into the caller-supplied
    // Instances and only copied afterwards.
    i = new Instances(i);

    // Missing-value handling: replace every missing value with a fill
    // value derived from the attribute's column (see fillMissingValue).
    for (int j = 0; j < i.numAttributes(); j++) {
        Attribute attr = i.attribute(j);
        for (int k = 0; k < i.numInstances(); k++) {
            Instance inst = i.instance(k);
            if (inst.isMissing(attr)) {
                inst.setValue(attr, fillMissingValue(i, attr));
                // performance of the fill strategy could still be tuned
            }
        }
    }

    i.deleteWithMissingClass();
    makeTree(i);
}

From source file:marytts.tools.newlanguage.LTSTrainer.java

License:Open Source License

/**
 * Train the tree, using binary decision nodes.
 *
 * Collects one context-window datapoint per aligned grapheme, trains one
 * C4.5 decision tree per possible center grapheme, and combines the trees
 * under a single decision node keyed on the center grapheme.
 *
 * @param minLeafData
 *            the minimum number of instances that have to occur in at least two subsets induced by split
 * @return bigTree
 * @throws IOException
 *             IOException
 */
public CART trainTree(int minLeafData) throws IOException {

    // One bucket of training datapoints per grapheme seen in training.
    Map<String, List<String[]>> grapheme2align = new HashMap<String, List<String[]>>();
    for (String gr : this.graphemeSet) {
        grapheme2align.put(gr, new ArrayList<String[]>());
    }

    // All phone chains (alignment targets) observed in the data.
    Set<String> phChains = new HashSet<String>();

    // for every alignment pair collect counts
    for (int i = 0; i < this.inSplit.size(); i++) {

        StringPair[] alignment = this.getAlignment(i);

        for (int inNr = 0; inNr < alignment.length; inNr++) {

            // quotation signs needed to represent empty string
            String outAlNr = "'" + alignment[inNr].getString2() + "'";

            // TODO: don't consider alignments to more than three characters
            // (5 = three characters plus the two surrounding quotes)
            if (outAlNr.length() > 5)
                continue;

            phChains.add(outAlNr);

            // storing context and target: 2*context+1 context slots
            // plus one target slot
            String[] datapoint = new String[2 * context + 2];

            for (int ct = 0; ct < 2 * context + 1; ct++) {
                int pos = inNr - context + ct;

                if (pos >= 0 && pos < alignment.length) {
                    datapoint[ct] = alignment[pos].getString1();
                } else {
                    // context position falls outside the word boundary
                    datapoint[ct] = "null";
                }

            }

            // set target
            datapoint[2 * context + 1] = outAlNr;

            // add datapoint, bucketed by the center grapheme
            grapheme2align.get(alignment[inNr].getString1()).add(datapoint);
        }
    }

    // for conversion need feature definition file
    FeatureDefinition fd = this.graphemeFeatureDef(phChains);

    // attributes are named att1..att(2*context+1); the center one is
    // at position context+1
    int centerGrapheme = fd.getFeatureIndex("att" + (context + 1));

    List<CART> stl = new ArrayList<CART>(fd.getNumberOfValues(centerGrapheme));

    // Train one decision tree per possible center grapheme.
    for (String gr : fd.getPossibleValues(centerGrapheme)) {
        System.out.println("      Training decision tree for: " + gr);
        logger.debug("      Training decision tree for: " + gr);

        ArrayList<Attribute> attributeDeclarations = new ArrayList<Attribute>();

        // attributes with values
        for (int att = 1; att <= context * 2 + 1; att++) {

            // ...collect possible values
            ArrayList<String> attVals = new ArrayList<String>();

            String featureName = "att" + att;

            for (String usableGrapheme : fd.getPossibleValues(fd.getFeatureIndex(featureName))) {
                attVals.add(usableGrapheme);
            }

            attributeDeclarations.add(new Attribute(featureName, attVals));
        }

        List<String[]> datapoints = grapheme2align.get(gr);

        // maybe training is faster with targets limited to grapheme
        Set<String> graphSpecPh = new HashSet<String>();
        for (String[] dp : datapoints) {
            graphSpecPh.add(dp[dp.length - 1]);
        }

        // targetattribute
        // ...collect possible values
        ArrayList<String> targetVals = new ArrayList<String>();
        for (String phc : graphSpecPh) {// TODO: use either fd or phChains
            targetVals.add(phc);
        }
        attributeDeclarations.add(new Attribute(TrainedLTS.PREDICTED_STRING_FEATURENAME, targetVals));

        // now, create the dataset adding the datapoints
        Instances data = new Instances(gr, attributeDeclarations, 0);

        // datapoints: last array slot is the target, the rest the context
        for (String[] point : datapoints) {

            Instance currInst = new DenseInstance(data.numAttributes());
            currInst.setDataset(data);

            for (int i = 0; i < point.length; i++) {

                currInst.setValue(i, point[i]);
            }

            data.add(currInst);
        }

        // Make the last attribute be the class
        data.setClassIndex(data.numAttributes() - 1);

        // build the tree without using the J48 wrapper class
        // standard parameters are:
        // binary split selection with minimum x instances at the leaves, tree is pruned, confidenced value, subtree raising,
        // cleanup, don't collapse
        // Here is used a modifed version of C45PruneableClassifierTree that allow using Unary Classes (see Issue #51)
        C45PruneableClassifierTree decisionTree;
        try {
            decisionTree = new C45PruneableClassifierTreeWithUnary(
                    new BinC45ModelSelection(minLeafData, data, true), true, 0.25f, true, true, false);
            decisionTree.buildClassifier(data);
        } catch (Exception e) {
            throw new RuntimeException("couldn't train decisiontree using weka: ", e);
        }

        CART maryTree = TreeConverter.c45toStringCART(decisionTree, fd, data);

        stl.add(maryTree);
    }

    // Root decision node dispatches on the center grapheme to the
    // per-grapheme subtree.
    DecisionNode.ByteDecisionNode rootNode = new DecisionNode.ByteDecisionNode(centerGrapheme, stl.size(), fd);
    for (CART st : stl) {
        rootNode.addDaughter(st.getRootNode());
    }

    Properties props = new Properties();
    props.setProperty("lowercase", String.valueOf(convertToLowercase));
    props.setProperty("stress", String.valueOf(considerStress));
    props.setProperty("context", String.valueOf(context));

    CART bigTree = new CART(rootNode, fd, props);

    return bigTree;
}

From source file:marytts.tools.voiceimport.PauseDurationTrainer.java

License:Open Source License

/**
 * Discretizes the given pause durations, appends them as a nominal
 * "target" attribute to the dataset, and marks that attribute as the
 * class attribute.
 *
 * @param data the feature instances, one per duration in {@code durs}
 * @param durs the raw pause durations in milliseconds
 * @return the same dataset, now carrying the discretized target
 */
private Instances enterDurations(Instances data, List<Integer> durs) {

    // Train the discretizer on the raw durations first; its output
    // bins become the nominal target values.
    GmmDiscretizer discr = GmmDiscretizer.trainDiscretizer(durs, 6, true);

    // Collect the possible discretized duration labels ("<n>ms").
    ArrayList<String> targetVals = new ArrayList<String>();
    for (int mappedDur : discr.getPossibleValues()) {
        targetVals.add(mappedDur + "ms");
    }

    // Append the target attribute as the last attribute of the dataset.
    data.insertAttributeAt(new Attribute("target", targetVals), data.numAttributes());

    // Fill in the discretized duration for every instance.
    int targetIndex = data.numAttributes() - 1;
    for (int i = 0; i < durs.size(); i++) {
        int dur = durs.get(i);
        data.instance(i).setValue(targetIndex, discr.discretize(dur) + "ms");
    }

    // The freshly added target attribute is the class attribute.
    data.setClassIndex(targetIndex);

    return data;
}

From source file:marytts.tools.voiceimport.PauseDurationTrainer.java

License:Open Source License

/**
 * Builds a new dense instance for the given dataset, copying only the
 * features this trainer uses (plus one empty target slot) out of the
 * feature vector.
 *
 * @param data the dataset supplying the attribute layout
 * @param fd   feature definition used to resolve feature names
 * @param fv   the source feature vector
 * @return a new instance attached to {@code data}
 */
private Instance createInstance(Instances data, FeatureDefinition fd, FeatureVector fv) {
    // One slot per attribute: the relevant features plus the target.
    Instance result = new DenseInstance(data.numAttributes());
    result.setDataset(data);

    // Copy only the configured features, addressed by attribute name.
    for (String attName : this.featureNames) {
        String value = fv.getFeatureAsString(fd.getFeatureIndex(attName), fd);
        result.setValue(data.attribute(attName), value);
    }

    return result;
}

From source file:matres.MatResUI.java

/**
 * Loads the four serialized models (J48 tree and NaiveBayes for both the
 * risk and the action task), builds one instance per task from the
 * currently selected combo-box values, classifies both instances, and
 * writes the predicted labels and the risk class distribution to the UI.
 */
private void doClassification() {
    J48 m_treeResiko;
    J48 m_treeAksi;
    NaiveBayes m_nbResiko;
    NaiveBayes m_nbAksi;
    FastVector m_fvInstanceRisks;
    FastVector m_fvInstanceActions;

    // Serialized models are bundled as classpath resources.
    InputStream isRiskTree = getClass().getResourceAsStream("data/ResikoTree.model");
    InputStream isRiskNB = getClass().getResourceAsStream("data/ResikoNB.model");
    InputStream isActionTree = getClass().getResourceAsStream("data/AksiTree.model");
    InputStream isActionNB = getClass().getResourceAsStream("data/AksiNB.model");

    // Defaults in case deserialization fails below.
    m_treeResiko = new J48();
    m_treeAksi = new J48();
    m_nbResiko = new NaiveBayes();
    m_nbAksi = new NaiveBayes();
    try {
        m_treeResiko = (J48) weka.core.SerializationHelper.read(isRiskTree);
        m_nbResiko = (NaiveBayes) weka.core.SerializationHelper.read(isRiskNB);
        m_treeAksi = (J48) weka.core.SerializationHelper.read(isActionTree);
        m_nbAksi = (NaiveBayes) weka.core.SerializationHelper.read(isActionNB);
    } catch (Exception ex) {
        Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    System.out.println("Setting up an Instance...");
    // Values for LIKELIHOOD OF OCCURRENCE
    FastVector fvLO = new FastVector(5);
    fvLO.addElement("> 10 in 1 year");
    fvLO.addElement("1 - 10 in 1 year");
    fvLO.addElement("1 in 1 year to 1 in 10 years");
    fvLO.addElement("1 in 10 years to 1 in 100 years");
    fvLO.addElement("1 in more than 100 years");
    // Values for SAFETY
    FastVector fvSafety = new FastVector(5);
    fvSafety.addElement("near miss");
    fvSafety.addElement("first aid injury, medical aid injury");
    fvSafety.addElement("lost time injury / temporary disability");
    fvSafety.addElement("permanent disability");
    fvSafety.addElement("fatality");
    // Values for EXTRA FUEL COST
    FastVector fvEFC = new FastVector(5);
    fvEFC.addElement("< 100 million rupiah");
    fvEFC.addElement("0,1 - 1 billion rupiah");
    fvEFC.addElement("1 - 10 billion rupiah");
    fvEFC.addElement("10 - 100  billion rupiah");
    fvEFC.addElement("> 100 billion rupiah");
    // Values for SYSTEM RELIABILITY
    FastVector fvSR = new FastVector(5);
    fvSR.addElement("< 100 MWh");
    fvSR.addElement("0,1 - 1 GWh");
    fvSR.addElement("1 - 10 GWh");
    fvSR.addElement("10 - 100 GWh");
    fvSR.addElement("> 100 GWh");
    // Values for EQUIPMENT COST
    FastVector fvEC = new FastVector(5);
    fvEC.addElement("< 50 million rupiah");
    fvEC.addElement("50 - 500 million rupiah");
    fvEC.addElement("0,5 - 5 billion rupiah");
    fvEC.addElement("5 -50 billion rupiah");
    fvEC.addElement("> 50 billion rupiah");
    // Values for CUSTOMER SATISFACTION SOCIAL FACTOR
    FastVector fvCSSF = new FastVector(5);
    fvCSSF.addElement("Complaint from the VIP customer");
    fvCSSF.addElement("Complaint from industrial customer");
    fvCSSF.addElement("Complaint from community");
    fvCSSF.addElement("Complaint from community that have potential riot");
    fvCSSF.addElement("High potential riot");
    // Values for RISK
    FastVector fvRisk = new FastVector(4);
    fvRisk.addElement("Low");
    fvRisk.addElement("Moderate");
    fvRisk.addElement("High");
    fvRisk.addElement("Extreme");
    // Values for ACTION
    FastVector fvAction = new FastVector(3);
    fvAction.addElement("Life Extension Program");
    fvAction.addElement("Repair/Refurbish");
    fvAction.addElement("Replace/Run to Fail + Investment");

    // Defining Attributes, including Class(es) Attributes
    Attribute attrLO = new Attribute("LO", fvLO);
    Attribute attrSafety = new Attribute("Safety", fvSafety);
    Attribute attrEFC = new Attribute("EFC", fvEFC);
    Attribute attrSR = new Attribute("SR", fvSR);
    Attribute attrEC = new Attribute("EC", fvEC);
    Attribute attrCSSF = new Attribute("CSSF", fvCSSF);
    Attribute attrRisk = new Attribute("Risk", fvRisk);
    Attribute attrAction = new Attribute("Action", fvAction);

    // Both tasks share the six input attributes; the class attribute
    // (Risk vs. Action) differs.
    m_fvInstanceRisks = new FastVector(7);
    m_fvInstanceRisks.addElement(attrLO);
    m_fvInstanceRisks.addElement(attrSafety);
    m_fvInstanceRisks.addElement(attrEFC);
    m_fvInstanceRisks.addElement(attrSR);
    m_fvInstanceRisks.addElement(attrEC);
    m_fvInstanceRisks.addElement(attrCSSF);
    m_fvInstanceRisks.addElement(attrRisk);

    m_fvInstanceActions = new FastVector(7);
    m_fvInstanceActions.addElement(attrLO);
    m_fvInstanceActions.addElement(attrSafety);
    m_fvInstanceActions.addElement(attrEFC);
    m_fvInstanceActions.addElement(attrSR);
    m_fvInstanceActions.addElement(attrEC);
    m_fvInstanceActions.addElement(attrCSSF);
    m_fvInstanceActions.addElement(attrAction);

    Instances dataRisk = new Instances("A-Risk-instance-to-classify", m_fvInstanceRisks, 0);
    Instances dataAction = new Instances("An-Action-instance-to-classify", m_fvInstanceActions, 0);
    // (Removed: riskValues/actionValues arrays were allocated here but never used.)

    String strLO = (String) m_cmbLO.getSelectedItem();
    String strSafety = (String) m_cmbSafety.getSelectedItem();
    String strEFC = (String) m_cmbEFC.getSelectedItem();
    String strSR = (String) m_cmbSR.getSelectedItem();
    String strEC = (String) m_cmbEC.getSelectedItem();
    String strCSSF = (String) m_cmbCSSF.getSelectedItem();

    Instance instRisk = new DenseInstance(7);
    Instance instAction = new DenseInstance(7);

    // Each combo box may be unset ("-- none --"), in which case the
    // corresponding attribute is marked missing for both tasks.
    if (strLO.equals("-- none --")) {
        instRisk.setMissing(0);
        instAction.setMissing(0);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(0), strLO);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(0), strLO);
    }
    if (strSafety.equals("-- none --")) {
        instRisk.setMissing(1);
        instAction.setMissing(1);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(1), strSafety);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(1), strSafety);
    }
    if (strEFC.equals("-- none --")) {
        instRisk.setMissing(2);
        instAction.setMissing(2);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(2), strEFC);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(2), strEFC);
    }
    if (strSR.equals("-- none --")) {
        instRisk.setMissing(3);
        instAction.setMissing(3);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(3), strSR);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(3), strSR);
    }
    if (strEC.equals("-- none --")) {
        instRisk.setMissing(4);
        instAction.setMissing(4);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(4), strEC);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(4), strEC);
    }
    if (strCSSF.equals("-- none --")) {
        instRisk.setMissing(5);
        instAction.setMissing(5);
    } else {
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(5), strCSSF);
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(5), strCSSF);
    }
    // The class attribute of each instance is what we want predicted.
    instRisk.setMissing(6);
    instAction.setMissing(6);

    dataRisk.add(instRisk);
    instRisk.setDataset(dataRisk);
    dataRisk.setClassIndex(dataRisk.numAttributes() - 1);

    dataAction.add(instAction);
    instAction.setDataset(dataAction);
    dataAction.setClassIndex(dataAction.numAttributes() - 1);

    System.out.println("Instance Resiko: " + dataRisk.instance(0));
    System.out.println("\tNum Attributes : " + dataRisk.numAttributes());
    System.out.println("\tNum instances  : " + dataRisk.numInstances());
    System.out.println("Instance Action: " + dataAction.instance(0));
    System.out.println("\tNum Attributes : " + dataAction.numAttributes());
    System.out.println("\tNum instances  : " + dataAction.numInstances());

    int classIndexRisk = 0;
    int classIndexAction = 0;
    String strClassRisk = null;
    String strClassAction = null;

    try {
        classIndexRisk = (int) m_treeResiko.classifyInstance(instRisk);
        classIndexAction = (int) m_treeAksi.classifyInstance(instAction);
    } catch (Exception ex) {
        Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    strClassRisk = (String) fvRisk.elementAt(classIndexRisk);
    strClassAction = (String) fvAction.elementAt(classIndexAction);
    System.out.println("[Risk  Class Index: " + classIndexRisk + " Class Label: " + strClassRisk + "]");
    System.out.println("[Action  Class Index: " + classIndexAction + " Class Label: " + strClassAction + "]");
    if (strClassRisk != null) {
        m_txtRisk.setText(strClassRisk);
    }

    double[] riskDist = null;
    double[] actionDist = null;
    try {
        riskDist = m_nbResiko.distributionForInstance(dataRisk.instance(0));
        actionDist = m_nbAksi.distributionForInstance(dataAction.instance(0));
        // set up RISK progress bars
        m_jBarRiskLow.setValue((int) (100 * riskDist[0]));
        m_jBarRiskLow.setString(String.format("%6.3f%%", 100 * riskDist[0]));
        m_jBarRiskModerate.setValue((int) (100 * riskDist[1]));
        m_jBarRiskModerate.setString(String.format("%6.3f%%", 100 * riskDist[1]));
        m_jBarRiskHigh.setValue((int) (100 * riskDist[2]));
        m_jBarRiskHigh.setString(String.format("%6.3f%%", 100 * riskDist[2]));
        m_jBarRiskExtreme.setValue((int) (100 * riskDist[3]));
        m_jBarRiskExtreme.setString(String.format("%6.3f%%", 100 * riskDist[3]));
    } catch (Exception ex) {
        Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    double predictedProb = 0.0;
    String predictedClass = "";

    // Guard: riskDist remains null if distributionForInstance threw above;
    // the original code then dereferenced null in this loop (NPE).
    if (riskDist != null) {
        // Loop over all the prediction labels in the distribution and
        // keep the most probable one.
        for (int predictionDistributionIndex = 0; predictionDistributionIndex < riskDist.length; predictionDistributionIndex++) {
            // Get this distribution index's class label.
            String predictionDistributionIndexAsClassLabel = dataRisk.classAttribute()
                    .value(predictionDistributionIndex);
            int classIndex = dataRisk.classAttribute().indexOfValue(predictionDistributionIndexAsClassLabel);
            // Get the probability.
            double predictionProbability = riskDist[predictionDistributionIndex];

            if (predictionProbability > predictedProb) {
                predictedProb = predictionProbability;
                predictedClass = predictionDistributionIndexAsClassLabel;
            }

            System.out.printf("[%2d %10s : %6.3f]", classIndex, predictionDistributionIndexAsClassLabel,
                    predictionProbability);
        }
    }
    m_txtRiskNB.setText(predictedClass);
}

From source file:meka.classifiers.multilabel.cc.CNode.java

License:Open Source License

/**
 * Transform - turn [y1,y2,y3,x1,x2] into [y1,y2,x1,x2].
 * @return transformed Instance/*from w w w .j  a  v  a2s . com*/
 */
public Instance transform(Instance x, double ypred[]) throws Exception {
    x = (Instance) x.copy();
    int L = x.classIndex();
    int L_c = (paY.length + 1);
    x.setDataset(null);
    for (int j = 0; j < (L - L_c); j++) {
        x.deleteAttributeAt(0);
    }
    for (int pa : paY) {
        //System.out.println("x_["+map[pa]+"] <- "+ypred[pa]);
        x.setValue(map[pa], ypred[pa]);
    }
    x.setDataset(T);
    x.setClassMissing();
    return x;
}

From source file:meka.classifiers.multilabel.cc.CNode.java

License:Open Source License

/**
 * Refreshes the parent-label slots of an already-transformed instance
 * with the latest label predictions.
 *
 * @param t_    the transformed instance to update in place
 * @param ypred predicted values for all labels; only parent labels are read
 */
public void updateTransform(Instance t_, double ypred[]) throws Exception {
    for (int parent : this.paY) {
        t_.setValue(this.map[parent], ypred[parent]);
    }
}

From source file:meka.classifiers.multilabel.CDN.java

License:Open Source License

/**
 * Estimates the label marginals p(y_j|x) by Gibbs sampling: repeatedly
 * visits the labels in a random order, resamples each label from its
 * conditional classifier, and averages the samples collected after the
 * burn-in phase.
 *
 * @param x the test instance; its label slots are overwritten during sampling
 * @return the estimated marginal for each of the L labels
 */
@Override
public double[] distributionForInstance(Instance x) throws Exception {

    int L = x.classIndex();

    double y[] = new double[L]; // accumulates the sampled label values (marginal counts)

    // Gibbs visit order over the labels. BUG FIX: the original called
    // Collections.shuffle(Arrays.asList(sequence)) on an int[], which
    // yields a single-element List<int[]> — the order was NEVER shuffled.
    // Boxing into an Integer[] makes Arrays.asList produce a proper
    // List<Integer> view whose shuffle permutes the array in place.
    Integer[] sequence = new Integer[L];
    int[] rawSequence = A.make_sequence(L);
    for (int j = 0; j < L; j++) {
        sequence[j] = rawSequence[j];
    }

    double likelihood[] = new double[L];

    for (int i = 0; i < I; i++) {
        // Shuffle with the classifier's RNG so sampling stays reproducible.
        Collections.shuffle(Arrays.asList(sequence), m_R);
        for (int j : sequence) {
            // x = [x,y[1],...,y[j-1],y[j+1],...,y[L]]
            x.setDataset(D_templates[j]);
            // q = h_j(x)    i.e. p(y_j | x)
            double dist[] = h[j].distributionForInstance(x);
            int k = A.samplePMF(dist, m_R);
            x.setValue(j, k);
            likelihood[j] = dist[k];
            // (Removed: an unused "double s = Utils.sum(likelihood)" was computed here.)
            // collect the sample once past the burn-in phase
            if (i > (I - I_c)) {
                y[j] += x.value(j);
            }
            // else still burning in
        }
    }
    // finish: average the collected samples into marginals
    for (int j = 0; j < L; j++) {
        y[j] /= I_c;
    }

    return y;
}