Example usage of weka.core.Instances.setClassIndex

List of usage examples for weka.core.Instances.setClassIndex

Introduction

On this page you can find example usages of weka.core.Instances.setClassIndex.

Prototype

public void setClassIndex(int classIndex) 

Source Link

Document

Sets the class index of the dataset.

Usage

From source file: edu.utexas.cs.tactex.utils.RegressionUtils.java

License:Open Source License

/**
 * Appends a new, unvalued attribute as the last column of {@code xInsts} and makes it the
 * class ("y") attribute.
 *
 * <p>The new attribute is named after its own column index. The dataset is modified in place
 * and the same object is returned for convenience.
 *
 * @param xInsts dataset holding only the x attributes
 * @return {@code xInsts}, now with the y column appended and selected as class
 */
public static Instances addYforWeka(Instances xInsts) {
    // the y column goes right after the current last attribute
    final int yIndex = xInsts.numAttributes();
    final Attribute yAttribute = new Attribute(String.valueOf(yIndex));
    xInsts.insertAttributeAt(yAttribute, yIndex);
    xInsts.setClassIndex(yIndex);
    return xInsts;
}

From source file: edu.washington.cs.knowitall.summarization.RedundancyClassifier.java

License:Open Source License

/**
 * Parses ARFF data from {@code testReader} and marks the last attribute as the class.
 *
 * <p>The reader is always closed, whether or not parsing succeeds. Previously a read failure
 * was only printed and surfaced later as an uninformative NullPointerException; it is now
 * rethrown with its original cause.
 *
 * @param testReader source of the ARFF data; closed by this method
 * @return the parsed instances with the class index set to the last attribute
 * @throws java.io.UncheckedIOException if the data cannot be read or parsed
 */
public Instances setupInstances(StringReader testReader) {
    try {
        Instances instances = new Instances(testReader);
        instances.setClassIndex(instances.numAttributes() - 1);
        return instances;
    } catch (IOException e) {
        throw new java.io.UncheckedIOException("could not read instances from reader", e);
    } finally {
        // StringReader.close() does not throw, so no nested try is needed here.
        testReader.close();
    }
}

From source file: edu.washington.cs.knowitall.utilities.Classifier.java

License:Open Source License

/**
 * Set up the instances from the reader: parses the ARFF data and marks the last attribute as
 * the class.
 *
 * <p>The reader is always closed. A failure to read the data is rethrown with its original
 * cause instead of surfacing later as a NullPointerException. A failure to close the reader is
 * reported on stderr but no longer terminates the JVM.
 *
 * @param instanceReader the source of the instances; closed by this method
 * @return the instances object, with the class index set to the last attribute
 * @throws java.io.UncheckedIOException if the data cannot be read or parsed
 */
public Instances setupInstances(Reader instanceReader) {
    try {
        Instances instances = new Instances(instanceReader);
        instances.setClassIndex(instances.numAttributes() - 1);
        return instances;
    } catch (IOException e) {
        throw new java.io.UncheckedIOException("could not read instances from reader", e);
    } finally {
        try {
            instanceReader.close();
        } catch (IOException e) {
            // By this point the data has been consumed (or parsing already failed), so a failed
            // close is not fatal for the caller: report it rather than calling System.exit.
            System.err.println("could not close reader");
            e.printStackTrace();
        }
    }
}

From source file: elh.eus.absa.CLI.java

License:Open Source License

/**
 * Main access to the train-atc functionality. Trains two separate classifiers for E#A aspect
 * categories: one for the entity category ("entCat") and one for the attribute category
 * ("attCat").
 *
 * <p>Each classifier is cross-validated with "foldNum" folds (results printed to stdout) and
 * then saved via saveModel as {@code elixa-atc_ent-<lang>.model} and
 * {@code elixa-atc_att-<lang>.model} respectively. Any exception raised while training is
 * printed and swallowed; the method still prints its "DONE" message afterwards.
 *
 * @param inputStream corpus to train on, in the format given by the "corpusFormat" argument
 * @throws IOException declared by the signature; NOTE(review): presumably raised while reading
 *         the corpus or loading features — confirm
 */
