com.wrmsr.wava.TestWhatever.java Source code

Java tutorial

Introduction

Here is the source code for com.wrmsr.wava.TestWhatever.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.wrmsr.wava;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.wrmsr.wava.analyze.Analyses;
import com.wrmsr.wava.analyze.ControlFlowGraph;
import com.wrmsr.wava.analyze.ValueTypeAnalysis;
import com.wrmsr.wava.core.node.Binary;
import com.wrmsr.wava.core.node.Block;
import com.wrmsr.wava.core.node.Break;
import com.wrmsr.wava.core.node.BreakTable;
import com.wrmsr.wava.core.node.If;
import com.wrmsr.wava.core.node.Label;
import com.wrmsr.wava.core.node.Loop;
import com.wrmsr.wava.core.node.Node;
import com.wrmsr.wava.core.node.Nop;
import com.wrmsr.wava.core.node.Unary;
import com.wrmsr.wava.core.op.BinaryOp;
import com.wrmsr.wava.core.op.UnaryOp;
import com.wrmsr.wava.core.type.Name;
import com.wrmsr.wava.core.type.Type;
import com.wrmsr.wava.core.unit.Module;
import com.wrmsr.wava.driver.StandardFunctionProcessor;
import com.wrmsr.wava.basic.Basic;
import com.wrmsr.wava.basic.BasicDominatorInfo;
import com.wrmsr.wava.basic.BasicLoopInfo;
import com.wrmsr.wava.basic.BasicSet;
import com.wrmsr.wava.basic.Basics;
import com.wrmsr.wava.basic.match.BooleanMatching;
import com.wrmsr.wava.basic.match.LoopMatching;
import com.wrmsr.wava.basic.match.SimpleMatching;
import com.wrmsr.wava.yen.global.YModule;
import com.wrmsr.wava.yen.parser.ModuleFactory;
import com.wrmsr.wava.yen.parser.Parser;
import com.wrmsr.wava.yen.parser.element.Element;
import com.wrmsr.wava.yen.parser.input.Input;
import com.wrmsr.wava.yen.translation.UnitTranslation;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.junit.Test;

import javax.annotation.CheckReturnValue;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.wrmsr.wava.TestGraphviz.showGraph;
import static com.wrmsr.wava.core.node.Nodes.nodify;
import static com.wrmsr.wava.basic.Basics.getUnconditionalTarget;
import static com.wrmsr.wava.basic.Basics.minBasicIndex;
import static com.wrmsr.wava.basic.Basics.transformBasics;
import static com.wrmsr.wava.basic.match.BooleanMatching.matchBoolean;
import static com.wrmsr.wava.util.collect.MoreCollectors.toIdentityMap;
import static com.wrmsr.wava.util.collect.MoreCollectors.toImmutableList;
import static com.wrmsr.wava.util.collect.MoreCollectors.toImmutableSet;
import static com.wrmsr.wava.util.collect.MoreMaps.indexIdentityMap;
import static com.wrmsr.wava.util.collect.MoreOptionals.optionalToList;
import static com.wrmsr.wava.util.collect.MoreOptionals.optionalToStream;
import static com.wrmsr.wava.util.function.Bind.bind;
import static java.util.Objects.requireNonNull;

/**
 * Scratch driver for experimenting with basic-block recovery and structuring on a single function of the SQLite
 * wast module.
 *
 * <p>It translates the module and builds a {@link BasicSet} from one function's control flow graph. It then applies a
 * series of pattern-based shrinking transforms until the number of basics stops changing, and renders the result as a
 * Graphviz digraph.
 */
