com.linkedin.pinot.query.transform.TransformGroupByTest.java Source code

Introduction

Here is the source code for com.linkedin.pinot.query.transform.TransformGroupByTest.java
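
The test builds a small segment with one string dimension column, one double metric column, and one millisecond-granularity time column, then runs sum aggregations grouped by transform expressions: ToUpper(dimension) for the string case, and timeConvert(millisSinceEpoch, 'MILLISECONDS', 'DAYS') for the time roll-up case. Expected results are recomputed row by row from the same record reader that produced the segment.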

Source

/**
 * Copyright (C) 2014-2016 LinkedIn Corp. (pinot-core@linkedin.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.linkedin.pinot.query.transform;

import com.linkedin.pinot.common.data.DimensionFieldSpec;
import com.linkedin.pinot.common.data.FieldSpec;
import com.linkedin.pinot.common.data.MetricFieldSpec;
import com.linkedin.pinot.common.data.Schema;
import com.linkedin.pinot.common.data.TimeFieldSpec;
import com.linkedin.pinot.common.request.AggregationInfo;
import com.linkedin.pinot.common.request.BrokerRequest;
import com.linkedin.pinot.common.request.GroupBy;
import com.linkedin.pinot.common.segment.ReadMode;
import com.linkedin.pinot.core.common.BlockValSet;
import com.linkedin.pinot.core.common.Operator;
import com.linkedin.pinot.core.data.GenericRow;
import com.linkedin.pinot.core.data.readers.FileFormat;
import com.linkedin.pinot.core.data.readers.RecordReader;
import com.linkedin.pinot.core.indexsegment.IndexSegment;
import com.linkedin.pinot.core.indexsegment.generator.SegmentGeneratorConfig;
import com.linkedin.pinot.core.operator.BReusableFilteredDocIdSetOperator;
import com.linkedin.pinot.core.operator.BaseOperator;
import com.linkedin.pinot.core.operator.MProjectionOperator;
import com.linkedin.pinot.core.operator.blocks.IntermediateResultsBlock;
import com.linkedin.pinot.core.operator.filter.MatchEntireSegmentOperator;
import com.linkedin.pinot.core.operator.query.AggregationGroupByOperator;
import com.linkedin.pinot.core.operator.transform.TransformExpressionOperator;
import com.linkedin.pinot.core.operator.transform.function.TimeConversionTransform;
import com.linkedin.pinot.core.operator.transform.function.TransformFunction;
import com.linkedin.pinot.core.operator.transform.function.TransformFunctionFactory;
import com.linkedin.pinot.core.plan.AggregationFunctionInitializer;
import com.linkedin.pinot.core.plan.DocIdSetPlanNode;
import com.linkedin.pinot.core.plan.TransformPlanNode;
import com.linkedin.pinot.core.query.aggregation.AggregationFunctionContext;
import com.linkedin.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import com.linkedin.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
import com.linkedin.pinot.core.segment.creator.impl.SegmentIndexCreationDriverImpl;
import com.linkedin.pinot.core.segment.index.loader.Loaders;
import com.linkedin.pinot.pql.parsers.Pql2Compiler;
import com.linkedin.pinot.util.TestUtils;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/**
 * Unit test for transforms on group by columns.
 */
public class TransformGroupByTest {
    private static final Logger LOGGER = LoggerFactory.getLogger(TransformGroupByTest.class);

    private static final String SEGMENT_DIR_NAME = System.getProperty("java.io.tmpdir") + File.separator
            + "xformGroupBy";
    private static final String SEGMENT_NAME = "xformGroupBySeg";
    private static final String TABLE_NAME = "xformGroupByTable";

    private static final long RANDOM_SEED = System.nanoTime();
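    // NUM_ROWS matches the maximum docs per operator call, so the whole segment is consumed in one block.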
    private static final int NUM_ROWS = DocIdSetPlanNode.MAX_DOC_PER_CALL;
    private static final double EPSILON = 1e-5;
    private static final String DIMENSION_NAME = "dimension";
    private static final String TIME_COLUMN_NAME = "millisSinceEpoch";
    private static final String METRIC_NAME = "metric";
    private static final String[] DIMENSION_VALUES = new String[] { "abcd", "ABCD", "bcde", "BCDE", "cdef",
            "CDEF" };

    private IndexSegment _indexSegment;
    private RecordReader _recordReader;

    @BeforeClass
    public void setup() throws Exception {
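        // Register the custom ToUpper transform and the built-in time-conversion transform with the factory.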
        TransformFunctionFactory
                .init(new String[] { ToUpper.class.getName(), TimeConversionTransform.class.getName() });

        Schema schema = buildSchema();
        _recordReader = buildSegment(SEGMENT_DIR_NAME, SEGMENT_NAME, schema);
        _indexSegment = Loaders.IndexSegment.load(new File(SEGMENT_DIR_NAME, SEGMENT_NAME), ReadMode.heap);
    }

    @AfterClass
    public void tearDown() throws IOException {
        FileUtils.deleteDirectory(new File(SEGMENT_DIR_NAME));
    }

    /**
     * Test for group-by on a string dimension column transformed to upper case.
     */
    @Test
    public void testGroupByString() throws Exception {
        String query = String.format("select sum(%s) from %s group by ToUpper(%s)", METRIC_NAME, TABLE_NAME,
                DIMENSION_NAME);
        AggregationGroupByResult groupByResult = executeGroupByQuery(_indexSegment, query);
        Assert.assertNotNull(groupByResult);

        // Compute the expected answer for the query.
        Map<String, Double> expectedValuesMap = new HashMap<>();
        _recordReader.rewind();
        for (int row = 0; row < NUM_ROWS; row++) {
            GenericRow genericRow = _recordReader.next();
            String key = ((String) genericRow.getValue(DIMENSION_NAME)).toUpperCase();
            Double value = (Double) genericRow.getValue(METRIC_NAME);
            Double prevValue = expectedValuesMap.get(key);

            if (prevValue == null) {
                expectedValuesMap.put(key, value);
            } else {
                expectedValuesMap.put(key, prevValue + value);
            }
        }

        compareGroupByResults(groupByResult, expectedValuesMap);
    }

    /**
     * Test for group-by on the time column, rolled up from milliseconds to days.
     *
     * @throws Exception
     */
    @Test
    public void testTimeRollUp() throws Exception {
        String query = String.format(
                "select sum(%s) from %s group by timeConvert(%s, 'MILLISECONDS', 'DAYS')", METRIC_NAME,
                TABLE_NAME, TIME_COLUMN_NAME);

        AggregationGroupByResult groupByResult = executeGroupByQuery(_indexSegment, query);
        Assert.assertNotNull(groupByResult);

        Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = groupByResult.getGroupKeyIterator();
        Assert.assertNotNull(groupKeyIterator);

        // Compute the expected answer for the query.
        Map<String, Double> expectedValuesMap = new HashMap<>();
        _recordReader.rewind();
        for (int row = 0; row < NUM_ROWS; row++) {
            GenericRow genericRow = _recordReader.next();
            long daysSinceEpoch = TimeUnit.DAYS.convert(((Long) genericRow.getValue(TIME_COLUMN_NAME)),
                    TimeUnit.MILLISECONDS);

            Double value = (Double) genericRow.getValue(METRIC_NAME);
            String key = String.valueOf(daysSinceEpoch);
            Double prevValue = expectedValuesMap.get(key);

            if (prevValue == null) {
                expectedValuesMap.put(key, value);
            } else {
                expectedValuesMap.put(key, prevValue + value);
            }
        }

        compareGroupByResults(groupByResult, expectedValuesMap);
    }

    /**
     * Helper method that executes the group by query on the given index segment and returns the group by result.
     *
     * @param indexSegment Index segment to query
     * @param query Query to execute
     * @return Group by result
     */
    private AggregationGroupByResult executeGroupByQuery(IndexSegment indexSegment, String query) {
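        // Operator chain: match-all filter -> doc-id set -> projection -> transform -> aggregation group-by.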
        Operator filterOperator = new MatchEntireSegmentOperator(indexSegment.getSegmentMetadata().getTotalDocs());
        final BReusableFilteredDocIdSetOperator docIdSetOperator = new BReusableFilteredDocIdSetOperator(
                filterOperator, indexSegment.getSegmentMetadata().getTotalDocs(), NUM_ROWS);

        final Map<String, BaseOperator> dataSourceMap = buildDataSourceMap(
                indexSegment.getSegmentMetadata().getSchema());
        final MProjectionOperator projectionOperator = new MProjectionOperator(dataSourceMap, docIdSetOperator);

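        // Compile the PQL query into a broker request, from which the aggregation and group-by info is read.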
        Pql2Compiler compiler = new Pql2Compiler();
        BrokerRequest brokerRequest = compiler.compileToBrokerRequest(query);

        List<AggregationInfo> aggregationsInfo = brokerRequest.getAggregationsInfo();
        int numAggFunctions = aggregationsInfo.size();

        AggregationFunctionContext[] aggrFuncContextArray = new AggregationFunctionContext[numAggFunctions];
        AggregationFunctionInitializer aggFuncInitializer = new AggregationFunctionInitializer(
                indexSegment.getSegmentMetadata());
        for (int i = 0; i < numAggFunctions; i++) {
            AggregationInfo aggregationInfo = aggregationsInfo.get(i);
            aggrFuncContextArray[i] = AggregationFunctionContext.instantiate(aggregationInfo);
            aggrFuncContextArray[i].getAggregationFunction().accept(aggFuncInitializer);
        }

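        // Wrap the projection in a transform operator that evaluates the group-by expressions (e.g. ToUpper).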
        GroupBy groupBy = brokerRequest.getGroupBy();
        Set<String> expressions = new HashSet<>(groupBy.getExpressions());

        TransformExpressionOperator transformOperator = new TransformExpressionOperator(projectionOperator,
                TransformPlanNode.buildTransformExpressionTrees(expressions));

        AggregationGroupByOperator groupByOperator = new AggregationGroupByOperator(aggrFuncContextArray, groupBy,
                Integer.MAX_VALUE, transformOperator, NUM_ROWS);

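        // Run the operator chain and collect the aggregated group-by result.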
        IntermediateResultsBlock block = (IntermediateResultsBlock) groupByOperator.nextBlock();
        return block.getAggregationGroupByResult();
    }

    /**
     * Helper method to build a segment with one dimension column containing values
     * from {@link #DIMENSION_VALUES}, one metric column, and one time column.
     *
     * @param segmentDirName Name of segment directory
     * @param segmentName Name of segment
     * @param schema Schema for segment
     * @return Record reader over the rows used to build the segment
     * @throws Exception
     */
    private RecordReader buildSegment(String segmentDirName, String segmentName, Schema schema) throws Exception {

        SegmentGeneratorConfig config = new SegmentGeneratorConfig(schema);
        config.setOutDir(segmentDirName);
        config.setFormat(FileFormat.AVRO);
        config.setTableName(TABLE_NAME);
        config.setSegmentName(segmentName);

        Random random = new Random(RANDOM_SEED);
        long currentTimeMillis = System.currentTimeMillis();

        // Divide the day into 1000 fixed slices and decrement the time column by this delta, so the input
        // spans a run of contiguous days (about 10 days for 10k rows).
        long timeDelta = TimeUnit.MILLISECONDS.convert(1, TimeUnit.DAYS) / 1000;

        final List<GenericRow> data = new ArrayList<>();
        int numDimValues = DIMENSION_VALUES.length;

        for (int row = 0; row < NUM_ROWS; row++) {
            HashMap<String, Object> map = new HashMap<>();

            map.put(DIMENSION_NAME, DIMENSION_VALUES[random.nextInt(numDimValues)]);
            map.put(METRIC_NAME, random.nextDouble());

            map.put(TIME_COLUMN_NAME, currentTimeMillis);
            currentTimeMillis -= timeDelta;

            GenericRow genericRow = new GenericRow();
            genericRow.init(map);
            data.add(genericRow);
        }

        SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
        RecordReader reader = new TestUtils.GenericRowRecordReader(schema, data);
        driver.init(config, reader);
        driver.build();

        LOGGER.info("Built segment {} at {}", segmentName, segmentDirName);
        return reader;
    }

    /**
     * Helper method to build a schema with one string dimension column, one double metric column, and one
     * long time column.
     */
    private static Schema buildSchema() {
        Schema schema = new Schema();
        DimensionFieldSpec dimensionFieldSpec = new DimensionFieldSpec(DIMENSION_NAME, FieldSpec.DataType.STRING,
                true);
        schema.addField(dimensionFieldSpec);

        MetricFieldSpec metricFieldSpec = new MetricFieldSpec(METRIC_NAME, FieldSpec.DataType.DOUBLE);
        schema.addField(metricFieldSpec);

        TimeFieldSpec timeFieldSpec = new TimeFieldSpec(TIME_COLUMN_NAME, FieldSpec.DataType.LONG,
                TimeUnit.MILLISECONDS);
        schema.setTimeFieldSpec(timeFieldSpec);
        return schema;
    }

    /**
     * Helper method to build the data source map for all columns in the schema.
     *
     * @param schema Schema for the index segment
     * @return Map of column name to its data source.
     */
    private Map<String, BaseOperator> buildDataSourceMap(Schema schema) {
        final Map<String, BaseOperator> dataSourceMap = new HashMap<>();
        for (String columnName : schema.getColumnNames()) {
            dataSourceMap.put(columnName, _indexSegment.getDataSource(columnName));
        }
        return dataSourceMap;
    }

    /**
     * Helper method to compare group by result from query execution against a map of group keys and values.
     *
     * @param groupByResult Group by result from query
     * @param expectedValuesMap Map of expected keys and values
     */
    private void compareGroupByResults(AggregationGroupByResult groupByResult,
            Map<String, Double> expectedValuesMap) {
        Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = groupByResult.getGroupKeyIterator();
        Assert.assertNotNull(groupKeyIterator);

        int numGroupKeys = 0;
        while (groupKeyIterator.hasNext()) {
            GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next();
            Double actual = (Double) groupByResult.getResultForKey(groupKey, 0 /* aggregation function index */);

            String stringKey = groupKey.getStringKey();
            Double expected = expectedValuesMap.get(stringKey);
            Assert.assertNotNull(expected, "Unexpected key in actual result: " + stringKey);
            Assert.assertEquals(actual, expected, EPSILON);
            numGroupKeys++;
        }

        Assert.assertEquals(numGroupKeys, expectedValuesMap.size(), "Mis-match in number of group keys");
    }

    /**
     * Implementation of TransformFunction that converts strings to upper case.
     */
    public static class ToUpper implements TransformFunction {
        @Override
        public String[] transform(int length, BlockValSet... input) {
            String[] inputStrings = input[0].getStringValuesSV();
            String[] outputStrings = new String[length];

            for (int i = 0; i < length; i++) {
                outputStrings[i] = inputStrings[i].toUpperCase();
            }
            return outputStrings;
        }

        @Override
        public FieldSpec.DataType getOutputType() {
            return FieldSpec.DataType.STRING;
        }

        @Override
        public String getName() {
            return "ToUpper";
        }
    }
}
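
Both test methods accumulate the expected per-group sums with an explicit null check before each put. On Java 8 and later, Map.merge collapses that pattern into a single call; below is a minimal standalone sketch of the same accumulation (the GroupedSum class and its sample inputs are illustrative only, not part of the test above):

import java.util.HashMap;
import java.util.Map;

public class GroupedSum {
    public static void main(String[] args) {
        Map<String, Double> expectedValuesMap = new HashMap<>();

        // Sample (key, value) pairs standing in for rows read back from the record reader.
        String[] keys = { "ABCD", "BCDE", "ABCD" };
        double[] values = { 1.5, 2.0, 0.5 };

        for (int i = 0; i < keys.length; i++) {
            // Sums values per key: if the key is absent the value is inserted,
            // otherwise Double::sum combines it with the previous total.
            expectedValuesMap.merge(keys[i], values[i], Double::sum);
        }

        // Prints ABCD=2.0 and BCDE=2.0 (HashMap iteration order is unspecified).
        System.out.println(expectedValuesMap);
    }
}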