Java tutorial
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package meka.classifiers.multitarget;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;

import meka.classifiers.multilabel.ProblemTransformationMethod;
import meka.core.PSUtils;
import meka.core.SuperLabelUtils;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.classifiers.trees.J48;
import meka.core.MLUtils;
import meka.core.A;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 * NSR.java - The Nearest Set Replacement (NSR) method.
 * A multi-target version of PS: The nearest sets are used to replace outliers, rather than subsets (as in PS).
 * Important Note: currently can only handle 10 values (or fewer) per target variable.
 *
 * @see meka.classifiers.multilabel.PS
 * @version Jan 2013
 * @author Jesse Read
 */
public class NSR extends meka.classifiers.multilabel.PS implements MultiTargetClassifier {

    /** for serialization. */
    private static final long serialVersionUID = 8373228150066785001L;

    /**
     * Creates an NSR instance with J48 as the base classifier (the default used by the GUI).
     */
    public NSR() {
        // default classifier for GUI
        this.m_Classifier = new J48();
    }

    /**
     * Returns the class name of the default base classifier (used by the CLI).
     *
     * @return the fully qualified class name of J48
     */
    @Override
    protected String defaultClassifierString() {
        // default classifier for CLI
        return "weka.classifiers.trees.J48";
    }

    /**
     * Description to display in the GUI.
     *
     * @return the description
     */
    @Override
    public String globalInfo() {
        return "The Nearest Set Relpacement (NSR) method.\n"
                + "A multi-target version of PS: The nearest sets are used to replace outliers, rather than subsets (as in PS).";
    }

    /**
     * Returns the capabilities of the superclass, with the minimum number of
     * instances set to 1.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.setMinimumNumberInstances(1);
        return result;
    }

    /**
     * Trains the base classifier on the single-target transformation of {@code D}
     * produced by {@link #convertInstances(Instances, int)}.
     * If training fails and {@code m_P > 0}, {@code m_P} is decremented (this
     * change is kept) and training is retried.
     *
     * @param D the training data; its class index is taken as the number of targets
     * @throws Exception if training still fails once {@code m_P} has reached 0;
     *                   the underlying failure is attached as the cause
     */
    @Override
    public void buildClassifier(Instances D) throws Exception {
        testCapabilities(D);

        int L = D.classIndex();
        try {
            m_Classifier.buildClassifier(convertInstances(D, L));
        } catch (Exception e) {
            if (m_P > 0) {
                m_P--;
                System.err.println("Not enough distinct class values, trying again with P = " + m_P + " ...");
                buildClassifier(D);
            } else {
                // keep the original failure so the real reason is not lost
                throw new Exception("Failed to construct a classifier.", e);
            }
        }
    }

    /**
     * Predicts the target values for {@code x}.
     *
     * @param x the instance to classify; its class index is taken as the number of targets {@code L}
     * @return an array of length {@code 2 * L}: entries {@code [0, L)} hold the
     *         decoded values of the most probable class combination, entries
     *         {@code [L, 2L)} hold, per target, the largest accumulated weight
     *         of any single value (confidence information)
     * @throws Exception if the instance conversion or the base classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance x) throws Exception {

        int L = x.classIndex();

        // the single-label version of the instance, where y in {comb_1,comb_2,...,comb_k}
        Instance x_sl = PSUtils.convertInstance(x, L, m_InstancesTemplate);

        // w[i] = p(comb_i) for each combination i = 1,...,k
        double w[] = m_Classifier.distributionForInstance(x_sl);
        // index of the most probable combination
        int max_j = Utils.maxIndex(w);
        // comb_i e.g. "0+3+0+0+1+2+0+0"
        String y_max = m_InstancesTemplate.classAttribute().value(max_j);

        // "0+3+0+0+1+2+0+0" -> [0.0,3.0,0.0,...,0.0], extended to length 2L (extra entries are 0.0)
        double y[] = Arrays.copyOf(A.toDoubleArray(MLUtils.decodeValue(y_max)), L * 2);

        // votes[j] maps each value of target j to the summed weight of the combinations containing it
        @SuppressWarnings("unchecked") // generic array creation; only HashMap<Double, Double> is ever stored
        HashMap<Double, Double> votes[] = new HashMap[L];
        for (int j = 0; j < L; j++) {
            votes[j] = new HashMap<Double, Double>();
        }

        for (int i = 0; i < w.length; i++) {
            double y_i[] = A.toDoubleArray(MLUtils.decodeValue(m_InstancesTemplate.classAttribute().value(i)));
            for (int j = 0; j < y_i.length; j++) {
                double value = y_i[j];
                Double previous = votes[j].get(value);
                votes[j].put(value, previous == null ? w[i] : previous + w[i]);
            }
        }

        // some confidence information
        for (int j = 0; j < L; j++) {
            y[j + L] = votes[j].size() > 0 ? Collections.max(votes[j].values()) : 0.0;
        }

        return y;
    }

    // TODO: use PSUtils
    /**
     * Converts a distribution over class combinations into a binary indicator vector.
     * For every combination {@code i} with {@code y_sl[i] > 0.0}, the class value
     * is parsed with {@code MLUtils.fromBitString} and each position holding a
     * positive value is set to 1.0 in the result.
     *
     * @param y_sl the distribution over the class values of the template's class attribute
     * @param L    the length of the returned vector
     * @return a vector of length {@code L} containing only 0.0 and 1.0
     */
    public double[] convertDistribution(double y_sl[], int L) {
        double y_ml[] = new double[L];
        for (int i = 0; i < y_sl.length; i++) {
            if (y_sl[i] > 0.0) {
                double d[] = MLUtils.fromBitString(m_InstancesTemplate.classAttribute().value(i));
                for (int j = 0; j < d.length; j++) {
                    if (d[j] > 0.0) {
                        y_ml[j] = 1.0;
                    }
                }
            }
        }
        return y_ml;
    }

