org.apache.mahout.math.SparseMatrix.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.math.SparseMatrix.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math;

import it.unimi.dsi.fastutil.ints.Int2ObjectMap.Entry;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;

import java.util.Iterator;
import java.util.Map;

import org.apache.mahout.math.flavor.MatrixFlavor;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.list.IntArrayList;

import com.google.common.collect.AbstractIterator;

/**
 * Doubly sparse matrix: only rows that have been touched are stored, and each stored row is itself a
 * {@link Vector} (a {@link RandomAccessSparseVector} whenever this class creates the row).
 */
public class SparseMatrix extends AbstractMatrix {

    /** Stored rows keyed by row index; a row with no entry reads as all zeros. */
    private final Int2ObjectOpenHashMap<Vector> rowVectors;

    /**
     * Construct a matrix of the given cardinality from a row map, cloning every row vector.
     *
     * @param rows no of rows
     * @param columns no of columns
     * @param rowVectors a {@code Map<Integer, Vector>} from row index to row vector
     */
    public SparseMatrix(int rows, int columns, Map<Integer, Vector> rowVectors) {
        this(rows, columns, rowVectors, false);
    }

    /**
     * Construct a matrix of the given cardinality from a row map.
     *
     * @param rows no of rows
     * @param columns no of columns
     * @param rowVectors a {@code Map<Integer, Vector>} from row index to row vector
     * @param shallow if {@code true} the supplied vectors are stored as-is (shared with the caller);
     *                if {@code false} each vector is cloned before being stored
     */
    public SparseMatrix(int rows, int columns, Map<Integer, Vector> rowVectors, boolean shallow) {

        // Why is this passed in as a map? Iterating it is pretty inefficient compared to a simple list...
        super(rows, columns);
        this.rowVectors = new Int2ObjectOpenHashMap<>();
        for (Map.Entry<Integer, Vector> entry : rowVectors.entrySet()) {
            Vector row = entry.getValue();
            this.rowVectors.put(entry.getKey().intValue(), shallow ? row : row.clone());
        }
    }

    /**
     * Construct a matrix with specified number of rows and columns and no stored rows.
     *
     * @param rows no of rows
     * @param columns no of columns
     */
    public SparseMatrix(int rows, int columns) {
        super(rows, columns);
        this.rowVectors = new Int2ObjectOpenHashMap<>();
    }

    /**
     * Returns a copy of this matrix with the same dimensions.
     *
     * <p>Rows are copied straight from the backing row map: each stored row is cloned with
     * {@link Vector#clone()} and stored under the same row index in the copy.
     *
     * @return a new {@code SparseMatrix} holding clones of the stored rows
     */
    @Override
    public Matrix clone() {
        SparseMatrix clone = new SparseMatrix(numRows(), numCols());
        for (ObjectIterator<Entry<Vector>> fastIterator = rowVectors.int2ObjectEntrySet().fastIterator();
                fastIterator.hasNext();) {
            // The fast iterator may reuse its entry object, so read key and value immediately.
            final Entry<Vector> entry = fastIterator.next();
            clone.rowVectors.put(entry.getIntKey(), entry.getValue().clone());
        }
        return clone;
    }

    /**
     * @return the number of rows currently stored in the row map (not the declared row count)
     */
    @Override
    public int numSlices() {
        return rowVectors.size();
    }

    /**
     * Iterates over the rows that are currently stored.
     *
     * <p>The set of row indices is captured when this method is called; rows added to the matrix
     * afterwards are not visited by the returned iterator.
     *
     * @return an iterator yielding one {@link MatrixSlice} per stored row
     */
    public Iterator<MatrixSlice> iterateNonEmpty() {
        final int[] keys = rowVectors.keySet().toIntArray();
        return new AbstractIterator<MatrixSlice>() {
            private int slice;

            @Override
            protected MatrixSlice computeNext() {
                // Bound by the snapshot length, not the live map size: the map can grow while
                // iterating (e.g. through setQuick or viewRow), which would otherwise index past keys.
                if (slice >= keys.length) {
                    return endOfData();
                }
                int i = keys[slice];
                Vector row = rowVectors.get(i);
                slice++;
                return new MatrixSlice(row, i);
            }
        };
    }

