org.knime.al.util.noveltydetection.knfst.KNFST.java Source code

Java tutorial

Introduction

Here is the source code for org.knime.al.util.noveltydetection.knfst.KNFST.java

Source

/*
 * ------------------------------------------------------------------------
 *
 *  Copyright by KNIME AG, Zurich, Switzerland
 *  Website: http://www.knime.com; Email: contact@knime.com
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License, Version 3, as
 *  published by the Free Software Foundation.
 *
 *  This program is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, see <http://www.gnu.org/licenses>.
 *
 *  Additional permission under GNU GPL version 3 section 7:
 *
 *  KNIME interoperates with ECLIPSE solely via ECLIPSE's plug-in APIs.
 *  Hence, KNIME and ECLIPSE are both independent programs and are not
 *  derived from each other. Should, however, the interpretation of the
 *  GNU GPL Version 3 ("License") under any applicable laws result in
 *  KNIME and ECLIPSE being a combined program, KNIME AG herewith grants
 *  you the additional permission to use and propagate KNIME together with
 *  ECLIPSE with only the license terms in place for ECLIPSE applying to
 *  ECLIPSE and the GNU GPL Version 3 applying for KNIME, provided the
 *  license terms of ECLIPSE themselves allow for the respective use and
 *  propagation of ECLIPSE together with KNIME.
 *
 *  Additional permission relating to nodes for KNIME that extend the Node
 *  Extension (and in particular that are based on subclasses of NodeModel,
 *  NodeDialog, and NodeView) and that only interoperate with KNIME through
 *  standard APIs ("Nodes"):
 *  Nodes are deemed to be separate and independent programs and to not be
 *  covered works.  Notwithstanding anything to the contrary in the
 *  License, the License does not apply to Nodes, you are not required to
 *  license Nodes under the License, and you are granted a license to
 *  prepare and propagate Nodes, in each case even if such Nodes are
 *  propagated with or for interoperation with KNIME.  The owner of a Node
 *  may freely choose the license terms applicable to such Node, including
 *  when such Node is propagated with or for interoperation with KNIME.
 * --------------------------------------------------------------------- *
 *
 */

package org.knime.al.util.noveltydetection.knfst;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;

import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.knime.al.util.noveltydetection.kernel.KernelCalculator;
import org.knime.core.data.DataRow;

public abstract class KNFST implements Externalizable {
    protected KernelCalculator m_kernel;
    protected RealMatrix m_projection;
    protected RealMatrix m_targetPoints;
    protected double[] m_betweenClassDistances;

    public KNFST() {

    }

    public KNFST(final KernelCalculator kernel) {
        m_kernel = kernel;
    }

    // public abstract NoveltyScores scoreTestData(BufferedDataTable test);

    public abstract NoveltyScores scoreTestData(DataRow testInstance);

    // public abstract NoveltyScores scoreTestData(double[][] test);

    public abstract NoveltyScores scoreTestData(RealMatrix kernelMatrix);

