clus.statistic.RegressionStat.java Source code

Java tutorial

Introduction

Here is the source code for clus.statistic.RegressionStat.java

Source

/*************************************************************************
 * Clus - Software for Predictive Clustering                             *
 * Copyright (C) 2007                                                    *
 *    Katholieke Universiteit Leuven, Leuven, Belgium                    *
 *    Jozef Stefan Institute, Ljubljana, Slovenia                        *
 *                                                                       *
 * This program is free software: you can redistribute it and/or modify  *
 * it under the terms of the GNU General Public License as published by  *
 * the Free Software Foundation, either version 3 of the License, or     *
 * (at your option) any later version.                                   *
 *                                                                       *
 * This program is distributed in the hope that it will be useful,       *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 * GNU General Public License for more details.                          *
 *                                                                       *
 * You should have received a copy of the GNU General Public License     *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>. *
 *                                                                       *
 * Contact information: <http://www.cs.kuleuven.be/~dtai/clus/>.         *
 *************************************************************************/

package clus.statistic;

import java.io.IOException;
import java.io.PrintWriter;
import java.text.*;
import java.util.ArrayList;
import java.util.Arrays;

import jeans.math.MathUtil;
import jeans.util.StringUtils;

import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.*;
import org.w3c.dom.Attr;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

import clus.main.ClusStatManager;
import clus.main.Settings;
import clus.util.*;
import clus.data.cols.*;
import clus.data.rows.*;
import clus.data.type.*;
import clus.data.attweights.*;
import clus.error.ClusNumericError;

/**
 * Regression statistic over one or more numeric target attributes.
 *
 * <p>For every target attribute {@code i} this statistic accumulates:
 * <ul>
 * <li>{@link #m_SumWeights}: the total weight of the examples whose value for that target was
 * recorded;</li>
 * <li>{@link #m_SumValues}: the weighted sum of those values;</li>
 * <li>{@link #m_SumSqValues}: the weighted sum of their squares.</li>
 * </ul>
 *
 * <p>{@link #updateWeighted(DataTuple, double)} skips a value equal to
 * {@code Double.POSITIVE_INFINITY} (it is treated as missing). As a result, {@code m_SumWeights[i]}
 * can be smaller than the overall weight {@code m_SumWeight}.
 *
 * <p>When constructed with {@code onlymean == true}, the three sum arrays are not allocated. Such
 * an instance (see {@link #cloneSimple()}) is presumably only used to carry means. Methods that
 * read the sum arrays must not be called on it.
 */
public class RegressionStat extends RegressionStatBase {

    public final static long serialVersionUID = Settings.SERIAL_VERSION_ID;

    /** Per target: weighted sum of the non-missing values. */
    public double[] m_SumValues;
    /** Per target: total weight of the examples whose value is not missing. */
    public double[] m_SumWeights;
    /** Per target: weighted sum of the squared non-missing values. */
    public double[] m_SumSqValues;
    /**
     * Optional training-set statistic. Its per-target variance is used as a fallback when this
     * statistic has (almost) no known values for a target. May be null.
     */
    public RegressionStat m_Training;

    /**
     * Creates a full statistic (with sum arrays) for the given numeric targets.
     *
     * @param attrs the numeric target attributes
     */
    public RegressionStat(NumericAttrType[] attrs) {
        this(attrs, false);
    }

    /**
     * Creates a statistic for the given numeric targets.
     *
     * @param attrs    the numeric target attributes
     * @param onlymean if true, the per-target sum arrays are not allocated
     */
    public RegressionStat(NumericAttrType[] attrs, boolean onlymean) {
        super(attrs, onlymean);
        if (!onlymean) {
            m_SumValues = new double[m_NbAttrs];
            m_SumWeights = new double[m_NbAttrs];
            m_SumSqValues = new double[m_NbAttrs];
        }
    }

