com.wrmsr.wava.transform.Outlining.java Source code

Java tutorial

Introduction

Here is the source code for com.wrmsr.wava.transform.Outlining.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.wrmsr.wava.transform;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.wrmsr.wava.analyze.ControlTransferAnalysis;
import com.wrmsr.wava.analyze.LocalAnalysis;
import com.wrmsr.wava.analyze.ValueTypeAnalysis;
import com.wrmsr.wava.core.literal.Literal;
import com.wrmsr.wava.core.node.Block;
import com.wrmsr.wava.core.node.Break;
import com.wrmsr.wava.core.node.BreakTable;
import com.wrmsr.wava.core.node.Call;
import com.wrmsr.wava.core.node.Const;
import com.wrmsr.wava.core.node.GetLocal;
import com.wrmsr.wava.core.node.Label;
import com.wrmsr.wava.core.node.Node;
import com.wrmsr.wava.core.node.Nop;
import com.wrmsr.wava.core.node.Return;
import com.wrmsr.wava.core.node.SetLocal;
import com.wrmsr.wava.core.node.Switch;
import com.wrmsr.wava.core.node.Unreachable;
import com.wrmsr.wava.core.node.visitor.Visitor;
import com.wrmsr.wava.core.op.HostOp;
import com.wrmsr.wava.core.type.Index;
import com.wrmsr.wava.core.type.Name;
import com.wrmsr.wava.core.type.Type;
import com.wrmsr.wava.core.unit.Function;
import com.wrmsr.wava.core.unit.Local;
import com.wrmsr.wava.core.unit.Locals;
import org.organicdesign.fp.collections.ImSet;

import javax.annotation.concurrent.Immutable;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkState;
import static com.wrmsr.wava.core.node.Nodes.packI64;
import static com.wrmsr.wava.core.node.Nodes.reconstructNode;
import static com.wrmsr.wava.core.node.Nodes.unpackI64;
import static com.wrmsr.wava.util.Itertools.enumerate;
import static com.wrmsr.wava.util.collect.MoreCollectors.toImmutableList;
import static com.wrmsr.wava.util.collect.MoreCollectors.toImmutableMap;
import static java.util.Objects.requireNonNull;

/**
 * Transformation that extracts ("outlines") a subtree of a function body into a new standalone function and builds
 * the replacement callsite for it.
 */
public final class Outlining {
    private Outlining() {
    }

    /**
     * Result of {@link Outlining#outlineFunction}: the node that should replace the outlined subtree in the original
     * function, and the newly created function that node calls.
     */
    @Immutable
    public static final class OutlinedFunction {
        private final Node callsite;
        private final Function function;

        public OutlinedFunction(Node callsite, Function function) {
            this.callsite = requireNonNull(callsite);
            this.function = requireNonNull(function);
        }

        /** Returns the node that replaces the outlined subtree in the original function. */
        public Node getCallsite() {
            return callsite;
        }

        /** Returns the newly created outlined function. */
        public Function getFunction() {
            return function;
        }
    }

