clus.algo.tdidt.tune.CDTuneSizeConstrPruning.java Source code

Java tutorial

Introduction

Here is the source code for clus.algo.tdidt.tune.CDTuneSizeConstrPruning.java

Source

/*************************************************************************
 * Clus - Software for Predictive Clustering                             *
 * Copyright (C) 2007                                                    *
 *    Katholieke Universiteit Leuven, Leuven, Belgium                    *
 *    Jozef Stefan Institute, Ljubljana, Slovenia                        *
 *                                                                       *
 * This program is free software: you can redistribute it and/or modify  *
 * it under the terms of the GNU General Public License as published by  *
 * the Free Software Foundation, either version 3 of the License, or     *
 * (at your option) any later version.                                   *
 *                                                                       *
 * This program is distributed in the hope that it will be useful,       *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 * GNU General Public License for more details.                          *
 *                                                                       *
 * You should have received a copy of the GNU General Public License     *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>. *
 *                                                                       *
 * Contact information: <http://www.cs.kuleuven.be/~dtai/clus/>.         *
 *************************************************************************/

/*
 * Created on May 11, 2005
 */
package clus.algo.tdidt.tune;

import java.io.*;
import java.util.*;

import jeans.io.MyFile;
import jeans.math.*;

import clus.main.*;
import clus.model.ClusModel;
import clus.model.ClusModelInfo;
import clus.selection.*;
import clus.statistic.*;
import clus.util.*;
import clus.pruning.*;
import clus.algo.ClusInductionAlgorithmType;
import clus.algo.tdidt.ClusDecisionTree;
import clus.algo.tdidt.ClusNode;
import clus.data.ClusData;
import clus.data.rows.*;
import clus.data.attweights.*;
import clus.data.type.*;
import clus.error.*;
import clus.ext.hierarchical.HierRemoveInsigClasses;

/*
import org.apache.commons.math.distribution.*;
import org.apache.commons.math.*;
*/

public class CDTuneSizeConstrPruning extends ClusDecisionTree {

    // Induction algorithm passed to the constructor; used to grow the per-fold
    // trees (induceSingleUnpruned) and the final model (induceAll).
    protected ClusInductionAlgorithmType m_Class;
    // Schema of the training set; assigned in induceAll.
    protected ClusSchema m_Schema;
    // Clustering statistic computed over the full training set in induceAll.
    protected ClusStatistic m_TotalStat;
    // Result of m_Schema.hasMissing(); selects how test errors are computed
    // (computeErrorStandard when true, TreeErrorComputer otherwise).
    protected boolean m_HasMissing;
    // Number of rows in the training set; assigned in induceAll.
    protected int m_NbExamples;
    //   protected TDistribution m_Distribution;
    // First size-constraint value read from the settings in induceAll.
    // The code compares it against -1, presumably meaning "no user limit" - TODO confirm.
    protected int m_OrigSize;
    // Threshold on the difference between two error values; used by
    // refineGraph and findOptimalSize.
    protected double m_RelErrAcc = 0.01;
    // Size/error points, optimal size and maximal size recorded by
    // setFinalResult and written out by saveInformation.
    protected ArrayList m_Graph;
    protected int m_Optimal, m_MaxSize;
    // Clustering weights taken from the statistics manager in induceAll and
    // applied to the error list in findBestSize.
    protected ClusAttributeWeights m_TargetWeights;
    // Set by setRelativeMeasure: when m_Relative is true, computeTreeError
    // divides every error by m_RelativeScale.
    protected boolean m_Relative;
    protected double m_RelativeScale;

    /**
     * Creates a size-tuning wrapper around an existing induction algorithm.
     *
     * @param clss the induction algorithm to wrap; its Clus instance is
     *             forwarded to the superclass constructor
     */
    public CDTuneSizeConstrPruning(ClusInductionAlgorithmType clss) {
        super(clss.getClus());
        this.m_Class = clss;
    }

    /**
     * Prints the name of this algorithm and of the heuristic reported by the
     * statistics manager to standard output.
     */
    public void printInfo() {
        System.out.println("TDIDT (Tuning Size Constraint)");
        ClusStatManager statManager = getStatManager();
        System.out.println("Heuristic: " + statManager.getHeuristicName());
    }