    /**
     * Sets the training-set statistic used as variance fallback by the getSVarS methods.
     *
     * @param train the training statistic; it is cast to {@code RegressionStat}
     */
    public void setTrainingStat(ClusStatistic train) {
        m_Training = (RegressionStat) train;
    }

    /**
     * Returns a new, empty full statistic over the same targets.
     * The new statistic shares this statistic's training statistic.
     */
    public ClusStatistic cloneStat() {
        RegressionStat res = new RegressionStat(m_Attrs, false);
        res.m_Training = m_Training;
        return res;
    }

    /**
     * Returns a new means-only statistic (no sum arrays) over the same targets.
     * The new statistic shares this statistic's training statistic.
     */
    public ClusStatistic cloneSimple() {
        RegressionStat res = new RegressionStat(m_Attrs, true);
        res.m_Training = m_Training;
        return res;
    }

    /**
     * Clone this statistic by taking the given weight into account.
     * This is used for example to get the weighted prediction of the default rule.
     *
     * @param weight factor by which every mean of the normalized copy is multiplied
     * @return a copy built by {@link #normalizedCopy()} whose means are scaled by {@code weight}
     */
    public ClusStatistic copyNormalizedWeighted(double weight) {
        RegressionStat newStat = normalizedCopy();
        for (int iTarget = 0; iTarget < newStat.getNbAttributes(); iTarget++) {
            newStat.m_Means[iTarget] = weight * newStat.m_Means[iTarget];
        }
        return newStat;
    }

    /** Clears all accumulated weights and sums. */
    public void reset() {
        m_SumWeight = 0.0;
        m_NbExamples = 0;
        Arrays.fill(m_SumWeights, 0.0);
        Arrays.fill(m_SumValues, 0.0);
        Arrays.fill(m_SumSqValues, 0.0);
    }

    /**
     * Overwrites the counts and sums of this statistic with those of {@code other}.
     * {@code m_Training} is not copied.
     *
     * @param other a full {@code RegressionStat} over the same targets
     */
    public void copy(ClusStatistic other) {
        RegressionStat or = (RegressionStat) other;
        m_SumWeight = or.m_SumWeight;
        m_NbExamples = or.m_NbExamples;
        System.arraycopy(or.m_SumWeights, 0, m_SumWeights, 0, m_NbAttrs);
        System.arraycopy(or.m_SumValues, 0, m_SumValues, 0, m_NbAttrs);
        System.arraycopy(or.m_SumSqValues, 0, m_SumSqValues, 0, m_NbAttrs);
    }

    /**
     * Used for combining weighted predictions.
     *
     * @return a means-only copy holding this statistic's per-target means, with total weight 1
     *         and example count 0
     */
    public RegressionStat normalizedCopy() {
        RegressionStat copy = (RegressionStat) cloneSimple();
        copy.m_NbExamples = 0;
        copy.m_SumWeight = 1;
        calcMean(copy.m_Means);
        return copy;
    }

    /**
     * Adds the counts and sums of {@code other} to this statistic.
     *
     * @param other a full {@code RegressionStat} over the same targets
     */
    public void add(ClusStatistic other) {
        RegressionStat or = (RegressionStat) other;
        m_SumWeight += or.m_SumWeight;
        m_NbExamples += or.m_NbExamples;
        for (int i = 0; i < m_NbAttrs; i++) {
            m_SumWeights[i] += or.m_SumWeights[i];
            m_SumValues[i] += or.m_SumValues[i];
            m_SumSqValues[i] += or.m_SumSqValues[i];
        }
    }

    /**
     * Adds {@code scale} times the weights and sums of {@code other} to this statistic.
     * The example count is added unscaled.
     *
     * @param scale factor applied to the other statistic's weights and sums
     * @param other a full {@code RegressionStat} over the same targets
     */
    public void addScaled(double scale, ClusStatistic other) {
        RegressionStat or = (RegressionStat) other;
        m_SumWeight += scale * or.m_SumWeight;
        m_NbExamples += or.m_NbExamples;
        for (int i = 0; i < m_NbAttrs; i++) {
            m_SumWeights[i] += scale * or.m_SumWeights[i];
            m_SumValues[i] += scale * or.m_SumValues[i];
            m_SumSqValues[i] += scale * or.m_SumSqValues[i];
        }
    }

