com.facebook.presto.operator.aggregation.AbstractTestApproximateAggregationFunction.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.operator.aggregation.AbstractTestApproximateAggregationFunction.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.block.BlockAssertions;
import com.facebook.presto.operator.OperatorAssertion;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.ImmutableList;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Random;

import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertApproximateAggregation;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.getFinalBlock;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static org.testng.Assert.assertTrue;

public abstract class AbstractTestApproximateAggregationFunction extends AbstractTestAggregationFunction {
    private static final int WEIGHT = 10;

    protected abstract Type getType();

    // values may contain nulls
    protected abstract Double getExpectedValue(List<Number> values);

    @Override
    public double getConfidence() {
        return 0.5;
    }

    @Override
    protected boolean isApproximate() {
        return true;
    }

    @Override
    protected List<String> getFunctionParameterTypes() {
        return ImmutableList.of(getType().getTypeSignature().toString());
    }

    @Override
    public Block[] getSequenceBlocks(int start, int length) {
        BlockBuilder blockBuilder = getType().createBlockBuilder(new BlockBuilderStatus(), length);
        for (int i = start; i < start + length; i++) {
            if (getType().equals(BIGINT)) {
                BIGINT.writeLong(blockBuilder, (long) i);
            } else {
                DOUBLE.writeDouble(blockBuilder, (double) i);
            }
        }
        return new Block[] { blockBuilder.build() };
    }

    @Override
    public Double getExpectedValue(int start, int length) {
        List<Number> values = new ArrayList<>();
        for (int i = 0; i < length; i++) {
            if (getType().equals(BIGINT)) {
                for (int j = 0; j < WEIGHT; j++) {
                    values.add((long) start + i);
                }
            } else {
                for (int j = 0; j < WEIGHT; j++) {
                    values.add((double) start + i);
                }
            }
        }

        return getExpectedValue(values);
    }

    @Override
    public Double getExpectedValueIncludingNulls(int start, int length, int lengthIncludingNulls) {
        List<Number> values = new ArrayList<>();
        for (int i = 0; i < length; i++) {
            if (getType().equals(BIGINT)) {
                for (int j = 0; j < WEIGHT; j++) {
                    values.add((long) start + i);
                }
            } else {
                for (int j = 0; j < WEIGHT; j++) {
                    values.add((double) start + i);
                }
            }
        }
        for (int i = length; i < lengthIncludingNulls; i++) {
            for (int j = 0; j < WEIGHT; j++) {
                values.add(null);
            }
        }

        return getExpectedValue(values);
    }

    @Test
    public void testCorrectnessOnGaussianData() throws Exception {
        int originalDataSize = 10000;
        Random distribution = new Random(0);
        List<Number> list = new ArrayList<>();
        for (int i = 0; i < originalDataSize; i++) {
            list.add(distribution.nextGaussian() * 100);
        }

        testCorrectnessOfErrorFunction(list);
    }

    @Test
    public void testCorrectnessOnUniformData() throws Exception {
        int originalDataSize = 10000;
        Random distribution = new Random(0);
        List<Number> list = new ArrayList<>();
        for (int i = 0; i < originalDataSize; i++) {
            list.add(distribution.nextDouble() * 1000);
        }

        testCorrectnessOfErrorFunction(list);
    }

    private void testCorrectnessOfErrorFunction(List<Number> inputList) throws Exception {
        int inRange = 0;
        int numberOfRuns = 1000;
        double sampleRatio = 1 / (double) WEIGHT;
        double actual = getExpectedValue(inputList);
        Random rand = new Random(1);

        for (int i = 0; i < numberOfRuns; i++) {
            //Compute Sampled Value using sampledList (numberOfRuns times)
            ImmutableList.Builder<Number> sampledList = ImmutableList.builder();
            for (Number x : inputList) {
                if (rand.nextDouble() < sampleRatio) {
                    sampledList.add(x);
                }
            }

            ImmutableList<Number> list = sampledList.build();
            BlockBuilder builder = getType().createBlockBuilder(new BlockBuilderStatus(), list.size());
            for (Number sample : list) {
                if (getType().equals(BIGINT)) {
                    BIGINT.writeLong(builder, sample.longValue());
                } else if (getType().equals(DOUBLE)) {
                    DOUBLE.writeDouble(builder, sample.doubleValue());
                } else {
                    throw new AssertionError("Can only handle longs and doubles");
                }
            }
            Page page = new Page(builder.build());
            page = OperatorAssertion.appendSampleWeight(ImmutableList.of(page), WEIGHT).get(0);
            Accumulator accumulator = getFunction().bind(ImmutableList.of(0), Optional.empty(),
                    Optional.of(page.getChannelCount() - 1), getConfidence()).createAccumulator();

            accumulator.addInput(page);
            Block result = getFinalBlock(accumulator);

            String approxValue = BlockAssertions.toValues(accumulator.getFinalType(), result).get(0).toString();
            double approx = Double.parseDouble(approxValue.split(" ")[0]);
            double error = Double.parseDouble(approxValue.split(" ")[2]);

            //Check if actual answer lies within [approxAnswer - error, approxAnswer + error]
            if (Math.abs(approx - actual) <= error) {
                inRange++;
            }
        }

        BinomialDistribution binomial = new BinomialDistribution(numberOfRuns, getConfidence());
        int lowerBound = binomial.inverseCumulativeProbability(0.01);
        int upperBound = binomial.inverseCumulativeProbability(0.99);
        assertTrue(lowerBound < inRange && inRange < upperBound, String
                .format("%d out of %d passed. Expected [%d, %d]", inRange, numberOfRuns, lowerBound, upperBound));
    }

    @Override
    protected void testAggregation(Object expectedValue, Block... block) {
        Page page = OperatorAssertion.appendSampleWeight(ImmutableList.of(new Page(block)), WEIGHT).get(0);
        assertApproximateAggregation(getFunction(), page.getChannelCount() - 1, getConfidence(),
                (Double) expectedValue, page);
    }
}