    /**
     * Prints the one-based number of a fold as a progress indicator; every
     * fold except the first is preceded by a single space.
     *
     * @param i zero-based fold index
     */
    private final void showFold(int i) {
        final String separator = (i == 0) ? "" : " ";
        System.out.print(separator + (i + 1));
        System.out.flush();
    }

    /**
     * Configures whether errors are reported relative to a reference value.
     *
     * @param enable {@code true} to divide errors by {@code value}
     * @param value  the reference value used as divisor when enabled
     */
    public void setRelativeMeasure(boolean enable, double value) {
        this.m_RelativeScale = value;
        this.m_Relative = enable;
    }

    /**
     * For every run, initializes the test errors of its tree and pushes each
     * test tuple through the tree with a {@link TreeErrorComputer}.
     *
     * @param runs  the cross-validation runs whose trees are processed
     * @param model index of the model to fetch from each run
     * @param error error measure handed to the test-error initialization
     */
    public void computeTestStatistics(ClusRun[] runs, int model, ClusError error)
            throws IOException, ClusException {
        TreeErrorComputer visitor = new TreeErrorComputer();
        for (int fold = 0; fold < runs.length; fold++) {
            ClusRun run = runs[fold];
            ClusNode root = (ClusNode) run.getModelInfo(model).getModel();
            TreeErrorComputer.initializeTestErrors(root, error);
            MemoryTupleIterator testIter = (MemoryTupleIterator) run.getTestIter();
            testIter.init();
            // readTuple() returns null once the test set is exhausted
            for (DataTuple current = testIter.readTuple(); current != null; current = testIter.readTuple()) {
                root.applyModelProcessor(current, visitor);
            }
        }
    }

    /**
     * Accumulates the test error of {@code tree} by predicting every test
     * tuple of {@code run} with {@code predictWeighted} and adding the result
     * to the first test error of the run's model info.
     *
     * @param tree  the tree used for prediction
     * @param model index of the model info whose test error is updated
     * @param run   the run providing the test tuples
     */
    public void computeErrorStandard(ClusNode tree, int model, ClusRun run) throws ClusException, IOException {
        ClusModelInfo info = run.getModelInfo(model);
        ClusError firstError = info.getTestError().getFirstError();
        MemoryTupleIterator testIter = (MemoryTupleIterator) run.getTestIter();
        testIter.init();
        for (DataTuple current = testIter.readTuple(); current != null; current = testIter.readTuple()) {
            ClusStatistic prediction = tree.predictWeighted(current);
            firstError.addExample(current, prediction);
        }
    }

    /**
     * Computes the test error of every fold's tree after restricting it to
     * {@code size} nodes, and the error summed over all folds.
     *
     * For {@code size == 1} only the root is cloned; otherwise the tree is
     * cloned and pruned when it is larger than {@code size}, and used as is
     * when it is not. In hierarchical mode insignificant classes are removed
     * from the tree before the error is measured. When the relative measure
     * is enabled, every error is divided by {@code m_RelativeScale}.
     *
     * @param runs    the cross-validation runs
     * @param pruners one initialized pruner per run
     * @param model   index of the model to evaluate
     * @param summ    summary whose first test error accumulates the fold errors
     * @param size    size constraint to evaluate
     * @return a list holding one error per fold, with the summed error as Y
     */
    public SingleStatList computeTreeError(ClusRun[] runs, SizeConstraintPruning[] pruners, int model,
            ClusSummary summ, int size) throws ClusException, IOException {
        ClusModelInfo summaryInfo = summ.getModelInfo(model);
        ClusError summaryError = summaryInfo.getTestError().getFirstError();
        summaryError.reset();
        SingleStatList foldErrors = new SingleStatList(runs.length);
        for (int fold = 0; fold < runs.length; fold++) {
            ClusRun run = runs[fold];
            ClusModelInfo foldInfo = run.getModelInfo(model);
            ClusNode tree = (ClusNode) foldInfo.getModel();
            if (size == 1) {
                tree = tree.cloneNodeWithVisitor();
            } else if (size < tree.getModelSize()) {
                // work on a copy so the unpruned tree stays available for other sizes
                tree = tree.cloneTreeWithVisitors();
                pruners[fold].pruneExecute(tree, size);
            }
            if (getStatManager().getMode() == ClusStatManager.MODE_HIERARCHICAL) {
                PruneTree basePruner = new PruneTree();
                boolean bonferroni = getSettings().isUseBonferroni();
                HierRemoveInsigClasses insigPruner = new HierRemoveInsigClasses(run.getPruneSet(), basePruner,
                        bonferroni, getStatManager().getHier());
                insigPruner.setSignificance(getSettings().getHierPruneInSig());
                insigPruner.prune(tree);
            }
            ClusError foldError = foldInfo.getTestError().getFirstError();
            foldError.reset();
            if (!m_HasMissing) {
                TreeErrorComputer.computeErrorSimple(tree, foldError);
            } else {
                computeErrorStandard(tree, model, run);
            }
            summaryError.add(foldError);
            MemoryTupleIterator testIter = (MemoryTupleIterator) run.getTestIter();
            foldInfo.getTestError().setNbExamples(testIter.getNbExamples());
            if (!m_Relative) {
                foldErrors.addFloat(foldError.getModelError());
            } else {
                foldErrors.addFloat(foldError.getModelError() / m_RelativeScale);
            }
        }
        summaryInfo.getTestError().setNbExamples(m_NbExamples);
        if (!m_Relative) {
            foldErrors.setY(summaryError.getModelError());
        } else {
            foldErrors.setY(summaryError.getModelError() / m_RelativeScale);
        }
        return foldErrors;
    }

