org.apache.sysml.hops.codegen.opt.ReachabilityGraph.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.sysml.hops.codegen.opt.ReachabilityGraph.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.codegen.opt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.hops.codegen.opt.PlanSelection.VisitMarkCost;

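/**
 * Reachability graph over the materialization points of a plan partition.
 * It is used to derive cut sets, to linearize the search space of
 * materialization points, and to split a plan into sub-problems.
 */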
public class ReachabilityGraph {
    // node links of all materialization points, keyed by (from hop ID, to hop ID)
    private HashMap<Pair<Long, Long>, NodeLink> _matPoints = null;
    // artificial root link (created without an interesting point) that the
    // links reachable from the partition roots are attached to
    private NodeLink _root = null;

    // linearized order of all materialization points: points of the cut sets
    // first (if any), followed by the remaining points sorted by size
    private InterestingPoint[] _searchSpace;
    // cut sets ordered by ascending score; empty array if none were found
    private CutSet[] _cutSets;

    /**
     * Builds the reachability graph for the given plan partition and derives
     * the cut sets as well as the linearized search space from it.
     *
     * @param part plan partition that provides the root hop IDs and the
     *             materialization points (via {@code getMatPointsExt()})
     * @param memo memo table used to resolve hop IDs to hops
     * @throws RuntimeException if the linearized search space does not have the
     *         same number of entries as {@code part.getMatPointsExt()}
     */
    public ReachabilityGraph(PlanPartition part, CPlanMemoTable memo) {
        //create repository of materialization points
        //(one node link per point, keyed by its from/to hop IDs)
        _matPoints = new HashMap<>();
        for (InterestingPoint p : part.getMatPointsExt())
            _matPoints.put(Pair.of(p._fromHopID, p._toHopID), new NodeLink(p));

        //create reachability graph
        //(all partition roots hang below a single artificial root link)
        _root = new NodeLink(null);
        HashSet<VisitMarkCost> visited = new HashSet<>();
        for (Long hopID : part.getRoots()) {
            Hop rootHop = memo.getHopRefs().get(hopID);
            addInputNodeLinks(rootHop, _root, part, memo, visited);
        }

        //create candidate cutsets: only links that have at least one input;
        //sorting uses NodeLink.compareTo, so links with identical input IDs
        //(i.e., equal links) end up next to each other
        List<NodeLink> tmpCS = _matPoints.values().stream().filter(p -> p._inputs.size() > 0 && p._p != null)
                .sorted().collect(Collectors.toList());

        //short-cut for partitions without cutsets
        if (tmpCS.isEmpty()) {
            _cutSets = new CutSet[0];
            //sort materialization points in decreasing order of their sizes
            //which can improve the pruning efficiency by skipping larger sub-spaces.
            _searchSpace = sortBySize(part.getMatPointsExt(), memo, false);
            return;
        }

        //create composite cutsets: group consecutive equal links of the
        //sorted candidate list into one candidate each
        ArrayList<ArrayList<NodeLink>> candCS = new ArrayList<>();
        ArrayList<NodeLink> current = new ArrayList<>();
        for (NodeLink node : tmpCS) {
            if (current.isEmpty())
                current.add(node);
            else if (current.get(0).equals(node))
                current.add(node);
            else {
                //different link: close the current group and start a new one
                candCS.add(current);
                current = new ArrayList<>();
                current.add(node);
            }
        }
        //flush the last group
        if (!current.isEmpty())
            candCS.add(current);

        //evaluate cutsets (single, and duplicate pairs);
        //rejected candidates are appended to remain
        ArrayList<ArrayList<NodeLink>> remain = new ArrayList<>();
        ArrayList<Pair<CutSet, Double>> cutSets = evaluateCutSets(candCS, remain);
        //NOTE(review): the limit of 5 presumably bounds the number of pairs
        //that are enumerated below -- confirm
        if (!remain.isEmpty() && remain.size() < 5) {
            //second chance: for pairs for remaining candidates
            //(all unordered pairs (i,j) with i<j, concatenated into one candidate)
            ArrayList<ArrayList<NodeLink>> candCS2 = new ArrayList<>();
            for (int i = 0; i < remain.size() - 1; i++)
                for (int j = i + 1; j < remain.size(); j++) {
                    ArrayList<NodeLink> tmp = new ArrayList<>();
                    tmp.addAll(remain.get(i));
                    tmp.addAll(remain.get(j));
                    candCS2.add(tmp);
                }
            //note: rejected pair candidates are appended to remain as well,
            //but remain is not read again afterwards
            ArrayList<Pair<CutSet, Double>> cutSets2 = evaluateCutSets(candCS2, remain);
            //ensure constructed cutsets are disjoint: accept a pair cutset
            //only if none of its cut points was used by a previously accepted pair
            HashSet<InterestingPoint> testDisjoint = new HashSet<>();
            for (Pair<CutSet, Double> cs : cutSets2) {
                if (!CollectionUtils.containsAny(testDisjoint, Arrays.asList(cs.getLeft().cut))) {
                    cutSets.add(cs);
                    CollectionUtils.addAll(testDisjoint, cs.getLeft().cut);
                }
            }
        }

        //sort and linearize search space according to scores
        //(ascending order of score, then drop the scores)
        _cutSets = cutSets.stream().sorted(Comparator.comparing(p -> p.getRight())).map(p -> p.getLeft())
                .toArray(CutSet[]::new);

        //created sorted order of materialization points
        //(cut sets in predetermined order, other points sorted by size);
        //probe maps each point to its position in the linearized search space
        HashMap<InterestingPoint, Integer> probe = new HashMap<>();
        ArrayList<InterestingPoint> lsearchSpace = new ArrayList<>();
        for (CutSet cs : _cutSets) {
            CollectionUtils.addAll(lsearchSpace, cs.cut);
            for (InterestingPoint p : cs.cut)
                probe.put(p, probe.size());
        }
        //sort materialization points in decreasing order of their sizes
        //which can improve the pruning efficiency by skipping larger sub-spaces.
        //(points already placed as part of a cut set are skipped)
        for (InterestingPoint p : sortBySize(part.getMatPointsExt(), memo, false))
            if (!probe.containsKey(p)) {
                lsearchSpace.add(p);
                probe.put(p, probe.size());
            }
        _searchSpace = lsearchSpace.toArray(new InterestingPoint[0]);

        //finalize cut sets (update positions wrt search space)
        for (CutSet cs : _cutSets)
            cs.updatePositions(probe);

        //final sanity check of interesting points
        //(the linearized space must have as many entries as the partition has points)
        if (_searchSpace.length != part.getMatPointsExt().length)
            throw new RuntimeException("Corrupt linearized search space: " + _searchSpace.length + " vs "
                    + part.getMatPointsExt().length);
    }

    /**
     * Returns the linearized search space that was computed at construction.
     *
     * @return the internal array of materialization points (not a copy)
     */
    public InterestingPoint[] getSortedSearchSpace() {
        return this._searchSpace;
    }

    /**
     * Checks whether the given plan activates any of the known cut sets.
     *
     * @param plan assignment over the linearized search space
     * @return true if at least one cut set has all its positions set in the plan
     */
    public boolean isCutSet(boolean[] plan) {
        return Arrays.stream(_cutSets).anyMatch(candidate -> isCutSet(candidate, plan));
    }

    /**
     * Checks whether the given plan has every position of one cut set set.
     *
     * @param cs   cut set whose positions are probed
     * @param plan assignment over the linearized search space
     * @return true if all cut positions are true in the plan (also true for
     *         a cut set without positions), false at the first unset position
     */
    public boolean isCutSet(CutSet cs, boolean[] plan) {
        for (int pos : cs.posCut) {
            if (!plan[pos])
                return false;
        }
        return true;
    }

    /**
     * Returns the first cut set, in score order, that the plan activates.
     *
     * @param plan assignment over the linearized search space
     * @return the first matching cut set
     * @throws RuntimeException if no cut set matches the plan
     */
    public CutSet getCutSet(boolean[] plan) {
        return Arrays.stream(_cutSets)
                .filter(candidate -> isCutSet(candidate, plan))
                .findFirst()
                .orElseThrow(() -> new RuntimeException("No valid cut set found."));
    }

    /**
     * Computes the number of plans that can be skipped for a plan that
     * activates a cut set: 2 to the power of the number of positions after
     * the last cut position of the first matching cut set.
     *
     * @param plan assignment over the linearized search space
     * @return number of plans to skip
     * @throws RuntimeException if the plan does not activate any cut set
     */
    public long getNumSkipPlans(boolean[] plan) {
        for (int i = 0; i < _cutSets.length; i++) {
            CutSet candidate = _cutSets[i];
            if (!isCutSet(candidate, plan))
                continue;
            //all positions behind the last cut position are free
            int[] cutPositions = candidate.posCut;
            int lastCutPos = cutPositions[cutPositions.length - 1];
            return UtilFunctions.pow(2, plan.length - lastCutPos - 1);
        }
        throw new RuntimeException("Failed to compute number of skip plans for plan without cutset.");
    }

    /**
     * Splits the search space into the two sub-problems of the cut set that
     * the plan activates.
     *
     * @param plan assignment over the linearized search space
     * @return array of two sub-problems (left, right), both using the number
     *         of cut points as offset
     * @throws RuntimeException if the plan does not activate any cut set
     */
    public SubProblem[] getSubproblems(boolean[] plan) {
        final CutSet match = getCutSet(plan);
        final int numCutPoints = match.cut.length;
        SubProblem leftProblem = new SubProblem(numCutPoints, match.posLeft, match.left);
        SubProblem rightProblem = new SubProblem(numCutPoints, match.posRight, match.right);
        return new SubProblem[] { leftProblem, rightProblem };
    }

    /**
     * Renders the number of materialization points followed by the
     * explain output of the graph, starting at the root link.
     */
    @Override
    public String toString() {
        StringBuilder out = new StringBuilder("ReachabilityGraph(");
        out.append(_matPoints.size()).append("):\n");
        out.append(_root.explain(new HashSet<>()));
        return out.toString();
    }

    /**
     * Recursively walks the inputs of a hop and wires the node links of all
     * materialization points below the given parent link.
     *
     * @param current hop whose inputs are processed
     * @param parent  node link that newly found links are attached to
     * @param part    plan partition providing the materialization points
     * @param memo    memo table, passed through to recursive calls
     * @param visited marks of already processed (hop ID, parent link ID) pairs
     */
    private void addInputNodeLinks(Hop current, NodeLink parent, PlanPartition part, CPlanMemoTable memo,
            HashSet<VisitMarkCost> visited) {
        //skip combinations of hop and parent link that were handled before
        boolean seen = visited.contains(new VisitMarkCost(current.getHopID(), parent._ID));
        if (seen)
            return;

        //process children
        for (Hop input : current.getInput()) {
            //by default, the inputs of this input stay attached to the same parent
            NodeLink target = parent;
            if (InterestingPoint.isMatPoint(part.getMatPointsExt(), current.getHopID(), input.getHopID())) {
                //edge is a materialization point: its link becomes the new parent
                target = _matPoints.get(Pair.of(current.getHopID(), input.getHopID()));
                parent.addInput(target);
            }
            addInputNodeLinks(input, target, part, memo, visited);
        }

        //mark as processed only after all inputs have been handled
        visited.add(new VisitMarkCost(current.getHopID(), parent._ID));
    }

    /**
     * Recursively collects all node links reachable via the inputs of the
     * given link, without descending into (or collecting) links in the probe set.
     *
     * @param current link whose inputs are traversed
     * @param probe   links that stop the traversal
     * @param inputs  output set that receives the collected links
     */
    private void rCollectInputs(NodeLink current, HashSet<NodeLink> probe, HashSet<NodeLink> inputs) {
        for (NodeLink child : current._inputs) {
            if (probe.contains(child))
                continue;
            rCollectInputs(child, probe, inputs);
            inputs.add(child);
        }
    }

    /**
     * Evaluates cut set candidates. A candidate is accepted if the links
     * reachable from the root (stopping at the candidate) and the links
     * reachable from the candidate itself are both non-empty and do not
     * overlap; accepted candidates are scored (smaller is better).
     *
     * @param candCS candidate cut sets, each a list of node links
     * @param remain output list that receives all rejected candidates
     * @return accepted cut sets paired with their scores, in candidate order
     */
    private ArrayList<Pair<CutSet, Double>> evaluateCutSets(ArrayList<ArrayList<NodeLink>> candCS,
            ArrayList<ArrayList<NodeLink>> remain) {
        ArrayList<Pair<CutSet, Double>> accepted = new ArrayList<>();

        for (ArrayList<NodeLink> cand : candCS) {
            HashSet<NodeLink> barrier = new HashSet<>(cand);

            //determine subproblems for cutset candidates
            HashSet<NodeLink> fromRoot = new HashSet<>();
            rCollectInputs(_root, barrier, fromRoot);
            HashSet<NodeLink> fromCut = new HashSet<>();
            for (NodeLink cutNode : cand)
                rCollectInputs(cutNode, barrier, fromCut);

            //reject candidates that do not split into two disjoint, non-empty parts
            if (CollectionUtils.containsAny(fromRoot, fromCut) || fromRoot.isEmpty() || fromCut.isEmpty()) {
                remain.add(cand);
                continue;
            }

            //score cutsets (smaller is better)
            double base = UtilFunctions.pow(2, _matPoints.size());
            double numComb = UtilFunctions.pow(2, cand.size());
            double share = 1 / numComb;
            double score = (numComb - 1) / numComb * base + share * UtilFunctions.pow(2, fromRoot.size())
                    + share * UtilFunctions.pow(2, fromCut.size());

            //construct cutset
            InterestingPoint[] cutPoints = cand.stream().map(n -> n._p).toArray(InterestingPoint[]::new);
            InterestingPoint[] leftPoints = fromRoot.stream().map(n -> n._p).toArray(InterestingPoint[]::new);
            InterestingPoint[] rightPoints = fromCut.stream().map(n -> n._p).toArray(InterestingPoint[]::new);
            accepted.add(Pair.of(new CutSet(cutPoints, leftPoints, rightPoints), score));
        }

        return accepted;
    }

    /**
     * Returns a new array with the given points sorted by the size of the
     * hop that each point leads to.
     *
     * @param points points to sort (the input array is not modified)
     * @param memo   memo table used to resolve the to-hop of each point
     * @param asc    true for ascending, false for descending size order
     * @return sorted copy of the points
     */
    private InterestingPoint[] sortBySize(InterestingPoint[] points, CPlanMemoTable memo, boolean asc) {
        //negating the size flips the natural order for the descending case
        final int direction = asc ? 1 : -1;
        Comparator<InterestingPoint> bySize = Comparator
                .comparing(point -> direction * getSize(memo.getHopRefs().get(point.getToHopID())));
        return Arrays.stream(points).sorted(bySize).toArray(InterestingPoint[]::new);
    }

    /**
     * Computes the size of a hop as the product of its two dimensions,
     * where each dimension is clamped to at least 1.
     *
     * @param hop hop to measure
     * @return product of the clamped dimensions
     */
    private static long getSize(Hop hop) {
        final long rows = Math.max(hop.getDim1(), 1);
        final long cols = Math.max(hop.getDim2(), 1);
        return rows * cols;
    }

    /**
     * Simple holder describing one sub-problem of a partitioned search space.
     */
    public static class SubProblem {
        //offset value supplied by the creator of the sub-problem
        public int offset;
        //positions of the free points
        public int[] freePos;
        //materialization points of the sub-problem
        public InterestingPoint[] freeMat;

        /**
         * Creates a sub-problem from the given values without copying them.
         *
         * @param off offset value
         * @param pos positions of the free points
         * @param mat materialization points
         */
        public SubProblem(int off, int[] pos, InterestingPoint[] mat) {
            this.offset = off;
            this.freePos = pos;
            this.freeMat = mat;
        }

        @Override
        public String toString() {
            StringBuilder out = new StringBuilder("SubProblem: ");
            out.append(Arrays.toString(freeMat));
            out.append("; ").append(offset);
            out.append("; ").append(Arrays.toString(freePos));
            return out.toString();
        }
    }

    /**
     * Cut set consisting of the cut points and the two groups of points it
     * separates. The left and right arrays both start with the cut points,
     * followed by the points of the respective group.
     */
    private static class CutSet {
        private final InterestingPoint[] cut;
        private final InterestingPoint[] left;
        private final InterestingPoint[] right;
        //positions wrt the linearized search space, set by updatePositions
        private int[] posCut;
        private int[] posLeft;
        private int[] posRight;

        /**
         * @param cutPoints points that form the cut
         * @param l         points of the first group (without cut points)
         * @param r         points of the second group (without cut points)
         */
        private CutSet(InterestingPoint[] cutPoints, InterestingPoint[] l, InterestingPoint[] r) {
            this.cut = cutPoints;
            this.left = (InterestingPoint[]) ArrayUtils.addAll(cutPoints, l);
            this.right = (InterestingPoint[]) ArrayUtils.addAll(cutPoints, r);
        }

        /**
         * Looks up the positions of the cut points and of the non-cut parts
         * of the left and right arrays in the given position map.
         *
         * @param probe map from point to its position in the search space
         */
        private void updatePositions(HashMap<InterestingPoint, Integer> probe) {
            final int numCut = cut.length;
            posCut = new int[numCut];
            fillPositions(probe, cut, 0, posCut);

            //left and right carry the cut points as prefix, which is skipped here
            posLeft = new int[left.length - numCut];
            fillPositions(probe, left, numCut, posLeft);

            posRight = new int[right.length - numCut];
            fillPositions(probe, right, numCut, posRight);
        }

        /** Fills target with the positions of points[from], points[from+1], ... */
        private static void fillPositions(HashMap<InterestingPoint, Integer> probe, InterestingPoint[] points,
                int from, int[] target) {
            for (int k = 0; k < target.length; k++)
                target[k] = probe.get(points[from + k]);
        }

        @Override
        public String toString() {
            String cutPoints = Arrays.toString(cut);
            return "Cut : " + cutPoints;
        }
    }

    /**
     * Node of the reachability graph. Each link has a sequence-generated ID,
     * an optional interesting point (null for links created without one,
     * such as the root), and an ordered list of input links.
     *
     * NOTE(review): equals and compareTo only look at the IDs of the inputs,
     * whereas hashCode additionally includes the link's own ID and point.
     * Links that are equal can therefore have different hash codes, which
     * keeps them as separate entries in hash-based collections. This is left
     * unchanged because callers appear to rely on it -- confirm before
     * aligning the two.
     */
    private static class NodeLink implements Comparable<NodeLink> {
        private static final IDSequence _seqID = new IDSequence();

        private ArrayList<NodeLink> _inputs = new ArrayList<>();
        private long _ID;
        private InterestingPoint _p;

        /**
         * @param p interesting point represented by this link, may be null
         */
        private NodeLink(InterestingPoint p) {
            _ID = _seqID.getNextID();
            _p = p;
        }

        /** Appends the given link to the inputs of this link. */
        private void addInput(NodeLink in) {
            _inputs.add(in);
        }

        /**
         * Hash over the inputs, the link ID, and the interesting point.
         * A missing point contributes 0, so links without a point can be
         * hashed without a NullPointerException.
         */
        @Override
        public int hashCode() {
            int pointHash = (_p != null) ? _p.hashCode() : 0;
            return Arrays.hashCode(new int[] { _inputs.hashCode(), Long.hashCode(_ID), pointHash });
        }

        /**
         * Two links are equal if they have the same number of inputs and
         * the inputs have pairwise identical IDs (in list order).
         */
        @Override
        public boolean equals(Object o) {
            if (!(o instanceof NodeLink))
                return false;
            NodeLink that = (NodeLink) o;
            boolean ret = (_inputs.size() == that._inputs.size());
            for (int i = 0; i < _inputs.size() && ret; i++)
                ret &= (_inputs.get(i)._ID == that._inputs.get(i)._ID);
            return ret;
        }

        /**
         * Orders links with more inputs first; links with the same number of
         * inputs are ordered by the first differing input ID (ascending).
         */
        @Override
        public int compareTo(NodeLink that) {
            if (_inputs.size() > that._inputs.size())
                return -1;
            else if (_inputs.size() < that._inputs.size())
                return 1;
            for (int i = 0; i < _inputs.size(); i++) {
                int comp = Long.compare(_inputs.get(i)._ID, that._inputs.get(i)._ID);
                if (comp != 0)
                    return comp;
            }
            return 0;
        }

        @Override
        public String toString() {
            StringBuilder inputs = new StringBuilder();
            for (NodeLink in : _inputs) {
                if (inputs.length() > 0)
                    inputs.append(",");
                inputs.append(in._ID);
            }
            return _ID + " (" + inputs.toString() + ") " + ((_p != null) ? _p : "null");
        }

        /**
         * Recursively renders this link and all links below it, inputs
         * first, one link per line; every link is rendered at most once.
         *
         * @param visited IDs of links that have already been rendered
         * @return rendered lines, or an empty string if this link was
         *         rendered before
         */
        private String explain(HashSet<Long> visited) {
            if (visited.contains(_ID))
                return "";
            //add children
            StringBuilder sb = new StringBuilder();
            StringBuilder inputs = new StringBuilder();
            for (NodeLink in : _inputs) {
                String tmp = in.explain(visited);
                if (!tmp.isEmpty())
                    sb.append(tmp).append("\n");
                if (inputs.length() > 0)
                    inputs.append(",");
                inputs.append(in._ID);
            }
            //add node itself
            sb.append(_ID).append(" (").append(inputs).append(") ").append((_p != null) ? _p : "null");
            visited.add(_ID);

            return sb.toString();
        }
    }
}