com.linkedin.pinot.common.query.gen.AvroQueryGenerator.java Source code

Java tutorial

Introduction

Here is the source code for com.linkedin.pinot.common.query.gen.AvroQueryGenerator.java

Source

/**
 * Copyright (C) 2014-2016 LinkedIn Corp. (pinot-core@linkedin.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.linkedin.pinot.common.query.gen;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;

import org.apache.avro.Schema;
import org.apache.avro.Schema.Field;
import org.apache.avro.Schema.Type;
import org.apache.avro.file.DataFileStream;
import org.apache.avro.generic.GenericData.Array;
import org.apache.avro.generic.GenericData.Record;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.Predicate;

import com.linkedin.pinot.common.data.FieldSpec.DataType;

/**
 * Oct 17, 2014
 */

public class AvroQueryGenerator {

    public interface TestAggreationQuery {
        String getPql();
    }

    public static class TestSimpleAggreationQuery implements TestAggreationQuery {
        public final String pql;
        public final Double result;

        public TestSimpleAggreationQuery(String pql, Double result) {
            this.pql = pql;
            this.result = result;
        }

        @Override
        public String toString() {
            return pql + " : " + result;
        }

        @Override
        public String getPql() {
            return pql;
        }
    }

    public static class TestGroupByAggreationQuery implements TestAggreationQuery {
        public final String pql;
        public final Map<Object, Double> groupResults;

        public TestGroupByAggreationQuery(String pql, Map<Object, Double> groupResults) {
            this.pql = pql;
            this.groupResults = groupResults;
        }

        @Override
        public String toString() {
            final StringBuilder bld = new StringBuilder();
            bld.append(pql + " : ");
            for (final Object key : groupResults.keySet()) {
                bld.append(key + ":" + groupResults.get(key) + ",");
            }
            return bld.toString();
        }

        @Override
        public String getPql() {
            return pql;
        }
    }

    private final File avroFile;
    private final List<String> dimensions;
    private final List<String> metrics;
    private final String time;
    private final Map<String, DataType> dataTypeMap;
    private final Map<String, Boolean> isSingleValueMap;

    private DataFileStream<GenericRecord> dataStream;
    private Schema schema;
    private List<TestSimpleAggreationQuery> aggregationQueries;
    private List<TestGroupByAggreationQuery> groupByQueries;
    private final String resourceName;
    private boolean isRealtimeSegment = false;

    public AvroQueryGenerator(File avroFile, List<String> dimensions, List<String> metrics, String time,
            String resourceName) throws FileNotFoundException, IOException {
        this.avroFile = avroFile;
        this.dimensions = dimensions;
        this.metrics = metrics;
        this.time = time;
        dataTypeMap = new HashMap<String, DataType>();
        isSingleValueMap = new HashMap<String, Boolean>();
        this.resourceName = resourceName;
    }

    public AvroQueryGenerator(File avroFile, List<String> dimensions, List<String> metrics, String time,
            String resourceName, boolean isRealtimeSegment) throws FileNotFoundException, IOException {
        this.avroFile = avroFile;
        this.dimensions = dimensions;
        this.metrics = metrics;
        this.time = time;
        dataTypeMap = new HashMap<String, DataType>();
        isSingleValueMap = new HashMap<String, Boolean>();
        this.resourceName = resourceName;
        this.isRealtimeSegment = isRealtimeSegment;
    }

    public void init() throws FileNotFoundException, IOException {
        dataStream = new DataFileStream<GenericRecord>(new FileInputStream(avroFile),
                new GenericDatumReader<GenericRecord>());
        schema = dataStream.getSchema();
    }

    public List<TestSimpleAggreationQuery> giveMeNSimpleAggregationQueries(int n) {
        Collections.shuffle(aggregationQueries, new Random(System.currentTimeMillis()));
        if (n <= aggregationQueries.size()) {
            return aggregationQueries.subList(0, n);
        }
        return aggregationQueries;
    }

    public List<TestGroupByAggreationQuery> giveMeNGroupByAggregationQueries(int n) {
        Collections.shuffle(groupByQueries, new Random(System.currentTimeMillis()));
        if (n <= aggregationQueries.size()) {
            return groupByQueries.subList(0, n);
        }
        return groupByQueries;
    }

    public void temp() {

    }