    /**
     * Evaluates the given size and inserts the resulting point into
     * {@code points}, which is kept ordered by increasing X (size).
     *
     * @param points  ordered list of {@code SingleStatList} points
     * @param size    the size constraint to evaluate
     * @param runs    the cross-validation runs
     * @param pruners one pruner per run
     * @param model   index of the model to evaluate
     * @param summ    summary used to accumulate the error
     * @return the newly inserted point, or {@code null} when a point with the
     *         same size is already present (nothing is computed in that case)
     */
    public SingleStatList addPoint(ArrayList points, int size, ClusRun[] runs, SizeConstraintPruning[] pruners,
            int model, ClusSummary summ) throws ClusException, IOException {
        final int count = points.size();
        int insertAt = 0;
        while (insertAt < count) {
            SingleStatList existing = (SingleStatList) points.get(insertAt);
            if (!(existing.getX() < size)) {
                // first element that is not smaller: either a duplicate or the insertion slot
                if (existing.getX() == size) {
                    return null;
                }
                break;
            }
            insertAt++;
        }
        SingleStatList created = computeTreeError(runs, pruners, model, summ, size);
        created.setX(size);
        points.add(insertAt, created);
        return created;
    }

    /**
     * Returns the absolute difference between the largest and the smallest Y
     * value among the points of {@code graph}.
     *
     * @param graph list of {@code SingleStatList} points
     * @return {@code |max - min|} of the Y values
     */
    public double getRange(ArrayList graph) {
        double lowest = Double.POSITIVE_INFINITY;
        double highest = Double.NEGATIVE_INFINITY;
        for (Iterator it = graph.iterator(); it.hasNext();) {
            double y = ((SingleStatList) it.next()).getY();
            if (y < lowest) {
                lowest = y;
            }
            if (y > highest) {
                highest = y;
            }
        }
        return Math.abs(highest - lowest);
    }

    /**
     * Repeatedly inserts a point between two neighbouring points of
     * {@code graph} whose Y values differ by more than {@code m_RelErrAcc}.
     * At most one point is inserted per scan, and scanning stops once a scan
     * leaves the graph size unchanged.
     *
     * The inserted size is the odd number {@code 2 * n + 1} derived from the
     * midpoint of the two neighbouring sizes; it is skipped when it equals
     * one of the neighbours or is not below {@code m_OrigSize} (unless
     * {@code m_OrigSize} is -1).
     *
     * NOTE(review): the scan bound is {@code graph.size() - 2}, so the last
     * adjacent pair of points is never examined - confirm this is intended.
     *
     * @param graph   ordered list of {@code SingleStatList} points, modified in place
     * @param runs    the cross-validation runs
     * @param pruners one pruner per run
     * @param model   index of the model to evaluate
     * @param summ    summary used to accumulate the error
     */
    public void refineGraph(ArrayList graph, ClusRun[] runs, SizeConstraintPruning[] pruners, int model,
            ClusSummary summ) throws ClusException, IOException {
        int lastCount = -1;
        for (;;) {
            for (int i = 0; i < graph.size() - 2; i++) {
                SingleStatList left = (SingleStatList) graph.get(i);
                SingleStatList right = (SingleStatList) graph.get(i + 1);
                if (!(Math.abs(left.getY() - right.getY()) > m_RelErrAcc)) {
                    continue;
                }
                int sizeLeft = (int) left.getX();
                int sizeRight = (int) right.getX();
                int innerNodes = ((sizeLeft + sizeRight) / 2 - 1) / 2;
                int candidate = 2 * innerNodes + 1;
                boolean distinct = candidate != sizeLeft && candidate != sizeRight;
                boolean withinLimit = m_OrigSize == -1 || candidate < m_OrigSize;
                if (distinct && withinLimit) {
                    addPoint(graph, candidate, runs, pruners, model, summ);
                    System.out.print("#");
                    System.out.flush();
                    // one insertion per scan: restart from the beginning
                    break;
                }
            }
            int count = graph.size();
            if (count == lastCount) {
                return;
            }
            lastCount = count;
        }
    }

