Java tutorial
/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.codegen.cplan; import java.util.ArrayList; import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType; import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeRow extends CNodeTpl { private static final String TEMPLATE = "package codegen;\n" + "import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\n" + "import org.apache.sysml.runtime.codegen.SpoofOperator.SideInput;\n" + "import org.apache.sysml.runtime.codegen.SpoofRowwise;\n" + "import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;\n" + "import org.apache.commons.math3.util.FastMath;\n" + "\n" + "public final class %TMP% extends SpoofRowwise { \n" + " public %TMP%() {\n" + " super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + " }\n" + " protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { \n" + "%BODY_dense%" + " }\n" + " protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { \n" + "%BODY_sparse%" + " }\n" + "}\n"; private static final String TEMPLATE_ROWAGG_OUT = " c[rix] = %IN%;\n"; private static final String TEMPLATE_FULLAGG_OUT = " c[0] += %IN%;\n"; private static final String TEMPLATE_NOAGG_OUT = " LibSpoofPrimitives.vectWrite(%IN%, c, ci, %LEN%);\n"; public CNodeRow(ArrayList<CNode> inputs, CNode output) { super(inputs, output); } private RowType _type = null; //access pattern private long _constDim2 = -1; //constant number of output columns private int _numVectors = -1; //number of intermediate vectors public void setRowType(RowType type) { _type = type; _hash = 0; } public RowType getRowType() { return _type; } public void setNumVectorIntermediates(int num) { _numVectors = num; _hash = 0; } public int getNumVectorIntermediates() { return _numVectors; } public void setConstDim2(long dim2) { _constDim2 = dim2; _hash = 0; } public long getConstDim2() { return _constDim2; } @Override public void renameInputs() { rRenameDataNode(_output, _inputs.get(0), "a"); // input matrix renameInputs(_inputs, 1); } @Override public String codegen(boolean sparse) { // note: ignore sparse flag, generate both String tmp = TEMPLATE; //generate dense/sparse bodies String tmpDense = _output.codegen(false) + getOutputStatement(_output.getVarname()); _output.resetGenerated(); String tmpSparse = _output.codegen(true) + getOutputStatement(_output.getVarname()); tmp = tmp.replace("%TMP%", createVarname()); tmp = tmp.replace("%BODY_dense%", tmpDense); tmp = tmp.replace("%BODY_sparse%", tmpSparse); //replace outputs tmp = tmp.replace("%OUT%", "c"); tmp = tmp.replace("%POSOUT%", "0"); //replace size information tmp = tmp.replace("%LEN%", "len"); //replace colvector information and number of vector intermediates tmp = tmp.replace("%TYPE%", _type.name()); tmp = tmp.replace("%CONST_DIM2%", String.valueOf(_constDim2)); tmp = tmp.replace("%TB1%", String.valueOf(TemplateUtils.containsBinary(_output, BinType.VECT_MATRIXMULT))); tmp = tmp.replace("%VECT_MEM%", String.valueOf(_numVectors)); return tmp; } private String getOutputStatement(String varName) { switch (_type) { case NO_AGG: case NO_AGG_B1: case NO_AGG_CONST: return TEMPLATE_NOAGG_OUT.replace("%IN%", varName).replace("%LEN%", _output.getVarname() + ".length"); case FULL_AGG: return TEMPLATE_FULLAGG_OUT.replace("%IN%", varName); case ROW_AGG: return TEMPLATE_ROWAGG_OUT.replace("%IN%", varName); default: return ""; //_type.isColumnAgg() } } @Override public void setOutputDims() { // TODO Auto-generated method stub } @Override public SpoofOutputDimsType getOutputDimType() { switch (_type) { case NO_AGG: return SpoofOutputDimsType.INPUT_DIMS; case NO_AGG_B1: return SpoofOutputDimsType.ROW_RANK_DIMS; case NO_AGG_CONST: return SpoofOutputDimsType.INPUT_DIMS_CONST2; case FULL_AGG: return SpoofOutputDimsType.SCALAR; case ROW_AGG: return SpoofOutputDimsType.ROW_DIMS; case COL_AGG: return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector case COL_AGG_T: return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector case COL_AGG_B1: return SpoofOutputDimsType.COLUMN_RANK_DIMS; case COL_AGG_B1_T: return SpoofOutputDimsType.COLUMN_RANK_DIMS_T; case COL_AGG_B1R: return SpoofOutputDimsType.RANK_DIMS_COLS; case COL_AGG_CONST: return SpoofOutputDimsType.VECT_CONST2; default: throw new RuntimeException("Unsupported row type: " + _type.toString()); } } @Override public CNodeTpl clone() { CNodeRow tmp = new CNodeRow(_inputs, _output); tmp.setRowType(_type); tmp.setNumVectorIntermediates(_numVectors); return tmp; } @Override public int hashCode() { if (_hash == 0) { int h = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode()); h = UtilFunctions.intHashCode(h, Long.hashCode(_constDim2)); _hash = UtilFunctions.intHashCode(h, Integer.hashCode(_numVectors)); } return _hash; } @Override public boolean equals(Object o) { if (!(o instanceof CNodeRow)) return false; CNodeRow that = (CNodeRow) o; return super.equals(o) && _type == that._type && _numVectors == that._numVectors && _constDim2 == that._constDim2 && equalInputReferences(_output, that._output, _inputs, that._inputs); } @Override public String getTemplateInfo() { StringBuilder sb = new StringBuilder(); sb.append("SPOOF ROWAGGREGATE [type="); sb.append(_type.name()); sb.append(", reqVectMem="); sb.append(_numVectors); sb.append("]"); return sb.toString(); } }