org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.codegen.cplan;

import java.util.ArrayList;
import java.util.Arrays;

import org.apache.commons.collections.CollectionUtils;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * Template node that generates the Java source of a {@code SpoofMultiAggregate}
 * subclass. Each output node contributes one aggregate: the generated
 * {@code genexec} body first contains the code of all output nodes and then one
 * assignment per output that folds the output's value into {@code c[i]}
 * according to the aggregation type at the same position of the agg-op list.
 */
public class CNodeMultiAgg extends CNodeTpl {
    private static final String TEMPLATE = "package codegen;\n"
            + "import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\n"
            + "import org.apache.sysml.runtime.codegen.SpoofCellwise;\n"
            + "import org.apache.sysml.runtime.codegen.SpoofCellwise.AggOp;\n"
            + "import org.apache.sysml.runtime.codegen.SpoofMultiAggregate;\n"
            + "import org.apache.sysml.runtime.codegen.SpoofOperator.SideInput;\n"
            + "import org.apache.commons.math3.util.FastMath;\n" + "\n"
            + "public final class %TMP% extends SpoofMultiAggregate { \n" + "  public %TMP%() {\n"
            + "    super(%SPARSE_SAFE%, %AGG_OP%);\n" + "  }\n"
            + "  protected void genexec(double a, SideInput[] b, double[] scalars, double[] c, "
            + "int m, int n, long grix, int rix, int cix) { \n" + "%BODY_dense%" + "  }\n" + "}\n";
    private static final String TEMPLATE_OUT_SUM = "    c[%IX%] += %IN%;\n";
    private static final String TEMPLATE_OUT_SUMSQ = "    c[%IX%] += %IN% * %IN%;\n";
    private static final String TEMPLATE_OUT_MIN = "    c[%IX%] = Math.min(c[%IX%], %IN%);\n";
    private static final String TEMPLATE_OUT_MAX = "    c[%IX%] = Math.max(c[%IX%], %IN%);\n";

    private ArrayList<CNode> _outputs = null;
    private ArrayList<AggOp> _aggOps = null;
    private ArrayList<Hop> _roots = null;
    private boolean _sparseSafe = false;

    /**
     * Creates a multi-aggregate template over the given inputs and outputs.
     * The single-output slot of the super class is left {@code null}; the
     * outputs are kept in this class instead.
     *
     * @param inputs  template inputs; position 0 is treated as the main input matrix
     * @param outputs output nodes, one per aggregate
     */
    public CNodeMultiAgg(ArrayList<CNode> inputs, ArrayList<CNode> outputs) {
        super(inputs, null);
        _outputs = outputs;
    }

    /**
     * @return the list of output nodes (the internal list, not a copy)
     */
    public ArrayList<CNode> getOutputs() {
        return _outputs;
    }

    /** Resets the visit status of every output node. */
    @Override
    public void resetVisitStatusOutputs() {
        for (CNode output : _outputs) {
            output.resetVisitStatus();
        }
    }

    /**
     * Sets the aggregation types, positionally aligned with the outputs, and
     * invalidates the cached hash code since the hash depends on them.
     *
     * @param aggOps aggregation type per output
     */
    public void setAggOps(ArrayList<AggOp> aggOps) {
        _aggOps = aggOps;
        _hash = 0;
    }

    /**
     * @return the aggregation types (the internal list, not a copy)
     */
    public ArrayList<AggOp> getAggOps() {
        return _aggOps;
    }

    /**
     * @param roots root hops to associate with this template (stored as-is)
     */
    public void setRootNodes(ArrayList<Hop> roots) {
        _roots = roots;
    }

    /**
     * @return the root hops previously set, or {@code null} if none were set
     */
    public ArrayList<Hop> getRootNodes() {
        return _roots;
    }

    /**
     * @param flag value emitted as first constructor argument of the generated class
     */
    public void setSparseSafe(boolean flag) {
        _sparseSafe = flag;
    }

    /**
     * @return the sparse-safe flag emitted into the generated class
     */
    public boolean isSparseSafe() {
        return _sparseSafe;
    }

    /**
     * Renames the data nodes below the outputs: the first input becomes
     * {@code a}, the remaining inputs are renamed starting at position 1.
     */
    @Override
    public void renameInputs() {
        rRenameDataNode(_outputs, _inputs.get(0), "a"); // input matrix
        renameInputs(_outputs, _inputs, 1);
    }

