org.knime.al.util.noveltydetection.kernel.KernelCalculator.java Source code

Java tutorial

Introduction

Here is the source code for org.knime.al.util.noveltydetection.kernel.KernelCalculator.java

Source

/*
 * ------------------------------------------------------------------------
 *
 *  Copyright by KNIME AG, Zurich, Switzerland
 *  Website: http://www.knime.com; Email: contact@knime.com
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License, Version 3, as
 *  published by the Free Software Foundation.
 *
 *  This program is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, see <http://www.gnu.org/licenses>.
 *
 *  Additional permission under GNU GPL version 3 section 7:
 *
 *  KNIME interoperates with ECLIPSE solely via ECLIPSE's plug-in APIs.
 *  Hence, KNIME and ECLIPSE are both independent programs and are not
 *  derived from each other. Should, however, the interpretation of the
 *  GNU GPL Version 3 ("License") under any applicable laws result in
 *  KNIME and ECLIPSE being a combined program, KNIME AG herewith grants
 *  you the additional permission to use and propagate KNIME together with
 *  ECLIPSE with only the license terms in place for ECLIPSE applying to
 *  ECLIPSE and the GNU GPL Version 3 applying for KNIME, provided the
 *  license terms of ECLIPSE themselves allow for the respective use and
 *  propagation of ECLIPSE together with KNIME.
 *
 *  Additional permission relating to nodes for KNIME that extend the Node
 *  Extension (and in particular that are based on subclasses of NodeModel,
 *  NodeDialog, and NodeView) and that only interoperate with KNIME through
 *  standard APIs ("Nodes"):
 *  Nodes are deemed to be separate and independent programs and to not be
 *  covered works.  Notwithstanding anything to the contrary in the
 *  License, the License does not apply to Nodes, you are not required to
 *  license Nodes under the License, and you are granted a license to
 *  prepare and propagate Nodes, in each case even if such Nodes are
 *  propagated with or for interoperation with KNIME.  The owner of a Node
 *  may freely choose the license terms applicable to such Node, including
 *  when such Node is propagated with or for interoperation with KNIME.
 * --------------------------------------------------------------------- *
 *
 */

package org.knime.al.util.noveltydetection.kernel;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.knime.core.data.DataCell;
import org.knime.core.data.DataRow;
import org.knime.core.data.DoubleValue;
import org.knime.core.node.BufferedDataTable;
import org.knime.core.node.ExecutionMonitor;
import org.knime.core.node.KNIMEConstants;
import org.knime.core.util.ThreadPool;

/**
 * Computes kernel values between a set of training samples and test samples
 * using a configurable {@link KernelFunction}.
 *
 * <p>
 * All kernel matrices produced by this class have one row per training sample
 * and one column per test sample, i.e. entry {@code (r, c)} is
 * {@code kernel(training[r], test[c])}.
 * </p>
 *
 * <p>
 * The calculator can be persisted via {@link Externalizable}: the kernel
 * function (class name plus its own externalized state) is written first,
 * followed by the dimensions and the values of the training data.
 * </p>
 */
public class KernelCalculator implements Externalizable {
    /** Default core count constant; it is not referenced inside this class. */
    static final int DEFAULT_NUM_CORES = 4;

    /** Kernel types known to this package; {@link #toString()} yields the display name. */
    public enum KernelType {
        RBF("RBF"), HIK("HIK"), EXPHIK("EXPHIK"), Polynomial("Polynomial");

        private final String m_name;

        private KernelType(final String name) {
            m_name = name;
        }

        @Override
        public String toString() {
            return m_name;
        }
    }

    // Holds the training data (one row per sample)
    private double[][] m_trainingData;
    private int m_rowCount;
    private int m_colCount;

    private KernelFunction m_kernelFunction;

    /**
     * No-arg constructor required by {@link Externalizable}; the state is
     * filled in by {@link #readExternal(ObjectInput)}.
     */
    public KernelCalculator() {

    }