    /**
     * Outlines {@code node}, a subtree of {@code function}'s body, into a new function named {@code outlinedName}.
     *
     * <p>Protocol between the generated callsite and the outlined function:
     * <ul>
     * <li>Locals the subtree writes and the rest of the function reads ("spilled-out"), followed by locals the rest
     * of the function writes and the subtree reads ("spilled-in"), are passed as arguments.</li>
     * <li>The outlined function returns an i32 control code. {@code 0} means the subtree completed normally
     * (fallthrough). {@code k + 1} means it transferred control to the {@code k}-th target reported by
     * {@code cfa} for {@code node} (a break to an outside label, or a return).</li>
     * <li>The accompanying value, if any, is packed to i64 and passed back through host spill slot {@code 0}. Final
     * values of spilled-out locals are passed back through slots {@code 1..n}.</li>
     * </ul>
     *
     * @param function the function containing {@code node}
     * @param node the subtree to outline
     * @param outlinedName name of the new function
     * @param externalRetControl i32 local of {@code function} that receives the control code at the callsite
     * @param externalRetValue i64 local of {@code function} that receives the packed value at the callsite
     * @param loa local get/put analysis covering {@code node} and the siblings along its ancestor chain
     * @param cfa control-transfer analysis covering {@code node}
     * @param vta value-type analysis covering {@code node} and the nodes named by break targets
     * @param parentsByNode parent of each node on the path to the root; the root maps to {@link Optional#empty()}
     * @param nodesByName node defining each label name that {@code node} may break to
     * @return the callsite node and the new function
     * @throws IllegalStateException if a control-transfer target or value type is of an unsupported kind
     */
    public static OutlinedFunction outlineFunction(Function function, Node node, Name outlinedName,
            Index externalRetControl, Index externalRetValue, LocalAnalysis loa, ControlTransferAnalysis cfa,
            ValueTypeAnalysis vta, Map<Node, Optional<Node>> parentsByNode, Map<Name, Node> nodesByName) {
        // Locals of the outlined function, in order: spilled-out (args), spilled-in (args), entry-only (plain locals).
        // The internal control/value locals are appended after these.
        final List<Index> allLocals;
        final List<Index> spilledOutLocalPuts;
        final List<Index> spilledInLocalGets;
        {
            ImSet<Index> entryLocalGets = loa.get(node).getLocalGets();
            ImSet<Index> entryLocalPuts = loa.get(node).getLocalPuts();

            // Collect local usage outside the subtree: every sibling of node and of each of its ancestors, up to
            // the root.
            ImSet<Index> nonEntryLocalGets = LocalAnalysis.EMPTY_LOCALS;
            ImSet<Index> nonEntryLocalPuts = LocalAnalysis.EMPTY_LOCALS;
            Optional<Node> cur = Optional.of(node);
            while (true) {
                Optional<Node> next = requireNonNull(parentsByNode.get(cur.get()));
                if (!next.isPresent()) {
                    break;
                }
                for (Node sibling : next.get().getChildren()) {
                    if (sibling == cur.get()) {
                        continue;
                    }
                    nonEntryLocalGets = nonEntryLocalGets.union(loa.get(sibling).getLocalGets());
                    nonEntryLocalPuts = nonEntryLocalPuts.union(loa.get(sibling).getLocalPuts());
                }
                cur = next;
            }

            Set<Index> entryLocals = Sets.union(entryLocalGets, entryLocalPuts);
            Set<Index> nonEntryLocals = Sets.union(nonEntryLocalGets, nonEntryLocalPuts);
            Set<Index> entryOnlyLocals = Sets.difference(entryLocals, nonEntryLocals);

            // Written inside, read outside: passed in (presumably because the subtree may not write them on every
            // path) and their final values are spilled back to the caller.
            Set<Index> spilledOutLocalPutsSet = Sets.intersection(entryLocalPuts, nonEntryLocalGets);
            spilledOutLocalPuts = spilledOutLocalPutsSet.stream().sorted().collect(toImmutableList());
            // Written outside, read inside: passed in only, unless already passed as spilled-out.
            spilledInLocalGets = Sets.intersection(nonEntryLocalPuts, entryLocalGets).stream()
                    .filter(i -> !spilledOutLocalPutsSet.contains(i)).sorted().collect(toImmutableList());

            List<Index> entryOnlyLocalList = entryOnlyLocals.stream().sorted().collect(toImmutableList());
            checkState(
                    Sets.intersection(ImmutableSet.copyOf(spilledInLocalGets), ImmutableSet.copyOf(entryOnlyLocals))
                            .isEmpty());

            // NOTE(review): a local used both inside and outside but not covered by either spill rule (e.g. only read
            // on both sides) lands in none of these groups and gets no translation below. Entry-only locals are
            // not passed as arguments, so a read-before-write inside the subtree will not see the caller's value.
            // Confirm that LocalAnalysis / callers rule these cases out.
            allLocals = ImmutableList.<Index>builder().addAll(spilledOutLocalPuts).addAll(spilledInLocalGets)
                    .addAll(entryOnlyLocalList).build();
        }

        // Original local index -> outlined-function local index (position in allLocals).
        Map<Index, Index> localTranslationMap = enumerate(allLocals.stream())
                .collect(toImmutableMap(i -> i.getItem(), i -> Index.of(i.getIndex())));

        // Each control-transfer target of the subtree gets its own epilog label inside the outlined function.
        ControlTransferAnalysis.Entry maxCfa = cfa.get(node);
        List<ControlTransferAnalysis.Target> targets = ImmutableList.copyOf(maxCfa.getTargets());
        Map<ControlTransferAnalysis.Target, Name> targetNameMap = enumerate(targets.stream())
                .collect(toImmutableMap(i -> i.getItem(), i -> Name.of("_epilog$" + i.getIndex())));

        // Value type carried by each target: the function result type for returns, the labelled node's type for
        // breaks.
        ImmutableMap.Builder<ControlTransferAnalysis.Target, Type> targetTypesBuilder = ImmutableMap.builder();
        for (ControlTransferAnalysis.Target target : targets) {
            if (target.equals(ControlTransferAnalysis.Target.RETURN)) {
                targetTypesBuilder.put(ControlTransferAnalysis.Target.RETURN, function.getResult());
            } else if (target instanceof ControlTransferAnalysis.NameTarget) {
                Name name = ((ControlTransferAnalysis.NameTarget) target).getName();
                Type type = vta.get(nodesByName.get(name)).getType();
                targetTypesBuilder.put(target, type);
            } else {
                throw new IllegalStateException();
            }
        }
        Map<ControlTransferAnalysis.Target, Type> targetTypes = targetTypesBuilder.build();

        // Rewrite the subtree for the new function: remap local indices and redirect breaks to outside labels and
        // returns to the per-target epilog labels.
        Node outlinedBody = node.accept(new Visitor<Void, Node>() {
            @Override
            protected Node visitNode(Node node, Void context) {
                return reconstructNode(node,
                        node.getChildren().stream().map(child -> child.accept(this, context)).iterator());
            }

            @Override
            public Node visitBreak(Break node, Void context) {
                ControlTransferAnalysis.Target target = ControlTransferAnalysis.Target.of(node.getTarget());
                if (targetNameMap.containsKey(target)) {
                    return new Break(targetNameMap.get(target), node.getValue().accept(this, context));
                } else {
                    return super.visitBreak(node, context);
                }
            }

            @Override
            public Node visitBreakTable(BreakTable node, Void context) {
                // NOTE(review): this only delegates to the default handling, so BreakTable targets are not redirected
                // to epilog labels the way visitBreak does. Confirm whether break tables can escape the subtree.
                return super.visitBreakTable(node, context);
            }

            @Override
            public Node visitGetLocal(GetLocal node, Void context) {
                return new GetLocal(localTranslationMap.get(node.getIndex()), node.getType());
            }

            @Override
            public Node visitReturn(Return node, Void context) {
                return new Break(targetNameMap.get(ControlTransferAnalysis.Target.RETURN),
                        node.getValue().accept(this, context));
            }

            @Override
            public Node visitSetLocal(SetLocal node, Void context) {
                return new SetLocal(localTranslationMap.get(node.getIndex()), node.getType(),
                        node.getValue().accept(this, context));
            }
        }, null);

        // Two extra locals after the translated ones: the control code and the packed i64 value.
        Index internalRetControl = Index.of(allLocals.size());
        Index internalRetValue = Index.of(allLocals.size() + 1);

        // Normal completion: store the value (if any) and jump to the epilog without touching the control code, which
        // is expected to still hold 0 (the callsite treats 0 as fallthrough).
        ValueTypeAnalysis.Entry maxVta = vta.get(node);
        if (maxCfa.getExecution() == ControlTransferAnalysis.Execution.FALLTHROUGH) {
            if (maxVta.getType() != Type.NONE) {
                outlinedBody = new Block(ImmutableList.of(
                        new SetLocal(internalRetValue, Type.I64, packI64(outlinedBody, maxVta.getType())),
                        new Break(Name.of("_epilog$"), new Nop())));
            } else {
                outlinedBody = new Block(ImmutableList.of(outlinedBody, new Break(Name.of("_epilog$"), new Nop())));
            }
        }

        // Wrap the body once per target. A break to target i exits its label, stores the packed value (if any),
        // sets the control code to i + 1 (0 is reserved for fallthrough) and jumps to the epilog.
        for (int i = 0; i < targets.size(); ++i) {
            ControlTransferAnalysis.Target target = targets.get(i);
            Name name = targetNameMap.get(target);
            Type type = targetTypes.get(target);

            outlinedBody = new Label(name, outlinedBody);

            switch (type) {
            case NONE:
                break;
            case I32:
            case I64:
            case F32:
            case F64:
                outlinedBody = new SetLocal(internalRetValue, Type.I64, packI64(outlinedBody, type));
                break;
            default:
                throw new IllegalStateException();
            }

            outlinedBody = new Block(ImmutableList.of(new Block(ImmutableList.of(outlinedBody,
                    new SetLocal(internalRetControl, Type.I32, new Const(Literal.of(i + 1))),
                    new Break(Name.of("_epilog$"), new Nop())))));
        }

        // Epilog: spill the packed value to slot 0 and spilled-out local k to slot k + 1, then return the control
        // code.
        Node returnValueSpiller = new Call(new Call.HostTarget(HostOp.SpillPut, Optional.empty()),
                HostOp.SpillPut.getSignature(),
                ImmutableList.of(new Const(Literal.of(0)), new GetLocal(internalRetValue, Type.I64)));

        List<Node> localSpillers = enumerate(spilledOutLocalPuts.stream()).map(i -> {
            Type type = function.getLocals().getLocal(i.getItem()).getType();
            return new Call(new Call.HostTarget(HostOp.SpillPut, Optional.empty()), HostOp.SpillPut.getSignature(),
                    ImmutableList.of(new Const(Literal.of(i.getIndex() + 1)),
                            packI64(new GetLocal(localTranslationMap.get(i.getItem()), type), type)));
        }).collect(toImmutableList());

        outlinedBody = new Block(ImmutableList.of(
                new Label(Name.of("_epilog$"), outlinedBody), new Block(Stream
                        .concat(Stream.of(returnValueSpiller), localSpillers.stream()).collect(toImmutableList())),
                new Return(new GetLocal(internalRetControl, Type.I32))));

        // Translated locals keep their original names and types under their new indices.
        List<Local> localList = ImmutableList.<Local>builder().addAll(allLocals.stream().map(l -> {
            Local o = function.getLocals().getLocal(l);
            return new Local(o.getName(), localTranslationMap.get(l), o.getType());
        }).iterator()).add(new Local(Name.of("_internal$control"), internalRetControl, Type.I32))
                .add(new Local(Name.of("_internal$value"), internalRetValue, Type.I64)).build();

        // Spilled-out and spilled-in locals are the leading entries of allLocals, so they become the parameters.
        int argCount = spilledOutLocalPuts.size() + spilledInLocalGets.size();

        Function outlinedFunction = new Function(outlinedName, Type.I32, argCount, new Locals(localList),
                outlinedBody);

        // Callsite: call with the current values of the argument locals, then read back the value and the
        // spilled-out locals.
        Node outlinedCall = new SetLocal(externalRetControl, Type.I32,
                new Call(new Call.DirectTarget(outlinedName), outlinedFunction.getSignature(),
                        Stream.concat(spilledOutLocalPuts.stream(), spilledInLocalGets.stream())
                                .map(i -> new GetLocal(i, function.getLocals().getLocal(i).getType()))
                                .collect(toImmutableList())));

        Node returnValueUnspiller = new SetLocal(externalRetValue, Type.I64,
                new Call(new Call.HostTarget(HostOp.SpillGet, Optional.empty()), HostOp.SpillGet.getSignature(),
                        ImmutableList.of(new Const(Literal.of(0)))));

        List<Node> localUnspillers = enumerate(spilledOutLocalPuts.stream()).map(i -> {
            Type type = function.getLocals().getLocal(i.getItem()).getType();
            return new SetLocal(i.getItem(), type,
                    unpackI64(new Call(new Call.HostTarget(HostOp.SpillGet, Optional.empty()),
                            HostOp.SpillGet.getSignature(),
                            ImmutableList.of(new Const(Literal.of(i.getIndex() + 1)))), type));
        }).collect(toImmutableList());

        // Dispatch on the control code. k + 1 re-issues the k-th break/return in the original function, carrying a
        // value of that target's own type. 0 is fallthrough (only when the subtree can complete normally). Anything
        // else is unreachable.
        Node controlSwitch = new Switch(new GetLocal(externalRetControl, Type.I32),
                Stream.concat(enumerate(targets.stream()).map(i -> {
                    ControlTransferAnalysis.Target target = i.getItem();
                    Node value = unpackValue(externalRetValue, targetTypes.get(target));
                    if (target.equals(ControlTransferAnalysis.Target.RETURN)) {
                        return new Switch.Entry(ImmutableList.of(Switch.Value.of(i.getIndex() + 1)),
                                new Return(value));
                    } else if (target instanceof ControlTransferAnalysis.NameTarget) {
                        Name name = ((ControlTransferAnalysis.NameTarget) target).getName();
                        return new Switch.Entry(ImmutableList.of(Switch.Value.of(i.getIndex() + 1)),
                                new Break(name, value));
                    } else {
                        throw new IllegalStateException();
                    }
                }), Stream.concat(
                        Stream.of(new Switch.Entry(ImmutableList.of(Switch.Value.DEFAULT), new Unreachable())),
                        maxCfa.getExecution() == ControlTransferAnalysis.Execution.FALLTHROUGH
                                ? Stream.of(new Switch.Entry(ImmutableList.of(Switch.Value.of(0)), new Nop()))
                                : Stream.empty()))
                        .collect(toImmutableList()));

        // On fallthrough the callsite evaluates to the subtree's value, passed back through externalRetValue.
        Node fallthroughValue = maxCfa.getExecution() == ControlTransferAnalysis.Execution.FALLTHROUGH
                ? unpackValue(externalRetValue, maxVta.getType())
                : new Nop();

        Node callsite = new Block(ImmutableList.<Node>builder().add(outlinedCall).add(returnValueUnspiller)
                .addAll(localUnspillers).add(controlSwitch).add(fallthroughValue)
                .build());

        return new OutlinedFunction(callsite, outlinedFunction);
    }

    /**
     * Builds a node that reads the packed i64 in {@code valueLocal} and unpacks it to {@code type}. Returns a
     * {@link Nop} when {@code type} is {@link Type#NONE}, since no value is packed in that case.
     */
    private static Node unpackValue(Index valueLocal, Type type) {
        return type != Type.NONE ? unpackI64(new GetLocal(valueLocal, Type.I64), type) : new Nop();
    }
}