public final void trainATC(final InputStream inputStream) throws IOException {
    // load training parameters file
    String paramFile = parsedArguments.getString("params");
    String corpusFormat = parsedArguments.getString("corpusFormat");
    //String validation = parsedArguments.getString("validation");
    int foldNum = Integer.parseInt(parsedArguments.getString("foldNum"));
    String lang = parsedArguments.getString("language");
    //boolean printPreds = parsedArguments.getBoolean("printPreds");
    boolean nullSentenceOpinions = parsedArguments.getBoolean("nullSentences");
    //double threshold = 0.2;
    //String modelsPath = "/home/inaki/Proiektuak/BOM/SEMEVAL2015/ovsaModels";

    CorpusReader reader = new CorpusReader(inputStream, corpusFormat, nullSentenceOpinions, lang);
    Features atcTrain = new Features(reader, paramFile, "3");
    Instances traindata = atcTrain.loadInstances(true, "atc");

    //setting class attribute (entCat|attCat|entAttCat|polarityCat)

    //HashMap<String, Integer> opInst = atcTrain.getOpinInst();
    WekaWrapper classifyEnts;
    WekaWrapper classifyAtts;
    //WekaWrapper onevsall;
    try {
        // train first classifier (entities): class = entCat; attCat and entAttCat are
        // filtered out of the input attributes.
        Instances traindataEnt = new Instances(traindata);
        // IMPORTANT: the filter indexes below are incremented by 1 because Weka's Remove filter
        // counts attributes from 1, while Attribute.index() counts from 0.
        traindataEnt.setClassIndex(traindataEnt.attribute("entCat").index());
        classifyEnts = new WekaWrapper(traindataEnt, true);
        String filtRange = String.valueOf(traindata.attribute("attCat").index() + 1) + ","
                + String.valueOf(traindata.attribute("entAttCat").index() + 1);
        classifyEnts.filterAttribute(filtRange);

        System.out.println("trainATC: entity classifier results -> ");
        classifyEnts.crossValidate(foldNum);
        classifyEnts.saveModel("elixa-atc_ent-" + lang + ".model");

        //Classifier entityCl = classify.getMLclass();

        // train second classifier (attributes): class = attCat; only entAttCat is filtered out,
        // so entCat remains available as an input attribute.
        Instances traindataAtt = new Instances(traindata);
        traindataAtt.setClassIndex(traindataAtt.attribute("attCat").index());
        classifyAtts = new WekaWrapper(traindataAtt, true);
        filtRange = String.valueOf(traindataAtt.attribute("entAttCat").index() + 1);
        classifyAtts.filterAttribute(filtRange);

        System.out.println("trainATC: attribute classifier results -> ");
        classifyAtts.crossValidate(foldNum);
        classifyAtts.saveModel("elixa-atc_att-" + lang + ".model");
        /*
        Instances traindataEntadded = classifyEnts.addClassification(classifyEnts.getMLclass(), traindataEnt);
        //train second classifier (entCat attributes will have the values of the entities always)
        traindataEntadded.setClassIndex(traindataEntadded.attribute("attCat").index());
        WekaWrapper classify2 = new WekaWrapper(traindataEntadded, true);
        System.out.println("trainATC: enhanced attribute classifier results -> ");
        classify2.saveModel("elixa-atc_att_enhanced.model");
        classify2.crossValidate(foldNum);      
        */
        //classify.printMultilabelPredictions(classify.multiLabelPrediction());      */   

        //reader.print2Semeval2015format(paramFile+"entAttCat.xml");
    } catch (Exception e) {
        // NOTE(review): errors are only printed; callers cannot tell that training failed.
        e.printStackTrace();
    }

    //traindata.setClass(traindata.attribute("entAttCat"));
    System.err.println("DONE CLI train-atc");
}

From source file: elh.eus.absa.CLI.java

License:Open Source License

