org.jpmml.evaluator.MatrixUtil.java Source code

Java tutorial

Introduction

Here is the source code for org.jpmml.evaluator.MatrixUtil.java

Source

/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import com.google.common.base.Predicate;
import com.google.common.collect.Iterables;
import org.dmg.pmml.Array;
import org.dmg.pmml.MatCell;
import org.dmg.pmml.Matrix;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * Static helpers for reading the elements and dimensions of a PMML {@link Matrix}.
 *
 * <p>
 * All row and column indices are 1-based. The following matrix kinds are handled:
 * <ul>
 *   <li><code>DIAGONAL</code> - a single Array holding the diagonal values. Off-diagonal elements take the <code>offDiagDefault</code> value.</li>
 *   <li><code>SYMMETRIC</code> - one Array per row, holding the lower left triangle.</li>
 *   <li><code>ANY</code> - either one Array per row, or a list of MatCell elements. Missing cells fall back to <code>diagDefault</code> or <code>offDiagDefault</code>.</li>
 * </ul>
 * Any other kind is rejected with an {@link UnsupportedFeatureException}.
 */
public class MatrixUtil {

    private MatrixUtil() {
    }

    /**
     * Retrieves a single element of a matrix.
     *
     * @param matrix The matrix.
     * @param row The row index. The index of the first row is <code>1</code>.
     * @param column The column index. The index of the first column is <code>1</code>.
     *
     * @return The element at the specified location, or <code>null</code> (e.g. when the applicable default value is not set).
     *
     * @throws IndexOutOfBoundsException If either the row or column index is out of range.
     * @throws UnsupportedFeatureException If the matrix kind is not supported.
     * @throws InvalidFeatureException If the matrix content is not consistent with its kind.
     */
    static public Number getElementAt(Matrix matrix, int row, int column) {
        List<Array> arrays = matrix.getArrays();
        List<MatCell> matCells = matrix.getMatCells();

        Matrix.Kind kind = matrix.getKind();
        switch (kind) {
        case DIAGONAL: {
            // "The content is just one Array of numbers representing the diagonal values"
            if (arrays.size() == 1) {
                Array array = arrays.get(0);

                List<? extends Number> elements = ArrayUtil.getNumberContent(array);

                // Diagonal element. List#get performs the bounds check
                if (row == column) {
                    return elements.get(row - 1);
                }

                // Off-diagonal element. There is no backing storage, so the bounds must be checked explicitly
                checkBounds(row, column, elements.size(), elements.size());

                return matrix.getOffDiagDefault();
            }
        }
            break;
        case SYMMETRIC: {
            // "The content must be represented by Arrays"
            if (arrays.size() > 0) {

                // Make sure the specified coordinates target the lower left triangle
                if (column > row) {
                    int temp = row;

                    row = column;
                    column = temp;
                }

                return getArrayValue(arrays, row, column);
            }
        }
            break;
        case ANY: {
            if (arrays.size() > 0) {
                return getArrayValue(arrays, row, column);
            } // End if

            if (matCells.size() > 0) {
                // A missing cell is not an error (it maps to a default value), so out-of-range
                // coordinates must be rejected explicitly, in both directions
                checkBounds(row, column, getRows(matrix), getColumns(matrix));

                Number value = getMatCellValue(matCells, row, column);
                if (value == null) {

                    if (row == column) {
                        return matrix.getDiagDefault();
                    }

                    return matrix.getOffDiagDefault();
                }

                return value;
            }
        }
            break;
        default:
            throw new UnsupportedFeatureException(matrix, kind);
        }

        throw new InvalidFeatureException(matrix);
    }

    /**
     * Verifies that 1-based coordinates fall inside a matrix of the given dimensions.
     *
     * @throws IndexOutOfBoundsException If either index is out of range.
     */
    static private void checkBounds(int row, int column, int rows, int columns) {

        if ((row < 1 || row > rows) || (column < 1 || column > columns)) {
            throw new IndexOutOfBoundsException("Row " + row + ", column " + column + " is outside of a " + rows + "x" + columns + " matrix");
        }
    }

    /**
     * Reads an element from a row-wise Array representation.
     * Out-of-range indices surface as an {@link IndexOutOfBoundsException} from {@link List#get(int)}.
     */
    static private Number getArrayValue(List<Array> arrays, int row, int column) {
        Array array = arrays.get(row - 1);

        List<? extends Number> elements = ArrayUtil.getNumberContent(array);

        return elements.get(column - 1);
    }