    /*
       public int significanceTest(SingleStatList l1, SingleStatList l2, double level) {
          double probability = 0.5 + level/2; // 1.0 - (1.0 - level)/2;
          try {
     double t_value = m_Distribution.inverseCumulativeProbability(probability);
     double t_test = l1.getTValueSigTest(l2);
     if (t_test > t_value) {
        if (l1.getY() > l2.getY()) return 1;
        else return -1;
     } else {
        return 0;
     }
          } catch (MathException e) {
     System.err.println("Math error: "+e.getMessage());
     return 0;
          }
       }
    */

    /**
     * Selects the tree size to use from the size/error graph.
     *
     * First the point with the best Y value is located (lowest when
     * {@code shouldBeLow}, highest otherwise; the first such point wins on
     * ties). Then all points before it are scanned from right to left, and
     * the left-most one with X of at least 3 whose Y lies within
     * {@code m_RelErrAcc} of the best Y is preferred. Progress is printed to
     * standard output.
     *
     * @param graph       ordered list of {@code SingleStatList} points
     * @param shouldBeLow whether smaller Y values are better
     * @return the selected size, or 1 when no best point could be determined
     */
    public int findOptimalSize(ArrayList graph, boolean shouldBeLow) {
        double bestValue = shouldBeLow ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
        int bestIndex = -1;
        for (int i = 0; i < graph.size(); i++) {
            double y = ((SingleStatList) graph.get(i)).getY();
            boolean improves = shouldBeLow ? (y < bestValue) : (y > bestValue);
            if (improves) {
                bestValue = y;
                bestIndex = i;
            }
        }
        if (bestIndex == -1) {
            return 1;
        }
        SingleStatList best = (SingleStatList) graph.get(bestIndex);
        System.out.print("[" + best.getX() + "," + best.getY() + "]");
        SingleStatList chosen = best;
        for (int i = bestIndex - 1; i >= 0; i--) {
            SingleStatList smaller = (SingleStatList) graph.get(i);
            if (smaller.getX() >= 3 && Math.abs(smaller.getY() - best.getY()) < m_RelErrAcc) {
                chosen = smaller;
                System.out.print(" < " + smaller.getX());
            }
        }
        return (int) chosen.getX();
    }

    /**
     * Builds the cross-validation fold selection from the "tune folds" setting.
     *
     * When the setting starts with a digit it is parsed as the number of
     * folds and a random selection with a fixed seed of 0 is created;
     * otherwise the setting is passed to
     * {@code XValDataSelection.readFoldsFile} as its first argument.
     *
     * @param sett   settings providing the tune-folds value
     * @param nbrows number of rows to distribute over the folds
     * @return the fold selection
     * @throws ClusException if the value starts with a digit but is not a valid integer
     */
    public final XValMainSelection getXValSelection(Settings sett, int nbrows) throws IOException, ClusException {
        String folds = sett.getTuneFolds();
        boolean startsWithDigit = folds.length() > 0 && Character.isDigit(folds.charAt(0));
        if (!startsWithDigit) {
            return XValDataSelection.readFoldsFile(folds, nbrows);
        }
        try {
            int nbfolds = Integer.parseInt(folds);
            Random seeded = new Random(0);
            return new XValRandomSelection(nbrows, nbfolds, seeded);
        } catch (NumberFormatException e) {
            throw new ClusException("Illegal number of folds: " + folds);
        }
    }