/**
 * Main access to the train-atc2 functionality. Predicts E#A aspect categories with two chained
 * one-vs-all classifiers: first for the entity (E, "entCat"), then for the attribute
 * (A, "attCat"). The resulting opinions are written in SemEval 2015 format to
 * {@code <params>entAttCat.xml}.
 *
 * <p>If "testOnly" is set, no models are trained and the models already stored under
 * {@code modelsPath} are used. If, in addition, the "testset" file exists, it replaces
 * {@code inputStream} as the data source. Any exception raised during training or prediction
 * is printed and swallowed.
 *
 * @param inputStream corpus to train on, in the format given by the "corpusFormat" argument
 * @throws IOException declared by the signature; NOTE(review): presumably raised while reading
 *         the corpus or the test file — confirm
 */
public final void trainATC2(final InputStream inputStream) throws IOException {
    // load training parameters file
    String paramFile = parsedArguments.getString("params");
    String testFile = parsedArguments.getString("testset");
    String paramFile2 = parsedArguments.getString("params2");
    String corpusFormat = parsedArguments.getString("corpusFormat");
    //String validation = parsedArguments.getString("validation");
    String lang = parsedArguments.getString("language");
    //int foldNum = Integer.parseInt(parsedArguments.getString("foldNum"));
    //boolean printPreds = parsedArguments.getBoolean("printPreds");
    boolean nullSentenceOpinions = parsedArguments.getBoolean("nullSentences");
    boolean onlyTest = parsedArguments.getBoolean("testOnly");
    // score thresholds for accepting a one-vs-all prediction (stage 1 / stage 2)
    double threshold = 0.5;
    double threshold2 = 0.5;
    // NOTE(review): hard-coded, machine-specific models directory.
    String modelsPath = "/home/inaki/elixa-atp/ovsaModels";

    CorpusReader reader = new CorpusReader(inputStream, corpusFormat, nullSentenceOpinions, lang);
    Features atcTrain = new Features(reader, paramFile, "3");
    Instances traindata = atcTrain.loadInstances(true, "atc");

    if (onlyTest) {
        if (FileUtilsElh.checkFile(testFile)) {
            System.err.println("read from test file");
            reader = new CorpusReader(new FileInputStream(new File(testFile)), corpusFormat,
                    nullSentenceOpinions, lang);
            atcTrain.setCorpus(reader);
            traindata = atcTrain.loadInstances(true, "atc");
        }
    }

    //setting class attribute (entCat|attCat|entAttCat|polarityCat)

    //HashMap<String, Integer> opInst = atcTrain.getOpinInst();      
    //WekaWrapper classifyAtts;
    WekaWrapper onevsall;
    try {

        //classify.printMultilabelPredictions(classify.multiLabelPrediction());      */   

        // Stage 1 (entities): one-vs-all over entCat, with attCat and entAttCat removed.
        Instances entdata = new Instances(traindata);
        entdata.deleteAttributeAt(entdata.attribute("attCat").index());
        entdata.deleteAttributeAt(entdata.attribute("entAttCat").index());
        entdata.setClassIndex(entdata.attribute("entCat").index());
        onevsall = new WekaWrapper(entdata, true);

        if (!onlyTest) {
            onevsall.trainOneVsAll(modelsPath, paramFile + "entCat");
            System.out.println("trainATC: one vs all models ready");
        }
        onevsall.setTestdata(entdata);
        // ovsaRes: instance index -> (class value -> score)
        HashMap<Integer, HashMap<String, Double>> ovsaRes = onevsall.predictOneVsAll(modelsPath,
                paramFile + "entCat");
        System.out.println("trainATC: one vs all predictions ready");
        // invert getOpinInst() (opinion id -> instance index) into instance index -> opinion id
        HashMap<Integer, String> instOps = new HashMap<Integer, String>();
        for (String oId : atcTrain.getOpinInst().keySet()) {
            instOps.put(atcTrain.getOpinInst().get(oId), oId);
        }

        // Stage 2 data: rebuild features with the second parameter file.
        // NOTE(review): assumes the "instanceId" values of this dataset match the instance
        // indices used as keys in ovsaRes/instOps above — confirm.
        atcTrain = new Features(reader, paramFile2, "3");
        entdata = atcTrain.loadInstances(true, "attTrain2_data");
        entdata.deleteAttributeAt(entdata.attribute("entAttCat").index());
        //entdata.setClassIndex(entdata.attribute("entCat").index());

        Attribute insAtt = entdata.attribute("instanceId");
        // NOTE(review): Weka's kthSmallestValue is 1-based and is documented to reorder the
        // data; with k = numDistinctValues - 1 this may be the second-largest id rather than
        // the largest — confirm intent.
        double maxInstId = entdata.kthSmallestValue(insAtt, entdata.numDistinctValues(insAtt) - 1);
        System.err.println("last instance has index: " + maxInstId);
        // NOTE(review): entdata.numInstances() is re-read on every iteration, so instances
        // appended inside this loop are visited as well; their ids are probably not keys of
        // ovsaRes (ovsaRes.get(i) would then be null) — confirm.
        for (int ins = 0; ins < entdata.numInstances(); ins++) {
            System.err.println("ins" + ins);
            int i = (int) entdata.instance(ins).value(insAtt);
            Instance currentInst = entdata.instance(ins);
            //System.err.println("instance "+i+" oid "+kk.get(i+1)+"kk contains key i?"+kk.containsKey(i));
            String sId = reader.getOpinion(instOps.get(i)).getsId();
            String oId = instOps.get(i);
            reader.removeSentenceOpinions(sId);
            int oSubId = 0;
            for (String cl : ovsaRes.get(i).keySet()) {
                //System.err.println("instance: "+i+" class "+cl+" value: "+ovsaRes.get(i).get(cl));
                if (ovsaRes.get(i).get(cl) > threshold) {
                    //System.err.println("one got through ! instance "+i+" class "+cl+" value: "+ovsaRes.get(i).get(cl));                  
                    // every accepted class after the first: add a copy of the instance with a
                    // new id and that class.
                    // NOTE(review): Weka's Instances.add stores a copy, so the setValue and
                    // setClassValue calls made after add() may not affect the stored instance;
                    // ids maxInstId + oSubId can also collide across different source
                    // instances — confirm.
                    if (oSubId >= 1) {
                        Instance newIns = new SparseInstance(currentInst);
                        newIns.setDataset(entdata);
                        entdata.add(newIns);
                        newIns.setValue(insAtt, maxInstId + oSubId);
                        newIns.setClassValue(cl);
                        instOps.put((int) maxInstId + oSubId, oId);

                    }
                    // first accepted class: relabel the current instance in place
                    else {
                        currentInst.setClassValue(cl);
                        //create and add opinion to the structure
                        //   trgt, offsetFrom, offsetTo, polarity, cat, sId);
                        //Opinion op = new Opinion(instOps.get(i)+"_"+oSubId, "", 0, 0, "", cl, sId);
                        //reader.addOpinion(op);
                    }
                    oSubId++;
                }
            } //finished updating instances data                                    
        }

        entdata.setClass(entdata.attribute("attCat"));
        onevsall = new WekaWrapper(entdata, true);

        /**
         *  Second classifier (attributes)
         * 
         * */
        if (!onlyTest) {
            onevsall.trainOneVsAll(modelsPath, paramFile + "attCat");
            System.out.println("trainATC: one vs all attcat models ready");
        }

        // NOTE(review): the models above are trained with prefix paramFile + "attCat" but are
        // loaded here with prefix paramFile + "entAttCat" — confirm which models are intended.
        ovsaRes = onevsall.predictOneVsAll(modelsPath, paramFile + "entAttCat");

        insAtt = entdata.attribute("instanceId");
        // NOTE(review): numValues() of a numeric attribute is 0, which is outside the 1-based
        // range of kthSmallestValue and differs from the computation in stage 1 — confirm.
        maxInstId = entdata.kthSmallestValue(insAtt, insAtt.numValues());
        System.err.println("last instance has index: " + maxInstId);
        for (int ins = 0; ins < entdata.numInstances(); ins++) {
            System.err.println("ins: " + ins);
            int i = (int) entdata.instance(ins).value(insAtt);
            Instance currentInst = entdata.instance(ins);
            //System.err.println("instance "+i+" oid "+kk.get(i+1)+"kk contains key i?"+kk.containsKey(i));
            String sId = reader.getOpinion(instOps.get(i)).getsId();
            String oId = instOps.get(i);
            reader.removeSentenceOpinions(sId);
            int oSubId = 0;
            for (String cl : ovsaRes.get(i).keySet()) {
                //System.err.println("instance: "+i+" class "+cl+" value: "+ovsaRes.get(i).get(cl));
                if (ovsaRes.get(i).get(cl) > threshold2) {
                    ///System.err.println("instance: "+i+" class "+cl+" value: "+ovsaRes.get(i).get(cl));
                    // NOTE(review): redundant with the threshold2 check above while both
                    // thresholds are 0.5.
                    if (ovsaRes.get(i).get(cl) > threshold) {
                        //System.err.println("one got through ! instance "+i+" class "+cl+" value: "+ovsaRes.get(i).get(cl));                  
                        // every accepted class after the first: add an extra opinion
                        // "<oId>_<oSubId>" labelled "<entity>#<attribute>".
                        // NOTE(review): no "entAtt" attribute is created in this method;
                        // confirm the Features output provides it.
                        if (oSubId >= 1) {
                            String label = currentInst.stringValue(entdata.attribute("entAtt")) + "#" + cl;
                            //create and add opinion to the structure
                            //   trgt, offsetFrom, offsetTo, polarity, cat, sId);                     
                            Opinion op = new Opinion(oId + "_" + oSubId, "", 0, 0, "", label, sId);
                            reader.addOpinion(op);
                        }
                        // first accepted class: replace the original opinion with a relabelled one
                        else {
                            String label = currentInst.stringValue(entdata.attribute("entAtt")) + "#" + cl;
                            //create and add opinion to the structure
                            //   trgt, offsetFrom, offsetTo, polarity, cat, sId);
                            reader.removeOpinion(oId);
                            Opinion op = new Opinion(oId + "_" + oSubId, "", 0, 0, "", label, sId);
                            reader.addOpinion(op);
                        }
                        oSubId++;
                    }
                } //finished updating instances data                                    
            }
        }
        reader.print2Semeval2015format(paramFile + "entAttCat.xml");
    } catch (Exception e) {
        // NOTE(review): errors are only printed; callers cannot tell that the run failed.
        e.printStackTrace();
    }

    //traindata.setClass(traindata.attribute("entAttCat"));
    System.err.println("DONE CLI train-atc2 (oneVsAll)");
}