    /**
     * Looks up the first MatCell at the given coordinates.
     *
     * @return The parsed cell value, or <code>null</code> if there is no such cell.
     *
     * @throws InvalidFeatureException If a scanned MatCell lacks its row or column attribute.
     */
    static private Number getMatCellValue(List<MatCell> matCells, final int row, final int column) {
        Predicate<MatCell> filter = new Predicate<MatCell>() {

            @Override
            public boolean apply(MatCell matCell) {
                return (getRow(matCell) == row) && (getColumn(matCell) == column);
            }
        };

        MatCell matCell = Iterables.getFirst(Iterables.filter(matCells, filter), null);
        if (matCell != null) {
            String value = matCell.getValue();

            return Double.valueOf(value);
        }

        return null;
    }

    /**
     * Determines the number of rows.
     * An explicit <code>nbRows</code> attribute takes precedence. Otherwise the count is derived from the matrix content.
     *
     * @return The number of rows.
     *
     * @throws UnsupportedFeatureException If the matrix kind is not supported.
     * @throws InvalidFeatureException If the matrix content is not consistent with its kind.
     */
    static public int getRows(Matrix matrix) {
        Integer nbRows = matrix.getNbRows();
        if (nbRows != null) {
            return nbRows.intValue();
        }

        List<Array> arrays = matrix.getArrays();
        List<MatCell> matCells = matrix.getMatCells();

        Matrix.Kind kind = matrix.getKind();
        switch (kind) {
        case DIAGONAL: {
            if (arrays.size() == 1) {
                Array array = arrays.get(0);

                return ArrayUtil.getSize(array);
            }
        }
            break;
        case SYMMETRIC: {
            if (arrays.size() > 0) {
                return arrays.size();
            }
        }
            break;
        case ANY: {
            if (arrays.size() > 0) {
                return arrays.size();
            } // End if

            if (matCells.size() > 0) {
                // Sparse content: the largest row index that is present
                MatCell matCell = Collections.max(matCells, MatrixUtil.rowComparator);

                return getRow(matCell);
            }
        }
            break;
        default:
            throw new UnsupportedFeatureException(matrix, kind);
        }

        throw new InvalidFeatureException(matrix);
    }

    /**
     * Determines the number of columns.
     * An explicit <code>nbCols</code> attribute takes precedence. Otherwise the count is derived from the matrix content.
     *
     * @return The number of columns.
     *
     * @throws UnsupportedFeatureException If the matrix kind is not supported.
     * @throws InvalidFeatureException If the matrix content is not consistent with its kind.
     */
    static public int getColumns(Matrix matrix) {
        Integer nbCols = matrix.getNbCols();
        if (nbCols != null) {
            return nbCols.intValue();
        }

        List<Array> arrays = matrix.getArrays();
        List<MatCell> matCells = matrix.getMatCells();

        Matrix.Kind kind = matrix.getKind();
        switch (kind) {
        case DIAGONAL: {
            if (arrays.size() == 1) {
                Array array = arrays.get(0);

                return ArrayUtil.getSize(array);
            }
        }
            break;
        case SYMMETRIC: {
            // A symmetric matrix is square
            if (arrays.size() > 0) {
                return arrays.size();
            }
        }
            break;
        case ANY: {
            if (arrays.size() > 0) {
                // The width is taken from the last row Array
                Array array = arrays.get(arrays.size() - 1);

                return ArrayUtil.getSize(array);
            } // End if

            if (matCells.size() > 0) {
                // Sparse content: the largest column index that is present
                MatCell matCell = Collections.max(matCells, MatrixUtil.columnComparator);

                return getColumn(matCell);
            }
        }
            break;
        default:
            throw new UnsupportedFeatureException(matrix, kind);
        }

        throw new InvalidFeatureException(matrix);
    }

    /**
     * @return The row index of the cell.
     *
     * @throws InvalidFeatureException If the row attribute is missing.
     */
    static private int getRow(MatCell matCell) {
        Integer row = matCell.getRow();
        if (row == null) {
            throw new InvalidFeatureException(matCell);
        }

        return row.intValue();
    }

    /**
     * @return The column index of the cell.
     *
     * @throws InvalidFeatureException If the column attribute is missing.
     */
    static private int getColumn(MatCell matCell) {
        Integer column = matCell.getCol();
        if (column == null) {
            throw new InvalidFeatureException(matCell);
        }

        return column.intValue();
    }

    /**
     * Compares two ints without the overflow risk of subtraction-based comparison.
     */
    static private int compareInts(int left, int right) {
        return (left < right) ? -1 : ((left == right) ? 0 : 1);
    }

    private static final Comparator<MatCell> rowComparator = new Comparator<MatCell>() {

        @Override
        public int compare(MatCell left, MatCell right) {
            return compareInts(getRow(left), getRow(right));
        }
    };

    private static final Comparator<MatCell> columnComparator = new Comparator<MatCell>() {

        @Override
        public int compare(MatCell left, MatCell right) {
            return compareInts(getColumn(left), getColumn(right));
        }
    };
}