Java tutorial
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.primitives.Ints;
import io.prestosql.memory.context.LocalMemoryContext;
import io.prestosql.operator.window.FramedWindowFunction;
import io.prestosql.operator.window.WindowPartition;
import io.prestosql.spi.Page;
import io.prestosql.spi.PageBuilder;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.SortOrder;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.plan.PlanNodeId;

import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiPredicate;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkPositionIndex;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.concat;
import static io.prestosql.spi.block.SortOrder.ASC_NULLS_LAST;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;

/**
 * Operator that evaluates window functions over its input.
 *
 * <p>Input pages are buffered in a {@link PagesIndex} one pre-grouped group at a time. Once a full
 * group has been buffered it is sorted (when required) and split into partitions; each partition is
 * handed to a {@link WindowPartition}, which writes rows into the output {@link PageBuilder}. The
 * output column types are the types of the selected output channels followed by the type of each
 * window function.
 */
public class WindowOperator implements Operator {
    /**
     * Factory that captures the window configuration once and creates {@link WindowOperator}
     * instances from it.
     */
    public static class WindowOperatorFactory implements OperatorFactory {
        private final int operatorId;
        private final PlanNodeId planNodeId;
        private final List<Type> sourceTypes;
        private final List<Integer> outputChannels;
        private final List<WindowFunctionDefinition> windowFunctionDefinitions;
        private final List<Integer> partitionChannels;
        private final List<Integer> preGroupedChannels;
        private final List<Integer> sortChannels;
        private final List<SortOrder> sortOrder;
        private final int preSortedChannelPrefix;
        private final int expectedPositions;
        // Set by noMoreOperators(); createOperator() refuses to run afterwards.
        private boolean closed;
        private final PagesIndex.Factory pagesIndexFactory;

        /**
         * Validates the configuration and stores immutable copies of all list arguments.
         *
         * @param preGroupedChannels must be a subset of {@code partitionChannels}
         * @param sortOrder must have the same size as {@code sortChannels}
         * @param preSortedChannelPrefix number of leading sort channels that are already sorted; must not
         *        exceed {@code sortChannels.size()} and may only be positive when every partition channel
         *        is pre-grouped
         * @throws NullPointerException if any reference argument is null
         * @throws IllegalArgumentException if any of the constraints above is violated
         */
        public WindowOperatorFactory(
                int operatorId,
                PlanNodeId planNodeId,
                List<? extends Type> sourceTypes,
                List<Integer> outputChannels,
                List<WindowFunctionDefinition> windowFunctionDefinitions,
                List<Integer> partitionChannels,
                List<Integer> preGroupedChannels,
                List<Integer> sortChannels,
                List<SortOrder> sortOrder,
                int preSortedChannelPrefix,
                int expectedPositions,
                PagesIndex.Factory pagesIndexFactory) {
            requireNonNull(sourceTypes, "sourceTypes is null");
            requireNonNull(planNodeId, "planNodeId is null");
            requireNonNull(outputChannels, "outputChannels is null");
            requireNonNull(windowFunctionDefinitions, "windowFunctionDefinitions is null");
            requireNonNull(partitionChannels, "partitionChannels is null");
            requireNonNull(preGroupedChannels, "preGroupedChannels is null");
            checkArgument(partitionChannels.containsAll(preGroupedChannels), "preGroupedChannels must be a subset of partitionChannels");
            requireNonNull(sortChannels, "sortChannels is null");
            requireNonNull(sortOrder, "sortOrder is null");
            requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");
            checkArgument(sortChannels.size() == sortOrder.size(), "Must have same number of sort channels as sort orders");
            checkArgument(preSortedChannelPrefix <= sortChannels.size(), "Cannot have more pre-sorted channels than specified sorted channels");
            checkArgument(
                    preSortedChannelPrefix == 0 || ImmutableSet.copyOf(preGroupedChannels).equals(ImmutableSet.copyOf(partitionChannels)),
                    "preSortedChannelPrefix can only be greater than zero if all partition channels are pre-grouped");

            this.pagesIndexFactory = pagesIndexFactory;
            this.operatorId = operatorId;
            this.planNodeId = planNodeId;
            this.sourceTypes = ImmutableList.copyOf(sourceTypes);
            this.outputChannels = ImmutableList.copyOf(outputChannels);
            this.windowFunctionDefinitions = ImmutableList.copyOf(windowFunctionDefinitions);
            this.partitionChannels = ImmutableList.copyOf(partitionChannels);
            this.preGroupedChannels = ImmutableList.copyOf(preGroupedChannels);
            this.sortChannels = ImmutableList.copyOf(sortChannels);
            this.sortOrder = ImmutableList.copyOf(sortOrder);
            this.preSortedChannelPrefix = preSortedChannelPrefix;
            this.expectedPositions = expectedPositions;
        }