    public void generateSimpleAggregationOnSingleColumnFilters() throws IOException {
        final Map<String, Map<Object, Integer>> cardinalityCountsMap = new HashMap<String, Map<Object, Integer>>();
        final Map<String, Map<Object, Map<String, Double>>> sumMap = new HashMap<String, Map<Object, Map<String, Double>>>();
        // here string key is columnName:columnValue:MetricName:GroupColumnName:groupKey:metricValue

        final Map<String, Map<Object, Double>> sumGroupBy = new HashMap<String, Map<Object, Double>>();

        aggregationQueries = new ArrayList<AvroQueryGenerator.TestSimpleAggreationQuery>();
        groupByQueries = new ArrayList<AvroQueryGenerator.TestGroupByAggreationQuery>();
        for (final Field f : schema.getFields()) {
            final String fieldName = f.name();
            if (dimensions.contains(fieldName) || metrics.contains(fieldName) || time.equals(fieldName)) {
                isSingleValueMap.put(fieldName, isSingleValueField(f));
                dataTypeMap.put(fieldName, getColumnType(f));
                if (!metrics.contains(fieldName)) {
                    cardinalityCountsMap.put(fieldName, new HashMap<Object, Integer>());
                }
            }
        }

        for (final String column : cardinalityCountsMap.keySet()) {
            sumMap.put(column, new HashMap<Object, Map<String, Double>>());
        }

        // here string key is columnName:columnValue:MetricName:GroupColumnName:groupKey:metricValue

        while (dataStream.hasNext()) {
            final GenericRecord record = dataStream.next();

            for (final String column : cardinalityCountsMap.keySet()) {
                Object value = record.get(column);

                if (value == null) {
                    switch (schema.getField(column).schema().getType()) {
                    case INT:
                        value = 0;
                        break;
                    case FLOAT:
                        value = 0F;
                        break;
                    case LONG:
                        value = 0L;
                        break;
                    case DOUBLE:
                        value = 0D;
                        break;
                    case STRING:
                    case BOOLEAN:
                        value = "null";
                        break;
                    }
                }

                if (value instanceof Utf8) {
                    value = ((Utf8) value).toString();
                }

                if (value instanceof Array) {
                    continue;
                }

                // here string key is columnName:columnValue:MetricName:GroupColumnName:groupKey:metricValue

                for (final String metricName : metrics) {
                    final String groupbyKeyBase = column + ":" + record.get(column) + ":" + metricName;
                    int dimCounter = 1;
                    for (final String dim : cardinalityCountsMap.keySet()) {
                        if (!dim.equals(column)) {
                            dimCounter++;
                            final String groupbyKey = groupbyKeyBase + ":" + dim;
                            if (sumGroupBy.containsKey(groupbyKey)) {
                                if (sumGroupBy.get(groupbyKey).containsKey(record.get(dim))) {
                                    sumGroupBy.get(groupbyKey).put(record.get(dim),
                                            getAppropriateNumberType(metricName, record.get(metricName),
                                                    sumGroupBy.get(groupbyKey).get(record.get(dim))));
                                } else {
                                    sumGroupBy.get(groupbyKey).put(record.get(dim),
                                            Double.parseDouble(record.get(metricName).toString()));
                                }
                            } else {
                                sumGroupBy.put(groupbyKey, new HashMap<Object, Double>());
                                sumGroupBy.get(groupbyKey).put(record.get(dim),
                                        Double.parseDouble(record.get(metricName).toString()));
                            }
                        }
                        if (dimCounter == 4) {
                            break;
                        }
                    }
                }

                if (cardinalityCountsMap.get(column).containsKey(value)) {
                    cardinalityCountsMap.get(column).put(value, cardinalityCountsMap.get(column).get(value) + 1);
                } else {
                    cardinalityCountsMap.get(column).put(value, 1);
                }

                if (!sumMap.get(column).containsKey(value)) {
                    sumMap.get(column).put(value, new HashMap<String, Double>());
                }

                for (final String metric : metrics) {
                    if (!sumMap.get(column).get(value).containsKey(metric)) {
                        sumMap.get(column).get(value).put(metric,
                                getAppropriateNumberType(metric, record.get(metric), 0D));
                    } else {
                        sumMap.get(column).get(value).put(metric, getAppropriateNumberType(metric,
                                record.get(metric), sumMap.get(column).get(value).get(metric)));
                    }
                }
                // here string key is columnName:columnValue:MetricName:GroupColumnName:groupKey:metricValue
            }
        }

        dataStream.close();

        if (!isRealtimeSegment) {
            for (final String column : cardinalityCountsMap.keySet()) {
                for (final Object entry : cardinalityCountsMap.get(column).keySet()) {
                    final StringBuilder bld = new StringBuilder();
                    bld.append("select count(*) from ");
                    bld.append(resourceName);
                    bld.append(" where ");
                    bld.append(column);
                    bld.append("=");
                    bld.append("'");
                    bld.append(entry);
                    bld.append("'");
                    bld.append(" ");
                    bld.append("limit 0");
                    String queryString = bld.toString();
                    if (!queryString.contains("null")) {
                        aggregationQueries.add(new TestSimpleAggreationQuery(queryString,
                                new Double(cardinalityCountsMap.get(column).get(entry))));
                    }
                }
            }
        }

        for (final String column : sumMap.keySet()) {
            for (final Object value : sumMap.get(column).keySet()) {
                for (final String metric : sumMap.get(column).get(value).keySet()) {
                    final StringBuilder bld = new StringBuilder();
                    bld.append("select sum('" + metric + "') from ");
                    bld.append(resourceName);
                    bld.append(" where ");
                    bld.append(column);
                    bld.append("=");
                    bld.append("'");
                    bld.append(value);
                    bld.append("'");
                    bld.append(" ");
                    bld.append("limit 0");
                    String queryString = bld.toString();
                    if (!queryString.contains("null")) {
                        aggregationQueries.add(new TestSimpleAggreationQuery(bld.toString(),
                                sumMap.get(column).get(value).get(metric)));
                    }
                }
            }
        }

        for (final String groupKey : sumGroupBy.keySet()) {
            final String columnName = groupKey.split(":")[0];
            final String columnValue = groupKey.split(":")[1];
            final String metricColumn = groupKey.split(":")[2];
            final String groupByColumnName = groupKey.split(":")[3];

            final StringBuilder bld = new StringBuilder();
            bld.append("select sum('" + metricColumn + "') from ");
            bld.append(resourceName);
            bld.append(" where ");
            bld.append(columnName);
            bld.append("=");
            bld.append("'");
            bld.append(columnValue);
            bld.append("'");
            bld.append(" ");
            bld.append(" group by ");
            bld.append(groupByColumnName);
            bld.append(" top 10 ");
            bld.append("limit 0");
            String queryString = bld.toString();
            if (!queryString.contains("null")) {
                groupByQueries.add(new TestGroupByAggreationQuery(bld.toString(), sumGroupBy.get(groupKey)));
            }
        }
    }

