Java tutorial
/*
 * Copyright (c) 2013-2015 Massachusetts Institute of Technology
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
package edu.mit.streamjit.impl.compiler2;

import static com.google.common.base.Preconditions.*;
import com.google.common.collect.DiscreteDomain;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.ImmutableTable;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.collect.Range;
import com.google.common.collect.Sets;
import edu.mit.streamjit.api.Filter;
import edu.mit.streamjit.api.Joiner;
import edu.mit.streamjit.api.Splitter;
import edu.mit.streamjit.util.bytecode.methodhandles.Combinators;
import static edu.mit.streamjit.util.bytecode.methodhandles.LookupUtils.findStatic;
import edu.mit.streamjit.util.bytecode.methodhandles.ProxyFactory;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

/**
 * Compiler IR for a fused group of workers (what used to be called StreamNode).
 * @author Jeffrey Bosboom <jbosboom@csail.mit.edu>
 * @since 9/22/2013
 */
public class ActorGroup implements Comparable<ActorGroup> {
    /** The actors in this group, in their natural order; replaced when an actor is removed. */
    private ImmutableSortedSet<Actor> actors;
    /** Executions of each actor per group iteration; null until {@link #setSchedule} is called. */
    private ImmutableMap<Actor, Integer> schedule;

    /**
     * Creates a group over the given actors and points each actor's group
     * reference at the new group.
     * @param actors the actors making up the group
     */
    private ActorGroup(ImmutableSortedSet<Actor> actors) {
        this.actors = actors;
        for (Actor a : actors) {
            a.setGroup(this);
        }
    }

    /**
     * Creates a group containing only the given actor.
     * @param actor an actor not yet in any group (checked by assertion)
     * @return a new single-actor group
     */
    public static ActorGroup of(Actor actor) {
        assert actor.group() == null : actor.group();
        return new ActorGroup(ImmutableSortedSet.of(actor));
    }

    /**
     * Creates a new group containing the actors of both given groups. The
     * actors' group references are updated to the new group.
     * @param first one group
     * @param second the other group
     * @return the fused group
     */
    public static ActorGroup fuse(ActorGroup first, ActorGroup second) {
        return new ActorGroup(
                ImmutableSortedSet.<Actor>naturalOrder()
                        .addAll(first.actors())
                        .addAll(second.actors())
                        .build());
    }

    /**
     * Removes the given actor from this group and, if a schedule has been
     * set, from the schedule as well.
     * @param a the actor to remove; must be in this group (checked by
     * assertion)
     */
    public void remove(Actor a) {
        assert actors.contains(a) : a;
        actors = ImmutableSortedSet.copyOf(Sets.difference(actors, ImmutableSet.of(a)));
        //Removal before scheduling leaves the schedule unset rather than
        //failing on a null map.
        if (schedule != null) {
            //Only the key matters for entriesOnlyOnLeft(); the 0 is a dummy value.
            schedule = ImmutableMap.copyOf(
                    Maps.difference(schedule, ImmutableMap.of(a, 0)).entriesOnlyOnLeft());
        }
    }

    /**
     * Returns the actors currently in this group.
     * @return this group's actors
     */
    public ImmutableSet<Actor> actors() {
        return actors;
    }