        /**
         * Creates a new {@link WindowOperator} registered under this factory's operator id and plan node id.
         *
         * @throws IllegalStateException if {@link #noMoreOperators()} has already been called
         */
        @Override
        public Operator createOperator(DriverContext driverContext) {
            checkState(!closed, "Factory is already closed");

            OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, WindowOperator.class.getSimpleName());
            return new WindowOperator(
                    operatorContext,
                    sourceTypes,
                    outputChannels,
                    windowFunctionDefinitions,
                    partitionChannels,
                    preGroupedChannels,
                    sortChannels,
                    sortOrder,
                    preSortedChannelPrefix,
                    expectedPositions,
                    pagesIndexFactory);
        }

        /** Marks the factory as closed; later {@link #createOperator} calls fail. */
        @Override
        public void noMoreOperators() {
            closed = true;
        }

        /** Returns a new, open factory built from the same configuration as this one. */
        @Override
        public OperatorFactory duplicate() {
            return new WindowOperatorFactory(
                    operatorId,
                    planNodeId,
                    sourceTypes,
                    outputChannels,
                    windowFunctionDefinitions,
                    partitionChannels,
                    preGroupedChannels,
                    sortChannels,
                    sortOrder,
                    preSortedChannelPrefix,
                    expectedPositions,
                    pagesIndexFactory);
        }
    }

    /** Lifecycle of the operator as driven by addInput / getOutput / finish. */
    private enum State {
        /** Accepting pages; {@link #getOutput()} returns null. */
        NEEDS_INPUT,
        /** A full group is buffered and is being drained through {@link #getOutput()}. */
        HAS_OUTPUT,
        /** {@link #finish()} was called; remaining buffered data is still being drained. */
        FINISHING,
        /** All output has been produced. */
        FINISHED
    }

    private final OperatorContext operatorContext;
    private final int[] outputChannels;
    private final List<FramedWindowFunction> windowFunctions;
    // Channels and orders actually passed to pagesIndex.sort(); derived in the constructor.
    private final List<Integer> orderChannels;
    private final List<SortOrder> ordering;
    private final LocalMemoryContext localUserMemoryContext;

    private final int[] preGroupedChannels;

    private final PagesHashStrategy preGroupedPartitionHashStrategy;
    private final PagesHashStrategy unGroupedPartitionHashStrategy;
    private final PagesHashStrategy preSortedPartitionHashStrategy;
    private final PagesHashStrategy peerGroupHashStrategy;

    private final PagesIndex pagesIndex;

    private final PageBuilder pageBuilder;

    private final WindowInfo.DriverWindowInfoBuilder windowInfo;
    // Empty until close() publishes the built info.
    private final AtomicReference<Optional<WindowInfo.DriverWindowInfo>> driverWindowInfo = new AtomicReference<>(Optional.empty());

    private State state = State.NEEDS_INPUT;

    // Partition currently being emitted, or null when none has been started for the current index contents.
    private WindowPartition partition;

    // Tail of an input page that belongs to a later group than the one currently buffered, or null.
    private Page pendingInput;

    /**
     * Validates the configuration, builds the window functions, the pages index, the hash strategies
     * used to detect group/partition/peer boundaries, and the effective sort specification.
     *
     * @param preGroupedChannels must be a subset of {@code partitionChannels}
     * @param sortOrder must have the same size as {@code sortChannels}
     * @param preSortedChannelPrefix number of leading sort channels that are already sorted; must not
     *        exceed {@code sortChannels.size()} and may only be positive when every partition channel
     *        is pre-grouped
     * @throws NullPointerException if any reference argument is null
     * @throws IllegalArgumentException if any of the constraints above is violated
     */
    public WindowOperator(
            OperatorContext operatorContext,
            List<Type> sourceTypes,
            List<Integer> outputChannels,
            List<WindowFunctionDefinition> windowFunctionDefinitions,
            List<Integer> partitionChannels,
            List<Integer> preGroupedChannels,
            List<Integer> sortChannels,
            List<SortOrder> sortOrder,
            int preSortedChannelPrefix,
            int expectedPositions,
            PagesIndex.Factory pagesIndexFactory) {
        requireNonNull(operatorContext, "operatorContext is null");
        // Checked up front for consistency with WindowOperatorFactory and the other arguments.
        requireNonNull(sourceTypes, "sourceTypes is null");
        requireNonNull(outputChannels, "outputChannels is null");
        requireNonNull(windowFunctionDefinitions, "windowFunctionDefinitions is null");
        requireNonNull(partitionChannels, "partitionChannels is null");
        requireNonNull(preGroupedChannels, "preGroupedChannels is null");
        checkArgument(partitionChannels.containsAll(preGroupedChannels), "preGroupedChannels must be a subset of partitionChannels");
        requireNonNull(sortChannels, "sortChannels is null");
        requireNonNull(sortOrder, "sortOrder is null");
        requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");
        checkArgument(sortChannels.size() == sortOrder.size(), "Must have same number of sort channels as sort orders");
        checkArgument(preSortedChannelPrefix <= sortChannels.size(), "Cannot have more pre-sorted channels than specified sorted channels");
        checkArgument(
                preSortedChannelPrefix == 0 || ImmutableSet.copyOf(preGroupedChannels).equals(ImmutableSet.copyOf(partitionChannels)),
                "preSortedChannelPrefix can only be greater than zero if all partition channels are pre-grouped");

        this.operatorContext = operatorContext;
        this.localUserMemoryContext = operatorContext.localUserMemoryContext();
        this.outputChannels = Ints.toArray(outputChannels);
        this.windowFunctions = windowFunctionDefinitions.stream()
                .map(functionDefinition -> new FramedWindowFunction(functionDefinition.createWindowFunction(), functionDefinition.getFrameInfo()))
                .collect(toImmutableList());

        // Output layout: the selected source channels followed by one column per window function.
        List<Type> types = Stream.concat(
                outputChannels.stream().map(sourceTypes::get),
                windowFunctionDefinitions.stream().map(WindowFunctionDefinition::getType))
                .collect(toImmutableList());

        this.pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions);
        this.preGroupedChannels = Ints.toArray(preGroupedChannels);
        this.preGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preGroupedChannels, OptionalInt.empty());
        List<Integer> unGroupedPartitionChannels = partitionChannels.stream()
                .filter(channel -> !preGroupedChannels.contains(channel))
                .collect(toImmutableList());
        this.unGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(unGroupedPartitionChannels, OptionalInt.empty());
        List<Integer> preSortedChannels = sortChannels.stream()
                .limit(preSortedChannelPrefix)
                .collect(toImmutableList());
        this.preSortedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, OptionalInt.empty());
        this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, OptionalInt.empty());

        this.pageBuilder = new PageBuilder(types);

        if (preSortedChannelPrefix > 0) {
            // This already implies that set(preGroupedChannels) == set(partitionChannels) (enforced with checkArgument)
            this.orderChannels = ImmutableList.copyOf(Iterables.skip(sortChannels, preSortedChannelPrefix));
            this.ordering = ImmutableList.copyOf(Iterables.skip(sortOrder, preSortedChannelPrefix));
        } else {
            // Otherwise, we need to sort by the unGroupedPartitionChannels and all original sort channels
            this.orderChannels = ImmutableList.copyOf(concat(unGroupedPartitionChannels, sortChannels));
            this.ordering = ImmutableList.copyOf(concat(nCopies(unGroupedPartitionChannels.size(), ASC_NULLS_LAST), sortOrder));
        }

        windowInfo = new WindowInfo.DriverWindowInfoBuilder();
        operatorContext.setInfoSupplier(this::getWindowInfo);
    }

    /**
     * @return a {@link WindowInfo} holding this driver's info once {@link #close()} has published it,
     *         or an empty {@link WindowInfo} before that
     */
    private OperatorInfo getWindowInfo() {
        return new WindowInfo(driverWindowInfo.get().map(ImmutableList::of).orElse(ImmutableList.of()));
    }

    @Override
    public OperatorContext getOperatorContext() {
        return operatorContext;
    }

    /**
     * Signals that no more input will arrive. Calling it again once finishing or finished is a no-op.
     */
    @Override
    public void finish() {
        if (state == State.FINISHING || state == State.FINISHED) {
            return;
        }
        if (state == State.NEEDS_INPUT) {
            // Since was waiting for more input, prepare what we have for output since we will not be getting any more input
            finishPagesIndex();
        }
        state = State.FINISHING;
    }

    @Override
    public boolean isFinished() {
        return state == State.FINISHED;
    }

    @Override
    public boolean needsInput() {
        return state == State.NEEDS_INPUT;
    }

    /**
     * Buffers the given page. Empty pages are ignored. When the page completes a group, the operator
     * switches to producing output.
     *
     * @throws IllegalStateException if the operator is not in the needs-input state or still holds
     *         pending input
     * @throws NullPointerException if {@code page} is null
     */
    @Override
    public void addInput(Page page) {
        checkState(state == State.NEEDS_INPUT, "Operator can not take input at this time");
        requireNonNull(page, "page is null");
        checkState(pendingInput == null, "Operator already has pending input");

        if (page.getPositionCount() == 0) {
            return;
        }

        pendingInput = page;
        if (processPendingInput()) {
            state = State.HAS_OUTPUT;
        }
        // Report the index's current estimated size to the user memory context.
        localUserMemoryContext.setBytes(pagesIndex.getEstimatedSize().toBytes());
    }

    /**
     * @return true if a full group has been buffered after processing the pendingInput, false otherwise
     */
    private boolean processPendingInput() {
        checkState(pendingInput != null);
        pendingInput = updatePagesIndex(pendingInput);

        // If we have unused input or are finishing, then we have buffered a full group
        if (pendingInput != null || state == State.FINISHING) {
            finishPagesIndex();
            return true;
        } else {
            return false;
        }
    }

    /**
     * @return the unused section of the page, or null if fully applied.
     * pagesIndex guaranteed to have at least one row after this method returns
     */
    private Page updatePagesIndex(Page page) {
        checkArgument(page.getPositionCount() > 0);

        // TODO: Fix pagesHashStrategy to allow specifying channels for comparison, it currently requires us to rearrange the right side blocks in consecutive channel order
        Page preGroupedPage = rearrangePage(page, preGroupedChannels);
        if (pagesIndex.getPositionCount() == 0 || pagesIndex.positionEqualsRow(preGroupedPartitionHashStrategy, 0, 0, preGroupedPage)) {
            // Find the position where the pre-grouped columns change
            int groupEnd = findGroupEnd(preGroupedPage, preGroupedPartitionHashStrategy, 0);

            // Add the section of the page that contains values for the current group
            pagesIndex.addPage(page.getRegion(0, groupEnd));

            if (page.getPositionCount() - groupEnd > 0) {
                // Save the remaining page, which may contain multiple partitions
                return page.getRegion(groupEnd, page.getPositionCount() - groupEnd);
            } else {
                // Page fully consumed
                return null;
            }
        } else {
            // We had previous results buffered, but the new page starts with new group values
            return page;
        }
    }

    /**
     * Builds a page whose i-th block is {@code page.getBlock(channels[i])}, with the same position count.
     */
    private static Page rearrangePage(Page page, int[] channels) {
        Block[] newBlocks = new Block[channels.length];
        for (int i = 0; i < channels.length; i++) {
            newBlocks[i] = page.getBlock(channels[i]);
        }
        return new Page(page.getPositionCount(), newBlocks);
    }

    /**
     * @return the next output page, or null when the operator needs input, is finished, or has nothing
     *         to emit on this call
     */
    @Override
    public Page getOutput() {
        if (state == State.NEEDS_INPUT || state == State.FINISHED) {
            return null;
        }

        Page page = extractOutput();
        // extractOutput() may have cleared or refilled the index; refresh the reported size.
        localUserMemoryContext.setBytes(pagesIndex.getEstimatedSize().toBytes());
        return page;
    }

    /**
     * Emits rows from the buffered partitions until the page builder is full or the buffered data
     * (including any pending input that can be consumed) is exhausted.
     *
     * @return a built page, or null when the data ran out and the operator either went back to
     *         needing input or finished with nothing left in the page builder
     */
    private Page extractOutput() {
        // INVARIANT: pagesIndex contains the full grouped & sorted data for one or more partitions

        // Iterate through the positions sequentially until we have one full page
        while (!pageBuilder.isFull()) {
            if (partition == null || !partition.hasNext()) {
                int partitionStart = partition == null ? 0 : partition.getPartitionEnd();

                if (partitionStart >= pagesIndex.getPositionCount()) {
                    // Finished all of the partitions in the current pagesIndex
                    partition = null;
                    pagesIndex.clear();

                    // Try to extract more partitions from the pendingInput
                    if (pendingInput != null && processPendingInput()) {
                        partitionStart = 0;
                    } else if (state == State.FINISHING) {
                        state = State.FINISHED;
                        // Output the remaining page if we have anything buffered
                        if (!pageBuilder.isEmpty()) {
                            Page page = pageBuilder.build();
                            pageBuilder.reset();
                            return page;
                        }
                        return null;
                    } else {
                        // Rows already written to pageBuilder stay there for a later call.
                        state = State.NEEDS_INPUT;
                        return null;
                    }
                }

                int partitionEnd = findGroupEnd(pagesIndex, unGroupedPartitionHashStrategy, partitionStart);
                partition = new WindowPartition(pagesIndex, partitionStart, partitionEnd, outputChannels, windowFunctions, peerGroupHashStrategy);
                windowInfo.addPartition(partition);
            }

            partition.processNextRow(pageBuilder);
        }

        Page page = pageBuilder.build();
        pageBuilder.reset();
        return page;
    }

    /**
     * Sorts the index by {@code orderChannels}/{@code ordering}, one run at a time, where a run is a
     * range of positions that are equal on the pre-sorted channels. Does nothing when there is at most
     * one position or no channel to sort by.
     */
    private void sortPagesIndexIfNecessary() {
        if (pagesIndex.getPositionCount() > 1 && !orderChannels.isEmpty()) {
            int startPosition = 0;
            while (startPosition < pagesIndex.getPositionCount()) {
                int endPosition = findGroupEnd(pagesIndex, preSortedPartitionHashStrategy, startPosition);
                pagesIndex.sort(orderChannels, ordering, startPosition, endPosition);
                startPosition = endPosition;
            }
        }
    }

    /** Sorts the buffered group if needed and records the index in the window info builder. */
    private void finishPagesIndex() {
        sortPagesIndexIfNecessary();
        windowInfo.addIndex(pagesIndex);
    }

    /**
     * Finds the exclusive end of the group that starts at {@code startPosition} within a page.
     */
    // Assumes input grouped on relevant pagesHashStrategy columns
    private static int findGroupEnd(Page page, PagesHashStrategy pagesHashStrategy, int startPosition) {
        checkArgument(page.getPositionCount() > 0, "Must have at least one position");
        checkPositionIndex(startPosition, page.getPositionCount(), "startPosition out of bounds");

        return findEndPosition(startPosition, page.getPositionCount(), (firstPosition, secondPosition) -> pagesHashStrategy.rowEqualsRow(firstPosition, page, secondPosition, page));
    }

    /**
     * Finds the exclusive end of the group that starts at {@code startPosition} within a pages index.
     */
    // Assumes input grouped on relevant pagesHashStrategy columns
    private static int findGroupEnd(PagesIndex pagesIndex, PagesHashStrategy pagesHashStrategy, int startPosition) {
        checkArgument(pagesIndex.getPositionCount() > 0, "Must have at least one position");
        checkPositionIndex(startPosition, pagesIndex.getPositionCount(), "startPosition out of bounds");

        return findEndPosition(startPosition, pagesIndex.getPositionCount(), (firstPosition, secondPosition) -> pagesIndex.positionEqualsPosition(pagesHashStrategy, firstPosition, secondPosition));
    }

    /**
     * Binary-searches for the end of the group starting at {@code startPosition}. The search only
     * compares {@code startPosition} against candidate positions, so it relies on all positions equal
     * to {@code startPosition} forming one contiguous run that begins at {@code startPosition}.
     *
     * @param startPosition - inclusive
     * @param endPosition - exclusive
     * @param comparator - returns true if positions given as parameters are equal
     * @return the end of the group position exclusive (the position the very next group starts)
     * @throws IllegalArgumentException if {@code startPosition} is negative or not less than
     *         {@code endPosition}
     */
    @VisibleForTesting
    static int findEndPosition(int startPosition, int endPosition, BiPredicate<Integer, Integer> comparator) {
        checkArgument(startPosition >= 0, "startPosition must be greater or equal than zero: %s", startPosition);
        checkArgument(startPosition < endPosition, "startPosition (%s) must be less than endPosition (%s)", startPosition, endPosition);

        // Loop invariant: position `left` is in the group; `right` is the first known position outside it
        // (or endPosition).
        int left = startPosition;
        int right = endPosition;

        while (left + 1 < right) {
            // Unsigned shift keeps the midpoint correct even if left + right overflows int.
            int middle = (left + right) >>> 1;

            if (comparator.test(startPosition, middle)) {
                left = middle;
            } else {
                right = middle;
            }
        }

        return right;
    }

    /** Publishes the accumulated window info so the info supplier can report it. */
    @Override
    public void close() {
        driverWindowInfo.set(Optional.of(windowInfo.build()));
    }
}