    /**
     * Generates the source code of the fused multi-aggregate operator.
     *
     * @param sparse ignored; every output is generated with {@code sparse=false}
     * @return the instantiated class template
     * @throws IllegalStateException if the aggregation types are missing, do
     *         not cover all outputs, or contain an unsupported type
     */
    @Override
    public String codegen(boolean sparse) {
        // fail fast with a clear message instead of a late NPE / index error
        if (_aggOps == null || _aggOps.size() < _outputs.size()) {
            throw new IllegalStateException("Multi-aggregate template requires one aggregation type "
                    + "per output, but got " + (_aggOps == null ? "none" : String.valueOf(_aggOps.size()))
                    + " for " + _outputs.size() + " outputs.");
        }

        // note: the sparse flag is ignored, only the dense body is generated
        StringBuilder body = new StringBuilder();
        for (CNode out : _outputs) {
            body.append(out.codegen(false));
        }
        for (CNode out : _outputs) {
            out.resetGenerated();
        }

        // append one aggregate assignment per output
        long mainInputID = ((CNodeData) _inputs.get(0)).getHopID();
        for (int i = 0; i < _outputs.size(); i++) {
            CNode out = _outputs.get(i);
            // an output that directly consumes the main input is referenced as 'a'
            boolean directInput = out instanceof CNodeData
                    && ((CNodeData) out).getHopID() == mainInputID;
            String varName = directInput ? "a" : out.getVarname();
            body.append(getAggTemplate(i)
                    .replace("%IN%", varName)
                    .replace("%IX%", String.valueOf(i)));
        }

        // comma-separated list of aggregation types for the super constructor
        StringBuilder aggList = new StringBuilder();
        for (AggOp aggOp : _aggOps) {
            if (aggList.length() > 0) {
                aggList.append(',');
            }
            aggList.append("AggOp.").append(aggOp.name());
        }

        // replace class name, body, and meta data information
        String tmp = TEMPLATE;
        tmp = tmp.replace("%TMP%", createVarname());
        tmp = tmp.replace("%BODY_dense%", body.toString());
        tmp = tmp.replace("%AGG_OP%", aggList.toString());
        tmp = tmp.replace("%SPARSE_SAFE%", String.valueOf(isSparseSafe()));
        return tmp;
    }

    /** Intentionally empty: this template does not set output dimensions. */
    @Override
    public void setOutputDims() {
        // nothing to do
    }

    /**
     * @return always {@link SpoofOutputDimsType#MULTI_SCALAR}
     */
    @Override
    public SpoofOutputDimsType getOutputDimType() {
        return SpoofOutputDimsType.MULTI_SCALAR;
    }

    /**
     * Creates a new template that is constructed from the same input and
     * output lists and carries over the aggregation types, root hops, and
     * sparse-safe flag of this template.
     *
     * @return the copied template
     */
    @Override
    public CNodeTpl clone() {
        CNodeMultiAgg ret = new CNodeMultiAgg(_inputs, _outputs);
        ret.setAggOps(getAggOps());
        ret.setRootNodes(getRootNodes());
        ret.setSparseSafe(isSparseSafe());
        return ret;
    }

    /**
     * Computes (and caches in {@code _hash}) a hash over the super-class hash
     * and the positional pairs of output and aggregation type. A cached value
     * of 0 means "not yet computed".
     */
    @Override
    public int hashCode() {
        if (_hash == 0) {
            int h = super.hashCode();
            for (int i = 0; i < _outputs.size(); i++) {
                int pairHash = UtilFunctions.intHashCode(
                        _outputs.get(i).hashCode(), _aggOps.get(i).hashCode());
                h = UtilFunctions.intHashCode(h, pairHash);
            }
            _hash = h;
        }
        return _hash;
    }

    /**
     * Two multi-aggregate templates are equal if the super class considers
     * them equal, they have the same sparse-safe flag, the same aggregation
     * types in the same order, and equal input references of their outputs.
     * The order matters because aggregation types are matched to outputs by
     * position (and are hashed positionally in {@link #hashCode()}); the
     * sparse-safe flag matters because it is part of the generated source.
     */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof CNodeMultiAgg)) {
            return false;
        }
        CNodeMultiAgg that = (CNodeMultiAgg) o;
        boolean sameAggOps = (_aggOps == null)
                ? that._aggOps == null
                : _aggOps.equals(that._aggOps); // order-sensitive list equality
        return super.equals(o)
                && _sparseSafe == that._sparseSafe
                && sameAggOps
                && equalInputReferences(_outputs, that._outputs, _inputs, that._inputs);
    }

    /**
     * @return a short description of the template including its aggregation types
     */
    @Override
    public String getTemplateInfo() {
        StringBuilder sb = new StringBuilder();
        sb.append("SPOOF MULTIAGG [aggOps=");
        sb.append(Arrays.toString(_aggOps.toArray(new AggOp[0])));
        sb.append("]");
        return sb.toString();
    }

    /**
     * Returns the output assignment template for the aggregation type at the
     * given position.
     *
     * @param pos position in the aggregation type list
     * @return assignment template with {@code %IX%} and {@code %IN%} placeholders
     * @throws IllegalStateException if the aggregation type is not one of
     *         SUM, SUM_SQ, MIN, or MAX
     */
    private String getAggTemplate(int pos) {
        AggOp op = _aggOps.get(pos);
        switch (op) {
        case SUM:
            return TEMPLATE_OUT_SUM;
        case SUM_SQ:
            return TEMPLATE_OUT_SUMSQ;
        case MIN:
            return TEMPLATE_OUT_MIN;
        case MAX:
            return TEMPLATE_OUT_MAX;
        default:
            throw new IllegalStateException(
                    "Unsupported aggregation type for multi-aggregate template: " + op);
        }
    }
}