Java tutorial
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.instructions.spark;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.IntStream;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.codegen.CodegenUtils;
import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysml.runtime.codegen.SpoofCellwise;
import org.apache.sysml.runtime.codegen.SpoofCellwise.AggOp;
import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
import org.apache.sysml.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysml.runtime.codegen.SpoofOperator;
import org.apache.sysml.runtime.codegen.SpoofOuterProduct;
import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;
import org.apache.sysml.runtime.codegen.SpoofRowwise;
import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;

import scala.Tuple2;

/**
 * Spark instruction that runs a generated (codegen) fused operator class over
 * a joined RDD of matrix blocks. Matrix inputs that fit the memory budgets are
 * passed as broadcasts; the remaining matrix inputs are joined with the main
 * input RDD. The generated class is dispatched by its direct superclass:
 * {@link SpoofCellwise}, {@link SpoofMultiAggregate}, {@link SpoofOuterProduct}
 * or {@link SpoofRowwise}.
 */
public class SpoofSPInstruction extends SPInstruction {
    /** Generated operator class as loaded on the driver. */
    private final Class<?> _class;
    /** Byte code of the generated class, shipped with the Spark functions. */
    private final byte[] _classBytes;
    /** Input operands (matrices and scalars) in instruction order. */
    private final CPOperand[] _in;
    /** Output operand. */
    private final CPOperand _out;

    private SpoofSPInstruction(Class<?> cls, byte[] classBytes, CPOperand[] in, CPOperand out,
            String opcode, String str) {
        super(SPType.SpoofFused, opcode, str);
        _class = cls;
        _classBytes = classBytes;
        _in = in;
        _out = out;
    }