    /**
     * Reads a single cell without bounds checking.
     *
     * @param row row index
     * @param column column index
     * @return the stored value, or {@code 0.0} when the row is not stored
     */
    @Override
    public double getQuick(int row, int column) {
        Vector r = rowVectors.get(row);
        return r == null ? 0.0 : r.getQuick(column);
    }

    /**
     * @return a new, empty {@code SparseMatrix} with the same row and column size as this one
     */
    @Override
    public Matrix like() {
        return new SparseMatrix(rowSize(), columnSize());
    }

    /**
     * @param rows no of rows
     * @param columns no of columns
     * @return a new, empty {@code SparseMatrix} of the requested size
     */
    @Override
    public Matrix like(int rows, int columns) {
        return new SparseMatrix(rows, columns);
    }

    /**
     * Writes a single cell without bounds checking, creating the row vector if it is not stored yet.
     *
     * @param row row index
     * @param column column index
     * @param value value to store
     */
    @Override
    public void setQuick(int row, int column, double value) {
        Vector r = rowVectors.get(row);
        if (r == null) {
            r = new RandomAccessSparseVector(columnSize());
            rowVectors.put(row, r);
        }
        r.setQuick(column, value);
    }

    /**
     * @return a two-element array: at {@code ROW} the number of stored rows, at {@code COL} the
     *         largest {@link Vector#getNumNondefaultElements()} among the stored rows
     */
    @Override
    public int[] getNumNondefaultElements() {
        int[] result = new int[2];
        result[ROW] = rowVectors.size();
        for (Vector row : rowVectors.values()) {
            result[COL] = Math.max(result[COL], row.getNumNondefaultElements());
        }
        return result;
    }

    /**
     * Returns a view of a rectangular part of this matrix.
     *
     * @param offset starting row and column of the view, indexed by {@code ROW} and {@code COL}
     * @param size number of rows and columns of the view, indexed by {@code ROW} and {@code COL}
     * @return a {@link MatrixView} over this matrix
     * @throws IndexException if an offset is negative or the requested part extends past the
     *         row or column size of this matrix
     */
    @Override
    public Matrix viewPart(int[] offset, int[] size) {
        if (offset[ROW] < 0) {
            throw new IndexException(offset[ROW], rowSize());
        }
        if (offset[ROW] + size[ROW] > rowSize()) {
            throw new IndexException(offset[ROW] + size[ROW], rowSize());
        }
        if (offset[COL] < 0) {
            throw new IndexException(offset[COL], columnSize());
        }
        if (offset[COL] + size[COL] > columnSize()) {
            throw new IndexException(offset[COL] + size[COL], columnSize());
        }
        return new MatrixView(this, offset, size);
    }

    /**
     * Combines this matrix with {@code other} using {@code function}.
     *
     * <p>When the function is {@link Functions#PLUS} and {@code other} is also a {@code SparseMatrix},
     * only the rows stored in {@code other} are visited; every other case is delegated to the
     * superclass implementation.
     *
     * @param other the matrix to combine with
     * @param function the function to apply
     * @return this matrix on the sparse PLUS path, otherwise the superclass result
     * @throws CardinalityException on the sparse PLUS path, if the row or column sizes differ
     */
    @Override
    public Matrix assign(Matrix other, DoubleDoubleFunction function) {
        //TODO generalize to other kinds of functions
        if (Functions.PLUS.equals(function) && other instanceof SparseMatrix) {
            int rows = rowSize();
            if (rows != other.rowSize()) {
                throw new CardinalityException(rows, other.rowSize());
            }
            int columns = columnSize();
            if (columns != other.columnSize()) {
                throw new CardinalityException(columns, other.columnSize());
            }

            SparseMatrix otherSparse = (SparseMatrix) other;
            for (ObjectIterator<Entry<Vector>> fastIterator = otherSparse.rowVectors.int2ObjectEntrySet()
                    .fastIterator(); fastIterator.hasNext();) {
                final Entry<Vector> entry = fastIterator.next();
                final int rowIndex = entry.getIntKey();
                Vector row = rowVectors.get(rowIndex);
                if (row == null) {
                    // Nothing stored here yet: 0 + x == x, so a clone of the other row is the sum.
                    rowVectors.put(rowIndex, entry.getValue().clone());
                } else {
                    row.assign(entry.getValue(), Functions.PLUS);
                }
            }
            return this;
        } else {
            return super.assign(other, function);
        }
    }

