com.ibm.bi.dml.api.MLMatrix.java Source code

Java tutorial

Introduction

Here is the source code for com.ibm.bi.dml.api.MLMatrix.java

Source

/**
 * (C) Copyright IBM Corp. 2010, 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * 
*/
package com.ibm.bi.dml.api;

import java.io.IOException;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SQLContext.QueryExecution;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.StructType;

import scala.Tuple2;

import com.ibm.bi.dml.parser.DMLTranslator;
import com.ibm.bi.dml.parser.ParseException;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.instructions.spark.functions.GetMIMBFromRow;
import com.ibm.bi.dml.runtime.instructions.spark.functions.GetMLBlock;
import com.ibm.bi.dml.runtime.matrix.MatrixCharacteristics;
import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
import com.ibm.bi.dml.runtime.matrix.data.MatrixIndexes;

/**
 * Experimental API: Might be discontinued in future release
 * 
 * This class serves four purposes:
 * 1. It allows SystemML to fit nicely in MLPipeline by reducing number of reblocks.
 * 2. It allows users to easily read and write matrices without worrying 
 * too much about format, metadata and type of underlying RDDs.
 * 3. It provides mechanism to convert to and from MLLib's BlockedMatrix format
 * 4. It provides off-the-shelf library for Distributed Blocked Matrix and reduces learning curve for using SystemML.
 * However, it is important to know that it is easy to abuse this off-the-shelf library and think it as replacement
 * to writing DML, which it is not. It does not provide any optimization between calls. A simple example
 * of the optimization that is conveniently skipped is: (t(m) %*% m)).
 * Also, note that this library is not thread-safe. The operator precedence is not exactly same as DML (as the precedence is
 * enforced by scala compiler), so please use appropriate brackets to enforce precedence. 
    
 import com.ibm.bi.dml.api.{MLContext, MLMatrix}
 val ml = new MLContext(sc)
 val mat1 = ml.read(sqlContext, "V_small.csv", "csv")
 val mat2 = ml.read(sqlContext, "W_small.mtx", "binary")
 val result = mat1.transpose() %*% mat2
 result.write("Result_small.mtx", "text")
     
 */
public class MLMatrix extends DataFrame {
    private static final long serialVersionUID = -7005940673916671165L;
    protected static final Log LOG = LogFactory.getLog(DMLScript.class.getName());

    protected MatrixCharacteristics mc = null;
    protected MLContext ml = null;

    protected MLMatrix(SQLContext sqlContext, LogicalPlan logicalPlan, MLContext ml) {
        super(sqlContext, logicalPlan);
        this.ml = ml;
    }

    protected MLMatrix(SQLContext sqlContext, QueryExecution queryExecution, MLContext ml) {
        super(sqlContext, queryExecution);
        this.ml = ml;
    }

    // Only used internally to set a new MLMatrix after one of matrix operations.
    // Not to be used externally.
    protected MLMatrix(DataFrame df, MatrixCharacteristics mc, MLContext ml) throws DMLRuntimeException {
        super(df.sqlContext(), df.logicalPlan());
        this.mc = mc;
        this.ml = ml;
    }

    static String writeStmt = "write(output, \"tmp\", format=\"binary\", rows_in_block="
            + DMLTranslator.DMLBlockSize + ", cols_in_block=" + DMLTranslator.DMLBlockSize + ");";

    // ------------------------------------------------------------------------------------------------

    //   /**
    //    * Experimental unstable API: Converts our blocked matrix format to MLLib's format
    //    * @return
    //    */
    //   public BlockMatrix toBlockedMatrix() {
    //      JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = getRDDLazily(this);
    //      RDD<Tuple2<Tuple2<Object, Object>, Matrix>> mllibBlocks = blocks.mapToPair(new GetMLLibBlocks(mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock())).rdd();
    //      return new BlockMatrix(mllibBlocks, mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getRows(), mc.getCols());
    //   }

    // ------------------------------------------------------------------------------------------------
    static MLMatrix createMLMatrix(MLContext ml, SQLContext sqlContext,
            JavaPairRDD<MatrixIndexes, MatrixBlock> blocks, MatrixCharacteristics mc) throws DMLRuntimeException {
        RDD<Row> rows = blocks.map(new GetMLBlock()).rdd();
        StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
        return new MLMatrix(sqlContext.createDataFrame(rows.toJavaRDD(), schema), mc, ml);
    }