    private Double getAppropriateNumberType(String column, Object entry, Double sum) {
        switch (dataTypeMap.get(column)) {
        case INT:
            return sum + (Integer) entry;
        case FLOAT:
            return sum + (Float) entry;
        case LONG:
            return sum + (Long) entry;
        case DOUBLE:
            return sum + (Double) entry;
        default:
            return 0D;
        }
    }

    private static boolean isSingleValueField(Field field) {
        org.apache.avro.Schema fieldSchema = field.schema();
        fieldSchema = extractSchemaFromUnionIfNeeded(fieldSchema);

        final Type type = fieldSchema.getType();
        if (type == Type.ARRAY) {
            return false;
        }
        return true;
    }

    public static DataType getColumnType(Field field) {
        org.apache.avro.Schema fieldSchema = field.schema();
        fieldSchema = extractSchemaFromUnionIfNeeded(fieldSchema);

        final Type type = fieldSchema.getType();
        if (type == Type.ARRAY) {
            org.apache.avro.Schema elementSchema = extractSchemaFromUnionIfNeeded(fieldSchema.getElementType());
            if (elementSchema.getType() == Type.RECORD) {
                if (elementSchema.getFields().size() == 1) {
                    elementSchema = elementSchema.getFields().get(0).schema();
                } else {
                    throw new RuntimeException("More than one schema in Multi-value column!");
                }
                elementSchema = extractSchemaFromUnionIfNeeded(elementSchema);
            }
            return DataType.valueOf(elementSchema.getType());
        } else {
            return DataType.valueOf(type);
        }
    }

    private static org.apache.avro.Schema extractSchemaFromUnionIfNeeded(org.apache.avro.Schema fieldSchema) {
        if ((fieldSchema).getType() == Type.UNION) {
            fieldSchema = ((org.apache.avro.Schema) CollectionUtils.find(fieldSchema.getTypes(), new Predicate() {
                @Override
                public boolean evaluate(Object object) {
                    return ((org.apache.avro.Schema) object).getType() != Type.NULL;
                }
            }));
        }
        return fieldSchema;
    }

    private static Object[] transformAvroArrayToObjectArray(Array arr) {
        if (arr == null) {
            return new Object[0];
        }
        final Object[] ret = new Object[arr.size()];
        final Iterator iterator = arr.iterator();
        int i = 0;
        while (iterator.hasNext()) {
            Object value = iterator.next();
            if (value instanceof Record) {
                value = ((Record) value).get(0);
            }
            if (value instanceof Utf8) {
                value = ((Utf8) value).toString();
            }
            ret[i++] = value;
        }
        return ret;
    }

}