    /**
     * Sets this statistic to {@code this - other}, component-wise.
     *
     * @param other a full {@code RegressionStat} over the same targets
     */
    public void subtractFromThis(ClusStatistic other) {
        RegressionStat or = (RegressionStat) other;
        m_SumWeight -= or.m_SumWeight;
        m_NbExamples -= or.m_NbExamples;
        for (int i = 0; i < m_NbAttrs; i++) {
            m_SumWeights[i] -= or.m_SumWeights[i];
            m_SumValues[i] -= or.m_SumValues[i];
            m_SumSqValues[i] -= or.m_SumSqValues[i];
        }
    }

    /**
     * Sets this statistic to {@code other - this}, component-wise.
     *
     * @param other a full {@code RegressionStat} over the same targets
     */
    public void subtractFromOther(ClusStatistic other) {
        RegressionStat or = (RegressionStat) other;
        m_SumWeight = or.m_SumWeight - m_SumWeight;
        m_NbExamples = or.m_NbExamples - m_NbExamples;
        for (int i = 0; i < m_NbAttrs; i++) {
            m_SumWeights[i] = or.m_SumWeights[i] - m_SumWeights[i];
            m_SumValues[i] = or.m_SumValues[i] - m_SumValues[i];
            m_SumSqValues[i] = or.m_SumSqValues[i] - m_SumSqValues[i];
        }
    }

    /**
     * Adds one example with the given weight.
     *
     * <p>The example always counts towards {@code m_NbExamples} and {@code m_SumWeight}. For a
     * target whose value is {@code Double.POSITIVE_INFINITY} (treated as missing), the per-target
     * sums are left unchanged.
     *
     * @param tuple  the example
     * @param weight the example's weight
     */
    public void updateWeighted(DataTuple tuple, double weight) {
        m_NbExamples++;
        m_SumWeight += weight;
        for (int i = 0; i < m_NbAttrs; i++) {
            double val = m_Attrs[i].getNumeric(tuple);
            if (val != Double.POSITIVE_INFINITY) {
                m_SumWeights[i] += weight;
                m_SumValues[i] += weight * val;
                m_SumSqValues[i] += weight * val * val;
            }
        }
    }

    /**
     * Writes the per-target means (sum of values / known weight) into {@code means}.
     * A target with zero known weight gets mean 0.
     *
     * @param means output array of length at least the number of targets
     */
    public void calcMean(double[] means) {
        for (int i = 0; i < m_NbAttrs; i++) {
            // If divider zero, return zero
            means[i] = m_SumWeights[i] != 0.0 ? m_SumValues[i] / m_SumWeights[i] : 0.0;
        }
    }

    /**
     * @param i target index
     * @return the weighted mean of the known values of target {@code i}, or 0 if there are none
     */
    public double getMean(int i) {
        return m_SumWeights[i] != 0.0 ? m_SumValues[i] / m_SumWeights[i] : 0.0;
    }

    /** @return the weighted sum of the known values of target {@code i} */
    public double getSumValues(int i) {
        return m_SumValues[i];
    }

    /** @return the total weight of the examples with a known value for target {@code i} */
    public double getSumWeights(int i) {
        return m_SumWeights[i];
    }

