co.cask.cdap.etl.planner.ControlDag.java Source code

Java tutorial

Introduction

Here is the source code for co.cask.cdap.etl.planner.ControlDag.java

Source

/*
 * Copyright  2016 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.cdap.etl.planner;

import co.cask.cdap.etl.proto.Connection;
import com.google.common.base.Joiner;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;

import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.UUID;

/**
 * A DAG (directed acyclic graph) where edges represent a happens-before relationship.
 * In these types of scenarios, certain edges may be redundant and can be removed.
 * This simplifies the dag into something that is much easier to work with if it needs
 * to be used as a fork-join type of dag for workflow execution.
 */
public class ControlDag extends Dag {
    private static final Set<String> EMPTY = ImmutableSet.of();
    private final Multiset<String> nodeVisits;

    public ControlDag(Collection<Connection> connections) {
        super(connections);
        this.nodeVisits = HashMultiset.create();
    }

    /**
     * Record that this node was visited.
     *
     * @param node node that was visited
     * @return the number of times this node was visited, including the visit from this call
     */
    public int visit(String node) {
        nodeVisits.add(node);
        return nodeVisits.count(node);
    }

    /**
     * Resets the number of times each node was visited back to 0.
     */
    public void resetVisitCounts() {
        nodeVisits.clear();
    }

    /**
     * Flattens the control dag to remove connections between branches of different forks, which would
     * make the dag unusable in pure fork-join workflows.
     *
     * For example the following dag is not a fork-join dag:
     *
     *      |--> n2 -------|
     *      |              |--> n5
     *      |--> n3 -------|
     *  n1--|           |
     *      |           v
     *      |--> n4 --> n6
     *
     *  There are many ways to turn this a fork-join while still respecting all happens-before relationships,
     *  but for simplicity we'll use an algorithm that doesn't have any nested forks and will turn the above into:
     *
     *       |--> n2 --|
     *       |         |               |--> n5 --|
     *  n1 --|--> n3 --|--> n2.n3.n4 --|         |--> n5.n6
     *       |         |               |--> n6 --|
     *       |--> n4 --|
     *
     *  The algorithm is to insert a join node whenever it sees a fork. Every time there is a fork, we will follow
     *  each branch to its endpoint (a node that forks, merges, or is a sink), then insert a join node that each
     *  branch endpoint connects to.
     */
    public void flatten() {
        // this should never be the case, as it should be checked when the dag is created.
        if (sources.isEmpty()) {
            throw new IllegalStateException("There are no sources in the graph, which means there is a cycle.");
        }

        trim();
        String source;
        // if we have multiple sources, insert a fork node as the new source
        if (sources.size() > 1) {
            // copy to avoid concurrent modification
            Set<String> sourcesCopy = new HashSet<>(sources);
            String newId = generateJoinNodeName(sourcesCopy);
            addNode(newId, EMPTY, sourcesCopy);
            source = newId;
        } else {
            source = sources.iterator().next();
        }

        flattenFrom(source);
    }

    private void flattenFrom(String node) {
        Set<String> outputs = outgoingConnections.get(node);
        if (outputs.isEmpty()) {
            return;
        }

        if (outputs.size() == 1) {
            flattenFrom(outputs.iterator().next());
            return;
        }

        Multimap<String, String> branchEndpointOutputs = HashMultimap.create();
        // can't just use branchEndpointOutputs.keySet(),
        // because that won't track branch endpoints that had no output (sinks)
        Set<String> branchEndpoints = new HashSet<>();
        for (String output : outputs) {
            String branchEndpoint = findBranchEnd(output);
            branchEndpoints.add(branchEndpoint);
            branchEndpointOutputs.putAll(branchEndpoint, outgoingConnections.get(branchEndpoint));
        }

        // if all the branch endpoints connect to a single node, there is no need to add a join node
        Set<String> endpointOutputs = new HashSet<>(branchEndpointOutputs.values());
        if (endpointOutputs.size() == 1) {
            flattenFrom(endpointOutputs.iterator().next());
            return;
        }

        // add a connection from each branch endpoint to a newly added join node
        // then move all outgoing connections from each branch endpoint so that they are coming out of the new join node
        String newJoinNode = generateJoinNodeName(branchEndpoints);
        addNode(newJoinNode, branchEndpoints, endpointOutputs);
        // remove the outgoing connections from endpoints that aren't going to our new join node
        for (Map.Entry<String, String> endpointEntry : branchEndpointOutputs.entries()) {
            removeConnection(endpointEntry.getKey(), endpointEntry.getValue());
        }
        /*
           have to trim again due to reshuffling of nodes. For example, if we have:
            
                  |--> n3
        |--> n2 --|
        |         |--> n4
           n1 --|              |
        |              v
        |--> n5 -----> n6
            
           after we insert the new join node we'll have:
            
        |--> n2 --|           |--> n3
        |         |           |
           n1 --|         |--> join --|--> n4
        |         |           |    |
        |--> n5 --|           |    v
                              |--> n6
            
           and we need to remove the connection from join -> n6, otherwise the algorithm will get messed up
         */
        trim();

        // then keep flattening from the new join node
        flattenFrom(newJoinNode);
    }