    /**
     * Convenient method to write a MLMatrix.
     */
    public void write(String filePath, String format) throws IOException, DMLException, ParseException {
        ml.reset();
        ml.registerInput("left", this);
        ml.executeScript(
                "left = read(\"\"); output=left; write(output, \"" + filePath + "\", format=\"" + format + "\");");
    }

    private double getScalarBuiltinFunctionResult(String fn) throws IOException, DMLException, ParseException {
        if (fn.compareTo("nrow") == 0 || fn.compareTo("ncol") == 0) {
            ml.reset();
            ml.registerInput("left", getRDDLazily(this), mc.getRows(), mc.getCols(), mc.getRowsPerBlock(),
                    mc.getColsPerBlock(), mc.getNonZeros());
            ml.registerOutput("output");
            String script = "left = read(\"\");" + "val = " + fn + "(left); "
                    + "output = matrix(val, rows=1, cols=1); " + writeStmt;
            MLOutput out = ml.executeScript(script);
            List<Tuple2<MatrixIndexes, MatrixBlock>> result = out.getBinaryBlockedRDD("output").collect();
            if (result == null || result.size() != 1) {
                throw new DMLRuntimeException("Error while computing the function: " + fn);
            }
            return result.get(0)._2.getValue(0, 0);
        } else {
            throw new DMLRuntimeException("The function " + fn + " is not yet supported in MLMatrix");
        }
    }

    /**
     * Gets or computes the number of rows.
     * @return
     * @throws ParseException 
     * @throws DMLException 
     * @throws IOException 
     */
    public long numRows() throws IOException, DMLException, ParseException {
        if (mc.rowsKnown()) {
            return mc.getRows();
        } else {
            return (long) getScalarBuiltinFunctionResult("nrow");
        }
    }

    /**
     * Gets or computes the number of columns.
     * @return
     * @throws ParseException 
     * @throws DMLException 
     * @throws IOException 
     */
    public long numCols() throws IOException, DMLException, ParseException {
        if (mc.colsKnown()) {
            return mc.getCols();
        } else {
            return (long) getScalarBuiltinFunctionResult("ncol");
        }
    }

    public int rowsPerBlock() {
        return mc.getRowsPerBlock();
    }

    public int colsPerBlock() {
        return mc.getColsPerBlock();
    }

    private String getScript(String binaryOperator) {
        return "left = read(\"\");" + "right = read(\"\");" + "output = left " + binaryOperator + " right; "
                + writeStmt;
    }

    private String getScalarBinaryScript(String binaryOperator, double scalar, boolean isScalarLeft) {
        if (isScalarLeft) {
            return "left = read(\"\");" + "output = " + scalar + " " + binaryOperator + " left ;" + writeStmt;
        } else {
            return "left = read(\"\");" + "output = left " + binaryOperator + " " + scalar + ";" + writeStmt;
        }
    }

    static JavaPairRDD<MatrixIndexes, MatrixBlock> getRDDLazily(MLMatrix mat) {
        return mat.rdd().toJavaRDD().mapToPair(new GetMIMBFromRow());
    }

    private MLMatrix matrixBinaryOp(MLMatrix that, String op) throws IOException, DMLException, ParseException {

        if (mc.getRowsPerBlock() != that.mc.getRowsPerBlock()
                || mc.getColsPerBlock() != that.mc.getColsPerBlock()) {
            throw new DMLRuntimeException(
                    "Incompatible block sizes: brlen:" + mc.getRowsPerBlock() + "!=" + that.mc.getRowsPerBlock()
                            + " || bclen:" + mc.getColsPerBlock() + "!=" + that.mc.getColsPerBlock());
        }

        if (op.compareTo("%*%") == 0) {
            if (mc.getCols() != that.mc.getRows()) {
                throw new DMLRuntimeException("Dimensions mismatch:" + mc.getCols() + "!=" + that.mc.getRows());
            }
        } else {
            if (mc.getRows() != that.mc.getRows() || mc.getCols() != that.mc.getCols()) {
                throw new DMLRuntimeException("Dimensions mismatch:" + mc.getRows() + "!=" + that.mc.getRows()
                        + " || " + mc.getCols() + "!=" + that.mc.getCols());
            }
        }

        ml.reset();
        ml.registerInput("left", this);
        ml.registerInput("right", that);
        ml.registerOutput("output");
        MLOutput out = ml.executeScript(getScript(op));
        RDD<Row> rows = out.getBinaryBlockedRDD("output").map(new GetMLBlock()).rdd();
        StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
        MatrixCharacteristics mcOut = out.getMatrixCharacteristics("output");
        return new MLMatrix(this.sqlContext().createDataFrame(rows.toJavaRDD(), schema), mcOut, ml);
    }