public class TestWhatever {
    /**
     * Folds a simple two-block loop into its header as a {@link Loop} node.
     *
     * <p>The header {@code basic} must have a plain two-way break table: exactly one explicit target plus the
     * default, neither of which is the header itself or {@link Basics#UNREACHABLE_NAME}. One of those targets must
     * be the loop body, which must satisfy both of these:
     * <ul>
     *   <li>it is the header's only back-edge input, and</li>
     *   <li>it is entered from nowhere but the header.</li>
     * </ul>
     * The other target becomes the unconditional successor of the rewritten header, and the loop body is removed
     * from the set.
     *
     * @param li loop info for {@code basics}; its back edges identify the loop body
     * @param basics the current set of basics
     * @param basic the candidate loop header
     * @return the rewritten set, or {@code basics} itself when the pattern does not match
     */
    @CheckReturnValue
    private static BasicSet shrinkSimpleLoop(BasicLoopInfo li, BasicSet basics, Basic basic) {
        if (basic.getAllTargets().size() != 2 || basic.getAllTargets().contains(Basics.UNREACHABLE_NAME)
                || basic.getAllTargets().contains(basic.getName())) {
            return basics;
        }
        // The rewrite below expresses the branch as a single boolean condition. That is only valid for a table with
        // exactly one explicit target. With more, getOnlyElement would throw in one branch and the multi-way
        // condition would be misread as a boolean in the other.
        List<Name> explicitTargets = ImmutableList.copyOf(basic.getBreakTable().getTargets());
        if (explicitTargets.size() != 1) {
            return basics;
        }
        // Inputs t of the header for which (header, t) is recorded as a back edge.
        Set<Name> backEdgeInputs = basics.getInputs(basic.getName()).stream()
                .filter(t -> li.getBackEdges().containsEntry(basic.getName(), t)).collect(toImmutableSet());
        if (backEdgeInputs.size() != 1) {
            return basics;
        }
        Name loopBodyName = getOnlyElement(backEdgeInputs);
        // The loop body must be one of the header's targets and must be entered from the header alone.
        if (!basic.getAllTargets().contains(loopBodyName) || !ImmutableSet
                .copyOf(basics.getInputs().get(loopBodyName)).equals(ImmutableSet.of(basic.getName()))) {
            return basics;
        }

        Basic loopBody = requireNonNull(basics.get(loopBodyName));
        Name succName;
        Node condition;
        if (loopBodyName.equals(basic.getBreakTable().getDefaultTarget())) {
            succName = getOnlyElement(explicitTargets);
            // FIXME r u sure
            // NOTE(review): assume a Break to a Loop's name re-enters the loop, and BreakTable follows wasm br_table
            // semantics (condition 0 selects the explicit target, anything else the default). Then the loop should
            // continue here when the condition is non-zero. That would make this EqZ, and the raw condition in the
            // else branch, inverted. Confirm against BreakTable.
            condition = new Unary(UnaryOp.EqZ, Type.I32, basic.getBreakTable().getCondition());
        } else {
            succName = basic.getBreakTable().getDefaultTarget();
            condition = basic.getBreakTable().getCondition();
        }

        // Loop body statements followed by a conditional Break targeting the loop's own name.
        Node loop = new Loop(
                // FIXME check
                loopBodyName, nodify(ImmutableList.<Node>builder().addAll(loopBody.getBody())
                        .add(new If(condition, new Break(loopBodyName, new Nop()), new Nop())).build()));

        // The header keeps its own statements, ends with the loop, and then branches unconditionally (no explicit
        // targets, only a default) to the successor.
        Basic newBasic = new Basic(basic.getName(),
                ImmutableList.<Node>builder().addAll(basic.getBody()).add(loop).build(),
                new BreakTable(ImmutableList.of(), succName, new Nop()), minBasicIndex(basic, loopBody));

        return basics.replace(newBasic).remove(loopBodyName);
    }

    //    public static final class MatchedIfOr
    //    {
    //        public final Node condition;
    //    }

    //    public static Stream<BasicSet> collapseIfAnd(BasicSet basics, Basic basic)
    //    {
    //
    //    }