    // go down a branch until we find a node with multiple outputs, a node with multiple inputs, or a sink
    private String findBranchEnd(String node) {
        Set<String> outputs = outgoingConnections.get(node);
        // if this is a sink, or if this is a fork on a branch
        if (outputs.isEmpty() || outputs.size() > 1) {
            return node;
        }
        // if the next node is a join node
        String output = outputs.iterator().next();
        if (incomingConnections.get(output).size() > 1) {
            return node;
        }
        // otherwise keep going down this branch
        return findBranchEnd(output);
    }

    /**
     * Returns the number of paths from the start node to the stop node.
     * The number of paths from a node to itself is 1.
     *
     * @param start the node to start from
     * @param stop the node to end at
     * @return the number of paths from the start node to the stop node
     */
    private int numPaths(String start, String stop) {
        if (start.equals(stop)) {
            return 1;
        }
        int count = 0;
        for (String output : getNodeOutputs(start)) {
            count += numPaths(output, stop);
        }
        return count;
    }

    /**
     * Trims any redundant control connections.
     *
     * For example:
     *   n1 ------> n2
     *       |      |
     *       |      v
     *       |----> n3
     * has a redundant edge n1 -> n3, because the edge from n2 -> n3 already enforces n1 -> n3.
     * The approach is look at each node (call it nodeB). For each input into nodeB (call it nodeA),
     * if there is another path from nodeA to nodeB besides the direct edge, we can remove the edge nodeA -> nodeB.
     *
     * @return number of connections removed.
     */
    public int trim() {
        int numRemoved = 0;
        for (String node : nodes) {
            Set<Connection> toRemove = new HashSet<>();
            for (String nodeInput : getNodeInputs(node)) {
                if (numPaths(nodeInput, node) > 1) {
                    toRemove.add(new Connection(nodeInput, node));
                }
            }
            for (Connection conn : toRemove) {
                removeConnection(conn.getFrom(), conn.getTo());
            }
            numRemoved += toRemove.size();
        }
        return numRemoved;
    }

    /**
     * Add a node with the following outputs and inputs
     */
    private void addNode(String node, Collection<String> inputs, Collection<String> outputs) {
        nodes.add(node);
        for (String output : outputs) {
            outgoingConnections.put(node, output);
            incomingConnections.put(output, node);
            sources.remove(output);
        }
        for (String input : inputs) {
            incomingConnections.put(node, input);
            outgoingConnections.put(input, node);
            sinks.remove(input);
        }
        if (outputs.isEmpty()) {
            sinks.add(node);
        }
        if (inputs.isEmpty()) {
            sources.add(node);
        }
    }

    private String generateJoinNodeName(Set<String> inputs) {
        // using sorted sets to guarantee the name is deterministic
        String name = Joiner.on('.').join(new TreeSet<>(inputs));
        if (nodes.contains(name)) {
            name += UUID.randomUUID().toString();
        }
        return name;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }

        ControlDag that = (ControlDag) o;

        return Objects.equals(nodeVisits, that.nodeVisits);
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), nodeVisits);
    }

    @Override
    public String toString() {
        return "ControlDag{" + "nodeVisits=" + nodeVisits + "} " + super.toString();
    }
}