    /**
     * Creates a calculator whose training data is read from a table.
     *
     * @param trainingData table holding the training samples, one per row
     * @param kernelFunction kernel used to compare samples
     * @throws IllegalStateException if the table has more than
     *             {@link Integer#MAX_VALUE} rows
     * @throws IllegalArgumentException if the table contains missing or
     *             non-numerical cells
     */
    public KernelCalculator(final BufferedDataTable trainingData, final KernelFunction kernelFunction) {
        // readBufferedDataTable already rejects tables that do not fit into an
        // array, so the row count can be taken from the array afterwards.
        m_trainingData = readBufferedDataTable(trainingData);
        m_rowCount = m_trainingData.length;
        m_colCount = trainingData.getDataTableSpec().getNumColumns();
        m_kernelFunction = kernelFunction;
    }

    /**
     * Creates a calculator from training data that is already available as an
     * array. The array is stored by reference, it is not copied.
     *
     * @param trainingData training samples, one per row; may be empty
     * @param kernelFunction kernel used to compare samples
     */
    public KernelCalculator(final double[][] trainingData, final KernelFunction kernelFunction) {
        m_trainingData = trainingData;
        m_rowCount = trainingData.length;
        // An empty data set has no first row to take the column count from.
        m_colCount = trainingData.length == 0 ? 0 : trainingData[0].length;
        m_kernelFunction = kernelFunction;
    }

    /**
     * Calculates the kernel matrix of the training data with itself.
     *
     * @param progMon monitor used for progress reporting and cancellation
     * @return m x m matrix for m training samples
     * @throws Exception if the calculation fails or is canceled
     */
    public RealMatrix kernelize(final ExecutionMonitor progMon) throws Exception {
        return calculateKernelMatrix(m_trainingData, m_trainingData, progMon);
    }

    /**
     * Calculates the kernel matrix between two tables. The training data
     * stored in this calculator is not used.
     *
     * @param trainingData table holding the training samples
     * @param testData table holding the test samples
     * @param progMon monitor used for progress reporting and cancellation
     * @return matrix with one row per training sample and one column per test
     *         sample
     * @throws Exception if the calculation fails or is canceled
     */
    public RealMatrix kernelize(final BufferedDataTable trainingData, final BufferedDataTable testData,
            final ExecutionMonitor progMon) throws Exception {
        return calculateKernelMatrix(readBufferedDataTable(trainingData), readBufferedDataTable(testData), progMon);
    }

    /**
     * Calculates the kernel matrix between the stored training data and the
     * given test table.
     *
     * @param testData table holding the test samples
     * @param progMon monitor used for progress reporting and cancellation
     * @return matrix with one row per training sample and one column per test
     *         sample
     * @throws Exception if the calculation fails or is canceled
     */
    public RealMatrix kernelize(final BufferedDataTable testData, final ExecutionMonitor progMon) throws Exception {
        return calculateKernelMatrix(m_trainingData, readBufferedDataTable(testData), progMon);
    }

    /**
     * Calculates the kernel values between the stored training data and a
     * single test instance.
     *
     * @param testInstance row holding the test sample
     * @return column matrix with one entry per training sample
     * @throws IllegalArgumentException if the row contains missing or
     *             non-numerical cells
     */
    public RealMatrix kernelize(final DataRow testInstance) {
        return calculateKernelVector(m_trainingData, readDataRow(testInstance), m_kernelFunction);
    }

    /**
     * Calculates the kernel matrix between the stored training data and the
     * given test samples.
     *
     * @param testData test samples, one per row
     * @param progMon monitor used for progress reporting and cancellation
     * @return matrix with one row per training sample and one column per test
     *         sample
     * @throws Exception if the calculation fails or is canceled
     */
    public RealMatrix kernelize(final double[][] testData, final ExecutionMonitor progMon) throws Exception {
        return calculateKernelMatrix(m_trainingData, testData, progMon);
    }

    /**
     * Calculates the kernel matrix between the two given data sets. The
     * training data stored in this calculator is not used.
     *
     * @param trainingData training samples, one per row
     * @param testData test samples, one per row
     * @param progMon monitor used for progress reporting and cancellation
     * @return matrix with one row per training sample and one column per test
     *         sample
     * @throws Exception if the calculation fails or is canceled
     */
    public RealMatrix kernelize(final double[][] trainingData, final double[][] testData,
            final ExecutionMonitor progMon) throws Exception {
        return calculateKernelMatrix(trainingData, testData, progMon);
    }