    // TODO: use SuperLabelUtils
    /**
     * GetTopNSubsets - return the top N combinations whose
     * {@code MLUtils.bitDifference} from y is at most 1, ranked by the
     * frequency stored in masterCombinations (most frequent first).
     *
     * @param y                  the '+'-separated class combination to find replacements for
     * @param masterCombinations the known combinations and their counts
     * @param N                  the maximum number of combinations to return
     * @return at most {@code N} combinations, ordered by descending count
     */
    public static String[] getTopNSubsets(String y, final HashMap<String, Integer> masterCombinations, int N) {
        String y_bits[] = y.split("\\+");
        ArrayList<String> Y = new ArrayList<String>();
        for (String y_ : masterCombinations.keySet()) {
            if (MLUtils.bitDifference(y_bits, y_.split("\\+")) <= 1) {
                Y.add(y_);
            }
        }
        Collections.sort(Y, new Comparator<String>() {
            @Override
            public int compare(String s1, String s2) {
                // @note this is just done by the count, @todo: could add further conditions
                // descending order: the higher count sorts first, equal counts compare as 0
                return masterCombinations.get(s2).compareTo(masterCombinations.get(s1));
            }
        });
        String Y_strings[] = Y.toArray(new String[Y.size()]);
        return Arrays.copyOf(Y_strings, Math.min(N, Y_strings.length));
    }

    // TODO use PSUtils
    /**
     * Transforms the multi-target data {@code D} into a single-target problem
     * whose class attribute "CLASS" has one value per retained class combination.
     * Combinations are counted, then pruned via {@code MLUtils.pruneCountHashMap}
     * using {@code m_P}. An instance whose combination was pruned is, if
     * {@code m_N > 0}, replaced by copies labelled with the combinations
     * returned by {@code SuperLabelUtils.getTopNSubsets}, each copy weighted
     * {@code 1 / (number of replacements)}; otherwise it is dropped.
     * As a side effect, {@code m_InstancesTemplate} is set to the header of the result.
     *
     * @param D the original multi-target data
     * @param L the number of target attributes
     * @return the transformed dataset, with the class attribute at index 0
     * @throws Exception if the transformation fails
     */
    public Instances convertInstances(Instances D, int L) throws Exception {

        // Gather combinations
        HashMap<String, Integer> distinctCombinations = MLUtils.classCombinationCounts(D);
        if (getDebug()) {
            System.out.println("Found " + distinctCombinations.size() + " unique combinations");
        }

        // Prune combinations
        MLUtils.pruneCountHashMap(distinctCombinations, m_P);
        if (getDebug()) {
            System.out.println("Pruned to " + distinctCombinations.size() + " with P=" + m_P);
        }

        // Remove all class attributes
        Instances D_ = MLUtils.deleteAttributesAt(new Instances(D), MLUtils.gen_indices(L));

        // Add a new class attribute
        D_.insertAttributeAt(new Attribute("CLASS", new ArrayList<String>(distinctCombinations.keySet())), 0);
        D_.setClassIndex(0);

        // Add class values (the loop is bounded by D, so copies appended to D_ are not revisited)
        for (int i = 0; i < D.numInstances(); i++) {
            String y = MLUtils.encodeValue(MLUtils.toIntArray(D.instance(i), L));
            if (distinctCombinations.containsKey(y)) {
                // its class value exists: add it
                D_.instance(i).setClassValue(y);
            } else if (m_N > 0) {
                // decomp: replace the pruned combination by its nearest retained combinations
                String d_subsets[] = SuperLabelUtils.getTopNSubsets(y, distinctCombinations, m_N);
                for (String s : d_subsets) {
                    Instance copy = (Instance) (D_.instance(i)).copy();
                    copy.setClassValue(s);
                    copy.setWeight(1.0 / d_subsets.length);
                    D_.add(copy);
                }
            }
        }

        // remove with missing class
        D_.deleteWithMissingClass();

        // keep the header of new dataset for classification
        m_InstancesTemplate = new Instances(D_, 0);

        if (getDebug()) {
            System.out.println("" + D_);
        }

        return D_;
    }

    /**
     * Splits an encoded class combination into its parts.
     *
     * @param a the encoded combination, e.g. "0+3+0+0+1+2+0+0"
     * @return the parts between the '+' separators
     */
    public static String[] decodeValue(String a) {
        return a.split("\\+");
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9117 $");
    }

    /**
     * Runs the evaluation of this classifier from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String args[]) {
        ProblemTransformationMethod.evaluation(new meka.classifiers.multitarget.NSR(), args);
    }
}