The owner of a Node
 *  may freely choose the license terms applicable to such Node, including
 *  when such Node is propagated with or for interoperation with KNIME.
 * ---------------------------------------------------------------------
 *
 *
 */
package org.knime.knip.suise.node.boundarymodel.contourdata;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import net.imglib2.util.Pair;
import net.imglib2.util.ValuePair;

import org.knime.knip.core.util.PermutationSort;

import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.DistanceFunction;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;

/**
 * Interval Rule Induction.
 *
 * <p>
 * Multi-instance classifier that repeatedly induces an {@link IntervalRule},
 * removes all bags that contain an instance covered by that rule and continues
 * until the fraction of positive bags that is still uncovered is at most
 * {@code 1 - m_coverRate}. A bag is classified by the best response of any
 * rule on any of the instances it contains.
 * </p>
 *
 * <p>
 * The code reads the bag content from attribute index 1 (relational) and the
 * bag label from attribute index 2, where the value 1 marks a positive bag.
 * </p>
 *
 * TODO's: contour models (consisting of different rules)
 *
 * @author <a href="">Martin Horn</a>
 * @author hornm, University of Konstanz
 */
public class IRI extends AbstractClassifier {

    // scales the number of rule-search iterations with the number of
    // remaining positive bags
    private double m_sampleRate = .1;

    // constant added to the denominator of the rule score (see ruleScore)
    private double m_bias = 100;

    // rule induction stops as soon as at most (1 - m_coverRate) of the
    // positive bags are still uncovered
    private double m_coverRate = .8;

    private List<IntervalRule> m_rules = new ArrayList<IntervalRule>();

    // size of the thread pool and number of rule searches run in parallel
    private int m_numThreads = 8;

    /**
     * Induces interval rules from the given multi-instance data and appends
     * them to the rules already held by this classifier.
     *
     * {@inheritDoc}
     *
     * @param miData
     *            the multi-instance training data
     * @throws Exception
     *             if the data can not be handled, or if the rule search fails
     */
    @Override
    public void buildClassifier(Instances miData) throws Exception {
        // can classifier handle the data?
        getCapabilities().testWithFail(miData);

        // working copies which shrink while bags get covered by new rules
        final Instances tmpMiData = new Instances(miData);
        final Instances flatData = toSingleInstanceDataset(miData, null);

        int numPosBags = 0;
        for (Instance bag : miData) {
            if (bag.value(2) == 1) {
                numPosBags++;
            }
        }
        int remainingNumPosBags = numPosBags;

        List<Future<Pair<IntervalRule, Double>>> futures =
                new ArrayList<Future<Pair<IntervalRule, Double>>>(
                        m_numThreads);
        ExecutorService pool = Executors.newFixedThreadPool(m_numThreads);
        try {
            // note: without positive bags the ratio is NaN and the loop is
            // skipped entirely
            while (remainingNumPosBags / (double) numPosBags > 1 - m_coverRate) {
                final int numIterations =
                        ((int) (m_sampleRate * remainingNumPosBags))
                                / m_numThreads + 1;

                futures.clear();
                for (int t = 0; t < m_numThreads; t++) {
                    futures.add(pool
                            .submit(new Callable<Pair<IntervalRule, Double>>() {
                                @Override
                                public Pair<IntervalRule, Double> call()
                                        throws Exception {
                                    return createRule(flatData, tmpMiData,
                                            numIterations);
                                }
                            }));
                }

                // select the best rule from the threads; waiting for ALL
                // futures also guarantees that no worker reads the shared
                // data anymore when it gets modified below
                double score = -Double.MAX_VALUE;
                IntervalRule rule = null;
                for (Future<Pair<IntervalRule, Double>> future : futures) {
                    Pair<IntervalRule, Double> candidate = future.get();
                    if (candidate.getB() > score) {
                        score = candidate.getB();
                        rule = candidate.getA();
                    }
                }
                if (rule == null) {
                    throw new IllegalStateException(
                            "Rule induction did not produce any rule.");
                }
                m_rules.add(rule);

                // only keep the bags whose instances are not covered by this
                // rule
                remainingNumPosBags = removeCoveredBags(rule, tmpMiData);
                flatData.clear();
                toSingleInstanceDataset(tmpMiData, flatData);
            }
        } finally {
            // release the worker threads even if the rule search failed
            pool.shutdownNow();
        }
    }