    /**
     * @return the number of training samples held by this calculator
     */
    public int getNumTrainingSamples() {
        return m_rowCount;
    }

    /**
     * Calculates the kernel matrix between {@code training} and {@code test}
     * in parallel on the KNIME global thread pool. One task is scheduled per
     * test sample (i.e. per result column); a semaphore bounds the number of
     * tasks that are in flight at the same time.
     *
     * @param training training samples, one per row
     * @param test test samples, one per row
     * @param progMon monitor used for progress reporting and cancellation
     * @return matrix with {@code training.length} rows and {@code test.length}
     *         columns
     * @throws Exception if a task fails, the calculation is canceled or
     *             interrupted, or the result matrix cannot be created
     *             (NOTE(review): {@code MatrixUtils.createRealMatrix} is
     *             expected to reject empty data - confirm against the
     *             commons-math documentation)
     */
    public RealMatrix calculateKernelMatrix(final double[][] training, final double[][] test,
            final ExecutionMonitor progMon) throws Exception {

        final ThreadPool pool = KNIMEConstants.GLOBAL_THREAD_POOL;
        // Use roughly two thirds of the available processors, but never less
        // than one: a semaphore without permits would block the very first
        // acquire() forever (e.g. on a single-core machine).
        final int procCount = Math.max(1, (int) (Runtime.getRuntime().availableProcessors() * (2.0 / 3)));
        final Semaphore semaphore = new Semaphore(procCount);
        final KernelFunction kernelFunction = m_kernelFunction;

        return pool.runInvisible(new Callable<RealMatrix>() {

            @Override
            public RealMatrix call() throws Exception {
                final int numberOfRunnables = test.length;
                final double[][] resultArrayMatrix = new double[training.length][numberOfRunnables];
                final Future<?>[] threads = new Future<?>[numberOfRunnables];
                double progCounter = 0;
                try {
                    // Phase 1: schedule one task per test sample. Each task
                    // needs a permit, which it hands back when it is done.
                    for (int i = 0; i < numberOfRunnables; i++) {
                        progMon.checkCanceled();
                        semaphore.acquire();
                        threads[i] = pool.enqueue(new CalculateKernelValuesRunnable(0, training.length, i, i + 1,
                                training, test, resultArrayMatrix, kernelFunction, semaphore));
                        progMon.setProgress(progCounter / (2 * numberOfRunnables),
                                "Kernel calculation started (" + i + "/" + numberOfRunnables + ")");
                        progCounter += 1;
                    }
                    // Phase 2: wait for every task. get() also surfaces task
                    // failures and makes the written results visible here.
                    for (int i = 0; i < numberOfRunnables; i++) {
                        progMon.checkCanceled();
                        threads[i].get();
                        progMon.setProgress(progCounter / (2 * numberOfRunnables),
                                "Kernel calculation finished (" + i + "/" + numberOfRunnables + ")");
                        progCounter += 1;
                    }
                } catch (final Exception e) {
                    // Cancellation, interruption or a failed task: do not
                    // leave already scheduled tasks running in the pool.
                    cancelAll(threads);
                    throw e;
                }
                return MatrixUtils.createRealMatrix(resultArrayMatrix);
            }

        });
    }

    /** Cancels every future that has already been scheduled. */
    private static void cancelAll(final Future<?>[] futures) {
        for (final Future<?> future : futures) {
            if (future != null) {
                future.cancel(true);
            }
        }
    }

    /**
     * Calculates the kernel values between all training samples and a single
     * test sample.
     *
     * @return column matrix with one entry per training sample
     */
    private RealMatrix calculateKernelVector(final double[][] training, final double[] test,
            final KernelFunction kernelFunction) {
        final double[] result = new double[training.length];

        for (int r = 0; r < training.length; r++) {
            result[r] = kernelFunction.calculate(training[r], test);
        }
        return MatrixUtils.createColumnRealMatrix(result);
    }