From source file: elh.eus.absa.CLI.java

License:Open Source License

/**
 * Trains ATC using a single one-vs-all classifier directly over the combined E#A aspect
 * category ("entAttCat").
 *
 * <p>Unless "testOnly" is set, one-vs-all models are trained under {@code modelsPath}. The
 * data is then scored with those models. Every category whose score exceeds 0.5 becomes a new
 * opinion {@code <opinionId>_<n>} (n starting at 1) in its sentence, and the result is written
 * in SemEval 2015 format to {@code <params>entAttCat.xml}. If "testOnly" is set and the
 * "testset" file exists, that file replaces {@code inputStream} as the data source. Any
 * exception raised during training or prediction is printed and swallowed.
 *
 * @param inputStream corpus to train on, in the format given by the "corpusFormat" argument
 * @throws IOException declared by the signature; NOTE(review): presumably raised while reading
 *         the corpus or the test file — confirm
 */
public final void trainATCsingleCategory(final InputStream inputStream) throws IOException {
    // load training parameters file
    String paramFile = parsedArguments.getString("params");
    String testFile = parsedArguments.getString("testset");
    String corpusFormat = parsedArguments.getString("corpusFormat");
    //String validation = parsedArguments.getString("validation");
    String lang = parsedArguments.getString("language");
    //int foldNum = Integer.parseInt(parsedArguments.getString("foldNum"));
    //boolean printPreds = parsedArguments.getBoolean("printPreds");
    boolean nullSentenceOpinions = parsedArguments.getBoolean("nullSentences");
    boolean onlyTest = parsedArguments.getBoolean("testOnly");
    // minimum one-vs-all score for a category to be emitted as an opinion
    double threshold = 0.5;

    // NOTE(review): hard-coded, machine-specific models directory.
    String modelsPath = "/home/inaki/Proiektuak/BOM/SEMEVAL2015/ovsaModels";

    CorpusReader reader = new CorpusReader(inputStream, corpusFormat, nullSentenceOpinions, lang);
    Features atcTrain = new Features(reader, paramFile, "3");
    Instances traindata = atcTrain.loadInstances(true, "atc");

    if (onlyTest) {
        if (FileUtilsElh.checkFile(testFile)) {
            System.err.println("read from test file");
            reader = new CorpusReader(new FileInputStream(new File(testFile)), corpusFormat,
                    nullSentenceOpinions, lang);
            atcTrain.setCorpus(reader);
            traindata = atcTrain.loadInstances(true, "atc");
        }
    }

    //setting class attribute (entCat|attCat|entAttCat|polarityCat)

    //HashMap<String, Integer> opInst = atcTrain.getOpinInst();
    //WekaWrapper classifyEnts;
    //WekaWrapper classifyAtts;
    WekaWrapper onevsall;
    try {

        //classify.printMultilabelPredictions(classify.multiLabelPrediction());      */   

        //onevsall
        //Instances entdata = new Instances(traindata);
        // traindata is modified in place (no copy): the separate entity/attribute labels are
        // dropped and the combined entAttCat label becomes the class.
        traindata.deleteAttributeAt(traindata.attribute("attCat").index());
        traindata.deleteAttributeAt(traindata.attribute("entCat").index());
        traindata.setClassIndex(traindata.attribute("entAttCat").index());
        onevsall = new WekaWrapper(traindata, true);

        if (!onlyTest) {
            onevsall.trainOneVsAll(modelsPath, paramFile + "entAttCat");
            System.out.println("trainATC: one vs all models ready");
        }
        onevsall.setTestdata(traindata);
        // ovsaRes: instance index -> (category -> score)
        HashMap<Integer, HashMap<String, Double>> ovsaRes = onevsall.predictOneVsAll(modelsPath,
                paramFile + "entAttCat");
        System.out.println("trainATC: one vs all predictions ready");
        // kk: instance index -> opinion id (inverse of getOpinInst())
        HashMap<Integer, String> kk = new HashMap<Integer, String>();
        for (String oId : atcTrain.getOpinInst().keySet()) {
            kk.put(atcTrain.getOpinInst().get(oId), oId);
        }

        // debug output: the category labels scored for instance 1.
        // NOTE(review): assumes ovsaRes has key 1; otherwise this throws NullPointerException.
        Object[] ll = ovsaRes.get(1).keySet().toArray();
        for (Object l : ll) {
            System.err.print((String) l + " - ");
        }
        System.err.print("\n");

        for (int i : ovsaRes.keySet()) {
            //System.err.println("instance "+i+" oid "+kk.get(i+1)+"kk contains key i?"+kk.containsKey(i));
            // replace all opinions of the instance's sentence with the predicted categories
            String sId = reader.getOpinion(kk.get(i)).getsId();
            reader.removeSentenceOpinions(sId);
            int oSubId = 0;
            for (String cl : ovsaRes.get(i).keySet()) {
                //System.err.println("instance: "+i+" class "+cl+" value: "+ovsaRes.get(i).get(cl));
                if (ovsaRes.get(i).get(cl) > threshold) {
                    //System.err.println("one got through ! instance "+i+" class "+cl+" value: "+ovsaRes.get(i).get(cl));
                    oSubId++;
                    //create and add opinion to the structure
                    //trgt, offsetFrom, offsetTo, polarity, cat, sId);
                    Opinion op = new Opinion(kk.get(i) + "_" + oSubId, "", 0, 0, "", cl, sId);
                    reader.addOpinion(op);
                }
            }
        }
        reader.print2Semeval2015format(paramFile + "entAttCat.xml");
    } catch (Exception e) {
        // NOTE(review): errors are only printed; callers cannot tell that the run failed.
        e.printStackTrace();
    }

    //traindata.setClass(traindata.attribute("entAttCat"));
    // NOTE(review): message names train-atc2 although this is the single-category variant.
    System.err.println("DONE CLI train-atc2 (oneVsAll)");
}

