// PLSFilter — Weka partial least squares filter (license and class documentation below)
/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* * PLSFilter.java * Copyright (C) 2006 University of Waikato, Hamilton, New Zealand * */ package weka.filters.supervised.attribute; import weka.core.Attribute; import weka.core.Capabilities; import weka.core.Capabilities.Capability; import weka.core.DenseInstance; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; import weka.core.RevisionUtils; import weka.core.SelectedTag; import weka.core.Tag; import weka.core.TechnicalInformation; import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; import weka.core.TechnicalInformationHandler; import weka.core.Utils; import weka.core.matrix.EigenvalueDecomposition; import weka.core.matrix.Matrix; import weka.filters.Filter; import weka.filters.SimpleBatchFilter; import weka.filters.SupervisedFilter; import weka.filters.unsupervised.attribute.Center; import weka.filters.unsupervised.attribute.ReplaceMissingValues; import weka.filters.unsupervised.attribute.Standardize; import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.Vector; /** * <!-- globalinfo-start --> Runs Partial Least Square Regression over the given * instances and computes the resulting beta matrix for prediction.<br/> * By default it replaces missing values and centers the data.<br/> * <br/> * For more information 
see:<br/> * <br/> * Tormod Naes, Tomas Isaksson, Tom Fearn, Tony Davies (2002). A User Friendly * Guide to Multivariate Calibration and Classification. NIR Publications.<br/> * <br/> * StatSoft, Inc.. Partial Least Squares (PLS).<br/> * <br/> * Bent Jorgensen, Yuri Goegebeur. Module 7: Partial least squares regression I.<br/> * <br/> * S. de Jong (1993). SIMPLS: an alternative approach to partial least squares * regression. Chemometrics and Intelligent Laboratory Systems. 18:251-263. * <p/> * <!-- globalinfo-end --> * * <!-- technical-bibtex-start --> BibTeX: * * <pre> * @book{Naes2002, * author = {Tormod Naes and Tomas Isaksson and Tom Fearn and Tony Davies}, * publisher = {NIR Publications}, * title = {A User Friendly Guide to Multivariate Calibration and Classification}, * year = {2002}, * ISBN = {0-9528666-2-5} * } * * @misc{missing_id, * author = {StatSoft, Inc.}, * booktitle = {Electronic Textbook StatSoft}, * title = {Partial Least Squares (PLS)}, * HTTP = {http://www.statsoft.com/textbook/stpls.html} * } * * @misc{missing_id, * author = {Bent Jorgensen and Yuri Goegebeur}, * booktitle = {ST02: Multivariate Data Analysis and Chemometrics}, * title = {Module 7: Partial least squares regression I}, * HTTP = {http://statmaster.sdu.dk/courses/ST02/module07/} * } * * @article{Jong1993, * author = {S. de Jong}, * journal = {Chemometrics and Intelligent Laboratory Systems}, * pages = {251-263}, * title = {SIMPLS: an alternative approach to partial least squares regression}, * volume = {18}, * year = {1993} * } * </pre> * <p/> * <!-- technical-bibtex-end --> * * <!-- options-start --> Valid options are: * <p/> * * <pre> * -D * Turns on output of debugging information. * </pre> * * <pre> * -C <num> * The number of components to compute. * (default: 20) * </pre> * * <pre> * -U * Updates the class attribute as well. * (default: off) * </pre> * * <pre> * -M * Turns replacing of missing values on. 
 * (default: off)
 * </pre>
 *
 * <pre>
 * -A <SIMPLS|PLS1>
 *  The algorithm to use.
 *  (default: PLS1)
 * </pre>
 *
 * <pre>
 * -P <none|center|standardize>
 *  The type of preprocessing that is applied to the data.
 *  (default: center)
 * </pre>
 *
 * <!-- options-end -->
 *
 * @author FracPete (fracpete at waikato dot ac dot nz)
 * @version $Revision$
 */