    /**
     * Task that fills the sub-block
     * {@code [trainingStart, trainingEnd) x [testStart, testEnd)} of the
     * result array and releases one semaphore permit when it terminates.
     */
    private static final class CalculateKernelValuesRunnable implements Runnable {

        private final int m_trainingStart;
        private final int m_trainingEnd;
        private final int m_testStart;
        private final int m_testEnd;
        private final double[][] m_training;
        private final double[][] m_test;
        private final double[][] m_result;
        private final KernelFunction m_kernelFunction;
        private final Semaphore m_semaphore;

        public CalculateKernelValuesRunnable(final int trainingStart, final int trainingEnd, final int testStart,
                final int testEnd, final double[][] training, final double[][] test, final double[][] result,
                final KernelFunction kernelFunction, final Semaphore semaphore) {
            m_trainingStart = trainingStart;
            m_trainingEnd = trainingEnd;
            m_testStart = testStart;
            m_testEnd = testEnd;
            m_training = training;
            m_test = test;
            m_result = result;
            m_kernelFunction = kernelFunction;
            m_semaphore = semaphore;
        }

        @Override
        public void run() {
            try {
                for (int r = m_trainingStart; r < m_trainingEnd; r++) {
                    for (int c = m_testStart; c < m_testEnd; c++) {
                        m_result[r][c] = m_kernelFunction.calculate(m_training[r], m_test[c]);
                    }
                }
            } finally {
                // Always hand the permit back; otherwise a failing kernel
                // function would starve the scheduling loop of permits.
                m_semaphore.release();
            }
        }

    }

    /**
     * Calculates the kernel matrix between {@code training} and {@code test}
     * sequentially in the calling thread.
     *
     * @param training training samples, one per row
     * @param test test samples, one per row
     * @return matrix with {@code training.length} rows and {@code test.length}
     *         columns
     */
    public RealMatrix calculateKernelMatrix_singleThread(final double[][] training, final double[][] test) {
        final double[][] result = new double[training.length][test.length];

        for (int s1 = 0; s1 < training.length; s1++) {
            for (int s2 = 0; s2 < test.length; s2++) {
                result[s1][s2] = m_kernelFunction.calculate(training[s1], test[s2]);
            }
        }

        return MatrixUtils.createRealMatrix(result);
    }

    /**
     * Converts a single cell to its double value.
     *
     * @throws IllegalArgumentException if the cell is missing or not
     *             compatible with {@link DoubleValue}
     */
    private static double toDouble(final DataCell cell) {
        if (cell.isMissing()) {
            throw new IllegalArgumentException("Missing values are not supported.");
        }
        if (!cell.getType().isCompatible(DoubleValue.class)) {
            throw new IllegalArgumentException("Only numerical data types are currently supported.");
        }
        return ((DoubleValue) cell).getDoubleValue();
    }

    /**
     * Reads all cells of a row into a double array.
     *
     * @throws IllegalArgumentException if the row contains missing or
     *             non-numerical cells
     */
    private static double[] readDataRow(final DataRow row) {
        final int numCells = row.getNumCells();
        final double[] data = new double[numCells];
        for (int i = 0; i < numCells; i++) {
            data[i] = toDouble(row.getCell(i));
        }
        return data;
    }

    /**
     * Reads a complete table into a two-dimensional double array (one array
     * row per table row, one entry per column of the table spec).
     *
     * @param table the table to read
     * @return the table content as doubles
     * @throws IllegalStateException if the table has more than
     *             {@link Integer#MAX_VALUE} rows
     * @throws IllegalArgumentException if the table contains missing or
     *             non-numerical cells
     */
    public static double[][] readBufferedDataTable(final BufferedDataTable table) {
        final long size = table.size();
        if (size > Integer.MAX_VALUE) {
            throw new IllegalStateException("The input table is too large.");
        }
        final int numColumns = table.getDataTableSpec().getNumColumns();
        final double[][] data = new double[(int) size][numColumns];

        final Iterator<DataRow> iterator = table.iterator();

        for (int r = 0; iterator.hasNext(); r++) {
            final DataRow next = iterator.next();
            for (int c = 0; c < numColumns; c++) {
                data[r][c] = toDouble(next.getCell(c));
            }
        }

        return data;
    }