From source file: elh.eus.absa.WekaWrapper.java

License:Open Source License

/**
 *      Train one vs all models over the given training data.
 *  /*from  w w  w. ja  va2  s. c o m*/
 * @param modelpath directory to store each model for the one vs. all method
 * @param prefix prefix the models should have (each model will have the name of its class appended
 * @throws Exception
 */
public void trainOneVsAll(String modelpath, String prefix) throws Exception {
    Instances orig = new Instances(traindata);
    Enumeration<Object> classValues = traindata.classAttribute().enumerateValues();
    String classAtt = traindata.classAttribute().name();
    while (classValues.hasMoreElements()) {
        String v = (String) classValues.nextElement();
        System.err.println("trainer onevsall for class " + v + " classifier");
        //needed because of weka's sparse data format problems THIS IS TROUBLE! ...
        if (v.equalsIgnoreCase("dummy")) {
            continue;
        }
        // copy instances and set the same class value
        Instances ovsa = new Instances(orig);
        //create a new class attribute         
        //   // Declare the class attribute along with its values
        ArrayList<String> classVal = new ArrayList<String>();
        classVal.add("dummy"); //needed because of weka's sparse data format problems...
        classVal.add(v);
        classVal.add("UNKNOWN");
        ovsa.insertAttributeAt(new Attribute(classAtt + "2", classVal), ovsa.numAttributes());
        //change all instance labels that have not the current class value to "other"
        for (int i = 0; i < ovsa.numInstances(); i++) {
            Instance inst = ovsa.instance(i);
            String instClass = inst.stringValue(ovsa.attribute(classAtt).index());
            if (instClass.equalsIgnoreCase(v)) {
                inst.setValue(ovsa.attribute(classAtt + "2").index(), v);
            } else {
                inst.setValue(ovsa.attribute(classAtt + "2").index(), "UNKNOWN");
            }
        }
        //delete the old class attribute and set the new.         
        ovsa.setClassIndex(ovsa.attribute(classAtt + "2").index());
        ovsa.deleteAttributeAt(ovsa.attribute(classAtt).index());
        ovsa.renameAttribute(ovsa.attribute(classAtt + "2").index(), classAtt);
        ovsa.setClassIndex(ovsa.attribute(classAtt).index());

        //build the classifier, crossvalidate and store the model
        setTraindata(ovsa);
        saveModel(modelpath + File.separator + prefix + "_" + v + ".model");
        setTestdata(ovsa);
        testModel(modelpath + File.separator + prefix + "_" + v + ".model");

        System.err.println("trained onevsall " + v + " classifier");
    }

    setTraindata(orig);
}

From source file: entities.WekaBaselineBOWFeatureVector.java

/**
 * Builds a Weka dataset from two lists of baseline bag-of-words feature vectors, marks the last
 * attribute as the class, and writes the dataset as ARFF to {@code ./data/test.arff}.
 *
 * <p>The relation is named after the label of the first vector in {@code vList}. When
 * {@code vList} is empty, the generic name {@code "dataset"} is used instead of failing with an
 * IndexOutOfBoundsException.
 *
 * @param vList first list of feature vectors
 * @param vList2 second list of feature vectors, appended after {@code vList}
 * @return the filled dataset
 * @throws IOException if the ARFF file cannot be written
 */
public Instances fillInstanceSet(ArrayList<BaselineBOWFeatureVector> vList,
        ArrayList<BaselineBOWFeatureVector> vList2) throws IOException {

    ArrayList<Attribute> attributes = initializeWekaFeatureVector();
    String relationName = vList.isEmpty() ? "dataset" : vList.get(0).getLabel();
    // presize for both lists, since both end up in the same dataset
    Instances isSet = new Instances(relationName, attributes, vList.size() + vList2.size());

    isSet.setClassIndex(isSet.numAttributes() - 1);

    for (BaselineBOWFeatureVector bowVector : vList) {
        isSet.add(fillFeatureVector(bowVector, isSet));
    }
    for (BaselineBOWFeatureVector bowVector : vList2) {
        isSet.add(fillFeatureVector(bowVector, isSet));
    }

    ArffSaver saver = new ArffSaver();
    saver.setInstances(isSet);
    saver.setFile(new File("./data/test.arff"));
    saver.writeBatch();

    return isSet;
}

From source file: entities.WekaBOWFeatureVector.java

/**
 * Builds a Weka dataset from two lists of bag-of-words feature vectors, marks the last attribute
 * as the class, and writes the dataset as ARFF to {@code ./data/test.arff}.
 *
 * <p>The relation is named after the label of the first vector in {@code vList}. When
 * {@code vList} is empty, the generic name {@code "dataset"} is used instead of failing with an
 * IndexOutOfBoundsException.
 *
 * @param vList first list of feature vectors
 * @param vList2 second list of feature vectors, appended after {@code vList}
 * @return the filled dataset
 * @throws IOException if the ARFF file cannot be written
 */
public Instances fillInstanceSet(ArrayList<BOWFeatureVector> vList, ArrayList<BOWFeatureVector> vList2)
        throws IOException {

    ArrayList<Attribute> attributes = initializeWekaFeatureVector();
    String relationName = vList.isEmpty() ? "dataset" : vList.get(0).getLabel();
    // presize for both lists, since both end up in the same dataset
    Instances isSet = new Instances(relationName, attributes, vList.size() + vList2.size());

    isSet.setClassIndex(isSet.numAttributes() - 1);

    for (BOWFeatureVector bowVector : vList) {
        isSet.add(fillFeatureVector(bowVector, isSet));
    }
    for (BOWFeatureVector bowVector : vList2) {
        isSet.add(fillFeatureVector(bowVector, isSet));
    }

    ArffSaver saver = new ArffSaver();
    saver.setInstances(isSet);
    saver.setFile(new File("./data/test.arff"));
    saver.writeBatch();

    return isSet;
}

From source file: entities.WekaHMMFeatureVector.java

/**
 * Builds a Weka dataset named {@code "dataset"} from two lists of HMM feature vectors, marks
 * the last attribute as the class, and writes the dataset as ARFF to {@code ./data/test.arff}.
 *
 * @param vList first list of feature vectors
 * @param vList2 second list of feature vectors, appended after {@code vList}
 * @return the filled dataset
 * @throws IOException if the ARFF file cannot be written
 */
public Instances fillInstanceSet(ArrayList<HMMFeatureVector> vList, ArrayList<HMMFeatureVector> vList2)
        throws IOException {

    ArrayList<Attribute> attributes = initializeWekaFeatureVector();
    // presize for both lists, since both end up in the same dataset
    Instances isSet = new Instances("dataset", attributes, vList.size() + vList2.size());

    isSet.setClassIndex(isSet.numAttributes() - 1);

    for (HMMFeatureVector hmmVector : vList) {
        isSet.add(fillFeatureVector(hmmVector, isSet));
    }
    for (HMMFeatureVector hmmVector : vList2) {
        isSet.add(fillFeatureVector(hmmVector, isSet));
    }

    ArffSaver saver = new ArffSaver();
    saver.setInstances(isSet);
    saver.setFile(new File("./data/test.arff"));
    saver.writeBatch();

    return isSet;
}