    /**
     * Returns true iff at least one actor in this group is a TokenActor.
     * @return whether this group contains a TokenActor
     */
    public boolean isTokenGroup() {
        for (Actor a : actors()) {
            if (a instanceof TokenActor) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns this group's id, which is the id of its least actor.
     * @return the id of the minimum actor in this group
     */
    public int id() {
        return Collections.min(actors()).id();
    }

    /**
     * Returns true iff at least one actor in this group peeks.
     * @return whether any actor reports {@code isPeeking()}
     */
    public boolean isPeeking() {
        for (Actor a : actors()) {
            if (a.isPeeking()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true iff at least one WorkerActor in this group has a stateful
     * archetype.
     * @return whether any worker actor's archetype is stateful
     */
    public boolean isStateful() {
        for (Actor a : actors()) {
            if (a instanceof WorkerActor && ((WorkerActor) a).archetype().isStateful()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the non-internal storage read by this group's actors.
     * @return the input storage of actors in this group that is not internal
     */
    public Set<Storage> inputs() {
        ImmutableSet.Builder<Storage> builder = ImmutableSet.builder();
        for (Actor a : actors()) {
            for (Storage s : a.inputs()) {
                if (!s.isInternal()) {
                    builder.add(s);
                }
            }
        }
        return builder.build();
    }

    /**
     * Returns the non-internal storage written by this group's actors.
     * @return the output storage of actors in this group that is not internal
     */
    public Set<Storage> outputs() {
        ImmutableSet.Builder<Storage> builder = ImmutableSet.builder();
        for (Actor a : actors()) {
            for (Storage s : a.outputs()) {
                if (!s.isInternal()) {
                    builder.add(s);
                }
            }
        }
        return builder.build();
    }

    /**
     * Returns the internal storage touched by this group's actors.
     * @return a filtered view of this group's edges that are internal
     */
    public Set<Storage> internalEdges() {
        return Sets.filter(allEdges(), Storage::isInternal);
    }

    /**
     * Collects every input and output storage of every actor in this group.
     * @return all storage adjacent to this group's actors
     */
    private Set<Storage> allEdges() {
        ImmutableSet.Builder<Storage> builder = ImmutableSet.builder();
        for (Actor a : actors) {
            builder.addAll(a.inputs());
            builder.addAll(a.outputs());
        }
        return builder.build();
    }

    /**
     * Returns the other groups containing actors upstream of this group's
     * input storage.
     * @return the groups (excluding this one) feeding this group
     */
    public Set<ActorGroup> predecessorGroups() {
        ImmutableSet.Builder<ActorGroup> builder = ImmutableSet.builder();
        for (Actor a : actors) {
            for (Storage s : a.inputs()) {
                for (Actor b : s.upstream()) {
                    if (b.group() != this) {
                        builder.add(b.group());
                    }
                }
            }
        }
        return builder.build();
    }

    /**
     * Returns the other groups containing actors downstream of this group's
     * output storage.
     * @return the groups (excluding this one) fed by this group
     */
    public Set<ActorGroup> successorGroups() {
        ImmutableSet.Builder<ActorGroup> builder = ImmutableSet.builder();
        for (Actor a : actors) {
            for (Storage s : a.outputs()) {
                for (Actor b : s.downstream()) {
                    if (b.group() != this) {
                        builder.add(b.group());
                    }
                }
            }
        }
        return builder.build();
    }

    /**
     * Returns this group's schedule, mapping each actor to the number of times
     * it executes per group iteration.
     * @return the schedule
     * @throws IllegalStateException if the schedule has not been set
     */
    public ImmutableMap<Actor, Integer> schedule() {
        checkState(schedule != null, "schedule not yet initialized");
        return schedule;
    }

    /**
     * Sets this group's schedule. May be called only once.
     * @param schedule a map containing an entry for every actor in this group
     * @throws IllegalStateException if the schedule was already set
     * @throws IllegalArgumentException if an actor in this group has no entry
     */
    public void setSchedule(ImmutableMap<Actor, Integer> schedule) {
        checkState(this.schedule == null, "already initialized schedule");
        for (Actor a : actors()) {
            checkArgument(schedule.containsKey(a), "schedule doesn't contain actor %s", a);
        }
        this.schedule = schedule;
    }

    /**
     * Converts a half-open range of group iterations into the corresponding
     * half-open range of executions of the given actor, using the schedule.
     * @param a the actor
     * @param firstIteration the first group iteration (inclusive)
     * @param endIteration the last group iteration (exclusive)
     * @return the range of actor executions
     * @throws IllegalStateException if the schedule has not been set
     */
    private Range<Integer> actorIterations(Actor a, int firstIteration, int endIteration) {
        int executionsPerIteration = schedule().get(a);
        return Range.closedOpen(firstIteration * executionsPerIteration,
                endIteration * executionsPerIteration);
    }

    /**
     * Returns the physical indices read from the given storage during the given
     * group iteration.
     * @param s the storage being read from
     * @param iteration the group iteration number
     * @return the physical indices read
     */
    public ImmutableSortedSet<Integer> reads(Storage s, int iteration) {
        ImmutableSortedSet.Builder<Integer> builder = ImmutableSortedSet.naturalOrder();
        for (Actor a : actors()) {
            builder.addAll(a.reads(s, actorIterations(a, iteration, iteration + 1)));
        }
        return builder.build();
    }

    /**
     * Returns the physical indices read from the given storage during the given
     * group iterations.
     * @param s the storage being read from
     * @param iterations the group iterations
     * @return the physical indices read
     */
    public ImmutableSortedSet<Integer> reads(Storage s, Range<Integer> iterations) {
        iterations = iterations.canonical(DiscreteDomain.integers());
        ImmutableSortedSet.Builder<Integer> builder = ImmutableSortedSet.naturalOrder();
        for (Actor a : actors()) {
            builder.addAll(a.reads(s,
                    actorIterations(a, iterations.lowerEndpoint(), iterations.upperEndpoint())));
        }
        return builder.build();
    }

    /**
     * Returns a map mapping each input Storage to the set of physical indices
     * read in that Storage during the given ActorGroup iteration.
     * @param iteration the iteration to simulate
     * @return a map of read physical indices
     */
    public ImmutableMap<Storage, ImmutableSortedSet<Integer>> reads(final int iteration) {
        return Maps.toMap(inputs(), (Storage input) -> reads(input, iteration));
    }

    /**
     * Returns the physical indices written to the given storage during the
     * given group iteration.
     * @param s the storage being written to
     * @param iteration the group iteration number
     * @return the physical indices written
     */
    public ImmutableSortedSet<Integer> writes(Storage s, int iteration) {
        ImmutableSortedSet.Builder<Integer> builder = ImmutableSortedSet.naturalOrder();
        for (Actor a : actors()) {
            builder.addAll(a.writes(s, actorIterations(a, iteration, iteration + 1)));
        }
        return builder.build();
    }

    /**
     * Returns the physical indices written to the given storage during the
     * given group iterations.
     * @param s the storage being written to
     * @param iterations the group iterations
     * @return the physical indices written
     */
    public ImmutableSortedSet<Integer> writes(Storage s, Range<Integer> iterations) {
        //Canonicalize so the endpoints below are closed-open regardless of the
        //bound types the caller used, matching reads(Storage, Range).
        iterations = iterations.canonical(DiscreteDomain.integers());
        ImmutableSortedSet.Builder<Integer> builder = ImmutableSortedSet.naturalOrder();
        for (Actor a : actors()) {
            builder.addAll(a.writes(s,
                    actorIterations(a, iterations.lowerEndpoint(), iterations.upperEndpoint())));
        }
        return builder.build();
    }

    /**
     * Returns a map mapping each output Storage to the set of physical indices
     * written in that Storage during the given ActorGroup iteration.
     * @param iteration the iteration to simulate
     * @return a map of written physical indices
     */
    public ImmutableMap<Storage, ImmutableSortedSet<Integer>> writes(final int iteration) {
        return Maps.toMap(outputs(), (Storage output) -> writes(output, iteration));
    }

    /**
     * Returns a void->void MethodHandle that will run this ActorGroup for the
     * given iterations using the given ConcreteStorage instances. The
     * iterations are executed {@code unrollFactor} at a time as far as
     * possible, with any remainder executed one at a time. An empty range
     * yields a handle that does nothing.
     * @param iterations the range of iterations to run for
     * @param storage the storage being used
     * @param switchFactory combines the per-input handles of a joiner (or the
     * per-output handles of a splitter) into a single handle
     * @param unrollFactor the number of group iterations to execute per
     * unrolled step; must be positive unless the range is empty
     * @param inputTransformers the index function transformer for each
     * (actor, input index) pair
     * @param outputTransformers the index function transformer for each
     * (actor, output index) pair
     * @return a void->void method handle
     * @throws IllegalArgumentException if the range is nonempty and
     * unrollFactor is not positive
     */
    public MethodHandle specialize(Range<Integer> iterations, Map<Storage, ConcreteStorage> storage,
            BiFunction<MethodHandle[], WorkerActor, MethodHandle> switchFactory, int unrollFactor,
            ImmutableTable<Actor, Integer, IndexFunctionTransformer> inputTransformers,
            ImmutableTable<Actor, Integer, IndexFunctionTransformer> outputTransformers) {
        //TokenActors are special.
        assert !isTokenGroup() : actors();

        //The endpoint arithmetic below assumes a closed-open range.
        Range<Integer> canonical = iterations.canonical(DiscreteDomain.integers());
        if (canonical.isEmpty()) {
            //Nothing to run; also avoids dividing by a zero unroll factor below.
            return Combinators.nop();
        }
        checkArgument(unrollFactor > 0, "unroll factor must be positive: %s", unrollFactor);

        Map<Actor, MethodHandle> withRWHandlesBound = bindActorsToStorage(canonical, storage,
                switchFactory, inputTransformers, outputTransformers);

        int firstIteration = canonical.lowerEndpoint();
        int endIteration = canonical.upperEndpoint();
        int totalIterations = endIteration - firstIteration;
        unrollFactor = Math.min(unrollFactor, totalIterations);
        int unrolls = totalIterations / unrollFactor;
        int unrollEndpoint = firstIteration + unrolls * unrollFactor;
        //Unrolled portion first, then the leftover iterations one at a time.
        return Combinators.semicolon(
                makeGroupLoop(Range.closedOpen(firstIteration, unrollEndpoint), unrollFactor,
                        withRWHandlesBound),
                makeGroupLoop(Range.closedOpen(unrollEndpoint, endIteration), 1,
                        withRWHandlesBound));
    }

    /**
     * Compute the read and write method handles for each Actor. These don't
     * depend on the iteration, so we can bind and reuse them.
     * @param iterations the group iterations being specialized for
     * @param storage the storage being used
     * @param switchFactory combines a joiner's read handles (or a splitter's
     * write handles) into a single handle
     * @param inputTransformers the index function transformer for each
     * (actor, input index) pair
     * @param outputTransformers the index function transformer for each
     * (actor, output index) pair
     * @return a map from each actor to its specialized work handle with the
     * read and write handles bound
     */
    private Map<Actor, MethodHandle> bindActorsToStorage(Range<Integer> iterations,
            Map<Storage, ConcreteStorage> storage,
            BiFunction<MethodHandle[], WorkerActor, MethodHandle> switchFactory,
            ImmutableTable<Actor, Integer, IndexFunctionTransformer> inputTransformers,
            ImmutableTable<Actor, Integer, IndexFunctionTransformer> outputTransformers) {
        Map<Actor, MethodHandle> withRWHandlesBound = new HashMap<>();
        for (Actor a : actors()) {
            WorkerActor wa = (WorkerActor) a;
            MethodHandle specialized = wa.archetype().specialize(wa);

            assert a.inputs().size() > 0 : a;
            MethodType readHandleType = MethodType.methodType(wa.inputType().getRawType(), int.class);
            MethodHandle read;
            if (wa.worker() instanceof Joiner) {
                //One read handle per input, selected through the switch.
                MethodHandle[] table = new MethodHandle[a.inputs().size()];
                for (int i = 0; i < table.length; ++i) {
                    table[i] = boundReadHandle(a, i, iterations, storage, inputTransformers,
                            readHandleType);
                }
                read = switchFactory.apply(table, wa);
            } else {
                read = boundReadHandle(a, 0, iterations, storage, inputTransformers, readHandleType);
            }

            assert a.outputs().size() > 0 : a;
            MethodType writeHandleType = MethodType.methodType(void.class, int.class,
                    wa.outputType().getRawType());
            MethodHandle write;
            if (wa.worker() instanceof Splitter) {
                //One write handle per output, selected through the switch.
                MethodHandle[] table = new MethodHandle[a.outputs().size()];
                for (int i = 0; i < table.length; ++i) {
                    table[i] = boundWriteHandle(a, i, iterations, storage, outputTransformers,
                            writeHandleType);
                }
                write = switchFactory.apply(table, wa);
            } else {
                write = boundWriteHandle(a, 0, iterations, storage, outputTransformers,
                        writeHandleType);
            }

            withRWHandlesBound.put(wa, specialized.bindTo(read).bindTo(write));
        }
        return withRWHandlesBound;
    }

    /**
     * Builds the read handle for one input of an actor: the storage's read
     * handle with its index argument filtered through the (transformed) input
     * index function, adapted to the given type.
     * @param a the actor
     * @param input the input index
     * @param iterations the group iterations being specialized for
     * @param storage the storage being used
     * @param inputTransformers the index function transformers for inputs
     * @param type the type the resulting handle is adapted to
     * @return the read handle for the given input
     */
    private static MethodHandle boundReadHandle(Actor a, int input, Range<Integer> iterations,
            Map<Storage, ConcreteStorage> storage,
            ImmutableTable<Actor, Integer, IndexFunctionTransformer> inputTransformers,
            MethodType type) {
        MethodHandle storageRead = storage.get(a.inputs().get(input)).readHandle();
        MethodHandle indexFunction = inputTransformers.get(a, input).transform(
                a.inputIndexFunctions().get(input).asHandle(), () -> a.peeks(input, iterations));
        return MethodHandles.filterArguments(storageRead, 0, indexFunction).asType(type);
    }

    /**
     * Builds the write handle for one output of an actor: the storage's write
     * handle with its index argument filtered through the (transformed) output
     * index function, adapted to the given type.
     * @param a the actor
     * @param output the output index
     * @param iterations the group iterations being specialized for
     * @param storage the storage being used
     * @param outputTransformers the index function transformers for outputs
     * @param type the type the resulting handle is adapted to
     * @return the write handle for the given output
     */
    private static MethodHandle boundWriteHandle(Actor a, int output, Range<Integer> iterations,
            Map<Storage, ConcreteStorage> storage,
            ImmutableTable<Actor, Integer, IndexFunctionTransformer> outputTransformers,
            MethodType type) {
        MethodHandle storageWrite = storage.get(a.outputs().get(output)).writeHandle();
        MethodHandle indexFunction = outputTransformers.get(a, output).transform(
                a.outputIndexFunctions().get(output).asHandle(), () -> a.pushes(output, iterations));
        return MethodHandles.filterArguments(storageWrite, 0, indexFunction).asType(type);
    }

    /**
     * Make loop handles for each Actor that execute the iteration given as
     * an argument, then bind them together in an outer loop body that
     * executes all the iterations. Before the outer loop we must also
     * reinitialize the splitter/joiner index arrays to their initial
     * values.
     * @param iterations the group iterations to execute; if empty, a handle
     * that does nothing is returned
     * @param unrollFactor the number of group iterations executed per pass of
     * the outer loop
     * @param withRWHandlesBound the work handle for each actor with read and
     * write handles bound
     * @return a handle executing the given iterations
     */
    private MethodHandle makeGroupLoop(Range<Integer> iterations, int unrollFactor,
            Map<Actor, MethodHandle> withRWHandlesBound) {
        if (iterations.isEmpty()) {
            return Combinators.nop();
        }

        List<MethodHandle> loopHandles = new ArrayList<>(actors().size());
        //Keyed by array identity: initial contents -> working array.
        Map<int[], int[]> requiredCopies = new LinkedHashMap<>();
        for (Actor a : actors()) {
            loopHandles.add(makeWorkerLoop((WorkerActor) a, withRWHandlesBound.get(a), unrollFactor,
                    iterations.lowerEndpoint(), requiredCopies));
        }

        MethodHandle groupLoop = MethodHandles.insertArguments(OVERALL_GROUP_LOOP, 0,
                Combinators.semicolon(loopHandles), iterations.lowerEndpoint(),
                iterations.upperEndpoint(), unrollFactor);

        if (!requiredCopies.isEmpty()) {
            //Flatten to [source0, dest0, source1, dest1, ...] as expected by
            //_reinitializeArrays.
            int[][] copies = new int[requiredCopies.size() * 2][];
            int i = 0;
            for (Map.Entry<int[], int[]> e : requiredCopies.entrySet()) {
                copies[i++] = e.getKey();
                copies[i++] = e.getValue();
            }
            groupLoop = Combinators.semicolon(
                    MethodHandles.insertArguments(REINITIALIZE_ARRAYS, 0, (Object) copies),
                    groupLoop);
        }
        return groupLoop;
    }

    /**
     * Makes the loop for the given actor, which implements group executions
     * based on the unroll factor.
     * @param a the actor
     * @param base the specialized work method with read/write handles bound;
     * takes two int or int[] parameters
     * @param unrollFactor the number of group iterations to execute
     * @param firstIteration the first iteration to execute, for computing the
     * initial contents of index arrays
     * @param requiredCopies accumulates the copies required to reinitialize the
     * index arrays
     * @return a MethodHandle taking one int parameter
     */
    private MethodHandle makeWorkerLoop(WorkerActor a, MethodHandle base, int unrollFactor,
            int firstIteration, Map<int[], int[]> requiredCopies) {
        int subiterations = schedule.get(a);

        Object pop, push;
        if (base.type().parameterType(0).equals(int.class)) {
            assert a.inputs().size() == 1;
            pop = a.pop(0).max();
        } else {
            int[] readIndices = new int[a.inputs().size()];
            for (int m = 0; m < a.inputs().size(); ++m) {
                readIndices[m] = firstIteration * subiterations * a.pop(m).max();
            }
            //The clone is the array the loop is bound to; readIndices keeps
            //the initial contents for reinitialization.
            pop = readIndices.clone();
            requiredCopies.put(readIndices, (int[]) pop);
        }

        if (base.type().parameterType(1).equals(int.class)) {
            assert a.outputs().size() == 1;
            push = a.push(0).max();
        } else {
            int[] writeIndices = new int[a.outputs().size()];
            for (int m = 0; m < a.outputs().size(); ++m) {
                writeIndices[m] = firstIteration * subiterations * a.push(m).max();
            }
            //As above: the clone is bound, writeIndices keeps the initial contents.
            push = writeIndices.clone();
            requiredCopies.put(writeIndices, (int[]) push);
        }

        MethodHandle loopHandle;
        if (a.worker() instanceof Filter) {
            loopHandle = FILTER_LOOP;
        } else if (a.worker() instanceof Splitter) {
            loopHandle = SPLITTER_LOOP;
        } else if (a.worker() instanceof Joiner) {
            loopHandle = JOINER_LOOP;
        } else {
            throw new AssertionError(a);
        }
        //Leaves only the trailing firstIteration parameter unbound.
        return MethodHandles.insertArguments(loopHandle, 0, base, unrollFactor, subiterations, pop,
                push);
    }

    private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
    private static final MethodHandle FILTER_LOOP = findStatic(LOOKUP, "_filterLoop");
    private static final MethodHandle SPLITTER_LOOP = findStatic(LOOKUP, "_splitterLoop");
    private static final MethodHandle JOINER_LOOP = findStatic(LOOKUP, "_joinerLoop");
    private static final MethodHandle REINITIALIZE_ARRAYS = findStatic(LOOKUP, "_reinitializeArrays");
    private static final MethodHandle OVERALL_GROUP_LOOP = findStatic(LOOKUP, "_overallGroupLoop");

    /**
     * Loop body for filters, looked up by name for {@link #FILTER_LOOP}.
     * Invokes the work handle once per execution {@code i} in the given group
     * iterations, passing {@code i * pop} and {@code i * push}.
     * @param work the work handle, invoked as (int, int)void
     * @param iterations the number of group iterations to execute
     * @param subiterations the executions per group iteration
     * @param pop the multiplier for the first work argument
     * @param push the multiplier for the second work argument
     * @param firstIteration the first group iteration to execute
     * @throws Throwable whatever the work handle throws
     */
    private static void _filterLoop(MethodHandle work, int iterations, int subiterations, int pop,
            int push, int firstIteration) throws Throwable {
        for (int i = firstIteration * subiterations;
                i < (firstIteration + iterations) * subiterations; ++i) {
            work.invokeExact(i * pop, i * push);
        }
    }

    /**
     * Loop body for splitters, looked up by name for {@link #SPLITTER_LOOP}.
     * Invokes the work handle once per execution {@code i} in the given group
     * iterations, passing {@code i * pop} and the write index array.
     * @param work the work handle, invoked as (int, int[])void
     * @param iterations the number of group iterations to execute
     * @param subiterations the executions per group iteration
     * @param pop the multiplier for the first work argument
     * @param writeIndices the array passed as the second work argument
     * @param firstIteration the first group iteration to execute
     * @throws Throwable whatever the work handle throws
     */
    private static void _splitterLoop(MethodHandle work, int iterations, int subiterations, int pop,
            int[] writeIndices, int firstIteration) throws Throwable {
        for (int i = firstIteration * subiterations;
                i < (firstIteration + iterations) * subiterations; ++i) {
            work.invokeExact(i * pop, writeIndices);
        }
    }

    /**
     * Loop body for joiners, looked up by name for {@link #JOINER_LOOP}.
     * Invokes the work handle once per execution {@code i} in the given group
     * iterations, passing the read index array and {@code i * push}.
     * @param work the work handle, invoked as (int[], int)void
     * @param iterations the number of group iterations to execute
     * @param subiterations the executions per group iteration
     * @param readIndices the array passed as the first work argument
     * @param push the multiplier for the second work argument
     * @param firstIteration the first group iteration to execute
     * @throws Throwable whatever the work handle throws
     */
    private static void _joinerLoop(MethodHandle work, int iterations, int subiterations,
            int[] readIndices, int push, int firstIteration) throws Throwable {
        for (int i = firstIteration * subiterations;
                i < (firstIteration + iterations) * subiterations; ++i) {
            work.invokeExact(readIndices, i * push);
        }
    }

    /**
     * Copies each even-indexed array over the array that follows it; looked up
     * by name for {@link #REINITIALIZE_ARRAYS}.
     * @param indexArrays pairs of arrays laid out as source, destination,
     * source, destination, ...
     */
    private static void _reinitializeArrays(int[][] indexArrays) {
        for (int i = 0; i < indexArrays.length; i += 2) {
            System.arraycopy(indexArrays[i], 0, indexArrays[i + 1], 0, indexArrays[i].length);
        }
    }

    /**
     * Invokes the loop body with each value from begin (inclusive) to end
     * (exclusive) in steps of increment; looked up by name for
     * {@link #OVERALL_GROUP_LOOP}.
     * @param loopBody the body, invoked as (int)void
     * @param begin the first value passed to the body
     * @param end the exclusive upper bound
     * @param increment the step between successive values
     * @throws Throwable whatever the loop body throws
     */
    private static void _overallGroupLoop(MethodHandle loopBody, int begin, int end, int increment)
            throws Throwable {
        for (int i = begin; i < end; i += increment) {
            loopBody.invokeExact(i);
        }
    }

    /**
     * This is inconsistent with equals, but we should never have two distinct
     * ActorGroup objects with the same id, so we'll never notice the
     * inconsistency. (We used to have equals() and hashCode() by id, but actor
     * removal would then change the hash code and screw up maps.)
     */
    @Override
    public int compareTo(ActorGroup o) {
        return Integer.compare(id(), o.id());
    }

    /**
     * Returns a string made of this group's id followed by its actors.
     * @return a string representation of this group
     */
    @Override
    public String toString() {
        return "ActorGroup@" + id() + actors();
    }
}