    /**
     * Collapses a short-circuit "or" diamond rooted at {@code basic} into a single {@link If}.
     *
     * <p>Shape matched ({@code c1}/{@code c2} are the branch conditions reported by {@code matchBoolean}):
     * <pre>
     *   basic: ...; if (c1) goto then; else goto or;
     *   or:    (empty body, entered only from basic) if (c2) goto then; else goto after;
     *   then:  (entered only from basic and or) ...; goto after;
     * </pre>
     * This becomes {@code basic: ...; if (c1 || c2) { then... } goto after;}, with {@code or} and {@code then}
     * removed from the set.
     *
     * @param basics the current set of basics
     * @param basic the candidate head of the diamond
     * @return a stream of rewritten sets, one per match; empty when the shape does not match
     */
    public static Stream<BasicSet> collapseIfOr(BasicSet basics, Basic basic) {
        return matchBoolean(basic).flatMap(m1 -> {
            // Branch targets are not guaranteed to name basics in the set (terminal names presumably do not), and
            // BasicSet.get appears to be nullable (see the requireNonNull in shrinkSimpleLoop). So bail out rather
            // than dereference null.
            Basic or = basics.get(m1.ifFalse);
            if (or == null || !or.getBody().isEmpty()
                    || !basics.getInputs(or).equals(ImmutableSet.of(basic.getName()))) {
                return Stream.empty();
            }
            Basic then = basics.get(m1.ifTrue);
            if (then == null || !basics.getInputs(then).equals(ImmutableSet.of(basic.getName(), or.getName()))) {
                return Stream.empty();
            }
            // 'then' must end in an unconditional branch; its target is the join point of the diamond.
            Optional<Name> after = getUnconditionalTarget(then.getBreakTable());
            if (!after.isPresent()) {
                return Stream.empty();
            }
            return matchBoolean(or).flatMap(m2 -> {
                if (!m2.ifTrue.equals(then.getName()) || !m2.ifFalse.equals(after.get())) {
                    return Stream.empty();
                }
                Basic newBasic = new Basic(basic.getName(),
                        ImmutableList.<Node>builder().addAll(basic.getBody())
                                .add(new If(new Binary(BinaryOp.CondOr, Type.I32, m1.condition, m2.condition),
                                        nodify(then.getBody()), new Nop()))
                                .build(),
                        new BreakTable(ImmutableList.of(), after.get(), new Nop()), minBasicIndex(basic, or, then));
                return Stream.of(basics.replace(newBasic).remove(or).remove(then));
            });
        });
    }

    // technical af
    //    private static void shrinkLoopSandwiches(Map<Name, Basic> basics, Multimap<Name, Name> inputs, Set<Name> loops, Multimap<Name, Name> backEdges)
    //    {
    //        basics = new HashMap<>(basics);
    //        inputs = HashMultimap.create(inputs);
    //        for (Name name : ImmutableList.copyOf(basics.keySet())) {
    //            if (!loops.contains(name)) {
    //                continue;
    //            }
    //            Basic basic = basics.get(name);
    //            if (basic == null) {
    //                continue;
    //            }
    //
    //            if (basic.getAllTargets().size() != 2 || basic.getAllTargets().contains(Basics.UNREACHABLE_NAME) || basic.getAllTargets().contains(name)) {
    //                continue;
    //            }
    //            Set<Name> backEdgeInputs = inputs.get(name).stream()
    //                    .filter(t -> backEdges.containsEntry(t, name))
    //                    .collect(toImmutableSet());
    //            if (backEdgeInputs.size() != 1) {
    //                continue;
    //            }
    //            Name loopBodyName = getOnlyElement(backEdgeInputs);
    //            if (!basic.getAllTargets().contains(loopBodyName) || !ImmutableSet.copyOf(inputs.get(loopBodyName)).equals(ImmutableSet.of(name))) {
    //                continue;
    //            }
    //        }
    //        return unmodifiableMap(basics);
    //    }