    /************ Externalizable methods *****************/
    /**
     * Restores the kernel function and the training data in the order written
     * by {@link #writeExternal(ObjectOutput)}. The fields of this object are
     * only updated once everything has been read successfully.
     *
     * @throws IOException if the stream cannot be read, the kernel function
     *             cannot be instantiated or the stored dimensions are negative
     * @throws ClassNotFoundException if the stored kernel function class is
     *             unknown
     */
    @Override
    public void readExternal(final ObjectInput in) throws IOException, ClassNotFoundException {
        // read kernelFunction
        // NOTE(review): the class name comes from the stream; the type is
        // checked before an instance is created, but the stream should still
        // only originate from a trusted source.
        final Class<? extends KernelFunction> kernelClass = Class.forName(in.readUTF())
                .asSubclass(KernelFunction.class);
        final KernelFunction kernelFunction;
        try {
            kernelFunction = kernelClass.getDeclaredConstructor().newInstance();
        } catch (final ReflectiveOperationException e) {
            throw new IOException(e);
        }
        kernelFunction.readExternal(in);

        // read trainingData: rows, columns, then the values row by row
        final int rowCount = in.readInt();
        final int colCount = in.readInt();
        if (rowCount < 0 || colCount < 0) {
            throw new IOException("Invalid training data dimensions: " + rowCount + " x " + colCount);
        }
        final double[][] trainingData = new double[rowCount][colCount];
        for (int r = 0; r < rowCount; r++) {
            for (int c = 0; c < colCount; c++) {
                trainingData[r][c] = in.readDouble();
            }
        }

        m_kernelFunction = kernelFunction;
        m_rowCount = rowCount;
        m_colCount = colCount;
        m_trainingData = trainingData;
    }

    /**
     * Writes the kernel function (class name followed by its own state), the
     * row and column count, and finally all training values row by row.
     */
    @Override
    public void writeExternal(final ObjectOutput out) throws IOException {
        // write kernelFunction
        out.writeUTF(m_kernelFunction.getClass().getName());
        m_kernelFunction.writeExternal(out);

        // write trainingData
        // rows
        out.writeInt(m_rowCount);
        // columns
        out.writeInt(m_colCount);
        // data
        for (final double[] row : m_trainingData) {
            for (final double value : row) {
                out.writeDouble(value);
            }
        }

    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + m_colCount;
        result = prime * result + ((m_kernelFunction == null) ? 0 : m_kernelFunction.hashCode());
        result = prime * result + m_rowCount;
        // equals() compares the training data with deepEquals, so the hash
        // has to be content based as well to honor the hashCode contract.
        result = prime * result + Arrays.deepHashCode(m_trainingData);
        return result;
    }

    @Override
    public boolean equals(final Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof KernelCalculator)) {
            // also covers obj == null
            return false;
        }
        final KernelCalculator other = (KernelCalculator) obj;
        final boolean sameKernel = m_kernelFunction == null ? other.m_kernelFunction == null
                : m_kernelFunction.equals(other.m_kernelFunction);
        return sameKernel && m_colCount == other.m_colCount && m_rowCount == other.m_rowCount
                && Arrays.deepEquals(m_trainingData, other.m_trainingData);
    }

    /**
     * @return a description containing at most the first ten training rows
     */
    @Override
    public String toString() {
        final int maxLen = 10;
        // deepToString prints the row contents; printing the rows via a List
        // would only show array identity strings such as "[D@1b6d3586".
        final String trainingPreview = m_trainingData == null ? null
                : Arrays.deepToString(Arrays.copyOf(m_trainingData, Math.min(m_trainingData.length, maxLen)));
        return "KernelCalculator [m_trainingData=" + trainingPreview + ", m_rowCount=" + m_rowCount
                + ", m_colCount=" + m_colCount + ", m_kernelFunction=" + m_kernelFunction + "]";
    }
}