    /**
     * Returns the sum-of-squared-deviations term of target {@code i}, corrected for missing values.
     *
     * <p>Let n be the total weight and k the known weight of target {@code i}.
     * <ul>
     * <li>If k is at or below {@code MathUtil.C1E_9} and a training statistic is set, the training
     * statistic's value is returned.</li>
     * <li>Otherwise, for k &gt; 1 the result is {@code ss*(n-1)/(k-1) - n*(sv/k)^2}. When k == n,
     * this reduces to {@code ss - sv^2/n}.</li>
     * <li>For k &lt;= 1 the result is 0.</li>
     * </ul>
     *
     * @param i target index
     * @return the (missing-value corrected) sum of squared deviations
     */
    public double getSVarS(int i) {
        double n_tot = m_SumWeight;
        double k_tot = m_SumWeights[i];
        double sv_tot = m_SumValues[i];
        double ss_tot = m_SumSqValues[i];
        if (k_tot <= MathUtil.C1E_9 && m_Training != null) {
            return m_Training.getSVarS(i);
        } else {
            return (k_tot > 1.0) ? ss_tot * (n_tot - 1) / (k_tot - 1) - n_tot * sv_tot / k_tot * sv_tot / k_tot
                    : 0.0;
        }
    }

    /**
     * Returns the attribute-weighted average, over all targets, of the per-target
     * sum-of-squared-deviations term.
     *
     * @param scale per-attribute weights
     * @return the weighted sum of the per-target terms divided by the number of targets
     */
    public double getSVarS(ClusAttributeWeights scale) {
        double result = 0.0;
        for (int i = 0; i < m_NbAttrs; i++) {
            double term = getSVarSTerm(i, m_SumWeight, m_SumWeights[i], m_SumValues[i], m_SumSqValues[i]);
            result += term * scale.getWeight(m_Attrs[i]);
        }
        return result / m_NbAttrs;
    }

    /**
     * Like {@link #getSVarS(ClusAttributeWeights)}, but for the statistic {@code this - other},
     * computed without modifying either statistic.
     *
     * @param scale per-attribute weights
     * @param other a full {@code RegressionStat} over the same targets
     * @return the weighted sum of the per-target terms of the difference divided by the number of targets
     */
    public double getSVarSDiff(ClusAttributeWeights scale, ClusStatistic other) {
        double result = 0.0;
        RegressionStat or = (RegressionStat) other;
        for (int i = 0; i < m_NbAttrs; i++) {
            double n_tot = m_SumWeight - or.m_SumWeight;
            double k_tot = m_SumWeights[i] - or.m_SumWeights[i];
            double sv_tot = m_SumValues[i] - or.m_SumValues[i];
            double ss_tot = m_SumSqValues[i] - or.m_SumSqValues[i];
            result += getSVarSTerm(i, n_tot, k_tot, sv_tot, ss_tot) * scale.getWeight(m_Attrs[i]);
        }
        return result / m_NbAttrs;
    }

    /**
     * Computes the unweighted sum-of-squared-deviations term of one target from explicit sums.
     * This is shared by {@link #getSVarS(ClusAttributeWeights)} and
     * {@link #getSVarSDiff(ClusAttributeWeights, ClusStatistic)}.
     *
     * <p>Degenerate inputs yield 0 instead of NaN or Infinity. This matches the
     * {@code k_tot > 1.0} guard of {@link #getSVarS(int)}.
     *
     * @param i      target index (only used for the training-statistic fallback)
     * @param n_tot  total weight
     * @param k_tot  weight of the examples with a known value for this target
     * @param sv_tot weighted sum of the known values
     * @param ss_tot weighted sum of the squared known values
     * @return the per-target term
     */
    private double getSVarSTerm(int i, double n_tot, double k_tot, double sv_tot, double ss_tot) {
        if (k_tot == n_tot) {
            // No missing values: plain sum of squared deviations; an empty stat would give 0/0.
            return n_tot != 0.0 ? ss_tot - sv_tot * sv_tot / n_tot : 0.0;
        }
        if (k_tot <= MathUtil.C1E_9 && m_Training != null) {
            return m_Training.getSVarS(i);
        }
        if (k_tot <= 1.0) {
            // (k_tot - 1) would be zero or negative: the correction below is undefined.
            return 0.0;
        }
        return ss_tot * (n_tot - 1) / (k_tot - 1) - n_tot * sv_tot / k_tot * sv_tot / k_tot;
    }