    /**
     * Renders {@code basics} as a Graphviz digraph and displays it with {@code TestGraphviz.showGraph}.
     *
     * <p>Each basic is labelled {@code name: index, statement count, total child count}. It is filled:
     * <ul>
     *   <li>blue if {@code li.isLoop} reports it,</li>
     *   <li>green if it is otherwise a two-target basic,</li>
     *   <li>orange if it has exactly one input and one target.</li>
     * </ul>
     * Control-flow edges use the default color, and back edges are drawn blue. A blue dotted edge runs from each
     * basic's loop parent to the basic. When {@code drawDoms} is set, solid red edges run from each basic's immediate
     * dominator and dotted red edges run to its dominance frontier.
     *
     * @param fname the function name, used in the graph title
     * @param basics the basics to draw
     * @param drawDoms whether to overlay dominator-tree and dominance-frontier edges
     * @throws Exception propagated from {@code showGraph}
     */
    private static void showBasics(Name fname, BasicSet basics, boolean drawDoms) throws Exception {
        BasicDominatorInfo dt = BasicDominatorInfo.build(basics);
        BasicLoopInfo li = BasicLoopInfo.build(basics, dt);

        // Unquoted Graphviz IDs may not contain '$', which appears in wast-style names; replace it with '_'.
        Function<Name, String> nameMangler = n -> n.get().replace('$', '_');
        StringBuilder sb = new StringBuilder();
        sb.append("digraph G {\n");
        sb.append("labelloc=\"t\";");
        sb.append(String.format("label=\"%s (%d)\";", fname.get(), basics.size()));
        // Emit basics in ascending index order so the output is stable.
        List<Name> order = basics.basics().stream()
                .sorted((l, r) -> Integer.compare(l.getIndex().getAsInt(), r.getIndex().getAsInt()))
                .map(Basic::getName).collect(toImmutableList());
        for (Name name : order) {
            Basic basic = basics.get(name);
            boolean isLoop = li.isLoop(basic.getName());
            boolean isIf = !isLoop && basic.getAllTargets().size() == 2; // FIXME WRONG 1035
            boolean isSingle = basics.getInputs(basic).size() == 1 && basic.getAllTargets().size() == 1;
            String nodeStyle = isLoop ? "fillcolor=blue,style=filled"
                    : isIf ? "fillcolor=green,style=filled" : isSingle ? "fillcolor=orange,style=filled" : "";
            int totalSize = basic.getBody().stream().mapToInt(Analyses::getChildCount).sum();
            sb.append(String.format("%s [label=\"%s: %d, %d, %d\",%s];\n", nameMangler.apply(basic.getName()),
                    basic.getName().get(), basic.getIndex().getAsInt(), basic.getBody().size(), totalSize,
                    nodeStyle));
            if (drawDoms) {
                // Only look up dominator data when it is actually drawn.
                Name idom = dt.getImmediateDominator(basic.getName());
                Set<Name> domFront = dt.getDominanceFrontiers().get(basic.getName());
                if (idom != null) {
                    sb.append(String.format("%s -> %s [color=red];\n", nameMangler.apply(idom),
                            nameMangler.apply(basic.getName())));
                }
                domFront.forEach(df -> sb.append(String.format("%s -> %s [color=red,style=dotted];\n",
                        nameMangler.apply(basic.getName()), nameMangler.apply(df))));
            }
            Name loopParent = li.getLoopParent(name);
            if (loopParent != null) {
                sb.append(String.format("%s -> %s [color=blue,style=dotted];\n", nameMangler.apply(loopParent),
                        nameMangler.apply(basic.getName())));
            }
            basic.getAllTargets().forEach(output -> {
                // On an edge, 'fillcolor' does not color the line itself (at most the arrowhead); 'color' draws the
                // whole back edge in blue.
                String edgeStyle = li.getBackEdges().containsEntry(output, basic.getName()) ? "color=blue" : "";
                sb.append(String.format("%s -> %s [%s];\n", nameMangler.apply(basic.getName()),
                        nameMangler.apply(output), edgeStyle));
            });
        }
        // Terminal names are branch targets that are not basics in the set, so declare them as plain nodes.
        Basics.TERMINAL_NAMES
                .forEach(n -> sb.append(String.format("%s [label=\"%s\"];\n", nameMangler.apply(n), n.get())));
        sb.append("}\n");
        showGraph(sb.toString());
    }

    /**
     * Checks that every basic carries an index and that no two basics share the same index.
     *
     * @param basics basics keyed by name
     * @throws IllegalStateException if an index is missing or duplicated
     */
    private static void checkIndicesPresentAndUnique(Map<Name, Basic> basics) {
        checkState(basics.values().stream().allMatch(basic -> basic.getIndex().isPresent()));
        checkState(basics.values().stream().flatMap(basic -> optionalToStream(basic.getIndex()).boxed())
                .collect(toImmutableSet()).size() == basics.size());
    }

    /**
     * Unfinished stub: validates the basics' indices and then always fails.
     *
     * @param basics basics keyed by name
     * @return never returns normally
     * @throws IllegalStateException always; thrown earlier by the validation if an index is missing or duplicated
     */
    public static Node cfgStackify(Map<Name, Basic> basics) {
        checkIndicesPresentAndUnique(basics);

        throw new IllegalStateException("cfgStackify is not implemented");
    }

