Java tutorial
/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.aggregation; import com.facebook.presto.block.BlockAssertions; import com.facebook.presto.operator.GroupByIdBlock; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.RunLengthEncodedBlock; import com.google.common.primitives.Ints; import org.apache.commons.math3.util.Precision; import java.util.Collections; import java.util.Objects; import java.util.Optional; import java.util.function.BiFunction; import java.util.stream.IntStream; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static org.testng.Assert.assertEquals; import static org.testng.Assert.fail; public final class AggregationTestUtils { private AggregationTestUtils() { } public static void assertAggregation(InternalAggregationFunction function, Object expectedValue, Block... blocks) { assertAggregation(function, expectedValue, new Page(blocks)); } public static void assertAggregation(InternalAggregationFunction function, Object expectedValue, Page page) { BiFunction<Object, Object, Boolean> equalAssertion; if (expectedValue instanceof Double && !expectedValue.equals(Double.NaN)) { equalAssertion = (actual, expected) -> Precision.equals((double) actual, (double) expected, 1e-10); } else if (expectedValue instanceof Float && !expectedValue.equals(Float.NaN)) { equalAssertion = (actual, expected) -> Precision.equals((float) actual, (float) expected, 1e-10f); } else { equalAssertion = Objects::equals; } assertAggregation(function, equalAssertion, null, page, expectedValue); } public static void assertAggregation(InternalAggregationFunction function, BiFunction<Object, Object, Boolean> equalAssertion, String testDescription, Page page, Object expectedValue) { int positions = page.getPositionCount(); for (int i = 1; i < page.getChannelCount(); i++) { assertEquals(positions, page.getBlock(i).getPositionCount(), "input blocks provided are not equal in position count"); } if (positions == 0) { assertAggregationInternal(function, equalAssertion, testDescription, expectedValue, new Page[] {}); } else if (positions == 1) { assertAggregationInternal(function, equalAssertion, testDescription, expectedValue, page); } else { int split = positions / 2; // [0, split - 1] goes to first list of blocks; [split, positions - 1] goes to second list of blocks. Page page1 = page.getRegion(0, split); Page page2 = page.getRegion(split, positions - split); assertAggregationInternal(function, equalAssertion, testDescription, expectedValue, page1, page2); } } public static Block getIntermediateBlock(Accumulator accumulator) { BlockBuilder blockBuilder = accumulator.getIntermediateType().createBlockBuilder(null, 1000); accumulator.evaluateIntermediate(blockBuilder); return blockBuilder.build(); } public static Block getIntermediateBlock(GroupedAccumulator accumulator) { BlockBuilder blockBuilder = accumulator.getIntermediateType().createBlockBuilder(null, 1000); accumulator.evaluateIntermediate(0, blockBuilder); return blockBuilder.build(); } public static Block getFinalBlock(Accumulator accumulator) { BlockBuilder blockBuilder = accumulator.getFinalType().createBlockBuilder(null, 1000); accumulator.evaluateFinal(blockBuilder); return blockBuilder.build(); } public static Block getFinalBlock(GroupedAccumulator accumulator) { BlockBuilder blockBuilder = accumulator.getFinalType().createBlockBuilder(null, 1000); accumulator.evaluateFinal(0, blockBuilder); return blockBuilder.build(); } private static void assertAggregationInternal(InternalAggregationFunction function, BiFunction<Object, Object, Boolean> isEqual, String testDescription, Object expectedValue, Page... pages) { // This assertAggregation does not try to split up the page to test the correctness of combine function. // Do not use this directly. Always use the other assertAggregation. assertFunctionEquals(isEqual, testDescription, aggregation(function, pages), expectedValue); assertFunctionEquals(isEqual, testDescription, partialAggregation(function, pages), expectedValue); if (pages.length > 0) { assertFunctionEquals(isEqual, testDescription, groupedAggregation(isEqual, function, pages), expectedValue); assertFunctionEquals(isEqual, testDescription, groupedPartialAggregation(isEqual, function, pages), expectedValue); assertFunctionEquals(isEqual, testDescription, distinctAggregation(function, pages), expectedValue); } } private static void assertFunctionEquals(BiFunction<Object, Object, Boolean> isEqual, String testDescription, Object actualValue, Object expectedValue) { if (!isEqual.apply(actualValue, expectedValue)) { StringBuilder sb = new StringBuilder(); if (testDescription != null) { sb.append(String.format("Test: %s, ", testDescription)); } sb.append(String.format("Expected: %s, actual: %s", expectedValue, actualValue)); fail(sb.toString()); } } public static Object distinctAggregation(InternalAggregationFunction function, Page... pages) { Optional<Integer> maskChannel = Optional.of(pages[0].getChannelCount()); // Execute normally Object aggregation = aggregation(function, createArgs(function), maskChannel, maskPages(true, pages)); Page[] dupedPages = new Page[pages.length * 2]; // Create two copies of each page with one of them masked off System.arraycopy(maskPages(true, pages), 0, dupedPages, 0, pages.length); System.arraycopy(maskPages(false, pages), 0, dupedPages, pages.length, pages.length); // Execute with masked pages and assure equal to normal execution Object aggregationWithDupes = aggregation(function, createArgs(function), maskChannel, dupedPages); assertEquals(aggregationWithDupes, aggregation, "Inconsistent results with mask"); return aggregation; } // Adds the mask as the last channel private static Page[] maskPages(boolean maskValue, Page... pages) { Page[] maskedPages = new Page[pages.length]; for (int i = 0; i < pages.length; i++) { Page page = pages[i]; BlockBuilder blockBuilder = BOOLEAN.createBlockBuilder(null, page.getPositionCount()); for (int j = 0; j < page.getPositionCount(); j++) { BOOLEAN.writeBoolean(blockBuilder, maskValue); } maskedPages[i] = page.appendColumn(blockBuilder.build()); } return maskedPages; } public static Object aggregation(InternalAggregationFunction function, Page... pages) { // execute with args in positions: arg0, arg1, arg2 Object aggregation = aggregation(function, createArgs(function), Optional.empty(), pages); // execute with args in reverse order: arg2, arg1, arg0 if (function.getParameterTypes().size() > 1) { Object aggregationWithOffset = aggregation(function, reverseArgs(function), Optional.empty(), reverseColumns(pages)); assertEquals(aggregationWithOffset, aggregation, "Inconsistent results with reversed channels"); } // execute with args at an offset (and possibly reversed): null, null, null, arg2, arg1, arg0 Object aggregationWithOffset = aggregation(function, offsetArgs(function, 3), Optional.empty(), offsetColumns(pages, 3)); assertEquals(aggregationWithOffset, aggregation, "Inconsistent results with channel offset"); return aggregation; } private static Object aggregation(InternalAggregationFunction function, int[] args, Optional<Integer> maskChannel, Page... pages) { Accumulator aggregation = function.bind(Ints.asList(args), maskChannel).createAccumulator(); for (Page page : pages) { if (page.getPositionCount() > 0) { aggregation.addInput(page); } } Block block = getFinalBlock(aggregation); return BlockAssertions.getOnlyValue(aggregation.getFinalType(), block); } public static Object partialAggregation(InternalAggregationFunction function, Page... pages) { // execute with args in positions: arg0, arg1, arg2 Object aggregation = partialAggregation(function, createArgs(function), pages); // execute with args in reverse order: arg2, arg1, arg0 if (function.getParameterTypes().size() > 1) { Object aggregationWithOffset = partialAggregation(function, reverseArgs(function), reverseColumns(pages)); assertEquals(aggregationWithOffset, aggregation, "Inconsistent results with reversed channels"); } // execute with args at an offset (and possibly reversed): null, null, null, arg2, arg1, arg0 Object aggregationWithOffset = partialAggregation(function, offsetArgs(function, 3), offsetColumns(pages, 3)); assertEquals(aggregationWithOffset, aggregation, "Inconsistent results with channel offset"); return aggregation; } public static Object partialAggregation(InternalAggregationFunction function, int[] args, Page... pages) { AccumulatorFactory factory = function.bind(Ints.asList(args), Optional.empty()); Accumulator finalAggregation = factory.createIntermediateAccumulator(); // Test handling of empty intermediate blocks Accumulator emptyAggregation = factory.createAccumulator(); Block emptyBlock = getIntermediateBlock(emptyAggregation); finalAggregation.addIntermediate(emptyBlock); for (Page page : pages) { Accumulator partialAggregation = factory.createAccumulator(); if (page.getPositionCount() > 0) { partialAggregation.addInput(page); } Block partialBlock = getIntermediateBlock(partialAggregation); finalAggregation.addIntermediate(partialBlock); } finalAggregation.addIntermediate(emptyBlock); Block finalBlock = getFinalBlock(finalAggregation); return BlockAssertions.getOnlyValue(finalAggregation.getFinalType(), finalBlock); } public static Object groupedAggregation(InternalAggregationFunction function, Page... pages) { return groupedAggregation(Objects::equals, function, pages); } public static Object groupedAggregation(BiFunction<Object, Object, Boolean> isEqual, InternalAggregationFunction function, Page... pages) { // execute with args in positions: arg0, arg1, arg2 Object aggregation = groupedAggregation(function, createArgs(function), pages); // execute with args in reverse order: arg2, arg1, arg0 if (function.getParameterTypes().size() > 1) { Object aggregationWithOffset = groupedAggregation(function, reverseArgs(function), reverseColumns(pages)); assertFunctionEquals(isEqual, "Inconsistent results with reversed channels", aggregationWithOffset, aggregation); } // execute with args at an offset (and possibly reversed): null, null, null, arg2, arg1, arg0 Object aggregationWithOffset = groupedAggregation(function, offsetArgs(function, 3), offsetColumns(pages, 3)); assertFunctionEquals(isEqual, "Consistent results with channel offset", aggregationWithOffset, aggregation); return aggregation; } public static Object groupedAggregation(InternalAggregationFunction function, int[] args, Page... pages) { GroupedAccumulator groupedAggregation = function.bind(Ints.asList(args), Optional.empty()) .createGroupedAccumulator(); for (Page page : pages) { groupedAggregation.addInput(createGroupByIdBlock(0, page.getPositionCount()), page); } Object groupValue = getGroupValue(groupedAggregation, 0); for (Page page : pages) { groupedAggregation.addInput(createGroupByIdBlock(4000, page.getPositionCount()), page); } Object largeGroupValue = getGroupValue(groupedAggregation, 4000); assertEquals(largeGroupValue, groupValue, "Inconsistent results with large group id"); return groupValue; } public static Object groupedPartialAggregation(BiFunction<Object, Object, Boolean> isEqual, InternalAggregationFunction function, Page... pages) { // execute with args in positions: arg0, arg1, arg2 Object aggregation = groupedPartialAggregation(function, createArgs(function), pages); // execute with args in reverse order: arg2, arg1, arg0 if (function.getParameterTypes().size() > 1) { Object aggregationWithOffset = groupedPartialAggregation(function, reverseArgs(function), reverseColumns(pages)); assertFunctionEquals(isEqual, "Consistent results with reversed channels", aggregationWithOffset, aggregation); } // execute with args at an offset (and possibly reversed): null, null, null, arg2, arg1, arg0 Object aggregationWithOffset = groupedPartialAggregation(function, offsetArgs(function, 3), offsetColumns(pages, 3)); assertFunctionEquals(isEqual, "Consistent results with channel offset", aggregationWithOffset, aggregation); return aggregation; } public static Object groupedPartialAggregation(InternalAggregationFunction function, int[] args, Page... pages) { AccumulatorFactory factory = function.bind(Ints.asList(args), Optional.empty()); GroupedAccumulator finalAggregation = factory.createGroupedIntermediateAccumulator(); // Add an empty block to test the handling of empty intermediates GroupedAccumulator emptyAggregation = factory.createGroupedAccumulator(); Block emptyBlock = getIntermediateBlock(emptyAggregation); finalAggregation.addIntermediate(createGroupByIdBlock(0, emptyBlock.getPositionCount()), emptyBlock); for (Page page : pages) { GroupedAccumulator partialAggregation = factory.createGroupedAccumulator(); partialAggregation.addInput(createGroupByIdBlock(0, page.getPositionCount()), page); Block partialBlock = getIntermediateBlock(partialAggregation); finalAggregation.addIntermediate(createGroupByIdBlock(0, partialBlock.getPositionCount()), partialBlock); } finalAggregation.addIntermediate(createGroupByIdBlock(0, emptyBlock.getPositionCount()), emptyBlock); return getGroupValue(finalAggregation, 0); } public static GroupByIdBlock createGroupByIdBlock(int groupId, int positions) { BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, positions); for (int i = 0; i < positions; i++) { BIGINT.writeLong(blockBuilder, groupId); } return new GroupByIdBlock(groupId, blockBuilder.build()); } private static int[] createArgs(InternalAggregationFunction function) { int[] args = new int[function.getParameterTypes().size()]; for (int i = 0; i < args.length; i++) { args[i] = i; } return args; } public static int[] reverseArgs(InternalAggregationFunction function) { int[] args = createArgs(function); Collections.reverse(Ints.asList(args)); return args; } public static int[] offsetArgs(InternalAggregationFunction function, int offset) { int[] args = createArgs(function); for (int i = 0; i < args.length; i++) { args[i] += offset; } return args; } public static Page[] reverseColumns(Page[] pages) { Page[] newPages = new Page[pages.length]; for (int i = 0; i < pages.length; i++) { Page page = pages[i]; if (page.getPositionCount() == 0) { newPages[i] = page; } else { Block[] newBlocks = new Block[page.getChannelCount()]; for (int channel = 0; channel < page.getChannelCount(); channel++) { newBlocks[channel] = page.getBlock(page.getChannelCount() - channel - 1); } newPages[i] = new Page(page.getPositionCount(), newBlocks); } } return newPages; } public static Page[] offsetColumns(Page[] pages, int offset) { Page[] newPages = new Page[pages.length]; for (int i = 0; i < pages.length; i++) { Page page = pages[i]; Block[] newBlocks = new Block[page.getChannelCount() + offset]; for (int channel = 0; channel < offset; channel++) { newBlocks[channel] = createNullRLEBlock(page.getPositionCount()); } for (int channel = 0; channel < page.getChannelCount(); channel++) { newBlocks[channel + offset] = page.getBlock(channel); } newPages[i] = new Page(page.getPositionCount(), newBlocks); } return newPages; } private static RunLengthEncodedBlock createNullRLEBlock(int positionCount) { return (RunLengthEncodedBlock) RunLengthEncodedBlock.create(BOOLEAN, null, positionCount); } public static Object getGroupValue(GroupedAccumulator groupedAggregation, int groupId) { BlockBuilder out = groupedAggregation.getFinalType().createBlockBuilder(null, 1); groupedAggregation.evaluateFinal(groupId, out); return BlockAssertions.getOnlyValue(groupedAggregation.getFinalType(), out.build()); } public static double[] constructDoublePrimitiveArray(int start, int length) { return IntStream.range(start, start + length).asDoubleStream().toArray(); } }