com.wrmsr.wava.TestOutlining.java Source code

Java tutorial

Introduction

Here is the source code for com.wrmsr.wava.TestOutlining.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.wrmsr.wava;

import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.wrmsr.wava.analyze.Analyses;
import com.wrmsr.wava.analyze.ControlTransferAnalysis;
import com.wrmsr.wava.analyze.LocalAnalysis;
import com.wrmsr.wava.analyze.ValueTypeAnalysis;
import com.wrmsr.wava.compile.binary.BinaryCompilerImpl;
import com.wrmsr.wava.compile.call.CallCompilerImpl;
import com.wrmsr.wava.compile.call.CallIndirectCompilerImpl;
import com.wrmsr.wava.compile.const_.ConstCompilerImpl;
import com.wrmsr.wava.compile.function.FunctionAccess;
import com.wrmsr.wava.compile.function.FunctionCompilerImpl;
import com.wrmsr.wava.compile.memory.LoadStoreCompilerImpl;
import com.wrmsr.wava.compile.unary.UnaryCompilerImpl;
import com.wrmsr.wava.core.node.Node;
import com.wrmsr.wava.core.node.Switch;
import com.wrmsr.wava.core.node.visitor.Visitor;
import com.wrmsr.wava.core.type.Index;
import com.wrmsr.wava.core.type.Name;
import com.wrmsr.wava.core.type.Type;
import com.wrmsr.wava.core.unit.Function;
import com.wrmsr.wava.core.unit.Local;
import com.wrmsr.wava.core.unit.Locals;
import com.wrmsr.wava.driver.StandardFunctionProcessor;
import com.wrmsr.wava.java.lang.JAccess;
import com.wrmsr.wava.java.lang.JRenderer;
import com.wrmsr.wava.java.lang.tree.declaration.JMethod;
import com.wrmsr.wava.java.poet.CodeBlock;
import com.wrmsr.wava.transform.Outlining;
import com.wrmsr.wava.transform.Transforms;
import com.wrmsr.wava.util.Json;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Test;

import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Sets.immutableEnumSet;

/**
 * Exploratory driver for {@link Outlining}.
 *
 * <p>Both entry points pick a "split point" in a function body. They walk down the chain of
 * largest children (by subtree size) and remember the node whose subtree shrinks the most between
 * it and its largest child. That node is then outlined into a separate function. If the node is a
 * {@link Switch}, the body of its largest entry is outlined instead.
 */
public class TestOutlining {
    /** Name of the extra I32 local passed to {@link Outlining#outlineFunction} as the control slot. */
    private static final String EXTERNAL_CONTROL_LOCAL = "_external$control";
    /** Name of the extra I64 local passed to {@link Outlining#outlineFunction} as the value slot. */
    private static final String EXTERNAL_VALUE_LOCAL = "_external$value";

    /**
     * Loads {@code tmp/post.json} and prints the largest-child chain of its body. Then prints the
     * chosen split point and its children (largest first), outlines it, and prints the Java
     * rendering of the result.
     *
     * <p>For a {@link Switch} split point, the largest entry body is outlined, and both the outlined
     * function and the rewritten caller are compiled. For any other node, only the outlined function
     * is compiled. The test makes no assertions. It fails only if loading, outlining or compilation
     * throws.
     */
    @Test
    public void testOutlining() throws Throwable {
        Path outputFile = Paths.get("tmp/post.json");
        Function function = Json.OBJECT_MAPPER_SUPPLIER.get().readValue(Files.readAllBytes(outputFile),
                Function.class);
        Node body = function.getBody();

        LocalAnalysis loa = LocalAnalysis.analyze(body);
        ControlTransferAnalysis cfa = ControlTransferAnalysis.analyze(body);
        ValueTypeAnalysis vta = ValueTypeAnalysis.analyze(body, false);

        Map<Node, Optional<Node>> parentsByNode = Analyses.getParents(body);
        Map<Node, Integer> totalChildrenByNode = Analyses.getChildCounts(body);
        Map<Name, Node> nodesByName = Analyses.getNamedNodes(body);

        Node splitNode = findSplitNode(body, totalChildrenByNode, true);

        System.out.println();
        System.out.println(splitNode);

        System.out.println();
        for (Node child : childrenBySizeDescending(splitNode, totalChildrenByNode)) {
            System.out.println(String.format("%s -> %d", child, totalChildrenByNode.get(child)));
        }
        System.out.println();

        // Reuses the external slots when the loaded function already declares them, which avoids
        // appending duplicate locals. This matches inlineThatOneSwitch.
        ExternalSlots slots = ExternalSlots.forFunction(function);

        splitNode.accept(new Visitor<Void, Void>() {
            @Override
            protected Void visitNode(Node node, Void context) {
                Outlining.OutlinedFunction of = Outlining.outlineFunction(function, node, Name.of("outlined"),
                        slots.control, slots.value, loa, cfa, vta, parentsByNode, nodesByName);
                compileUnchecked(of.getFunction());
                return null;
            }

            @Override
            public Void visitSwitch(Switch node, Void context) {
                Node entryBody = largestSwitchEntryBody(node, totalChildrenByNode);

                Outlining.OutlinedFunction of = Outlining.outlineFunction(function, entryBody,
                        Name.of("outlined"), slots.control, slots.value, loa, cfa, vta, parentsByNode,
                        nodesByName);
                compileUnchecked(of.getFunction());

                // The caller keeps its name and signature; the outlined entry body is replaced by the
                // callsite produced by Outlining.
                Function newFunc = new Function(function.getName(), function.getResult(), function.getArgCount(),
                        new Locals(slots.locals),
                        Transforms.replaceNode(function.getBody(), entryBody, of.getCallsite(), true));

                System.out.println();
                compileUnchecked(newFunc);

                /*
                TODO:
                 - FALLTHROUGH analysis
                 - local index translation
                 - spills in/out
                 - breaks as return codes (including break values), and returns as return codes
                  - I64 retval cell

                NEXT:
                 - vta for non switch-cases, with inline value returns if no breaks
                 - pre-alloc cells? would need careful handling of return temp loading
                 - sp-based retvals (setjmp/exceptions may complicate this)
                  - needs a shadow stack? Doable: pushed ONLY immediately before ret, popped ALWAYS
                    immediately after ret, so the stack is unchanged during execution
                */

                return null;
            }
        }, null);
    }