    /**
     * Lowers {@code basics} into a single nested node by wrapping them in ascending index order.
     *
     * <p>Each basic becomes a {@link Loop} (if its name is in {@code loops}) or a {@link Label}. The node's body is,
     * in order:
     * <ol>
     *   <li>the node built so far,</li>
     *   <li>the basic's statements,</li>
     *   <li>its simplified break table, if any.</li>
     * </ol>
     * The basic with the highest index is therefore the outermost node.
     *
     * <p>NOTE(review): if a Break to a Label exits that label (wasm-style), then breaking to a basic's name here
     * skips that basic's own statements and lands on those of the next basic. Confirm the intended Label/Loop
     * semantics.
     *
     * @param basics basics keyed by name; every basic must have a distinct index
     * @param loops names of basics to emit as {@link Loop}s instead of {@link Label}s
     * @return the outermost node
     * @throws IllegalStateException if an index is missing or duplicated, or if {@code basics} is empty
     */
    public static Node worst(Map<Name, Basic> basics, Set<Name> loops) {
        checkIndicesPresentAndUnique(basics);
        List<Basic> basicList = new ArrayList<>(basics.values());
        Collections.sort(basicList,
                (left, right) -> Integer.compare(left.getIndex().getAsInt(), right.getIndex().getAsInt()));
        checkState(!basicList.isEmpty());

        Node ret = new Nop();
        for (Basic basic : basicList) {
            Block body = new Block(ImmutableList.<Node>builder().add(ret).addAll(basic.getBody())
                    .addAll(optionalToList(Basics.simplifyBreakTable(basic.getBreakTable()))).build());
            if (loops.contains(basic.getName())) {
                ret = new Loop(basic.getName(), body);
            } else {
                ret = new Label(basic.getName(), body);
            }
        }

        return ret;
    }