    /**
     * Tunes the size constraint by internal cross-validation on {@code trset}
     * and stores the selected size in the settings.
     *
     * One tree is induced per fold with {@code induceSingleUnpruned}, and a
     * size/error graph is built over sizes 1, the maximal size and
     * {@code 2^n + 1}. The graph is then refined and
     * {@link #findOptimalSize} picks the final size. Errors in the graph are
     * expressed relative to the error of the size-1 tree.
     *
     * Verbosity is set to 0 for the duration of the search and is always
     * restored afterwards, including on the early return for
     * {@code maxsize == 1} and when an exception is thrown.
     *
     * @param trset the training data used for tuning
     */
    public void findBestSize(ClusData trset) throws ClusException, IOException {
        int prevVerb = Settings.enableVerbose(0);
        try {
            ClusStatManager mgr = getStatManager();
            ClusSummary summ = new ClusSummary();
            ClusErrorList errorpar = mgr.createDefaultError();
            errorpar.setWeights(m_TargetWeights);
            summ.setTestError(errorpar);
            int model = ClusModel.ORIGINAL;
            XValMainSelection sel = getXValSelection(getSettings(), trset.getNbRows());
            int nbfolds = sel.getNbFolds();
            ClusRun[] runs = new ClusRun[nbfolds];
            // Create trees based on cross-validation
            for (int i = 0; i < nbfolds; i++) {
                showFold(i);
                XValSelection msel = new XValSelection(sel, i);
                ClusRun cr = m_Clus.partitionDataBasic(trset, msel, summ, i + 1);
                ClusModel tree = m_Class.induceSingleUnpruned(cr);
                cr.getModelInfo(model).setModel(tree);
                runs[i] = cr;
            }
            // Construct tree pruners; sizes are capped at the user-given limit (if any)
            int maxsize = 0;
            SizeConstraintPruning[] pruners = new SizeConstraintPruning[nbfolds];
            for (int i = 0; i < nbfolds; i++) {
                ClusNode tree = (ClusNode) runs[i].getModelInfo(model).getModel();
                int size = tree.getModelSize();
                if (m_OrigSize != -1 && size > m_OrigSize)
                    size = m_OrigSize;
                if (size > maxsize)
                    maxsize = size;
                SizeConstraintPruning pruner = new SizeConstraintPruning(size, mgr.getClusteringWeights());
                pruner.pruneInitialize(tree, size);
                pruners[i] = pruner;
            }
            if (maxsize == 1) {
                System.out.println("Optimal size (maxsize = 1) = 1");
                m_Class.getSettings().setSizeConstraintPruning(1);
                // Record the outcome so saveInformation never sees a null or stale graph
                setFinalResult(new ArrayList(), 1, maxsize);
                return;
            }
            // Get training data in trees
            ClusError error = summ.getModelInfo(model).getTestError().getFirstError();
            if (!m_HasMissing) {
                computeTestStatistics(runs, model, error);
            }
            // Compute errors of default models; the size-1 error becomes the
            // reference value for all relative errors computed afterwards
            ArrayList graph = new ArrayList();
            setRelativeMeasure(false, 0.0);
            SingleStatList point = computeTreeError(runs, pruners, model, summ, 1);
            setRelativeMeasure(true, point.getY());
            System.out.print(" ");
            System.out.print("<" + point.getY() + ">");
            addPoint(graph, 1, runs, pruners, model, summ);
            addPoint(graph, maxsize, runs, pruners, model, summ);
            // Add trees with exponentially increasing size
            int n = 1;
            boolean shouldBeLow = error.shouldBeLow();
            while (true) {
                int size = (int) (Math.pow(2, n) + 1.0);
                if (size > maxsize)
                    break;
                if (m_OrigSize != -1 && size > m_OrigSize)
                    break;
                SingleStatList new_pt = addPoint(graph, size, runs, pruners, model, summ);
                if (new_pt == null) {
                    break;
                } else {
                    // Stop once the relative error is clearly worse than that of the size-1 tree
                    if (shouldBeLow) {
                        if (graph.size() > 5 && new_pt.getY() > 1.1) {
                            break;
                        }
                    } else {
                        if (graph.size() > 5 && new_pt.getY() < 0.9) {
                            break;
                        }
                    }
                    System.out.print("*");
                    System.out.flush();
                }
                n++;
            }
            refineGraph(graph, runs, pruners, model, summ);
            int optimalSize = findOptimalSize(graph, shouldBeLow);
            System.out.println(" Best = " + optimalSize);
            // Remember the graph for saveInformation
            setFinalResult(graph, optimalSize, maxsize);
            getSettings().setSizeConstraintPruning(optimalSize);
        } finally {
            // Restore the caller's verbosity on every exit path
            Settings.enableVerbose(prevVerb);
        }
    }