    /**
     * Removes all bags from the given data that contain at least one instance
     * which the rule classifies as positive.
     *
     * @param rule
     *            the rule to test the bag contents with
     * @param bags
     *            the multi-instance data, modified in place
     * @return the number of positive bags that remain in {@code bags}
     * @throws Exception
     *             if the rule fails to classify an instance
     */
    private int removeCoveredBags(IntervalRule rule, Instances bags)
            throws Exception {
        Instances previous = new Instances(bags);
        bags.clear();
        int remainingPositive = 0;
        for (Instance bag : previous) {
            boolean covered = false;
            for (Instance inst : bag.relationalValue(1)) {
                double[] distr = rule.distributionForInstance(inst);
                if (distr[1] > distr[0]) {
                    covered = true;
                    break;
                }
            }
            if (!covered) {
                bags.add(bag);
                if (bag.value(2) == 1) {
                    remainingPositive++;
                }
            }
        }
        return remainingPositive;
    }

    /**
     * Heuristically searches for a good rule. In every iteration a positive
     * bag is drawn at random, the rule is seeded with the heaviest instance of
     * that bag and then extended by the nearest positive instances for as long
     * as the rule score does not decrease.
     *
     * @param flatData
     *            the instances of all bags as a single-instance data set
     * @param miData
     *            the multi-instance data the flat data was derived from
     * @param iterations
     *            the number of random restarts
     * @return the best rule found together with its score; the rule is
     *         {@code null} if {@code iterations} is not positive
     * @throws Exception
     *             if there is no positive bag to start from, or if a rule can
     *             not be updated or evaluated
     */
    private Pair<IntervalRule, Double> createRule(Instances flatData,
            Instances miData, int iterations) throws Exception {
        // store for the distances between the reference distance and all
        // others
        double[] distances = new double[flatData.numInstances()];

        // the distance function
        DistanceFunction distFunc = new EuclideanDistance(flatData);

        // permutation which sorts the distances
        Integer[] perm = new Integer[flatData.numInstances()];

        // the bags a reference instance may be drawn from; collected once
        // such that drawing a bag always terminates
        List<Integer> positiveBagIndices = new ArrayList<Integer>();
        for (int b = 0; b < miData.numInstances(); b++) {
            if (miData.get(b).value(2) != 0) {
                positiveBagIndices.add(b);
            }
        }
        if (iterations > 0 && positiveBagIndices.isEmpty()) {
            throw new IllegalStateException(
                    "Cannot create a rule: the data contains no positive bag.");
        }

        Random random = new Random();
        IntervalRule bestRule = null;
        double bestRuleScore = -Double.MAX_VALUE;

        // retrieve the best rule heuristically for a number of iterations
        for (int ruleIterations = 0; ruleIterations < iterations; ruleIterations++) {

            // randomly select an initial instance, i.e. selecting a positive
            // bag randomly and taking the instance with the largest weight
            int bagIdx = positiveBagIndices.get(
                    random.nextInt(positiveBagIndices.size()));
            Instances bagContent = miData.get(bagIdx).relationalValue(1);

            // the reference instance for the next rule
            Instance refInstance = bagContent.firstInstance();
            for (Instance candidate : bagContent) {
                if (candidate.weight() > refInstance.weight()) {
                    refInstance = candidate;
                }
            }

            IntervalRule rule = new IntervalRule();
            rule.updateClassifier(refInstance);

            // calculate the distance from that particular reference instance
            // to all other positive instances (negatives are set to NaN) and
            // sort them
            Arrays.fill(distances, Double.NaN);
            for (int i = 0; i < distances.length; i++) {
                if (flatData.get(i).classValue() == 1) {
                    distances[i] =
                            distFunc.distance(refInstance, flatData.get(i));
                }
            }
            PermutationSort.sortPermInPlace(distances, perm);

            // extend the rule successively by the nearest instances till the
            // score doesn't increase anymore or no candidate is left
            double currentScore = 0;
            for (int instanceIdx = 0; instanceIdx < perm.length; instanceIdx++) {
                int flatIdx = perm[instanceIdx];
                if (Double.isNaN(distances[flatIdx])) {
                    // negative instance, not a candidate
                    continue;
                }
                IntervalRule tmpRule = new IntervalRule(rule);
                tmpRule.updateClassifier(flatData.get(flatIdx));

                // evaluate rule
                double tmpRuleScore = ruleScore(tmpRule, flatData);
                if (tmpRuleScore >= currentScore) {
                    currentScore = tmpRuleScore;
                    rule = tmpRule;
                } else {
                    break;
                }
            }

            if (currentScore > bestRuleScore) {
                bestRuleScore = currentScore;
                bestRule = rule;
            }
        } // iterations per rule

        return new ValuePair<IntervalRule, Double>(bestRule, bestRuleScore);
    }

