List of usage examples for weka.core.Instances.setClassIndex
public void setClassIndex(int classIndex)
From source file:marytts.tools.newlanguage.LTSTrainer.java
License:Open Source License
/**
 * Train the tree, using binary decision nodes.
 *
 * One separate C4.5 decision tree is trained per grapheme (predicting the
 * phone chain aligned to that grapheme from a symmetric grapheme context
 * window), and the per-grapheme trees are then joined under a single root
 * decision node that switches on the center grapheme.
 *
 * @param minLeafData
 *            the minimum number of instances that have to occur in at least two subsets induced by split
 * @return bigTree
 * @throws IOException
 *             IOException
 */
public CART trainTree(int minLeafData) throws IOException {
    // One bucket of training datapoints per grapheme; a separate tree is
    // trained for each grapheme further below.
    Map<String, List<String[]>> grapheme2align = new HashMap<String, List<String[]>>();
    for (String gr : this.graphemeSet) {
        grapheme2align.put(gr, new ArrayList<String[]>());
    }
    // All phone chains observed anywhere (used for the feature definition).
    Set<String> phChains = new HashSet<String>();
    // for every alignment pair collect counts
    for (int i = 0; i < this.inSplit.size(); i++) {
        StringPair[] alignment = this.getAlignment(i);
        for (int inNr = 0; inNr < alignment.length; inNr++) {
            // quotation signs needed to represent empty string
            String outAlNr = "'" + alignment[inNr].getString2() + "'";
            // TODO: don't consider alignments to more than three characters
            if (outAlNr.length() > 5)
                continue;
            phChains.add(outAlNr);
            // storing context and target: 2*context+1 grapheme slots plus
            // one target slot for the aligned phone chain
            String[] datapoint = new String[2 * context + 2];
            for (int ct = 0; ct < 2 * context + 1; ct++) {
                int pos = inNr - context + ct;
                if (pos >= 0 && pos < alignment.length) {
                    datapoint[ct] = alignment[pos].getString1();
                } else {
                    // positions outside the word get an explicit "null" token
                    datapoint[ct] = "null";
                }
            }
            // set target
            datapoint[2 * context + 1] = outAlNr;
            // add datapoint, bucketed under the center grapheme
            grapheme2align.get(alignment[inNr].getString1()).add(datapoint);
        }
    }
    // for conversion need feature definition file
    FeatureDefinition fd = this.graphemeFeatureDef(phChains);
    // "att<context+1>" is the feature holding the center grapheme.
    int centerGrapheme = fd.getFeatureIndex("att" + (context + 1));
    List<CART> stl = new ArrayList<CART>(fd.getNumberOfValues(centerGrapheme));
    for (String gr : fd.getPossibleValues(centerGrapheme)) {
        System.out.println(" Training decision tree for: " + gr);
        logger.debug(" Training decision tree for: " + gr);
        ArrayList<Attribute> attributeDeclarations = new ArrayList<Attribute>();
        // attributes with values: one nominal attribute per context slot
        for (int att = 1; att <= context * 2 + 1; att++) {
            // ...collect possible values
            ArrayList<String> attVals = new ArrayList<String>();
            String featureName = "att" + att;
            for (String usableGrapheme : fd.getPossibleValues(fd.getFeatureIndex(featureName))) {
                attVals.add(usableGrapheme);
            }
            attributeDeclarations.add(new Attribute(featureName, attVals));
        }
        List<String[]> datapoints = grapheme2align.get(gr);
        // maybe training is faster with targets limited to grapheme
        Set<String> graphSpecPh = new HashSet<String>();
        for (String[] dp : datapoints) {
            graphSpecPh.add(dp[dp.length - 1]);
        }
        // target attribute: nominal, one value per phone chain seen for
        // this particular grapheme
        // ...collect possible values
        ArrayList<String> targetVals = new ArrayList<String>();
        for (String phc : graphSpecPh) {// todo: use either fd of phChains
            targetVals.add(phc);
        }
        attributeDeclarations.add(new Attribute(TrainedLTS.PREDICTED_STRING_FEATURENAME, targetVals));
        // now, create the dataset adding the datapoints
        Instances data = new Instances(gr, attributeDeclarations, 0);
        // datapoints
        for (String[] point : datapoints) {
            Instance currInst = new DenseInstance(data.numAttributes());
            // setDataset must precede setValue so nominal strings resolve
            // against this dataset's attribute declarations
            currInst.setDataset(data);
            for (int i = 0; i < point.length; i++) {
                currInst.setValue(i, point[i]);
            }
            data.add(currInst);
        }
        // Make the last attribute be the class
        data.setClassIndex(data.numAttributes() - 1);
        // build the tree without using the J48 wrapper class.
        // standard parameters are: binary split selection with minimum x
        // instances at the leaves, tree is pruned, confidence value,
        // subtree raising, cleanup, don't collapse.
        // Here a modified version of C45PruneableClassifierTree is used
        // that allows using Unary Classes (see Issue #51).
        C45PruneableClassifierTree decisionTree;
        try {
            decisionTree = new C45PruneableClassifierTreeWithUnary(
                    new BinC45ModelSelection(minLeafData, data, true), true, 0.25f, true, true, false);
            decisionTree.buildClassifier(data);
        } catch (Exception e) {
            throw new RuntimeException("couldn't train decisiontree using weka: ", e);
        }
        // Convert the Weka tree into MARY's CART representation.
        CART maryTree = TreeConverter.c45toStringCART(decisionTree, fd, data);
        stl.add(maryTree);
    }
    // Join the per-grapheme trees under one root that branches on the
    // center grapheme.
    DecisionNode.ByteDecisionNode rootNode = new DecisionNode.ByteDecisionNode(centerGrapheme, stl.size(), fd);
    for (CART st : stl) {
        rootNode.addDaughter(st.getRootNode());
    }
    // Persist the training configuration alongside the tree.
    Properties props = new Properties();
    props.setProperty("lowercase", String.valueOf(convertToLowercase));
    props.setProperty("stress", String.valueOf(considerStress));
    props.setProperty("context", String.valueOf(context));
    CART bigTree = new CART(rootNode, fd, props);
    return bigTree;
}
From source file:marytts.tools.voiceimport.PauseDurationTrainer.java
License:Open Source License
private Instances enterDurations(Instances data, List<Integer> durs) { // System.out.println("discretizing durations..."); // now discretize and set target attributes (= pause durations) // for that, first train discretizer GmmDiscretizer discr = GmmDiscretizer.trainDiscretizer(durs, 6, true); // used to store the collected values ArrayList<String> targetVals = new ArrayList<String>(); for (int mappedDur : discr.getPossibleValues()) { targetVals.add(mappedDur + "ms"); }//from w w w. j a v a2 s . c o m // FastVector attributeDeclarations = data.; // attribute declaration finished data.insertAttributeAt(new Attribute("target", targetVals), data.numAttributes()); for (int i = 0; i < durs.size(); i++) { Instance currInst = data.instance(i); int dur = durs.get(i); // System.out.println(" mapping " + dur + " to " + discr.discretize(dur) + " - bi:" + // data.instance(i).value(data.attribute("breakindex"))); currInst.setValue(data.numAttributes() - 1, discr.discretize(dur) + "ms"); } // Make the last attribute be the class data.setClassIndex(data.numAttributes() - 1); return data; }
From source file:matres.MatResUI.java
/**
 * Classifies the currently selected combo-box values with four pre-trained
 * models (J48 risk/action trees and two NaiveBayes models), prints the
 * predictions, writes the predicted risk label into the UI, and fills the
 * risk progress bars with the NaiveBayes class distribution.
 */
private void doClassification() {
    J48 m_treeResiko;
    J48 m_treeAksi;
    NaiveBayes m_nbResiko;
    NaiveBayes m_nbAksi;
    FastVector m_fvInstanceRisks;
    FastVector m_fvInstanceActions;
    // Serialized models are bundled as classpath resources.
    InputStream isRiskTree = getClass().getResourceAsStream("data/ResikoTree.model");
    InputStream isRiskNB = getClass().getResourceAsStream("data/ResikoNB.model");
    InputStream isActionTree = getClass().getResourceAsStream("data/AksiTree.model");
    InputStream isActionNB = getClass().getResourceAsStream("data/AksiNB.model");
    // Fallback instances in case deserialization below fails.
    m_treeResiko = new J48();
    m_treeAksi = new J48();
    m_nbResiko = new NaiveBayes();
    m_nbAksi = new NaiveBayes();
    try {
        m_treeResiko = (J48) weka.core.SerializationHelper.read(isRiskTree);
        m_nbResiko = (NaiveBayes) weka.core.SerializationHelper.read(isRiskNB);
        m_treeAksi = (J48) weka.core.SerializationHelper.read(isActionTree);
        m_nbAksi = (NaiveBayes) weka.core.SerializationHelper.read(isActionNB);
    } catch (Exception ex) {
        Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex);
    }
    System.out.println("Setting up an Instance...");
    // Values for LIKELIHOOD OF OCCURRENCE
    FastVector fvLO = new FastVector(5);
    fvLO.addElement("> 10 in 1 year");
    fvLO.addElement("1 - 10 in 1 year");
    fvLO.addElement("1 in 1 year to 1 in 10 years");
    fvLO.addElement("1 in 10 years to 1 in 100 years");
    fvLO.addElement("1 in more than 100 years");
    // Values for SAFETY
    FastVector fvSafety = new FastVector(5);
    fvSafety.addElement("near miss");
    fvSafety.addElement("first aid injury, medical aid injury");
    fvSafety.addElement("lost time injury / temporary disability");
    fvSafety.addElement("permanent disability");
    fvSafety.addElement("fatality");
    // Values for EXTRA FUEL COST
    FastVector fvEFC = new FastVector(5);
    fvEFC.addElement("< 100 million rupiah");
    fvEFC.addElement("0,1 - 1 billion rupiah");
    fvEFC.addElement("1 - 10 billion rupiah");
    fvEFC.addElement("10 - 100 billion rupiah");
    fvEFC.addElement("> 100 billion rupiah");
    // Values for SYSTEM RELIABILITY
    FastVector fvSR = new FastVector(5);
    fvSR.addElement("< 100 MWh");
    fvSR.addElement("0,1 - 1 GWh");
    fvSR.addElement("1 - 10 GWh");
    fvSR.addElement("10 - 100 GWh");
    fvSR.addElement("> 100 GWh");
    // Values for EQUIPMENT COST
    FastVector fvEC = new FastVector(5);
    fvEC.addElement("< 50 million rupiah");
    fvEC.addElement("50 - 500 million rupiah");
    fvEC.addElement("0,5 - 5 billion rupiah");
    fvEC.addElement("5 -50 billion rupiah");
    fvEC.addElement("> 50 billion rupiah");
    // Values for CUSTOMER SATISFACTION SOCIAL FACTOR
    FastVector fvCSSF = new FastVector(5);
    fvCSSF.addElement("Complaint from the VIP customer");
    fvCSSF.addElement("Complaint from industrial customer");
    fvCSSF.addElement("Complaint from community");
    fvCSSF.addElement("Complaint from community that have potential riot");
    fvCSSF.addElement("High potential riot");
    // Values for RISK
    FastVector fvRisk = new FastVector(4);
    fvRisk.addElement("Low");
    fvRisk.addElement("Moderate");
    fvRisk.addElement("High");
    fvRisk.addElement("Extreme");
    // Values for ACTION
    FastVector fvAction = new FastVector(3);
    fvAction.addElement("Life Extension Program");
    fvAction.addElement("Repair/Refurbish");
    fvAction.addElement("Replace/Run to Fail + Investment");
    // Defining Attributes, including Class(es) Attributes
    Attribute attrLO = new Attribute("LO", fvLO);
    Attribute attrSafety = new Attribute("Safety", fvSafety);
    Attribute attrEFC = new Attribute("EFC", fvEFC);
    Attribute attrSR = new Attribute("SR", fvSR);
    Attribute attrEC = new Attribute("EC", fvEC);
    Attribute attrCSSF = new Attribute("CSSF", fvCSSF);
    Attribute attrRisk = new Attribute("Risk", fvRisk);
    Attribute attrAction = new Attribute("Action", fvAction);
    // Both headers share the six input attributes and differ only in the
    // class attribute (Risk vs. Action).
    m_fvInstanceRisks = new FastVector(7);
    m_fvInstanceRisks.addElement(attrLO);
    m_fvInstanceRisks.addElement(attrSafety);
    m_fvInstanceRisks.addElement(attrEFC);
    m_fvInstanceRisks.addElement(attrSR);
    m_fvInstanceRisks.addElement(attrEC);
    m_fvInstanceRisks.addElement(attrCSSF);
    m_fvInstanceRisks.addElement(attrRisk);
    m_fvInstanceActions = new FastVector(7);
    m_fvInstanceActions.addElement(attrLO);
    m_fvInstanceActions.addElement(attrSafety);
    m_fvInstanceActions.addElement(attrEFC);
    m_fvInstanceActions.addElement(attrSR);
    m_fvInstanceActions.addElement(attrEC);
    m_fvInstanceActions.addElement(attrCSSF);
    m_fvInstanceActions.addElement(attrAction);
    Instances dataRisk = new Instances("A-Risk-instance-to-classify", m_fvInstanceRisks, 0);
    Instances dataAction = new Instances("An-Action-instance-to-classify", m_fvInstanceActions, 0);
    // NOTE(review): these two arrays are never used below.
    double[] riskValues = new double[dataRisk.numAttributes()];
    double[] actionValues = new double[dataRisk.numAttributes()];
    String strLO = (String) m_cmbLO.getSelectedItem();
    String strSafety = (String) m_cmbSafety.getSelectedItem();
    String strEFC = (String) m_cmbEFC.getSelectedItem();
    String strSR = (String) m_cmbSR.getSelectedItem();
    String strEC = (String) m_cmbEC.getSelectedItem();
    String strCSSF = (String) m_cmbCSSF.getSelectedItem();
    Instance instRisk = new DenseInstance(7);
    Instance instAction = new DenseInstance(7);
    // "-- none --" in a combo box maps to a missing attribute value.
    if (strLO.equals("-- none --")) {
        instRisk.setMissing(0);
        instAction.setMissing(0);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(0), strLO);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(0), strLO);
    }
    if (strSafety.equals("-- none --")) {
        instRisk.setMissing(1);
        instAction.setMissing(1);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(1), strSafety);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(1), strSafety);
    }
    if (strEFC.equals("-- none --")) {
        instRisk.setMissing(2);
        instAction.setMissing(2);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(2), strEFC);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(2), strEFC);
    }
    if (strSR.equals("-- none --")) {
        instRisk.setMissing(3);
        instAction.setMissing(3);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(3), strSR);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(3), strSR);
    }
    if (strEC.equals("-- none --")) {
        instRisk.setMissing(4);
        instAction.setMissing(4);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(4), strEC);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(4), strEC);
    }
    if (strCSSF.equals("-- none --")) {
        instRisk.setMissing(5);
        instAction.setMissing(5);
    } else {
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(5), strCSSF);
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(5), strCSSF);
    }
    // The class attribute (index 6) is the value to be predicted.
    instRisk.setMissing(6);
    instAction.setMissing(6);
    dataRisk.add(instRisk);
    instRisk.setDataset(dataRisk);
    dataRisk.setClassIndex(dataRisk.numAttributes() - 1);
    dataAction.add(instAction);
    instAction.setDataset(dataAction);
    dataAction.setClassIndex(dataAction.numAttributes() - 1);
    System.out.println("Instance Resiko: " + dataRisk.instance(0));
    System.out.println("\tNum Attributes : " + dataRisk.numAttributes());
    System.out.println("\tNum instances : " + dataRisk.numInstances());
    System.out.println("Instance Action: " + dataAction.instance(0));
    System.out.println("\tNum Attributes : " + dataAction.numAttributes());
    System.out.println("\tNum instances : " + dataAction.numInstances());
    int classIndexRisk = 0;
    int classIndexAction = 0;
    String strClassRisk = null;
    String strClassAction = null;
    try {
        // classifyInstance returns the index of the predicted nominal value.
        classIndexRisk = (int) m_treeResiko.classifyInstance(instRisk);
        classIndexAction = (int) m_treeAksi.classifyInstance(instAction);
    } catch (Exception ex) {
        Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex);
    }
    strClassRisk = (String) fvRisk.elementAt(classIndexRisk);
    strClassAction = (String) fvAction.elementAt(classIndexAction);
    System.out.println("[Risk Class Index: " + classIndexRisk + " Class Label: " + strClassRisk + "]");
    System.out.println("[Action Class Index: " + classIndexAction + " Class Label: " + strClassAction + "]");
    if (strClassRisk != null) {
        m_txtRisk.setText(strClassRisk);
    }
    double[] riskDist = null;
    double[] actionDist = null;
    try {
        riskDist = m_nbResiko.distributionForInstance(dataRisk.instance(0));
        actionDist = m_nbAksi.distributionForInstance(dataAction.instance(0));
        String strProb;
        // set up RISK progress bars
        m_jBarRiskLow.setValue((int) (100 * riskDist[0]));
        m_jBarRiskLow.setString(String.format("%6.3f%%", 100 * riskDist[0]));
        m_jBarRiskModerate.setValue((int) (100 * riskDist[1]));
        m_jBarRiskModerate.setString(String.format("%6.3f%%", 100 * riskDist[1]));
        m_jBarRiskHigh.setValue((int) (100 * riskDist[2]));
        m_jBarRiskHigh.setString(String.format("%6.3f%%", 100 * riskDist[2]));
        m_jBarRiskExtreme.setValue((int) (100 * riskDist[3]));
        m_jBarRiskExtreme.setString(String.format("%6.3f%%", 100 * riskDist[3]));
    } catch (Exception ex) {
        Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex);
    }
    double predictedProb = 0.0;
    String predictedClass = "";
    // Loop over all the prediction labels in the distribution, tracking
    // the most probable one.
    for (int predictionDistributionIndex = 0; predictionDistributionIndex < riskDist.length; predictionDistributionIndex++) {
        // Get this distribution index's class label.
        String predictionDistributionIndexAsClassLabel = dataRisk.classAttribute()
                .value(predictionDistributionIndex);
        int classIndex = dataRisk.classAttribute().indexOfValue(predictionDistributionIndexAsClassLabel);
        // Get the probability.
        double predictionProbability = riskDist[predictionDistributionIndex];
        if (predictionProbability > predictedProb) {
            predictedProb = predictionProbability;
            predictedClass = predictionDistributionIndexAsClassLabel;
        }
        System.out.printf("[%2d %10s : %6.3f]", classIndex, predictionDistributionIndexAsClassLabel,
                predictionProbability);
    }
    m_txtRiskNB.setText(predictedClass);
}
From source file:mcib3d.Classification.DataSet.java
/**
 * Loads a dataset from an ARFF file into {@code instances}.
 *
 * The dataset is accepted only when its attribute count matches the
 * expected {@code attributes} definition; otherwise a message is logged
 * and the current {@code instances} field is left untouched.
 *
 * @param fileName path of the ARFF file to read
 */