    private MLMatrix scalarBinaryOp(Double scalar, String op, boolean isScalarLeft)
            throws IOException, DMLException, ParseException {
        ml.reset();
        ml.registerInput("left", this);
        ml.registerOutput("output");
        MLOutput out = ml.executeScript(getScalarBinaryScript(op, scalar, isScalarLeft));
        RDD<Row> rows = out.getBinaryBlockedRDD("output").map(new GetMLBlock()).rdd();
        StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
        MatrixCharacteristics mcOut = out.getMatrixCharacteristics("output");
        return new MLMatrix(this.sqlContext().createDataFrame(rows.toJavaRDD(), schema), mcOut, ml);
    }

    // ---------------------------------------------------
    // Simple operator loading but doesnot utilize the optimizer

    public MLMatrix $greater(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, ">");
    }

    public MLMatrix $less(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "<");
    }

    public MLMatrix $greater$eq(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, ">=");
    }

    public MLMatrix $less$eq(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "<=");
    }

    public MLMatrix $eq$eq(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "==");
    }

    public MLMatrix $bang$eq(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "!=");
    }

    public MLMatrix $up(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "^");
    }

    public MLMatrix exp(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "^");
    }

    public MLMatrix $plus(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "+");
    }

    public MLMatrix add(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "+");
    }

    public MLMatrix $minus(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "-");
    }

    public MLMatrix minus(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "-");
    }

    public MLMatrix $times(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "*");
    }

    public MLMatrix elementWiseMultiply(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "*");
    }

    public MLMatrix $div(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "/");
    }

    public MLMatrix divide(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "/");
    }

    public MLMatrix $percent$div$percent(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%/%");
    }

    public MLMatrix integerDivision(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%/%");
    }

    public MLMatrix $percent$percent(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%%");
    }

    public MLMatrix modulus(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%%");
    }

    public MLMatrix $percent$times$percent(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%*%");
    }

    public MLMatrix multiply(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%*%");
    }

    public MLMatrix transpose() throws IOException, DMLException, ParseException {
        ml.reset();
        ml.registerInput("left", this);
        ml.registerOutput("output");
        String script = "left = read(\"\");" + "output = t(left); " + writeStmt;
        MLOutput out = ml.executeScript(script);
        RDD<Row> rows = out.getBinaryBlockedRDD("output").map(new GetMLBlock()).rdd();
        StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
        MatrixCharacteristics mcOut = out.getMatrixCharacteristics("output");
        return new MLMatrix(this.sqlContext().createDataFrame(rows.toJavaRDD(), schema), mcOut, ml);
    }

    // TODO: For 'scalar op matrix' operations: Do implicit conversions 
    public MLMatrix $plus(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "+", false);
    }

    public MLMatrix add(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "+", false);
    }

    public MLMatrix $minus(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "-", false);
    }

    public MLMatrix minus(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "-", false);
    }

    public MLMatrix $times(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "*", false);
    }

    public MLMatrix elementWiseMultiply(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "*", false);
    }

    public MLMatrix $div(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "/", false);
    }

    public MLMatrix divide(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "/", false);
    }

    public MLMatrix $greater(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, ">", false);
    }

    public MLMatrix $less(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "<", false);
    }

    public MLMatrix $greater$eq(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, ">=", false);
    }

    public MLMatrix $less$eq(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "<=", false);
    }

    public MLMatrix $eq$eq(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "==", false);
    }

    public MLMatrix $bang$eq(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "!=", false);
    }

}