    /**
     * Scores a rule on the given data: the summed weight of the covered
     * positive instances divided by the number of all covered instances plus
     * the bias. Instances with a weight of zero or less are ignored.
     *
     * @param rule
     *            the rule to evaluate
     * @param data
     *            the single-instance data to evaluate the rule on
     * @return the score of the rule
     * @throws Exception
     *             if the rule fails to classify an instance
     */
    private double ruleScore(IntervalRule rule, Instances data)
            throws Exception {
        double posCount = 0;
        double negCount = 0;
        double posSumWeights = 0;
        for (Instance inst : data) {
            if (inst.weight() > 0) {
                double[] dist = rule.distributionForInstance(inst);
                if (dist[1] > dist[0]) {
                    if (inst.classValue() == 1) {
                        posSumWeights += inst.weight();
                        posCount++;
                    } else {
                        negCount++;
                    }
                }
            }
        }
        return posSumWeights / (posCount + negCount + m_bias);
    }

    /**
     * Flattens multi-instance data into single-instance data, drops the bag
     * index attribute and transfers the weights of the bag instances to the
     * flattened instances.
     *
     * @param miData
     *            the multi-instance data to flatten
     * @param flatData
     *            the data set to append the flattened instances to, or
     *            {@code null} to create a new one
     * @return the data set the instances were appended to
     * @throws Exception
     *             if the conversion fails
     */
    private Instances toSingleInstanceDataset(Instances miData,
            Instances flatData) throws Exception {
        MultiInstanceToPropositional convertToProp =
                new MultiInstanceToPropositional();
        convertToProp.setInputFormat(miData);
        for (int i = 0; i < miData.numInstances(); i++) {
            convertToProp.input(miData.instance(i));
        }
        convertToProp.batchFinished();

        if (flatData == null) {
            flatData = convertToProp.getOutputFormat();
            flatData.deleteAttributeAt(0); // remove the bag index attribute
        }

        // position of the first instance appended by this call; zero for a
        // new or cleared data set
        int instanceIdx = flatData.numInstances();

        Instance processed;
        while ((processed = convertToProp.output()) != null) {
            processed.setDataset(null);
            processed.deleteAttributeAt(0); // remove the bag index attribute
            flatData.add(processed);
        }

        // set weights; relies on the filter emitting the instances in the
        // order of the bags and their content
        for (Instance bag : miData) {
            for (Instance instance : bag.relationalValue(1)) {
                flatData.get(instanceIdx).setWeight(instance.weight());
                instanceIdx++;
            }
        }
        return flatData;
    }

    /**
     * Classifies a bag: returns the distribution of the first contained
     * instance that is classified as positive, or {@code {1.0, 0}} if there is
     * no such instance.
     *
     * {@inheritDoc}
     */
    @Override
    public double[] distributionForInstance(Instance bag) throws Exception {
        Instances contents = bag.relationalValue(1);
        for (Instance i : contents) {
            double[] res = distributionForSingleInstance(i);
            if (res[1] > res[0]) {
                return res;
            }
        }
        return new double[] { 1.0, 0 };
    }

    /**
     * Classifies a single (non-bag) instance with all rules.
     *
     * @param instance
     *            the instance to classify
     * @return {@code {1 - p, p}} with {@code p} being the largest positive
     *         response of all rules (0 if there are no rules)
     * @throws Exception
     *             if a rule fails to classify the instance
     */
    public double[] distributionForSingleInstance(Instance instance)
            throws Exception {
        double res = 0;
        for (IntervalRule r : m_rules) {
            res = Math.max(res, r.distributionForInstance(instance)[1]);
        }
        return new double[] { 1 - res, res };
    }

    /**
     * Determines the rule with the largest positive response for the given
     * instance.
     *
     * @param instance
     *            the instance to classify
     * @return the index of the best rule, -1 if no fitting rule was found,
     *         i.e. no rule has a positive response larger than zero
     * @throws Exception
     *             if a rule fails to classify the instance
     */
    public int getBestRuleIndexForSingleInstance(Instance instance)
            throws Exception {
        int bestIdx = -1;
        double bestProb = 0;
        for (int i = 0; i < m_rules.size(); i++) {
            double[] distr = m_rules.get(i).distributionForInstance(instance);
            if (distr[1] > bestProb) {
                bestProb = distr[1];
                bestIdx = i;
            }
        }
        return bestIdx;
    }

    /**
     * @param index
     *            the index of the rule
     * @return the rule at the given index
     */
    public IntervalRule getRule(int index) {
        return m_rules.get(index);
    }

    /**
     * @param bias
     *            the constant added to the denominator of the rule score
     */
    public void setBias(double bias) {
        m_bias = bias;
    }

    /**
     * Appends the given rules to the rules of this classifier.
     *
     * @param rules
     *            the rules to add
     */
    public void addRules(IntervalRule... rules) {
        m_rules.addAll(Arrays.asList(rules));
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();

        // attributes
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.RELATIONAL_ATTRIBUTES);
        result.disable(Capability.MISSING_VALUES);

        // class
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.enable(Capability.BINARY_CLASS);

        // Only multi instance data
        result.enable(Capability.ONLY_MULTIINSTANCE);

        return result;
    }
}