public class PLSFilter extends SimpleBatchFilter implements SupervisedFilter,
  TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = -3335106965521265631L;

  /** the type of algorithm: SIMPLS */
  public static final int ALGORITHM_SIMPLS = 1;
  /** the type of algorithm: PLS1 */
  public static final int ALGORITHM_PLS1 = 2;
  /** the types of algorithm */
  public static final Tag[] TAGS_ALGORITHM = {
    new Tag(ALGORITHM_SIMPLS, "SIMPLS"), new Tag(ALGORITHM_PLS1, "PLS1") };

  /** the type of preprocessing: None */
  public static final int PREPROCESSING_NONE = 0;
  /** the type of preprocessing: Center */
  public static final int PREPROCESSING_CENTER = 1;
  /** the type of preprocessing: Standardize */
  public static final int PREPROCESSING_STANDARDIZE = 2;
  /** the types of preprocessing */
  public static final Tag[] TAGS_PREPROCESSING = {
    new Tag(PREPROCESSING_NONE, "none"),
    new Tag(PREPROCESSING_CENTER, "center"),
    new Tag(PREPROCESSING_STANDARDIZE, "standardize") };

  /** the maximum number of components to generate */
  protected int m_NumComponents = 20;

  /** the type of algorithm */
  protected int m_Algorithm = ALGORITHM_PLS1;

  /** the regression vector "r-hat" for PLS1 (computed on the first batch) */
  protected Matrix m_PLS1_RegVector = null;

  /** the P matrix for PLS1 (loadings; cached after the first batch) */
  protected Matrix m_PLS1_P = null;

  /** the W matrix for PLS1 (weights; cached after the first batch) */
  protected Matrix m_PLS1_W = null;

  /** the b-hat vector for PLS1 (per-component regression coefficients) */
  protected Matrix m_PLS1_b_hat = null;

  /** the W matrix for SIMPLS (cached after the first batch) */
  protected Matrix m_SIMPLS_W = null;

  /** the B matrix for SIMPLS (used for prediction) */
  protected Matrix m_SIMPLS_B = null;

  /** whether to include the prediction, i.e., modifying the
   * class attribute */
  protected boolean m_PerformPrediction = false;

  /** for replacing missing values */
  protected Filter m_Missing = null;

  /** whether to replace missing values */
  protected boolean m_ReplaceMissing = true;

  /** for centering/standardizing the data (null = no preprocessing) */
  protected Filter m_Filter = null;

  /** the type of preprocessing */
  protected int m_Preprocessing = PREPROCESSING_CENTER;

  /** the mean of the class (used to undo centering on predictions) */
  protected double m_ClassMean = 0;

  /** the standard deviation of the class (used to undo standardizing) */
  protected double m_ClassStdDev = 0;

  /**
   * default constructor
   */
  public PLSFilter() {
    super();

    // setup pre-processing: replace missing values, center the data (defaults)
    m_Missing = new ReplaceMissingValues();
    m_Filter = new Center();
  }

  /**
   * Returns a string describing this classifier.
   *
   * @return a description of the classifier suitable for displaying in the
   *         explorer/experimenter gui
   */
  @Override
  public String globalInfo() {
    return "Runs Partial Least Square Regression over the given instances "
      + "and computes the resulting beta matrix for prediction.\n"
      + "By default it replaces missing values and centers the data.\n\n"
      + "For more information see:\n\n" + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing detailed
   * information about the technical background of this class, e.g., paper
   * reference or book this class is based on.
* * @return the technical information about this class */ @Override public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; TechnicalInformation additional; result = new TechnicalInformation(Type.BOOK); result.setValue(Field.AUTHOR, "Tormod Naes and Tomas Isaksson and Tom Fearn and Tony Davies"); result.setValue(Field.YEAR, "2002"); result.setValue(Field.TITLE, "A User Friendly Guide to Multivariate Calibration and Classification"); result.setValue(Field.PUBLISHER, "NIR Publications"); result.setValue(Field.ISBN, "0-9528666-2-5"); additional = result.add(Type.MISC); additional.setValue(Field.AUTHOR, "StatSoft, Inc."); additional.setValue(Field.TITLE, "Partial Least Squares (PLS)"); additional.setValue(Field.BOOKTITLE, "Electronic Textbook StatSoft"); additional.setValue(Field.HTTP, "http://www.statsoft.com/textbook/stpls.html"); additional = result.add(Type.MISC); additional.setValue(Field.AUTHOR, "Bent Jorgensen and Yuri Goegebeur"); additional.setValue(Field.TITLE, "Module 7: Partial least squares regression I"); additional.setValue(Field.BOOKTITLE, "ST02: Multivariate Data Analysis and Chemometrics"); additional.setValue(Field.HTTP, "http://statmaster.sdu.dk/courses/ST02/module07/"); additional = result.add(Type.ARTICLE); additional.setValue(Field.AUTHOR, "S. de Jong"); additional.setValue(Field.YEAR, "1993"); additional.setValue(Field.TITLE, "SIMPLS: an alternative approach to partial least squares regression"); additional.setValue(Field.JOURNAL, "Chemometrics and Intelligent Laboratory Systems"); additional.setValue(Field.VOLUME, "18"); additional.setValue(Field.PAGES, "251-263"); return result; } /** * Gets an enumeration describing the available options. * * @return an enumeration of all the available options. 
*/ @Override public Enumeration<Option> listOptions() { Vector<Option> result = new Vector<Option>(); result.addElement( new Option("\tThe number of components to compute.\n" + "\t(default: 20)", "C", 1, "-C <num>")); result.addElement( new Option("\tUpdates the class attribute as well.\n" + "\t(default: off)", "U", 0, "-U")); result.addElement( new Option("\tTurns replacing of missing values on.\n" + "\t(default: off)", "M", 0, "-M")); String param = ""; for (int i = 0; i < TAGS_ALGORITHM.length; i++) { if (i > 0) { param += "|"; } SelectedTag tag = new SelectedTag(TAGS_ALGORITHM[i].getID(), TAGS_ALGORITHM); param += tag.getSelectedTag().getReadable(); } result.addElement( new Option("\tThe algorithm to use.\n" + "\t(default: PLS1)", "A", 1, "-A <" + param + ">")); param = ""; for (int i = 0; i < TAGS_PREPROCESSING.length; i++) { if (i > 0) { param += "|"; } SelectedTag tag = new SelectedTag(TAGS_PREPROCESSING[i].getID(), TAGS_PREPROCESSING); param += tag.getSelectedTag().getReadable(); } result.addElement( new Option("\tThe type of preprocessing that is applied to the data.\n" + "\t(default: center)", "P", 1, "-P <" + param + ">")); result.addAll(Collections.list(super.listOptions())); return result.elements(); } /** * returns the options of the current setup * * @return the current options */ @Override public String[] getOptions() { Vector<String> result = new Vector<String>(); result.add("-C"); result.add("" + getNumComponents()); if (getPerformPrediction()) { result.add("-U"); } if (getReplaceMissing()) { result.add("-M"); } result.add("-A"); result.add("" + getAlgorithm().getSelectedTag().getReadable()); result.add("-P"); result.add("" + getPreprocessing().getSelectedTag().getReadable()); Collections.addAll(result, super.getOptions()); return result.toArray(new String[result.size()]); } /** * Parses the options for this object. * <p/> * * <!-- options-start --> Valid options are: * <p/> * * <pre> * -D * Turns on output of debugging information. 
* </pre> * * <pre> * -C <num> * The number of components to compute. * (default: 20) * </pre> * * <pre> * -U * Updates the class attribute as well. * (default: off) * </pre> * * <pre> * -M * Turns replacing of missing values on. * (default: off) * </pre> * * <pre> * -A <SIMPLS|PLS1> * The algorithm to use. * (default: PLS1) * </pre> * * <pre> * -P <none|center|standardize> * The type of preprocessing that is applied to the data. * (default: center) * </pre> * * <!-- options-end --> * * @param options the options to use * @throws Exception if the option setting fails */ @Override public void setOptions(String[] options) throws Exception { String tmpStr; tmpStr = Utils.getOption("C", options); if (tmpStr.length() != 0) { setNumComponents(Integer.parseInt(tmpStr)); } else { setNumComponents(20); } setPerformPrediction(Utils.getFlag("U", options)); setReplaceMissing(Utils.getFlag("M", options)); tmpStr = Utils.getOption("A", options); if (tmpStr.length() != 0) { setAlgorithm(new SelectedTag(tmpStr, TAGS_ALGORITHM)); } else { setAlgorithm(new SelectedTag(ALGORITHM_PLS1, TAGS_ALGORITHM)); } tmpStr = Utils.getOption("P", options); if (tmpStr.length() != 0) { setPreprocessing(new SelectedTag(tmpStr, TAGS_PREPROCESSING)); } else { setPreprocessing(new SelectedTag(PREPROCESSING_CENTER, TAGS_PREPROCESSING)); } super.setOptions(options); Utils.checkForRemainingOptions(options); } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String numComponentsTipText() { return "The number of components to compute."; } /** * sets the maximum number of attributes to use. * * @param value the maximum number of attributes */ public void setNumComponents(int value) { m_NumComponents = value; } /** * returns the maximum number of attributes to use. 
* * @return the current maximum number of attributes */ public int getNumComponents() { return m_NumComponents; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String performPredictionTipText() { return "Whether to update the class attribute with the predicted value."; } /** * Sets whether to update the class attribute with the predicted value. * * @param value if true the class value will be replaced by the predicted * value. */ public void setPerformPrediction(boolean value) { m_PerformPrediction = value; } /** * Gets whether the class attribute is updated with the predicted value. * * @return true if the class attribute is updated */ public boolean getPerformPrediction() { return m_PerformPrediction; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String algorithmTipText() { return "Sets the type of algorithm to use."; } /** * Sets the type of algorithm to use * * @param value the algorithm type */ public void setAlgorithm(SelectedTag value) { if (value.getTags() == TAGS_ALGORITHM) { m_Algorithm = value.getSelectedTag().getID(); } } /** * Gets the type of algorithm to use * * @return the current algorithm type. */ public SelectedTag getAlgorithm() { return new SelectedTag(m_Algorithm, TAGS_ALGORITHM); } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String replaceMissingTipText() { return "Whether to replace missing values."; } /** * Sets whether to replace missing values. * * @param value if true missing values are replaced with the * ReplaceMissingValues filter. */ public void setReplaceMissing(boolean value) { m_ReplaceMissing = value; } /** * Gets whether missing values are replace. 
* * @return true if missing values are replaced with the ReplaceMissingValues * filter */ public boolean getReplaceMissing() { return m_ReplaceMissing; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String preprocessingTipText() { return "Sets the type of preprocessing to use."; } /** * Sets the type of preprocessing to use * * @param value the preprocessing type */ public void setPreprocessing(SelectedTag value) { if (value.getTags() == TAGS_PREPROCESSING) { m_Preprocessing = value.getSelectedTag().getID(); } } /** * Gets the type of preprocessing to use * * @return the current preprocessing type. */ public SelectedTag getPreprocessing() { return new SelectedTag(m_Preprocessing, TAGS_PREPROCESSING); } /** * Determines the output format based on the input format and returns this. In * case the output format cannot be returned immediately, i.e., * immediateOutputFormat() returns false, then this method will be called from * batchFinished(). 
   *
   * @param inputFormat the input format to base the output format on
   * @return the output format
   * @throws Exception in case the determination goes wrong
   * @see #hasImmediateOutputFormat()
   * @see #batchFinished()
   */
  @Override
  protected Instances determineOutputFormat(Instances inputFormat)
    throws Exception {
    // generate header: one numeric attribute per component, named after the
    // chosen algorithm (e.g. "PLS1_1", "PLS1_2", ...), plus the class last
    ArrayList<Attribute> atts = new ArrayList<Attribute>();
    String prefix = getAlgorithm().getSelectedTag().getReadable();
    for (int i = 0; i < getNumComponents(); i++) {
      atts.add(new Attribute(prefix + "_" + (i + 1)));
    }
    atts.add(new Attribute(inputFormat.classAttribute().name()));
    Instances result = new Instances(prefix, atts, 0);
    // the class is always the last attribute of the output format
    result.setClassIndex(result.numAttributes() - 1);

    return result;
  }

  /**
   * returns the data minus the class column as matrix
   *
   * @param instances the data to work on
   * @return the data without class attribute
   */
  protected Matrix getX(Instances instances) {
    double[][] x;
    double[] values;
    Matrix result;
    int i;
    int n;
    int j;
    int clsIndex;

    clsIndex = instances.classIndex();
    x = new double[instances.numInstances()][];

    for (i = 0; i < instances.numInstances(); i++) {
      values = instances.instance(i).toDoubleArray();
      x[i] = new double[values.length - 1];

      // copy all attribute values, skipping the class column wherever it is
      j = 0;
      for (n = 0; n < values.length; n++) {
        if (n != clsIndex) {
          x[i][j] = values[n];
          j++;
        }
      }
    }

    result = new Matrix(x);

    return result;
  }

  /**
   * returns the data minus the class column as matrix
   *
   * NOTE(review): unlike getX(Instances), this copies the FIRST length-1
   * values, i.e., it assumes the class attribute is the last attribute.
   * This holds for the output format built by determineOutputFormat(), but
   * would be wrong for arbitrary input — confirm callers only pass such data.
   *
   * @param instance the instance to work on
   * @return the data without the class attribute
   */
  protected Matrix getX(Instance instance) {
    double[][] x;
    double[] values;
    Matrix result;

    x = new double[1][];
    values = instance.toDoubleArray();
    x[0] = new double[values.length - 1];
    System.arraycopy(values, 0, x[0], 0, values.length - 1);

    result = new Matrix(x);

    return result;
  }

  /**
   * returns the data class column as matrix (n x 1)
   *
   * @param instances the data to work on
   * @return the class attribute
   */
  protected Matrix getY(Instances instances) {
    double[][] y;
    Matrix result;
    int i;

    y = new double[instances.numInstances()][1];
    for (i = 0; i < instances.numInstances(); i++) {
      y[i][0] = instances.instance(i).classValue();
    }

    result = new Matrix(y);

    return result;
  }

  /**
   * returns the data class column as matrix (1 x 1)
   *
   * @param instance the instance to work on
   * @return the class attribute
   */
  protected Matrix getY(Instance instance) {
    double[][] y;
    Matrix result;

    y = new double[1][1];
    y[0][0] = instance.classValue();

    result = new Matrix(y);

    return result;
  }

  /**
   * returns the X and Y matrix again as Instances object, based on the given
   * header (must have a class attribute set).
   *
   * @param header the format of the instance object
   * @param x the X matrix (data)
   * @param y the Y matrix (class)
   * @return the assembled data
   */
  protected Instances toInstances(Instances header, Matrix x, Matrix y) {
    double[] values;
    int i;
    int n;
    Instances result;
    int rows;
    int cols;
    int offset;
    int clsIdx;

    result = new Instances(header, 0);

    rows = x.getRowDimension();
    cols = x.getColumnDimension();
    clsIdx = header.classIndex();

    for (i = 0; i < rows; i++) {
      values = new double[cols + 1];
      offset = 0;
      for (n = 0; n < values.length; n++) {
        if (n == clsIdx) {
          // insert the class value; shift subsequent X columns back by one
          offset--;
          values[n] = y.get(i, 0);
        } else {
          values[n] = x.get(i, n + offset);
        }
      }
      result.add(new DenseInstance(1.0, values));
    }

    return result;
  }

  /**
   * returns the given column as a vector (actually a n x 1 matrix)
   *
   * @param m the matrix to work on
   * @param columnIndex the column to return
   * @return the column as n x 1 matrix
   */
  protected Matrix columnAsVector(Matrix m, int columnIndex) {
    Matrix result;
    int i;

    result = new Matrix(m.getRowDimension(), 1);

    for (i = 0; i < m.getRowDimension(); i++) {
      result.set(i, 0, m.get(i, columnIndex));
    }

    return result;
  }

  /**
   * stores the data from the (column) vector in the matrix at the specified
   * index
   *
   * @param v the vector to store in the matrix
   * @param m the receiving matrix
   * @param columnIndex the column to store the values in
   */
  protected void setVector(Matrix v, Matrix m, int columnIndex) {
    m.setMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex, v);
  }

  /**
   * returns the (column) vector of the matrix at the specified index
   *
   * @param m the matrix to work on
   * @param columnIndex the column to get the values from
   * @return the column vector
   */
  protected Matrix getVector(Matrix m, int columnIndex) {
    return m.getMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex);
  }

  /**
   * determines the dominant eigenvector for the given matrix and returns it
   *
   * @param m the matrix to determine the dominant eigenvector for
   * @return the dominant eigenvector
   */
  protected Matrix getDominantEigenVector(Matrix m) {
    EigenvalueDecomposition eigendecomp;
    double[] eigenvalues;
    int index;
    Matrix result;

    eigendecomp = m.eig();
    eigenvalues = eigendecomp.getRealEigenvalues();
    // pick the eigenvector whose (real) eigenvalue is largest
    index = Utils.maxIndex(eigenvalues);
    result = columnAsVector(eigendecomp.getV(), index);

    return result;
  }

  /**
   * normalizes the given vector (inplace) to unit Euclidean length
   *
   * NOTE(review): no guard against the zero vector — division by a zero
   * length would produce NaNs; presumably callers never pass one. Confirm.
   *
   * @param v the vector to normalize
   */
  protected void normalizeVector(Matrix v) {
    double sum;
    int i;

    // determine length
    sum = 0;
    for (i = 0; i < v.getRowDimension(); i++) {
      sum += v.get(i, 0) * v.get(i, 0);
    }
    sum = StrictMath.sqrt(sum);

    // normalize content
    for (i = 0; i < v.getRowDimension(); i++) {
      v.set(i, 0, v.get(i, 0) / sum);
    }
  }

  /**
   * processes the instances using the PLS1 algorithm
   *
   * On the first batch the W/P/b_hat matrices are computed via NIPALS-style
   * deflation and cached; afterwards every instance (first batch included) is
   * projected onto the cached components.
   *
   * @param instances the data to process
   * @return the modified data
   * @throws Exception in case the processing goes wrong
   */
  protected Instances processPLS1(Instances instances) throws Exception {
    Matrix X, X_trans, x;
    Matrix y;
    Matrix W, w;
    Matrix T, t, t_trans;
    Matrix P, p, p_trans;
    double b;
    Matrix b_hat;
    int i;
    int j;
    Matrix tmp;
    Instances result;
    Instances tmpInst;

    // initialization
    if (!isFirstBatchDone()) {
      // split up data
      X = getX(instances);
      y = getY(instances);
      X_trans = X.transpose();

      // init
      W = new Matrix(instances.numAttributes() - 1, getNumComponents());
      P = new Matrix(instances.numAttributes() - 1, getNumComponents());
      T = new Matrix(instances.numInstances(), getNumComponents());
      b_hat = new Matrix(getNumComponents(), 1);

      for (j = 0; j < getNumComponents(); j++) {
        // 1. step: wj (weight vector = normalized X'y)
        w = X_trans.times(y);
        normalizeVector(w);
        setVector(w, W, j);

        // 2. step: tj (score vector)
        t = X.times(w);
        t_trans = t.transpose();
        setVector(t, T, j);

        // 3. step: ^bj (regression coefficient for this component)
        b = t_trans.times(y).get(0, 0) / t_trans.times(t).get(0, 0);
        b_hat.set(j, 0, b);

        // 4. step: pj (loading vector)
        p = X_trans.times(t).times(1 / t_trans.times(t).get(0, 0));
        p_trans = p.transpose();
        setVector(p, P, j);

        // 5. step: Xj+1 (deflate X and y before the next component)
        X = X.minus(t.times(p_trans));
        y = y.minus(t.times(b));
      }

      // W*(P^T*W)^-1
      tmp = W.times(((P.transpose()).times(W)).inverse());

      // factor = W*(P^T*W)^-1 * b_hat
      m_PLS1_RegVector = tmp.times(b_hat);

      // save matrices for processing subsequent batches
      m_PLS1_P = P;
      m_PLS1_W = W;
      m_PLS1_b_hat = b_hat;
    }

    result = new Instances(getOutputFormat());

    for (i = 0; i < instances.numInstances(); i++) {
      // work on each instance: wrap it in a single-row dataset
      tmpInst = new Instances(instances, 0);
      tmpInst.add((Instance) instances.instance(i).copy());
      x = getX(tmpInst);
      X = new Matrix(1, getNumComponents());
      T = new Matrix(1, getNumComponents());

      for (j = 0; j < getNumComponents(); j++) {
        setVector(x, X, j);
        // 1. step: tj = xj * wj
        t = x.times(getVector(m_PLS1_W, j));
        setVector(t, T, j);
        // 2. step: xj+1 = xj - tj*pj^T (tj is 1x1 matrix!)
        x = x.minus(getVector(m_PLS1_P, j).transpose().times(t.get(0, 0)));
      }

      if (getPerformPrediction()) {
        // replace the class with the PLS1 prediction T * b_hat
        tmpInst = toInstances(getOutputFormat(), T, T.times(m_PLS1_b_hat));
      } else {
        // keep the (filtered) class value
        tmpInst = toInstances(getOutputFormat(), T, getY(tmpInst));
      }

      result.add(tmpInst.instance(0));
    }

    return result;
  }

  /**
   * processes the instances using the SIMPLS algorithm
   *
   * On the first batch the W and B matrices are computed (de Jong, 1993) and
   * cached; subsequent batches are projected with the cached matrices.
   *
   * @param instances the data to process
   * @return the modified data
   * @throws Exception in case the processing goes wrong
   */
  protected Instances processSIMPLS(Instances instances) throws Exception {
    Matrix A, A_trans;
    Matrix M;
    Matrix X, X_trans;
    Matrix X_new;
    Matrix Y, y;
    Matrix C, c;
    Matrix Q, q;
    Matrix W, w;
    Matrix P, p, p_trans;
    Matrix v, v_trans;
    Matrix T;
    Instances result;
    int h;

    if (!isFirstBatchDone()) {
      // init
      X = getX(instances);
      X_trans = X.transpose();
      Y = getY(instances);
      A = X_trans.times(Y);
      M = X_trans.times(X);
      C = Matrix.identity(instances.numAttributes() - 1,
        instances.numAttributes() - 1);
      W = new Matrix(instances.numAttributes() - 1, getNumComponents());
      P = new Matrix(instances.numAttributes() - 1, getNumComponents());
      Q = new Matrix(1, getNumComponents());

      for (h = 0; h < getNumComponents(); h++) {
        // 1. qh as dominant EigenVector of Ah'*Ah
        A_trans = A.transpose();
        q = getDominantEigenVector(A_trans.times(A));

        // 2. wh=Ah*qh, ch=wh'*Mh*wh, wh=wh/sqrt(ch), store wh in W as column
        w = A.times(q);
        c = w.transpose().times(M).times(w);
        w = w.times(1.0 / StrictMath.sqrt(c.get(0, 0)));
        setVector(w, W, h);

        // 3. ph=Mh*wh, store ph in P as column
        p = M.times(w);
        p_trans = p.transpose();
        setVector(p, P, h);

        // 4. qh=Ah'*wh, store qh in Q as column
        q = A_trans.times(w);
        setVector(q, Q, h);

        // 5. vh=Ch*ph, vh=vh/||vh||
        v = C.times(p);
        normalizeVector(v);
        v_trans = v.transpose();

        // 6. Ch+1=Ch-vh*vh', Mh+1=Mh-ph*ph'
        C = C.minus(v.times(v_trans));
        M = M.minus(p.times(p_trans));

        // 7. Ah+1=ChAh (actually Ch+1)
        A = C.times(A);
      }

      // finish: cache projection and prediction matrices
      m_SIMPLS_W = W;
      T = X.times(m_SIMPLS_W);
      X_new = T;
      m_SIMPLS_B = W.times(Q.transpose());

      if (getPerformPrediction()) {
        y = T.times(P.transpose()).times(m_SIMPLS_B);
      } else {
        y = getY(instances);
      }

      result = toInstances(getOutputFormat(), X_new, y);
    } else {
      // NOTE(review): this first assignment is a dead store — result is
      // unconditionally overwritten by toInstances() below
      result = new Instances(getOutputFormat());

      X = getX(instances);
      X_new = X.times(m_SIMPLS_W);

      if (getPerformPrediction()) {
        y = X.times(m_SIMPLS_B);
      } else {
        y = getY(instances);
      }

      result = toInstances(getOutputFormat(), X_new, y);
    }

    return result;
  }

  /**
   * Returns the Capabilities of this filter.
   *
   * @return the capabilities of this object
   * @see Capabilities
   */
  @Override
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);

    return result;
  }

  /**
   * Processes the given data (may change the provided dataset) and returns the
   * modified version. This method is called in batchFinished().
   *
   * @param instances the data to process
   * @return the modified data
   * @throws Exception in case the processing goes wrong
   * @see #batchFinished()
   */
  @Override
  protected Instances process(Instances instances) throws Exception {
    Instances result;
    int i;
    double clsValue;
    double[] clsValues;

    result = null;

    // save original class values if no prediction is performed
    if (!getPerformPrediction()) {
      clsValues = instances.attributeToDoubleArray(instances.classIndex());
    } else {
      clsValues = null;
    }

    if (!isFirstBatchDone()) {
      // init filters on the first batch only; subsequent batches reuse them
      if (m_ReplaceMissing) {
        m_Missing.setInputFormat(instances);
      }

      // remember class mean/stddev so predictions can be mapped back to the
      // original class scale after centering/standardizing
      switch (m_Preprocessing) {
      case PREPROCESSING_CENTER:
        m_ClassMean = instances.meanOrMode(instances.classIndex());
        m_ClassStdDev = 1;
        m_Filter = new Center();
        ((Center) m_Filter).setIgnoreClass(true);
        break;
      case PREPROCESSING_STANDARDIZE:
        m_ClassMean = instances.meanOrMode(instances.classIndex());
        m_ClassStdDev = StrictMath
          .sqrt(instances.variance(instances.classIndex()));
        m_Filter = new Standardize();
        ((Standardize) m_Filter).setIgnoreClass(true);
        break;
      default:
        // no preprocessing: identity mapping for the class
        m_ClassMean = 0;
        m_ClassStdDev = 1;
        m_Filter = null;
      }
      if (m_Filter != null) {
        m_Filter.setInputFormat(instances);
      }
    }

    // filter data
    if (m_ReplaceMissing) {
      instances = Filter.useFilter(instances, m_Missing);
    }
    if (m_Filter != null) {
      instances = Filter.useFilter(instances, m_Filter);
    }

    switch (m_Algorithm) {
    case ALGORITHM_SIMPLS:
      result = processSIMPLS(instances);
      break;
    case ALGORITHM_PLS1:
      result = processPLS1(instances);
      break;
    default:
      throw new IllegalStateException("Algorithm type '" + m_Algorithm
        + "' is not recognized!");
    }

    // add the mean to the class again if predictions are to be performed,
    // otherwise restore original class values
    for (i = 0; i < result.numInstances(); i++) {
      if (!getPerformPrediction()) {
        result.instance(i).setClassValue(clsValues[i]);
      } else {
        clsValue = result.instance(i).classValue();
        result.instance(i).setClassValue(clsValue * m_ClassStdDev + m_ClassMean);
      }
    }

    return result;
  }
/** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision$"); } /** * runs the filter with the given arguments. * * @param args the commandline arguments */ public static void main(String[] args) { runFilter(new PLSFilter(), args); } }