    /**
     * Parses an instruction string into a {@code SpoofSPInstruction}.
     *
     * <p>Layout of the parsed parts as used here: {@code parts[0]} is the opcode
     * prefix, {@code parts[1]} the generated class name, {@code parts[2..n-3]}
     * the inputs, {@code parts[n-2]} the output and {@code parts[n-1]} the
     * number of threads (ignored).
     *
     * @param str instruction string
     * @return the parsed instruction
     */
    public static SpoofSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);

        Class<?> cls = CodegenUtils.getClass(parts[1]);
        byte[] classBytes = CodegenUtils.getClassData(parts[1]);
        //extended opcode: prefix plus the spoof type reported by the operator
        String opcode = parts[0] + CodegenUtils.createInstance(cls).getSpoofType();

        ArrayList<CPOperand> inlist = new ArrayList<>();
        for (int i = 2; i < parts.length - 2; i++)
            inlist.add(new CPOperand(parts[i]));
        CPOperand out = new CPOperand(parts[parts.length - 2]);
        //note: number of threads parts[parts.length-1] always ignored

        return new SpoofSPInstruction(cls, classBytes, inlist.toArray(new CPOperand[0]), out, opcode, str);
    }

    /**
     * Executes the generated operator on Spark and binds the result (RDD
     * handle, matrix block or scalar) to the output variable.
     *
     * @param ec execution context; cast to {@link SparkExecutionContext}
     * @throws DMLRuntimeException if a rowwise operator is applied to an input
     *         with more than one column block, or the operator type is unknown
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;

        //decide upon broadcast side inputs
        boolean[] bcVect = determineBroadcastInputs(sec, _in);
        boolean[] bcVect2 = getMatrixBroadcastVector(sec, _in, bcVect);
        int main = getMainInputIndex(_in, bcVect);

        //create joined input rdd w/ replication if needed
        MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[main].getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock[]> in = createJoinedInputRDD(sec, _in, bcVect,
            (_class.getSuperclass() == SpoofOuterProduct.class));
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;

        //create lists of input broadcasts and scalars
        ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<>();
        ArrayList<ScalarObject> scalars = new ArrayList<>();
        for (int i = 0; i < _in.length; i++) {
            if (_in[i].getDataType() == DataType.MATRIX && bcVect[i]) {
                bcMatrices.add(sec.getBroadcastForVariable(_in[i].getName()));
            }
            else if (_in[i].getDataType() == DataType.SCALAR) {
                //note: even if literal, it might be compiled as scalar placeholder
                scalars.add(sec.getScalarInput(_in[i].getName(), _in[i].getValueType(), _in[i].isLiteral()));
            }
        }

        //execute generated operator
        if (_class.getSuperclass() == SpoofCellwise.class) { //CELL
            SpoofCellwise op = (SpoofCellwise) CodegenUtils.createInstance(_class);
            //may be null for aggregation types without an operator; only
            //dereferenced by the aggregation utilities below
            AggregateOperator aggop = getAggregateOperator(op.getAggOp());

            //execute codegen block operation (same for matrix and scalar output)
            out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(), _classBytes,
                bcVect2, bcMatrices, scalars, mcIn.getRowsPerBlock()), true);

            if (_out.getDataType() == DataType.MATRIX) {
                //partial aggregates of multiple blocks map to the same output key
                if ((op.getCellType() == CellType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock())
                    || (op.getCellType() == CellType.COL_AGG && mcIn.getRows() > mcIn.getRowsPerBlock())) {
                    long numBlocks = (op.getCellType() == CellType.ROW_AGG) ?
                        mcIn.getNumRowBlocks() : mcIn.getNumColBlocks();
                    out = RDDAggregateUtils.aggByKeyStable(out, aggop,
                        (int) Math.min(out.getNumPartitions(), numBlocks), false);
                }
                sec.setRDDHandleForVariable(_out.getName(), out);

                //maintain lineage info and output characteristics
                maintainLineageInfo(sec, _in, bcVect, _out);
                updateOutputMatrixCharacteristics(sec, op);
            }
            else { //SCALAR
                MatrixBlock tmpMB = RDDAggregateUtils.aggStable(out, aggop);
                sec.setVariable(_out.getName(), new DoubleObject(tmpMB.getValue(0, 0)));
            }
        }
        else if (_class.getSuperclass() == SpoofMultiAggregate.class) { //MAGG
            SpoofMultiAggregate op = (SpoofMultiAggregate) CodegenUtils.createInstance(_class);
            AggOp[] aggOps = op.getAggOps();

            MatrixBlock tmpMB = in
                .mapToPair(new MultiAggregateFunction(_class.getName(), _classBytes,
                    bcVect2, bcMatrices, scalars, mcIn.getRowsPerBlock()))
                .values().fold(new MatrixBlock(), new MultiAggAggregateFunction(aggOps));
            sec.setMatrixOutput(_out.getName(), tmpMB, getExtendedOpcode());
        }
        else if (_class.getSuperclass() == SpoofOuterProduct.class) { //OUTER
            if (_out.getDataType() == DataType.MATRIX) {
                SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class);
                OutProdType type = ((SpoofOuterProduct) op).getOuterProdType();

                //update matrix characteristics
                updateOutputMatrixCharacteristics(sec, op);
                MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName());

                out = in.mapPartitionsToPair(new OuterProductFunction(
                    _class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true);
                if (type == OutProdType.LEFT_OUTER_PRODUCT || type == OutProdType.RIGHT_OUTER_PRODUCT) {
                    long numBlocks = mcOut.getNumRowBlocks() * mcOut.getNumColBlocks();
                    out = RDDAggregateUtils.sumByKeyStable(out,
                        (int) Math.min(out.getNumPartitions(), numBlocks), false);
                }
                sec.setRDDHandleForVariable(_out.getName(), out);

                //maintain lineage info and output characteristics
                maintainLineageInfo(sec, _in, bcVect, _out);
            }
            else {
                out = in.mapPartitionsToPair(new OuterProductFunction(
                    _class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true);
                MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
                sec.setVariable(_out.getName(), new DoubleObject(tmp.getValue(0, 0)));
            }
        }
        else if (_class.getSuperclass() == SpoofRowwise.class) { //ROW
            if (mcIn.getCols() > mcIn.getColsPerBlock()) {
                throw new DMLRuntimeException("Invalid spark rowwise operator w/ ncol=" +
                    mcIn.getCols() + ", ncolpb=" + mcIn.getColsPerBlock() + ".");
            }
            SpoofRowwise op = (SpoofRowwise) CodegenUtils.createInstance(_class);
            //second vector length: constant dim2 if applicable, else #cols of
            //the second input for B1 row types, else -1
            long clen2 = op.getRowType().isConstDim2(op.getConstDim2()) ? op.getConstDim2() :
                op.getRowType().isRowTypeB1() ? sec.getMatrixCharacteristics(_in[1].getName()).getCols() : -1;
            RowwiseFunction fmmc = new RowwiseFunction(_class.getName(), _classBytes, bcVect2,
                bcMatrices, scalars, mcIn.getRowsPerBlock(), (int) mcIn.getCols(), (int) clen2);
            //partitioning is preserved only if output keys follow the input row index
            out = in.mapPartitionsToPair(fmmc, op.getRowType() == RowType.ROW_AGG
                || op.getRowType() == RowType.NO_AGG);

            if (op.getRowType().isColumnAgg() || op.getRowType() == RowType.FULL_AGG) {
                MatrixBlock tmpMB = RDDAggregateUtils.sumStable(out);
                if (op.getRowType().isColumnAgg())
                    sec.setMatrixOutput(_out.getName(), tmpMB, getExtendedOpcode());
                else
                    sec.setScalarOutput(_out.getName(),
                        new DoubleObject(tmpMB.quickGetValue(0, 0)));
            }
            else { //row-agg or no-agg
                //NOTE: currently unreachable because of the ncol guard above;
                //kept so the aggregation is in place if that guard is relaxed
                if (op.getRowType() == RowType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock()) {
                    out = RDDAggregateUtils.sumByKeyStable(out,
                        (int) Math.min(out.getNumPartitions(), mcIn.getNumRowBlocks()), false);
                }
                sec.setRDDHandleForVariable(_out.getName(), out);

                //maintain lineage info and output characteristics
                maintainLineageInfo(sec, _in, bcVect, _out);
                updateOutputMatrixCharacteristics(sec, op);
            }
        }
        else {
            throw new DMLRuntimeException("Operator " + _class.getSuperclass() + " is not supported on Spark");
        }
    }

    /**
     * Decides per input whether it is passed as a broadcast.
     *
     * <p>A matrix input is broadcast if its estimated in-memory plus
     * partitioned size fits the remaining local budget and its partitioned
     * size fits the remaining broadcast budget; both budgets are reduced for
     * every accepted broadcast. If afterwards no matrix input is left as an
     * RDD, entry 0 is reset to {@code false}.
     *
     * @param sec spark execution context used to look up matrix characteristics
     * @param inputs all instruction inputs
     * @return flags aligned with {@code inputs}; {@code true} means broadcast
     */
    private static boolean[] determineBroadcastInputs(SparkExecutionContext sec, CPOperand[] inputs) {
        boolean[] ret = new boolean[inputs.length];
        double localBudget = OptimizerUtils.getLocalMemBudget()
            - CacheableData.getBroadcastSize(); //account for other broadcasts
        double bcBudget = SparkExecutionContext.getBroadcastMemoryBudget();

        //decided for each matrix input if it fits into remaining memory
        //budget; the major input, i.e., inputs[0] is always an RDD
        for (int i = 0; i < inputs.length; i++)
            if (inputs[i].getDataType().isMatrix()) {
                MatrixCharacteristics mc = sec.getMatrixCharacteristics(inputs[i].getName());
                double sizeL = OptimizerUtils.estimateSizeExactSparsity(mc);
                double sizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(mc);
                //account for partitioning and local/remote budgets
                ret[i] = localBudget > (sizeL + sizeP) && bcBudget > sizeP;
                localBudget -= ret[i] ? sizeP : 0; //in local block manager
                bcBudget -= ret[i] ? sizeP : 0; //in remote block managers
            }

        //ensure there is at least one RDD input, with awareness for scalars
        if (!IntStream.range(0, ret.length).anyMatch(i -> inputs[i].isMatrix() && !ret[i]))
            ret[0] = false;

        return ret;
    }

    /**
     * Projects the broadcast flags onto the matrix inputs only, preserving
     * their relative order.
     *
     * @param sec spark execution context (unused)
     * @param inputs all instruction inputs
     * @param bcVect broadcast flags aligned with {@code inputs}
     * @return broadcast flags with one entry per matrix input
     */
    private static boolean[] getMatrixBroadcastVector(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect) {
        int numMtx = (int) Arrays.stream(inputs)
            .filter(in -> in.getDataType().isMatrix()).count();
        boolean[] ret = new boolean[numMtx];
        for (int i = 0, pos = 0; i < inputs.length; i++)
            if (inputs[i].getDataType().isMatrix())
                ret[pos++] = bcVect[i];
        return ret;
    }

    /**
     * Builds the input RDD whose values are arrays holding the main input
     * block followed by the blocks of all non-broadcast matrix side inputs,
     * in input order.
     *
     * @param sec spark execution context
     * @param inputs all instruction inputs
     * @param bcVect broadcast flags aligned with {@code inputs}
     * @param outer whether the operator is an outer-product operator, in which
     *        case input 2 is replicated via {@link ReplicateRightFactorFunction}
     * @return joined RDD keyed by the block indexes of the main input
     */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock[]> createJoinedInputRDD(SparkExecutionContext sec,
            CPOperand[] inputs, boolean[] bcVect, boolean outer) {
        //get input rdd for main input
        int main = getMainInputIndex(inputs, bcVect);
        MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(inputs[main].getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(inputs[main].getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock[]> ret = in.mapValues(new MapInputSignature());

        for (int i = 0; i < inputs.length; i++)
            if (i != main && inputs[i].getDataType().isMatrix() && !bcVect[i]) {
                //create side input rdd
                String varname = inputs[i].getName();
                JavaPairRDD<MatrixIndexes, MatrixBlock> tmp = sec.getBinaryBlockRDDHandleForVariable(varname);
                MatrixCharacteristics mcTmp = sec.getMatrixCharacteristics(varname);
                //replicate blocks if mismatch with main input
                if (outer && i == 2)
                    tmp = tmp.flatMapToPair(new ReplicateRightFactorFunction(
                        mcIn.getRows(), mcIn.getRowsPerBlock()));
                else if (mcIn.getNumRowBlocks() > mcTmp.getNumRowBlocks())
                    tmp = tmp.flatMapToPair(new ReplicateBlockFunction(
                        mcIn.getRows(), mcIn.getRowsPerBlock(), false));
                else if (mcIn.getNumColBlocks() > mcTmp.getNumColBlocks())
                    tmp = tmp.flatMapToPair(new ReplicateBlockFunction(
                        mcIn.getCols(), mcIn.getColsPerBlock(), true));
                //join main and side inputs and consolidate signature
                ret = ret.join(tmp)
                    .mapValues(new MapJoinSignature());
            }

        return ret;
    }

    /**
     * Registers every matrix input as a lineage parent of the output.
     *
     * @param sec spark execution context
     * @param inputs all instruction inputs
     * @param bcVect broadcast flags aligned with {@code inputs}
     * @param output output operand
     */
    private static void maintainLineageInfo(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect, CPOperand output) {
        //add lineage info for all rdd/broadcast inputs
        for (int i = 0; i < inputs.length; i++)
            if (inputs[i].getDataType().isMatrix())
                sec.addLineage(output.getName(), inputs[i].getName(), bcVect[i]);
    }

    /**
     * Returns the index of the first matrix input that is not broadcast, or 0
     * if there is none.
     *
     * @param inputs all instruction inputs
     * @param bcVect broadcast flags aligned with {@code inputs}
     * @return index of the main (RDD) input
     */
    private static int getMainInputIndex(CPOperand[] inputs, boolean[] bcVect) {
        return IntStream.range(0, bcVect.length)
            .filter(i -> inputs[i].isMatrix() && !bcVect[i]).min().orElse(0);
    }

    /**
     * Sets the output matrix characteristics from the input characteristics
     * according to the operator and aggregation type. Types without a case
     * below leave the output characteristics unchanged.
     *
     * @param sec spark execution context
     * @param op instance of the generated operator
     */
    private void updateOutputMatrixCharacteristics(SparkExecutionContext sec, SpoofOperator op) {
        if (op instanceof SpoofCellwise) {
            MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName());
            MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName());
            CellType type = ((SpoofCellwise) op).getCellType();
            if (type == CellType.ROW_AGG)
                mcOut.set(mcIn.getRows(), 1, mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
            else if (type == CellType.COL_AGG) //1 x ncol, same shape as rowwise COL_AGG
                mcOut.set(1, mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
            else if (type == CellType.NO_AGG)
                mcOut.set(mcIn);
        }
        else if (op instanceof SpoofOuterProduct) {
            MatrixCharacteristics mcIn1 = sec.getMatrixCharacteristics(_in[0].getName()); //X
            MatrixCharacteristics mcIn2 = sec.getMatrixCharacteristics(_in[1].getName()); //U
            MatrixCharacteristics mcIn3 = sec.getMatrixCharacteristics(_in[2].getName()); //V
            MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName());
            OutProdType type = ((SpoofOuterProduct) op).getOuterProdType();

            if (type == OutProdType.CELLWISE_OUTER_PRODUCT)
                mcOut.set(mcIn1.getRows(), mcIn1.getCols(), mcIn1.getRowsPerBlock(), mcIn1.getColsPerBlock());
            else if (type == OutProdType.LEFT_OUTER_PRODUCT)
                mcOut.set(mcIn3.getRows(), mcIn3.getCols(), mcIn3.getRowsPerBlock(), mcIn3.getColsPerBlock());
            else if (type == OutProdType.RIGHT_OUTER_PRODUCT)
                mcOut.set(mcIn2.getRows(), mcIn2.getCols(), mcIn2.getRowsPerBlock(), mcIn2.getColsPerBlock());
        }
        else if (op instanceof SpoofRowwise) {
            MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName());
            MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName());
            RowType type = ((SpoofRowwise) op).getRowType();
            if (type == RowType.NO_AGG)
                mcOut.set(mcIn);
            else if (type == RowType.ROW_AGG)
                mcOut.set(mcIn.getRows(), 1, mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
            else if (type == RowType.COL_AGG)
                mcOut.set(1, mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
            else if (type == RowType.COL_AGG_T)
                mcOut.set(mcIn.getCols(), 1, mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
        }
    }

    /** Wraps a single block into a one-element block array. */
    private static class MapInputSignature implements Function<MatrixBlock, MatrixBlock[]> {
        private static final long serialVersionUID = -816443970067626102L;

        @Override
        public MatrixBlock[] call(MatrixBlock v1) throws Exception {
            return new MatrixBlock[] { v1 };
        }
    }

    /** Appends the joined side-input block to the existing block array. */
    private static class MapJoinSignature implements Function<Tuple2<MatrixBlock[], MatrixBlock>, MatrixBlock[]> {
        private static final long serialVersionUID = -704403012606821854L;

        @Override
        public MatrixBlock[] call(Tuple2<MatrixBlock[], MatrixBlock> v1) throws Exception {
            return ArrayUtils.add(v1._1(), v1._2());
        }
    }

    /**
     * Common state of all Spark functions that execute a generated operator:
     * the shipped class (name and bytes), the broadcast inputs, the scalars and
     * the per-matrix-input broadcast flags.
     */
    private static class SpoofFunction implements Serializable {
        private static final long serialVersionUID = 2953479427746463003L;

        /** Broadcast flag per matrix input (scalars excluded). */
        protected final boolean[] _bcInd;
        /** Broadcast matrix inputs, in the order of their {@code true} flags. */
        protected final ArrayList<PartitionedBroadcast<MatrixBlock>> _inputs;
        protected final ArrayList<ScalarObject> _scalars;
        protected final byte[] _classBytes;
        protected final String _className;

        protected SpoofFunction(String className, byte[] classBytes, boolean[] bcInd,
                ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) {
            _bcInd = bcInd;
            _inputs = bcMatrices;
            _scalars = scalars;
            _classBytes = classBytes;
            _className = className;
        }

        /** Same as {@link #getAllMatrixInputs(MatrixIndexes, MatrixBlock[], boolean)} with {@code outer=false}. */
        protected ArrayList<MatrixBlock> getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn) {
            return getAllMatrixInputs(ixIn, blkIn, false);
        }

        /**
         * Assembles the matrix inputs of one operator call in input order,
         * taking RDD inputs from {@code blkIn} and broadcast inputs from the
         * partitioned broadcasts.
         *
         * @param ixIn block indexes of the current main input block
         * @param blkIn blocks of the joined RDD inputs
         * @param outer whether the operator is an outer-product operator
         * @return list of input blocks, one per matrix input
         */
        protected ArrayList<MatrixBlock> getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn, boolean outer) {
            ArrayList<MatrixBlock> ret = new ArrayList<>();
            //add all rdd/broadcast inputs (main and side inputs)
            for (int i = 0, posRdd = 0, posBc = 0; i < _bcInd.length; i++) {
                if (_bcInd[i]) {
                    PartitionedBroadcast<MatrixBlock> pb = _inputs.get(posBc++);
                    //outer-product input 2 is addressed by the column index of
                    //the main block; otherwise use the main block's index if
                    //the broadcast has that many blocks, else fall back to 1
                    int rowIndex = (int) ((outer && i == 2) ? ixIn.getColumnIndex() :
                        (pb.getNumRowBlocks() >= ixIn.getRowIndex()) ? ixIn.getRowIndex() : 1);
                    int colIndex = (int) ((outer && i == 2) ? 1 :
                        (pb.getNumColumnBlocks() >= ixIn.getColumnIndex()) ? ixIn.getColumnIndex() : 1);
                    ret.add(pb.getBlock(rowIndex, colIndex));
                }
                else
                    ret.add(blkIn[posRdd++]);
            }
            return ret;
        }
    }

    /**
     * Executes a generated {@link SpoofRowwise} operator over all blocks of a
     * partition. For column and full aggregation the partition is aggregated
     * into a single block with key (1,1); otherwise one output block per
     * input block is emitted.
     */
    private static class RowwiseFunction extends SpoofFunction
        implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -7926980450209760212L;

        private final int _brlen;
        private final int _clen;
        private final int _clen2;
        private SpoofRowwise _op = null;

        /**
         * @param className name of the generated class
         * @param classBytes byte code of the generated class
         * @param bcInd broadcast flag per matrix input
         * @param bcMatrices broadcast matrix inputs
         * @param scalars scalar inputs
         * @param brlen number of rows per block
         * @param clen number of columns of the main input
         * @param clen2 second vector length passed to the thread-local memory
         *        setup; the caller passes -1 if not applicable
         */
        public RowwiseFunction(String className, byte[] classBytes, boolean[] bcInd,
                ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices,
                ArrayList<ScalarObject> scalars, int brlen, int clen, int clen2) {
            super(className, classBytes, bcInd, bcMatrices, scalars);
            _brlen = brlen;
            _clen = clen;
            //fix: was assigned from clen, silently dropping the clen2 argument
            _clen2 = clen2;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg) {
            //lazy load of shipped class
            if (_op == null) {
                Class<?> loadedClass = CodegenUtils.getClassSync(_className, _classBytes);
                _op = (SpoofRowwise) CodegenUtils.createInstance(loadedClass);
            }

            //setup local memory for reuse
            LibSpoofPrimitives.setupThreadLocalMemory(_op.getNumIntermediates(), _clen, _clen2);

            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            boolean aggIncr = (_op.getRowType().isColumnAgg() //aggregate entire partition
                || _op.getRowType() == RowType.FULL_AGG);
            MatrixBlock blkOut = aggIncr ? new MatrixBlock() : null;

            while (arg.hasNext()) {
                //get main input block and indexes
                Tuple2<MatrixIndexes, MatrixBlock[]> e = arg.next();
                MatrixIndexes ixIn = e._1();
                MatrixBlock[] blkIn = e._2();
                long rix = (ixIn.getRowIndex() - 1) * _brlen; //0-based

                //prepare output and execute single-threaded operator
                ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn);
                blkOut = aggIncr ? blkOut : new MatrixBlock();
                blkOut = _op.execute(inputs, _scalars, blkOut, false, aggIncr, rix);
                if (!aggIncr) {
                    MatrixIndexes ixOut = new MatrixIndexes(ixIn.getRowIndex(),
                        _op.getRowType() != RowType.NO_AGG ? 1 : ixIn.getColumnIndex());
                    ret.add(new Tuple2<>(ixOut, blkOut));
                }
            }

            //cleanup and final result preparations
            LibSpoofPrimitives.cleanupThreadLocalMemory();
            if (aggIncr) {
                blkOut.recomputeNonZeros();
                blkOut.examSparsity(); //deferred format change
                ret.add(new Tuple2<>(new MatrixIndexes(1, 1), blkOut));
            }

            return ret.iterator();
        }
    }

    /**
     * Executes a generated {@link SpoofCellwise} operator on every block of a
     * partition and emits one output block per input block.
     */
    private static class CellwiseFunction extends SpoofFunction
        implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8209188316939435099L;

        private SpoofCellwise _op = null;
        private final int _brlen;

        public CellwiseFunction(String className, byte[] classBytes, boolean[] bcInd,
                ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices,
                ArrayList<ScalarObject> scalars, int brlen) {
            super(className, classBytes, bcInd, bcMatrices, scalars);
            _brlen = brlen;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg)
            throws Exception {
            //lazy load of shipped class
            if (_op == null) {
                Class<?> loadedClass = CodegenUtils.getClassSync(_className, _classBytes);
                _op = (SpoofCellwise) CodegenUtils.createInstance(loadedClass);
            }

            List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            while (arg.hasNext()) {
                Tuple2<MatrixIndexes, MatrixBlock[]> tmp = arg.next();
                MatrixIndexes ixIn = tmp._1();
                MatrixBlock[] blkIn = tmp._2();
                MatrixIndexes ixOut = ixIn;
                MatrixBlock blkOut = new MatrixBlock();
                ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn);
                long rix = (ixIn.getRowIndex() - 1) * _brlen; //0-based

                //execute core operation
                if (_op.getCellType() == CellType.FULL_AGG) {
                    //scalar result is wrapped into a 1x1 block
                    ScalarObject obj = _op.execute(inputs, _scalars, 1, rix);
                    blkOut.reset(1, 1);
                    blkOut.quickSetValue(0, 0, obj.getDoubleValue());
                }
                else {
                    //aggregated dimension collapses to block index 1
                    if (_op.getCellType() == CellType.ROW_AGG)
                        ixOut = new MatrixIndexes(ixOut.getRowIndex(), 1);
                    else if (_op.getCellType() == CellType.COL_AGG)
                        ixOut = new MatrixIndexes(1, ixOut.getColumnIndex());
                    blkOut = _op.execute(inputs, _scalars, blkOut, 1, rix);
                }
                ret.add(new Tuple2<>(ixOut, blkOut));
            }
            return ret.iterator();
        }
    }

    /**
     * Executes a generated {@link SpoofMultiAggregate} operator on a single
     * joined input block and returns the partial aggregate under the input key.
     */
    private static class MultiAggregateFunction extends SpoofFunction
        implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -5224519291577332734L;

        private SpoofMultiAggregate _op = null;
        private final int _brlen;

        public MultiAggregateFunction(String className, byte[] classBytes, boolean[] bcInd,
                ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices,
                ArrayList<ScalarObject> scalars, int brlen) {
            super(className, classBytes, bcInd, bcMatrices, scalars);
            _brlen = brlen;
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg)
            throws Exception {
            //lazy load of shipped class
            if (_op == null) {
                Class<?> loadedClass = CodegenUtils.getClassSync(_className, _classBytes);
                _op = (SpoofMultiAggregate) CodegenUtils.createInstance(loadedClass);
            }

            //execute core operation
            ArrayList<MatrixBlock> inputs = getAllMatrixInputs(arg._1(), arg._2());
            MatrixBlock blkOut = new MatrixBlock();
            long rix = (arg._1().getRowIndex() - 1) * _brlen; //0-based
            blkOut = _op.execute(inputs, _scalars, blkOut, 1, rix);

            return new Tuple2<>(arg._1(), blkOut);
        }
    }

    /**
     * Combiner for partial multi-aggregate results. A block with zero rows or
     * columns is treated as the neutral element.
     */
    private static class MultiAggAggregateFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 5978731867787952513L;

        private AggOp[] _ops = null;

        public MultiAggAggregateFunction(AggOp[] ops) {
            _ops = ops;
        }

        @Override
        public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1)
            throws Exception {
            //prepare combiner block
            if (arg0.getNumRows() <= 0 || arg0.getNumColumns() <= 0) {
                arg0.copy(arg1);
                return arg0;
            }
            else if (arg1.getNumRows() <= 0 || arg1.getNumColumns() <= 0) {
                return arg0;
            }

            //aggregate second input (in-place)
            SpoofMultiAggregate.aggregatePartialResults(_ops, arg0, arg1);

            return arg0;
        }
    }

    /**
     * Executes a generated {@link SpoofOuterProduct} operator on every block of
     * a partition and emits one output block per input block.
     */
    private static class OuterProductFunction extends SpoofFunction
        implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8209188316939435099L;

        private SpoofOperator _op = null;

        public OuterProductFunction(String className, byte[] classBytes, boolean[] bcInd,
                ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) {
            super(className, classBytes, bcInd, bcMatrices, scalars);
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg)
            throws Exception {
            //lazy load of shipped class
            if (_op == null) {
                Class<?> loadedClass = CodegenUtils.getClassSync(_className, _classBytes);
                _op = (SpoofOperator) CodegenUtils.createInstance(loadedClass);
            }

            List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            while (arg.hasNext()) {
                Tuple2<MatrixIndexes, MatrixBlock[]> tmp = arg.next();
                MatrixIndexes ixIn = tmp._1();
                MatrixBlock[] blkIn = tmp._2();
                MatrixBlock blkOut = new MatrixBlock();
                ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn, true);

                //execute core operation
                if (((SpoofOuterProduct) _op).getOuterProdType() == OutProdType.AGG_OUTER_PRODUCT) {
                    //scalar result is wrapped into a 1x1 block
                    ScalarObject obj = _op.execute(inputs, _scalars, 1);
                    blkOut.reset(1, 1);
                    blkOut.quickSetValue(0, 0, obj.getDoubleValue());
                }
                else {
                    blkOut = _op.execute(inputs, _scalars, blkOut);
                }

                ret.add(new Tuple2<>(createOutputIndexes(ixIn, _op), blkOut));
            }

            return ret.iterator();
        }

        /**
         * Maps input block indexes to output block indexes: (col,1) for left
         * outer products, (row,1) for right outer products, unchanged otherwise.
         */
        private static MatrixIndexes createOutputIndexes(MatrixIndexes in, SpoofOperator spoofOp) {
            if (((SpoofOuterProduct) spoofOp).getOuterProdType() == OutProdType.LEFT_OUTER_PRODUCT)
                return new MatrixIndexes(in.getColumnIndex(), 1);
            else if (((SpoofOuterProduct) spoofOp).getOuterProdType() == OutProdType.RIGHT_OUTER_PRODUCT)
                return new MatrixIndexes(in.getRowIndex(), 1);
            else
                return in;
        }
    }

    /**
     * Replicates each input block once per row block of the left-hand side:
     * a block with row index {@code j} is emitted under keys
     * {@code (1,j) .. (numBlocks,j)}, where {@code numBlocks = ceil(len/blen)}.
     * The same block object is reused for all emitted tuples (no copy).
     */
    public static class ReplicateRightFactorFunction
        implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -7295989688796126442L;

        private final long _len;
        private final long _blen;

        /**
         * @param len total length of the dimension to replicate along
         * @param blen block length of that dimension
         */
        public ReplicateRightFactorFunction(long len, long blen) {
            _len = len;
            _blen = blen;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0)
            throws Exception {
            LinkedList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new LinkedList<>();
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();

            long numBlocks = (long) Math.ceil((double) _len / _blen);

            //replicate wrt # row blocks in LHS
            long j = ixIn.getRowIndex();
            for (long i = 1; i <= numBlocks; i++) {
                MatrixIndexes tmpix = new MatrixIndexes(i, j);
                MatrixBlock tmpblk = blkIn;
                ret.add(new Tuple2<>(tmpix, tmpblk));
            }

            //output list of new tuples
            return ret.iterator();
        }
    }

    /**
     * Maps a codegen aggregation type to the aggregate operator used for
     * combining partial results.
     *
     * @param aggop aggregation type of the generated operator
     * @return Kahan-plus operator for {@code SUM}/{@code SUM_SQ}, min/max
     *         operators for {@code MIN}/{@code MAX}, or {@code null} for any
     *         other value
     */
    public static AggregateOperator getAggregateOperator(AggOp aggop) {
        if (aggop == AggOp.SUM || aggop == AggOp.SUM_SQ)
            return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.NONE);
        else if (aggop == AggOp.MIN)
            return new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject(BuiltinCode.MIN), false, CorrectionLocationType.NONE);
        else if (aggop == AggOp.MAX)
            return new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject(BuiltinCode.MAX), false, CorrectionLocationType.NONE);
        return null;
    }
}