    public static RealMatrix projection(final RealMatrix kernelMatrix, final String[] labels)
            throws KNFSTException {

        final ClassWrapper[] classes = ClassWrapper.classes(labels);

        // check labels
        if (classes.length == 1) {
            throw new IllegalArgumentException(
                    "not able to calculate a nullspace from data of a single class using KNFST (input variable \"labels\" only contains a single value)");
        }

        // check kernel matrix
        if (!kernelMatrix.isSquare()) {
            throw new IllegalArgumentException("The KernelMatrix must be quadratic!");
        }

        // calculate weights of orthonormal basis in kernel space
        final RealMatrix centeredK = centerKernelMatrix(kernelMatrix);

        final EigenDecomposition eig = new EigenDecomposition(centeredK);
        final double[] eigVals = eig.getRealEigenvalues();
        final ArrayList<Integer> nonZeroEigValIndices = new ArrayList<Integer>();
        for (int i = 0; i < eigVals.length; i++) {
            if (eigVals[i] > 1e-12) {
                nonZeroEigValIndices.add(i);
            }
        }

        int eigIterator = 0;
        final RealMatrix eigVecs = eig.getV();
        RealMatrix basisvecs = null;
        try {
            basisvecs = MatrixUtils.createRealMatrix(eigVecs.getRowDimension(), nonZeroEigValIndices.size());
        } catch (final Exception e) {
            throw new KNFSTException("Something went wrong. Try different parameters or a different kernel.");
        }

        for (final Integer index : nonZeroEigValIndices) {
            final double normalizer = 1 / Math.sqrt(eigVals[index]);
            final RealVector basisVec = eigVecs.getColumnVector(eigIterator).mapMultiply(normalizer);
            basisvecs.setColumnVector(eigIterator++, basisVec);
        }

        // calculate transformation T of within class scatter Sw:
        // T= B'*K*(I-L) and L a block matrix
        final RealMatrix L = kernelMatrix.createMatrix(kernelMatrix.getRowDimension(),
                kernelMatrix.getColumnDimension());
        int start = 0;
        for (final ClassWrapper cl : classes) {
            final int count = cl.getCount();
            L.setSubMatrix(MatrixFunctions.ones(count, count).scalarMultiply(1.0 / count).getData(), start, start);
            start += count;
        }

        // need Matrix M with all entries 1/m to modify basisvecs which allows
        // usage of
        // uncentered kernel values (eye(size(M)).M)*basisvecs
        final RealMatrix M = MatrixFunctions
                .ones(kernelMatrix.getColumnDimension(), kernelMatrix.getColumnDimension())
                .scalarMultiply(1.0 / kernelMatrix.getColumnDimension());
        final RealMatrix I = MatrixUtils.createRealIdentityMatrix(M.getColumnDimension());

        // compute helper matrix H
        final RealMatrix H = ((I.subtract(M)).multiply(basisvecs)).transpose().multiply(kernelMatrix)
                .multiply(I.subtract(L));

        // T = H*H' = B'*Sw*B with B=basisvecs
        final RealMatrix T = H.multiply(H.transpose());

        // calculate weights for null space
        RealMatrix eigenvecs = MatrixFunctions.nullspace(T);

        if (eigenvecs == null) {
            final EigenDecomposition eigenComp = new EigenDecomposition(T);
            final double[] eigenvals = eigenComp.getRealEigenvalues();
            eigenvecs = eigenComp.getV();
            final int minId = MatrixFunctions.argmin(MatrixFunctions.abs(eigenvals));
            final double[] eigenvecsData = eigenvecs.getColumn(minId);
            eigenvecs = MatrixUtils.createColumnRealMatrix(eigenvecsData);
        }

        // System.out.println("eigenvecs:");
        // test.printMatrix(eigenvecs);

        // calculate null space projection
        final RealMatrix proj = ((I.subtract(M)).multiply(basisvecs)).multiply(eigenvecs);

        return proj;
    }

    private static RealMatrix centerKernelMatrix(final RealMatrix kernelMatrix) {
        // get size of kernelMatrix
        final int n = kernelMatrix.getRowDimension();

        // get mean values for each row/column
        final RealVector columnMeans = MatrixFunctions.columnMeans(kernelMatrix);
        final double matrixMean = MatrixFunctions.mean(kernelMatrix);

        RealMatrix centeredKernelMatrix = kernelMatrix.copy();

        for (int k = 0; k < n; k++) {
            centeredKernelMatrix.setRowVector(k, centeredKernelMatrix.getRowVector(k).subtract(columnMeans));
            centeredKernelMatrix.setColumnVector(k, centeredKernelMatrix.getColumnVector(k).subtract(columnMeans));
        }

        centeredKernelMatrix = centeredKernelMatrix.scalarAdd(matrixMean);

        return centeredKernelMatrix;
    }

    // Methods of Externalizable interface