    /**
     * Formats the mean of target {@code i}. If the target has no known weight, "?" is used instead.
     */
    private String formatMeanOrUnknown(NumberFormat fr, int i) {
        double tot = getSumWeights(i);
        return tot == 0 ? "?" : fr.format(getSumValues(i) / tot);
    }

    /**
     * Renders the per-target means as {@code [m1,m2,...]}.
     *
     * <p>Depending on {@code info}, the output is followed by either the per-target known weights
     * ({@code : [k1,k2,...]}) or the total weight ({@code : n}).
     *
     * @param info print options
     * @return the textual representation
     */
    public String getString(StatisticPrintInfo info) {
        NumberFormat fr = ClusFormat.SIX_AFTER_DOT;
        StringBuilder buf = new StringBuilder();

        buf.append("[");
        for (int i = 0; i < m_NbAttrs; i++) {
            if (i != 0) {
                buf.append(",");
            }
            buf.append(formatMeanOrUnknown(fr, i));
        }
        buf.append("]");
        if (info.SHOW_EXAMPLE_COUNT_BYTARGET) {
            buf.append(": [");
            for (int i = 0; i < m_NbAttrs; i++) {
                if (i != 0) {
                    buf.append(",");
                }
                buf.append(fr.format(m_SumWeights[i]));
            }
            buf.append("]");
        } else if (info.SHOW_EXAMPLE_COUNT) {
            buf.append(": ");
            buf.append(fr.format(m_SumWeight));
        }
        return buf.toString();
    }

    /**
     * Prints the raw sums to standard output, followed by the per-target mean and sum of squared
     * deviations and then {@code getError()}.
     */
    public void printDebug() {
        for (int i = 0; i < getNbAttributes(); i++) {
            double n_tot = m_SumWeight;
            double k_tot = m_SumWeights[i];
            double sv_tot = m_SumValues[i];
            double ss_tot = m_SumSqValues[i];
            System.out.println("n: " + n_tot + " k: " + k_tot);
            System.out.println("sv: " + sv_tot);
            System.out.println("ss: " + ss_tot);
            // The sums only cover known values, so normalize by the known weight k_tot, not n_tot.
            double mean = getMean(i);
            double var = ss_tot - k_tot * mean * mean;
            System.out.println("mean: " + mean);
            System.out.println("var: " + var);
        }
        System.out.println("err: " + getError());
    }

    /** @return this statistic */
    public RegressionStat getRegressionStat() {
        return this;
    }

    /**
     * Returns the squared Euclidean distance between the per-target means of the two statistics.
     *
     * @param other a {@code RegressionStat} over the same targets
     * @return the sum over targets of the squared mean differences
     */
    public double getSquaredDistance(ClusStatistic other) {
        double result = 0.0;
        RegressionStat o = (RegressionStat) other;
        for (int i = 0; i < m_NbAttrs; i++) {
            double distance = getMean(i) - o.getMean(i);
            result += distance * distance;
        }
        return result;
    }

    /**
     * Builds the XML prediction element. It is a {@code RegressionStat} element with an
     * {@code examples} attribute (the total weight) and one {@code Target} child per target.
     * Each child's text is the mean, or "?" when the target has no known values.
     *
     * @param doc the owning document
     * @return the new element
     */
    @Override
    public Element getPredictElement(Document doc) {
        Element stats = doc.createElement("RegressionStat");
        NumberFormat fr = ClusFormat.SIX_AFTER_DOT;
        Attr examples = doc.createAttribute("examples");
        examples.setValue(fr.format(m_SumWeight));

        stats.setAttributeNode(examples);
        for (int i = 0; i < m_NbAttrs; i++) {
            Element attr = doc.createElement("Target");
            Attr name = doc.createAttribute("name");
            name.setValue(m_Attrs[i].getName());
            attr.setAttributeNode(name);
            attr.setTextContent(formatMeanOrUnknown(fr, i));
            stats.appendChild(attr);
        }
        return stats;
    }
}