    /**
     * Writes the tuning result to {@code fname + ".dat"}: a first line with
     * the optimal and maximal size, followed by one line per graph point
     * (X and Y separated by a tab).
     *
     * If no graph has been recorded yet, only the first line is written. The
     * file is closed even when writing fails.
     *
     * @param fname file name without the ".dat" extension
     */
    public void saveInformation(String fname) {
        System.out.println("Saving: " + fname + ".dat");
        MyFile file = new MyFile(fname + ".dat");
        try {
            file.log("" + m_Optimal + "\t" + m_MaxSize);
            // m_Graph stays null until setFinalResult has been called
            if (m_Graph != null) {
                for (int i = 0; i < m_Graph.size(); i++) {
                    SingleStatList elem = (SingleStatList) m_Graph.get(i);
                    file.log("" + elem.getX() + "\t" + elem.getY());
                }
            }
        } finally {
            file.close();
        }
    }

    /**
     * Stores the outcome of the tuning so that it can be saved later.
     *
     * @param graph   the size/error points
     * @param optimal the selected size
     * @param maxsize the largest size that was considered
     */
    public void setFinalResult(ArrayList graph, int optimal, int maxsize) {
        this.m_MaxSize = maxsize;
        this.m_Optimal = optimal;
        this.m_Graph = graph;
    }

    /**
     * Not supported by this algorithm: prints an error message to standard
     * output and returns {@code null}.
     *
     * @param cr ignored
     * @return always {@code null}
     */
    public ClusModel induceSingle(ClusRun cr) {
        final String message = ">>> Error: induceSingle/1 not implemented";
        System.out.println(message);
        return null;
    }

    /**
     * Creates a fresh clustering statistic from the wrapped algorithm's
     * statistics manager and fills it via
     * {@code data.calcTotalStatBitVector}.
     *
     * @param data the rows to summarize
     * @return the filled statistic
     */
    public ClusStatistic createTotalStat(RowData data) {
        final ClusStatistic total = m_Class.getStatManager().createClusteringStat();
        data.calcTotalStatBitVector(total);
        return total;
    }

    /**
     * Tunes the size constraint on the training set of {@code cr}, induces
     * the final model with the wrapped algorithm, and then writes the
     * original size constraint back into the settings.
     *
     * NOTE(review): although the method declares {@code ClusException}, both
     * {@code ClusException} and {@code IOException} raised inside are only
     * reported on standard error and not propagated - confirm callers rely
     * on this.
     *
     * @param cr the run providing the training set and receiving the induction time
     */
    public void induceAll(ClusRun cr) throws ClusException {
        try {
            final long startedAt = System.currentTimeMillis();
            m_OrigSize = getSettings().getSizeConstraintPruning(0);
            boolean severalSizes = getSettings().getSizeConstraintPruningNumber() > 1;
            if (severalSizes) {
                throw new ClusException("Only one value is allowed for MaxSize if -tunesize is given");
            }
            RowData train = (RowData) cr.getTrainingSet();
            m_Schema = train.getSchema();
            m_HasMissing = m_Schema.hasMissing();
            m_TotalStat = createTotalStat(train);
            m_NbExamples = train.getNbRows();
            System.out.println("Has missing values: " + m_HasMissing);
            m_TargetWeights = m_Class.getStatManager().getClusteringWeights();
            // Search for the best size constraint by cross-validation
            findBestSize(train);
            System.out.println();
            // Induce the final model with the tuned size, then put the original size back
            m_Class.induceAll(cr);
            getSettings().setSizeConstraintPruning(m_OrigSize);
            long elapsed = System.currentTimeMillis() - startedAt;
            if (Settings.VERBOSE > 0) {
                System.out.println("Time: " + elapsed / 1000.0 + " sec");
            }
            cr.setInductionTime(elapsed);
        } catch (ClusException e) {
            System.err.println("Error: " + e);
        } catch (IOException e) {
            System.err.println("IO Error: " + e);
        }
    }
}