public void loadDataARFF(String fileName) {
    Instances dataTmp;
    // try-with-resources closes the reader (the original leaked it), and
    // the early return avoids the NPE the original hit when the file
    // could not be read (dataTmp stayed null).
    try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) {
        dataTmp = new Instances(reader);
    } catch (IOException e) {
        e.printStackTrace();
        return;
    }
    dataTmp.setClassIndex(attributes.getClassIndex());
    if (dataTmp.numAttributes() == attributes.size())
        instances = dataTmp;
    else
        IJ.log("Pb reading arff, number of attributes different");
}
From source file:meddle.PredictByDomainOS.java
License:Open Source License
public static boolean loadAllModels(String className) { domainOSModel = new HashMap<String, Classifier>(); domainOSFeature = new HashMap<String, Map<String, Integer>>(); domainOSStruct = new HashMap<String, Instances>(); try {/*from w w w .j a v a2s .com*/ File modelFolder = new File(RConfig.modelFolder); File[] models = modelFolder.listFiles(); if (models != null) for (int i = 0; i < models.length; i++) { String fn = models[i].getName(); if (!fn.endsWith(className + ".model")) continue; String domainOS = fn.substring(0, fn.length() - className.length() - ".model".length() - 1); Classifier classifier; classifier = (Classifier) (Class.forName(className).newInstance()); classifier = (Classifier) SerializationHelper.read(RConfig.modelFolder + fn); // System.out.println(domainOS); domainOSModel.put(domainOS, classifier); ArffLoader loader = new ArffLoader(); String arffStructureFile = RConfig.arffFolder + domainOS + ".arff"; File af = new File(arffStructureFile); if (!af.exists()) continue; loader.setFile(new File(arffStructureFile)); Instances structure; try { structure = loader.getStructure(); } catch (Exception e) { continue; } structure.setClassIndex(structure.numAttributes() - 1); domainOSStruct.put(domainOS, structure); Map<String, Integer> fi = new HashMap<String, Integer>(); for (int j = 0; j < structure.numAttributes(); j++) { fi.put(structure.attribute(j).name(), j); } domainOSFeature.put(domainOS, fi); } } catch (InstantiationException | IllegalAccessException | ClassNotFoundException e) { e.printStackTrace(); return false; } catch (Exception e) { e.printStackTrace(); return false; } isModelLoaded = true; return true; }
From source file:meddle.PredictByDomainOS.java
License:Open Source License
/**
 * Classifies a single traffic flow with the model trained for the given
 * domain/OS pair.
 *
 * @param line raw flow line to be converted into a feature vector
 * @param domainOS key selecting the model, header and feature map
 * @return true when the classifier predicts a class value greater than 0;
 *         false when no model exists for the key, the prediction is not
 *         positive, or classification throws
 */
private static boolean predictOneFlow(String line, String domainOS) {
    // No trained model for this domain/OS pair: nothing to predict.
    if (!domainOSModel.containsKey(domainOS))
        return false;
    try {
        Classifier classifier = domainOSModel.get(domainOS);
        Map<String, Integer> fi = domainOSFeature.get(domainOS);
        Instances structure = domainOSStruct.get(domainOS);
        // Build a one-row dataset that shares the training header so the
        // classifier sees a compatible attribute layout, and mark the
        // class value as the one to be predicted.
        Instance current = getInstance(line, fi, fi.size());
        Instances is = new Instances(structure);
        is.setClassIndex(is.numAttributes() - 1);
        is.add(current);
        current = is.get(is.size() - 1);
        current.setClassMissing();
        // A predicted class value > 0 is treated as a positive result.
        // (Replaces the original's verbose if/else-return-boolean.)
        return classifier.classifyInstance(current) > 0;
    } catch (Exception e) {
        e.printStackTrace();
    }
    return false;
}
From source file:meddle.TrainModelByDomainOS.java
License:Open Source License
public static Instances populateArff(Info info, Map<String, Integer> wordCount, ArrayList<Map<String, Integer>> trainMatrix, ArrayList<Integer> PIILabels, int numSamples, int theta) { // System.out.println(info); // Mapping feature_name_index Map<String, Integer> fi = new HashMap<String, Integer>(); int index = 0; // Populate Features ArrayList<Attribute> attributes = new ArrayList<Attribute>(); int high_freq = trainMatrix.size(); if (high_freq - theta < 30) theta = 0;/*ww w.ja v a 2 s . c o m*/ for (Map.Entry<String, Integer> entry : wordCount.entrySet()) { // filter low frequency word String currentWord = entry.getKey(); int currentWordFreq = entry.getValue(); if (currentWordFreq < theta) { if (!SharedMem.wildKeys.get("android").containsKey(currentWord) && !SharedMem.wildKeys.get("ios").containsKey(currentWord) && !SharedMem.wildKeys.get("windows").containsKey(currentWord)) continue; } Attribute attribute = new Attribute(currentWord); attributes.add(attribute); fi.put(currentWord, index); index++; } ArrayList<String> classVals = new ArrayList<String>(); classVals.add("" + LABEL_NEGATIVE); classVals.add("" + LABEL_POSITIVE); attributes.add(new Attribute("PIILabel", classVals)); // Populate Data Points Iterator<Map<String, Integer>> all = trainMatrix.iterator(); int count = 0; Instances trainingInstances = new Instances("Rel", attributes, 0); trainingInstances.setClassIndex(trainingInstances.numAttributes() - 1); while (all.hasNext()) { Map<String, Integer> dataMap = all.next(); double[] instanceValue = new double[attributes.size()]; for (int i = 0; i < attributes.size() - 1; i++) { instanceValue[i] = 0; } int label = PIILabels.get(count); instanceValue[attributes.size() - 1] = label; for (Map.Entry<String, Integer> entry : dataMap.entrySet()) { if (fi.containsKey(entry.getKey())) { int i = fi.get(entry.getKey()); int val = entry.getValue(); instanceValue[i] = val; } } Instance data = new SparseInstance(1.0, instanceValue); trainingInstances.add(data); count++; } 
// Write into .arff file for persistence try { BufferedWriter bw = new BufferedWriter(new FileWriter(RConfig.arffFolder + info.domainOS + ".arff")); bw.write(trainingInstances.toString()); bw.close(); } catch (IOException e) { e.printStackTrace(); } return trainingInstances; }
From source file:meddle.TrainModelByDomainOS.java
License:Open Source License
/** * @param modelPath/* w w w. jav a2 s . co m*/ * @param arffPath * @param org */ public static MetaEvaluationMeasures loadAndEvaluateClassifier(String modelPath, String arffPath, String org) { MetaEvaluationMeasures mem = null; try { if ((new File(modelPath).exists())) { J48 j48 = (J48) SerializationHelper.read(modelPath); // Read data from BufferedReader reader = new BufferedReader(new FileReader(arffPath)); Instances data = new Instances(reader); reader.close(); // setting class attribute data.setClassIndex(data.numAttributes() - 1); mem = doEvaluation(j48, org, data, new MetaEvaluationMeasures()); } } catch (Exception e) { e.printStackTrace(); } return mem; }
From source file:meka.classifiers.multilabel.BRq.java
License:Open Source License
@Override public void buildClassifier(Instances data) throws Exception { testCapabilities(data);//from w ww . java 2s .c o m int c = data.classIndex(); if (getDebug()) System.out.print("-: Creating " + c + " models (" + m_Classifier.getClass().getName() + "): "); m_MultiClassifiers = AbstractClassifier.makeCopies(m_Classifier, c); Instances sub_data = null; for (int i = 0; i < c; i++) { int indices[][] = new int[c][c - 1]; for (int j = 0, k = 0; j < c; j++) { if (j != i) { indices[i][k++] = j; } } //Select only class attribute 'i' Remove FilterRemove = new Remove(); FilterRemove.setAttributeIndicesArray(indices[i]); FilterRemove.setInputFormat(data); FilterRemove.setInvertSelection(true); sub_data = Filter.useFilter(data, FilterRemove); sub_data.setClassIndex(0); /* BEGIN downsample for this link */ sub_data.randomize(m_Random); int numToRemove = sub_data.numInstances() - (int) Math.round(sub_data.numInstances() * m_DownSampleRatio); for (int m = 0, removed = 0; m < sub_data.numInstances(); m++) { if (sub_data.instance(m).classValue() <= 0.0) { sub_data.instance(m).setClassMissing(); if (++removed >= numToRemove) break; } } sub_data.deleteWithMissingClass(); /* END downsample for this link */ //Build the classifier for that class m_MultiClassifiers[i].buildClassifier(sub_data); if (getDebug()) System.out.print(" " + (i + 1)); } if (getDebug()) System.out.println(" :-"); m_InstancesTemplate = new Instances(sub_data, 0); }
From source file:meka.classifiers.multilabel.cc.CNode.java
License:Open Source License
/**
 * Transform - transform dataset D for this node.
 * this.j defines the current node index, e.g., 3
 * this.paY[] defines parents, e.g., [1,4]
 * we should remove the rest, e.g., [0,2,5,...,L-1]
 * @return dataset we should remove all variables from D EXCEPT current node, and parents.
 */
public Instances transform(Instances D) throws Exception {
    // Labels occupy the first L attributes of D (MEKA convention: the
    // class index stores the label count).
    int L = D.classIndex();
    // d = number of non-label (feature) attributes.
    d = D.numAttributes() - L;
    // Keep this node's own label plus its parents' labels.
    int keep[] = A.append(this.paY, j); // keep all parents and self!
    Arrays.sort(keep);
    // Every other label index < L gets removed.
    int remv[] = A.invert(keep, L); // i.e., remove the rest < L
    Arrays.sort(remv);
    // map[j] = position of label j within the kept labels (negative if
    // removed); computed via binarySearch, hence the sorts above.
    // NOTE: the loop variable j shadows the field this.j here.
    map = new int[L];
    for (int j = 0; j < L; j++) {
        map[j] = Arrays.binarySearch(keep, j);
    }
    // Drop the removed label attributes and point the class index at this
    // node's own label within the reduced dataset.
    Instances D_ = F.remove(new Instances(D), remv, false);
    D_.setClassIndex(map[this.j]);
    return D_;
}