    /**
     * Runs {@code function} through {@link StandardFunctionProcessor} and compiles it to exactly one
     * Java method. Renders that method with a 4-space indent and prints it to stdout.
     *
     * @throws Throwable propagated from processing, compilation or rendering
     */
    private static void compileFunction(Function function) throws Throwable {
        Function processed = new StandardFunctionProcessor().processFunction(function);
        JMethod method = getOnlyElement(
                new FunctionCompilerImpl(new FunctionAccess(immutableEnumSet(JAccess.PUBLIC, JAccess.FINAL)),
                        new BinaryCompilerImpl(), new CallCompilerImpl(), new CallIndirectCompilerImpl(),
                        new ConstCompilerImpl(), new LoadStoreCompilerImpl(), new UnaryCompilerImpl())
                                .compileFunction(processed));

        CodeBlock.Builder code = CodeBlock.builder();
        new JRenderer(code).renderDeclaration(method);
        CodeBlock block = code.build();

        System.out.println(JRenderer.renderWithIndent(block, "    "));
    }

    /**
     * Calls {@link #compileFunction} for use from visitor methods that cannot throw checked exceptions.
     * Unchecked exceptions and errors are rethrown as-is. Anything else is wrapped in a
     * {@link RuntimeException} with the original as its cause. This mirrors the semantics of the
     * deprecated Guava {@code Throwables.propagate}.
     */
    private static void compileUnchecked(Function function) {
        try {
            compileFunction(function);
        } catch (RuntimeException | Error e) {
            throw e;
        } catch (Throwable e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Outlines the split point of {@code function} into a new function named
     * {@code function.getName().get() + "$outlined$" + num}. Returns the rewritten caller together
     * with the outlined function.
     *
     * <p>If the split point is a {@link Switch}, the body of its largest entry is outlined instead of
     * the whole switch. The {@code _external$control} and {@code _external$value} locals are reused
     * when {@code function} already declares them. Otherwise they are appended to its locals.
     *
     * @param function the function to split
     * @param num suffix distinguishing successive outlined functions
     * @return a pair of (rewritten caller, outlined function)
     * @throws IllegalStateException if the chosen switch has no entries
     */
    public static Pair<Function, Function> inlineThatOneSwitch(Function function, int num) {
        Node body = function.getBody();

        LocalAnalysis loa = LocalAnalysis.analyze(body);
        ControlTransferAnalysis cfa = ControlTransferAnalysis.analyze(body);
        ValueTypeAnalysis vta = ValueTypeAnalysis.analyze(body, false);

        Map<Node, Optional<Node>> parentsByNode = Analyses.getParents(body);
        Map<Node, Integer> totalChildrenByNode = Analyses.getChildCounts(body);
        Map<Name, Node> nodesByName = Analyses.getNamedNodes(body);

        Node splitNode = findSplitNode(body, totalChildrenByNode, false);
        ExternalSlots slots = ExternalSlots.forFunction(function);

        Node target = splitNode instanceof Switch
                ? largestSwitchEntryBody((Switch) splitNode, totalChildrenByNode)
                : splitNode;

        Outlining.OutlinedFunction of = Outlining.outlineFunction(function, target,
                Name.of(function.getName().get() + "$outlined$" + num), slots.control, slots.value, loa,
                cfa, vta, parentsByNode, nodesByName);

        Function newFunc = new Function(function.getName(), function.getResult(), function.getArgCount(),
                new Locals(slots.locals), Transforms.replaceNode(function.getBody(), target, of.getCallsite(), true));

        return ImmutablePair.of(newFunc, of.getFunction());
    }

    /**
     * Walks from {@code body} down the chain of largest children and returns the node whose subtree
     * size drops the most relative to its largest child. Returns {@code body} if no step shows a
     * positive drop. Ties between children resolve to the first one in {@code getChildren()} order.
     *
     * <p>Assumes every node reachable from {@code body} has an entry in {@code totalChildrenByNode}.
     *
     * @param verbose if true, prints every visited node with its subtree size and direct child count
     */
    private static Node findSplitNode(Node body, Map<Node, Integer> totalChildrenByNode, boolean verbose) {
        Comparator<Node> bySize = bySubtreeSize(totalChildrenByNode);
        Node best = body;
        int bestDrop = 0;

        Node cur = body;
        while (true) {
            if (verbose) {
                System.out.println(
                        String.format("%s -> %d (%d)", cur, totalChildrenByNode.get(cur), cur.getChildren().size()));
            }
            Optional<Node> largestChild = cur.getChildren().stream().max(bySize);
            if (!largestChild.isPresent()) {
                return best;
            }
            int drop = totalChildrenByNode.get(cur) - totalChildrenByNode.get(largestChild.get());
            if (drop > bestDrop) {
                best = cur;
                bestDrop = drop;
            }
            cur = largestChild.get();
        }
    }

    /** Returns a mutable copy of {@code node}'s children, sorted by subtree size, largest first (stable). */
    private static List<Node> childrenBySizeDescending(Node node, Map<Node, Integer> totalChildrenByNode) {
        List<Node> children = new ArrayList<>(node.getChildren());
        children.sort(bySubtreeSize(totalChildrenByNode).reversed());
        return children;
    }

    /**
     * Returns the body of {@code node}'s entry with the largest subtree. Ties resolve to the first
     * such entry.
     *
     * @throws IllegalStateException if the switch has no entries
     */
    private static Node largestSwitchEntryBody(Switch node, Map<Node, Integer> totalChildrenByNode) {
        return node.getEntries().stream()
                .map(Switch.Entry::getBody)
                .max(bySubtreeSize(totalChildrenByNode))
                .orElseThrow(() -> new IllegalStateException("Switch has no entries: " + node));
    }

    /** Orders nodes ascending by their total descendant count in {@code totalChildrenByNode}. */
    private static Comparator<Node> bySubtreeSize(Map<Node, Integer> totalChildrenByNode) {
        return Comparator.comparingInt(n -> totalChildrenByNode.get(n));
    }

    /**
     * The control (I32) and value (I64) local indices handed to {@link Outlining#outlineFunction},
     * plus the caller's full local list including those two locals.
     */
    private static final class ExternalSlots {
        final Index control;
        final Index value;
        final List<Local> locals;

        private ExternalSlots(Index control, Index value, List<Local> locals) {
            this.control = control;
            this.value = value;
            this.locals = locals;
        }

        /**
         * Reuses the existing external locals if {@code function} declares {@code _external$control}.
         * Otherwise appends both locals after the existing ones.
         *
         * <p>NOTE(review): when {@code _external$control} exists, {@code _external$value} is assumed to
         * exist as well.
         */
        static ExternalSlots forFunction(Function function) {
            Locals existing = function.getLocals();
            if (existing.getLocalsByName().containsKey(Name.of(EXTERNAL_CONTROL_LOCAL))) {
                return new ExternalSlots(
                        existing.getLocal(Name.of(EXTERNAL_CONTROL_LOCAL)).getIndex(),
                        existing.getLocal(Name.of(EXTERNAL_VALUE_LOCAL)).getIndex(),
                        existing.getList());
            }

            int next = existing.getList().size();
            Index control = Index.of(next);
            Index value = Index.of(next + 1);
            List<Local> extended = ImmutableList.<Local>builder().addAll(existing.getList())
                    .add(new Local(Name.of(EXTERNAL_CONTROL_LOCAL), control, Type.I32))
                    .add(new Local(Name.of(EXTERNAL_VALUE_LOCAL), value, Type.I64)).build();
            return new ExternalSlots(control, value, extended);
        }
    }
}