    /**
     * Recovers basic blocks for one function, shrinks them until the count stops changing, and draws the result.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Print the function's name and run {@link StandardFunctionProcessor} over it.</li>
     *   <li>Build a {@link BasicSet} from the processed body's shallow control flow graph.</li>
     *   <li>Repeatedly apply the merge/shrink/collapse transforms until a full pass leaves the number of basics
     *       unchanged, printing the count before each pass.</li>
     *   <li>Show the resulting graph with dominator edges.</li>
     * </ol>
     *
     * @param function_ the function to process
     * @throws Exception propagated from processing or rendering
     */
    public static void doFunction(com.wrmsr.wava.core.unit.Function function_) throws Exception {
        System.out.println(function_.getName().get());
        com.wrmsr.wava.core.unit.Function function = new StandardFunctionProcessor().processFunction(function_);

        BasicSet basics;
        {
            Node root = function.getBody();
            Map<Name, Node> namedNodes = Analyses.getNamedNodes(root);
            // Reverse lookup (by identity) from a named node to its name.
            Map<Node, Name> namedNodeNames = namedNodes.entrySet().stream()
                    .collect(toIdentityMap(e -> e.getValue(), e -> e.getKey()));
            List<Node> nodes = Analyses.linearize(root);
            Map<Node, Integer> nodeIndices = indexIdentityMap(nodes);
            // Every linearized node gets a name: its own if it has one, otherwise "node<linearization index>".
            Map<Node, Name> nodeNames = nodes.stream()
                    .map(node -> ImmutablePair.of(node,
                            namedNodeNames.containsKey(node) ? namedNodeNames.get(node)
                                    : Name.of("node" + nodeIndices.get(node))))
                    .collect(toIdentityMap(Map.Entry::getKey, Map.Entry::getValue));

            ValueTypeAnalysis vta = ValueTypeAnalysis.analyze(root, false);
            ControlFlowGraph cfg = ControlFlowGraph.analyzeShallow(root, namedNodes, vta);

            basics = BasicSet.build(
                    Basics.buildBasics(cfg, vta, nodeNames, nodeIndices).values().stream().map(Basics::cleanBasic));
        }

        //        showBasics(basics, true);
        //        {
        //            Node root = worst(basics, findBasicLoops(basics, generateBasicDominatorTree(basics)));
        //            JDeclaration jdecl = jcompileFunction(new com.wrmsr.wava.core.unit.Function(function.getName(), function.getResult(), function.getArgCount(), function.getLocals(), root)).get(0);
        //            CodeBlock.Builder code = CodeBlock.builder();
        //            new JRenderer(code).renderDeclaration(jdecl);
        //            System.out.println(JRenderer.renderWithIndent(code.build(), "    "));
        //        }

        // Apply every transform once per pass; stop when a pass leaves the basic count unchanged.
        while (true) {
            int size = basics.size();
            System.out.println(size);

            basics = transformBasics(SimpleMatching::mergeUnconditionalBasic, basics);
            basics = transformBasics(SimpleMatching::mergeEmptyBasic, basics);
            basics = transformBasics(BooleanMatching::shrinkIf, basics);
            basics = transformBasics(BooleanMatching::shrinkIfElse, basics);
            // Loop info is rebuilt here because the transforms above may have changed the set.
            // NOTE(review): it is computed once per transformBasics call; if transformBasics applies the rewrite
            // repeatedly, later rewrites in the same call see stale loop info. Confirm that this is acceptable.
            basics = transformBasics(bind(LoopMatching::shrinkSelfLoops,
                    BasicLoopInfo.build(basics, BasicDominatorInfo.build(basics)))::apply, basics);
            basics = transformBasics(bind(TestWhatever::shrinkSimpleLoop,
                    BasicLoopInfo.build(basics, BasicDominatorInfo.build(basics)))::apply, basics);
            basics = transformBasics(TestWhatever::collapseIfOr, basics);

            if (basics.size() == size) {
                break;
            }
        }

        showBasics(function.getName(), basics, true);

        // getLoopContents(basics, Name.of("node155"));

        //        Map<Name, Name> loopParents = new HashMap<>();
        //        {
        //            Stack<Name> rootStack = new Stack<>();
        //            rootStack.push(null);
        //            Set<Name> set = new HashSet<>(loops);
        //            while (!rootStack.isEmpty()) {
        //                Name parent = rootStack.pop();
        //                if (parent != null) {
        //                    set.remove(parent);
        //                }
        //                Set<Name> rootLoops = set.stream().filter(loop -> !set.stream().anyMatch(otherLoop -> !otherLoop.equals(loop) && loopContents.get(otherLoop).contains(loop))).collect(toImmutableSet());
        //                for (Name rootLoop : rootLoops) {
        //                    loopParents.put(rootLoop, parent);
        //                    rootStack.push(rootLoop);
        //                }
        //            }
        //            checkState(set.isEmpty());
        //        }

        //        if (basics.size() > 3 && basics.size() < 8) {
        //        showBasics(function.getName(), basics, true);
        //        }

        //        cfgStackify(basics);

        //        DirectedGraph<Integer, DefaultEdge> graph = new DefaultDirectedGraph<>(DefaultEdge.class);
        //        nodeIndices.values().forEach(graph::addVertex);
        //        cfg.getEdges().forEach(e -> graph.addEdge(nodeIndices.get(e.getInput()), nodeIndices.get(e.getOutput())));

        //        DominatorTree<Integer, DefaultEdge> dt = new DominatorTree<>(graph, 0);
    }

    /**
     * Parses the {@code new/sqlite3.wast} resource, translates it into a {@link Module}, and runs
     * {@link #doFunction} on one function.
     *
     * <p>The function is chosen by the last uncommented assignment to {@code target}.
     */
    @Test
    public void testFoo() throws Throwable {
        String target;
        target = "sqlite3VdbeExec";
        //        target = "yy_reduce";
        target = "sqlite3VXPrintf";
        //        target = "sqlite3Error";
        //        target = "sqlite3WalCheckpoint";
        //        target = "sqlite3StatusAdd";
        //        target = "sqlite3_compileoption_used";

        Module module;
        Element root = (new Parser(Input.ofResource("new/sqlite3.wast"))).parse();
        YModule ymodule = new ModuleFactory(root).create();
        module = UnitTranslation.translateModule(Name.of("HelloWorld"), ymodule);

        doFunction(module.getFunctions().get(Name.of(target)));
        //        module.getFunctions().values().forEach(propagatingConsumer(TestWhatever::doFunction));
    }
}