    @Override
    public void readExternal(final ObjectInput arg0) throws IOException, ClassNotFoundException {

        try {
            // read kernel
            m_kernel = (KernelCalculator) Class.forName(arg0.readUTF()).newInstance();
            m_kernel.readExternal(arg0);

            // read projection
            // rows
            final int rowsProj = arg0.readInt();
            // columns
            final int colsProj = arg0.readInt();
            // data
            final double[][] projData = new double[rowsProj][colsProj];
            for (int r = 0; r < rowsProj; r++) {
                for (int c = 0; c < colsProj; c++) {
                    projData[r][c] = arg0.readDouble();
                }
            }
            // Matrix construction
            m_projection = MatrixUtils.createRealMatrix(projData);

            // read targetPoints
            // rows
            final int rowsTar = arg0.readInt();
            // columns
            final int colsTar = arg0.readInt();
            // data
            final double[][] tarData = new double[rowsTar][colsTar];
            for (int r = 0; r < rowsTar; r++) {
                for (int c = 0; c < colsTar; c++) {
                    tarData[r][c] = arg0.readDouble();
                }
            }
            // Matrix construction
            m_targetPoints = MatrixUtils.createRealMatrix(tarData);

            // read betweenClassDistances
            final double[] betweenClassDistances = new double[arg0.readInt()];
            for (int i = 0; i < betweenClassDistances.length; i++) {
                betweenClassDistances[i] = arg0.readDouble();
            }
            m_betweenClassDistances = betweenClassDistances;

        } catch (InstantiationException | IllegalAccessException e) {
            throw new IOException(e);
        }
    }

    @Override
    public void writeExternal(final ObjectOutput arg0) throws IOException {
        // write kernel
        arg0.writeUTF(m_kernel.getClass().getName());
        m_kernel.writeExternal(arg0);

        // write projection
        // rows
        arg0.writeInt(m_projection.getRowDimension());
        // columns
        arg0.writeInt(m_projection.getColumnDimension());
        // data
        final double[][] projData = m_projection.getData();
        for (final double[] row : projData) {
            for (final double cell : row) {
                arg0.writeDouble(cell);
            }
        }

        // write targetPoints
        // rows
        arg0.writeInt(m_targetPoints.getRowDimension());
        // columns
        arg0.writeInt(m_targetPoints.getColumnDimension());
        // data
        final double[][] tarData = m_targetPoints.getData();
        for (final double[] row : tarData) {
            for (final double cell : row) {
                arg0.writeDouble(cell);
            }
        }

        // write betweenClassDistances
        // length
        arg0.writeInt(m_betweenClassDistances.length);
        // data
        for (final double dist : m_betweenClassDistances) {
            arg0.writeDouble(dist);
        }
    }

    public double[] getBetweenClassDistances() {
        return m_betweenClassDistances;
    }

    public KernelCalculator getKernel() {
        return m_kernel;
    }

    public RealMatrix getProjection() {
        return m_projection;
    }

    public double[][] getTargetPoints() {
        return m_targetPoints.getData();
    }

    public int getNullspaceDimension() {
        return m_targetPoints.getColumnDimension();
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + Arrays.hashCode(m_betweenClassDistances);
        result = prime * result + ((m_kernel == null) ? 0 : m_kernel.hashCode());
        result = prime * result + ((m_projection == null) ? 0 : m_projection.hashCode());
        result = prime * result + ((m_targetPoints == null) ? 0 : m_targetPoints.hashCode());
        return result;
    }

    @Override
    public boolean equals(final Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (!(obj instanceof KNFST)) {
            return false;
        }
        final KNFST other = (KNFST) obj;
        if (!Arrays.equals(m_betweenClassDistances, other.m_betweenClassDistances)) {
            return false;
        }
        if (m_kernel == null) {
            if (other.m_kernel != null) {
                return false;
            }
        } else if (!m_kernel.equals(other.m_kernel)) {
            return false;
        }
        if (m_projection == null) {
            if (other.m_projection != null) {
                return false;
            }
        } else if (!m_projection.equals(other.m_projection)) {
            return false;
        }
        if (m_targetPoints == null) {
            if (other.m_targetPoints != null) {
                return false;
            }
        } else if (!m_targetPoints.equals(other.m_targetPoints)) {
            return false;
        }
        return true;
    }

    @Override
    public String toString() {
        return "KNFST [m_kernel=" + m_kernel + ", m_projection=" + m_projection + ", m_targetPoints="
                + m_targetPoints + ", m_betweenClassDistances=" + Arrays.toString(m_betweenClassDistances) + "]";
    }
}