    /**
     * Overwrites one column of this matrix with the values of {@code other}.
     *
     * <p>Non-zero values are written, creating the row if needed. Zero values are written only into
     * rows that are already stored, so a previously stored value in that column is cleared without
     * allocating rows just to hold zeros.
     *
     * @param column the column to overwrite
     * @param other the new column values; its size must equal the row size of this matrix
     * @return this matrix
     * @throws CardinalityException if {@code other.size()} differs from the row size
     * @throws IndexException if {@code column} is outside {@code [0, columnSize())}
     */
    @Override
    public Matrix assignColumn(int column, Vector other) {
        if (rowSize() != other.size()) {
            throw new CardinalityException(rowSize(), other.size());
        }
        if (column < 0 || column >= columnSize()) {
            throw new IndexException(column, columnSize());
        }
        for (int row = 0; row < rowSize(); row++) {
            double val = other.getQuick(row);
            Vector r = rowVectors.get(row);
            if (val != 0.0) {
                if (r == null) {
                    r = new RandomAccessSparseVector(columnSize());
                    rowVectors.put(row, r);
                }
                r.setQuick(column, val);
            } else if (r != null) {
                // Clear whatever was stored before; an absent row already reads as zero.
                r.setQuick(column, 0.0);
            }
        }
        return this;
    }

    /**
     * Replaces one row of this matrix with {@code other}.
     *
     * <p>The vector is stored by reference, not copied: later changes to {@code other} are visible
     * through this matrix and vice versa.
     *
     * @param row the row to replace
     * @param other the new row; its size must equal the column size of this matrix
     * @return this matrix
     * @throws CardinalityException if {@code other.size()} differs from the column size
     * @throws IndexException if {@code row} is outside {@code [0, rowSize())}
     */
    @Override
    public Matrix assignRow(int row, Vector other) {
        if (columnSize() != other.size()) {
            throw new CardinalityException(columnSize(), other.size());
        }
        if (row < 0 || row >= rowSize()) {
            throw new IndexException(row, rowSize());
        }
        rowVectors.put(row, other);
        return this;
    }

    /**
     * Returns the vector backing the given row, creating and storing an empty
     * {@link RandomAccessSparseVector} first if the row is not stored yet, so that writes through
     * the returned vector end up in this matrix.
     *
     * @param row the row index
     * @return the stored row vector
     * @throws IndexException if {@code row} is outside {@code [0, rowSize())}
     */
    @Override
    public Vector viewRow(int row) {
        if (row < 0 || row >= rowSize()) {
            throw new IndexException(row, rowSize());
        }
        Vector res = rowVectors.get(row);
        if (res == null) {
            res = new RandomAccessSparseVector(columnSize());
            rowVectors.put(row, res);
        }
        return res;
    }

    /**
     * Special method necessary for efficient serialization.
     *
     * @return the indices of all rows currently stored in the row map
     */
    public IntArrayList nonZeroRowIndices() {
        return new IntArrayList(rowVectors.keySet().toIntArray());
    }

    /**
     * @return {@link MatrixFlavor#SPARSEROWLIKE}
     */
    @Override
    public MatrixFlavor getFlavor() {
        return MatrixFlavor.SPARSEROWLIKE;
    }
}