org.apache.flink.table.codegen.SortCodeGeneratorTest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.flink.table.codegen.SortCodeGeneratorTest.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.codegen;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.operators.sort.QuickSort;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.codegen.sort.SortCodeGenerator;
import org.apache.flink.table.dataformat.BinaryArray;
import org.apache.flink.table.dataformat.BinaryGeneric;
import org.apache.flink.table.dataformat.BinaryRow;
import org.apache.flink.table.dataformat.BinaryRowWriter;
import org.apache.flink.table.dataformat.BinaryString;
import org.apache.flink.table.dataformat.BinaryWriter;
import org.apache.flink.table.dataformat.DataFormatConverters;
import org.apache.flink.table.dataformat.Decimal;
import org.apache.flink.table.dataformat.GenericRow;
import org.apache.flink.table.dataformat.TypeGetterSetters;
import org.apache.flink.table.generated.GeneratedNormalizedKeyComputer;
import org.apache.flink.table.generated.GeneratedRecordComparator;
import org.apache.flink.table.generated.NormalizedKeyComputer;
import org.apache.flink.table.generated.RecordComparator;
import org.apache.flink.table.plan.util.SortUtil;
import org.apache.flink.table.runtime.sort.BinaryInMemorySortBuffer;
import org.apache.flink.table.type.ArrayType;
import org.apache.flink.table.type.DecimalType;
import org.apache.flink.table.type.GenericType;
import org.apache.flink.table.type.InternalType;
import org.apache.flink.table.type.InternalTypes;
import org.apache.flink.table.type.RowType;
import org.apache.flink.table.typeutils.AbstractRowSerializer;
import org.apache.flink.table.typeutils.BinaryRowSerializer;
import org.apache.flink.util.MutableObjectIterator;

import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Assert;
import org.junit.Test;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Random test for sort code generator.
 */
public class SortCodeGeneratorTest {

    private static final int RECORD_NUM = 3000;

    private final InternalType[] types = new InternalType[] { InternalTypes.BOOLEAN, InternalTypes.BYTE,
            InternalTypes.SHORT, InternalTypes.INT, InternalTypes.LONG, InternalTypes.FLOAT, InternalTypes.DOUBLE,
            InternalTypes.STRING, new DecimalType(18, 2), new DecimalType(38, 18), InternalTypes.BINARY,
            new ArrayType(InternalTypes.BYTE), new RowType(InternalTypes.INT),
            new RowType(new RowType(InternalTypes.INT)), new GenericType<>(Types.INT) };

    private int[] fields;
    private int[] keys;
    private boolean[] orders;
    private boolean[] nullsIsLast;

    private static final DataFormatConverters.DataFormatConverter INT_ROW_CONV = DataFormatConverters
            .getConverterForTypeInfo(new RowTypeInfo(Types.INT));
    private static final TypeComparator INT_ROW_COMP = new RowTypeInfo(Types.INT).createComparator(new int[] { 0 },
            new boolean[] { true }, 0, new ExecutionConfig());
    private static final DataFormatConverters.DataFormatConverter NEST_ROW_CONV = DataFormatConverters
            .getConverterForTypeInfo(new RowTypeInfo(new RowTypeInfo(Types.INT)));
    private static final TypeComparator NEST_ROW_COMP = new RowTypeInfo(new RowTypeInfo(Types.INT))
            .createComparator(new int[] { 0 }, new boolean[] { true }, 0, new ExecutionConfig());

    @Test
    public void testMultiKeys() throws Exception {
        for (int i = 0; i < 100; i++) {
            randomKeysAndOrders();
            testInner();
        }
    }

    @Test
    public void testOneKey() throws Exception {
        for (int time = 0; time < 100; time++) {
            Random rnd = new Random();
            fields = new int[rnd.nextInt(9) + 1];
            for (int i = 0; i < fields.length; i++) {
                fields[i] = rnd.nextInt(types.length);
            }

            keys = new int[] { 0 };
            orders = new boolean[] { rnd.nextBoolean() };
            nullsIsLast = SortUtil.getNullDefaultOrders(orders);
            testInner();
        }
    }

    private void randomKeysAndOrders() {
        Random rnd = new Random();
        fields = new int[rnd.nextInt(9) + 1];
        for (int i = 0; i < fields.length; i++) {
            fields[i] = rnd.nextInt(types.length);
        }

        keys = new int[rnd.nextInt(fields.length) + 1];
        LinkedList<Integer> indexQueue = new LinkedList<>();
        for (int i = 0; i < fields.length; i++) {
            indexQueue.add(i);
        }
        Collections.shuffle(indexQueue);
        orders = new boolean[keys.length];
        for (int i = 0; i < keys.length; i++) {
            keys[i] = indexQueue.poll();
            orders[i] = rnd.nextBoolean();
        }
        nullsIsLast = SortUtil.getNullDefaultOrders(orders);
    }

    private Object[] shuffle(Object[] objects) {
        Collections.shuffle(Arrays.asList(objects));
        return objects;
    }

    private BinaryRow row(int i, Object[][] values) {
        BinaryRow row = new BinaryRow(fields.length);
        BinaryRowWriter writer = new BinaryRowWriter(row);
        for (int j = 0; j < fields.length; j++) {
            Object value = values[j][i];
            if (value == null) {
                writer.setNullAt(j);
            } else {
                BinaryWriter.write(writer, j, value, types[fields[j]]);
            }
        }

        writer.complete();
        return row;
    }

    private BinaryRow[] getTestData() {
        BinaryRow[] result = new BinaryRow[RECORD_NUM];
        Object[][] values = new Object[fields.length][];
        for (int i = 0; i < fields.length; i++) {
            values[i] = shuffle(generateValues(types[fields[i]]));
        }
        for (int i = 0; i < RECORD_NUM; i++) {
            result[i] = row(i, values);
        }
        return result;
    }

    private Object[] generateValues(InternalType type) {

        Random rnd = new Random();

        int seedNum = RECORD_NUM / 5;
        Object[] seeds = new Object[seedNum];
        seeds[0] = null;
        seeds[1] = value1(type, rnd);
        seeds[2] = value2(type, rnd);
        seeds[3] = value3(type, rnd);
        for (int i = 4; i < seeds.length; i++) {
            if (type.equals(InternalTypes.BOOLEAN)) {
                seeds[i] = rnd.nextBoolean();
            } else if (type.equals(InternalTypes.BYTE)) {
                seeds[i] = (byte) rnd.nextLong();
            } else if (type.equals(InternalTypes.SHORT)) {
                seeds[i] = (short) rnd.nextLong();
            } else if (type.equals(InternalTypes.INT)) {
                seeds[i] = rnd.nextInt();
            } else if (type.equals(InternalTypes.LONG)) {
                seeds[i] = rnd.nextLong();
            } else if (type.equals(InternalTypes.FLOAT)) {
                seeds[i] = rnd.nextFloat() * rnd.nextLong();
            } else if (type.equals(InternalTypes.DOUBLE)) {
                seeds[i] = rnd.nextDouble() * rnd.nextLong();
            } else if (type.equals(InternalTypes.STRING)) {
                seeds[i] = BinaryString.fromString(RandomStringUtils.random(rnd.nextInt(20)));
            } else if (type instanceof DecimalType) {
                DecimalType decimalType = (DecimalType) type;
                BigDecimal decimal = new BigDecimal(rnd.nextInt()).divide(
                        new BigDecimal(ThreadLocalRandom.current().nextInt(1, 256)),
                        ThreadLocalRandom.current().nextInt(1, 30), BigDecimal.ROUND_HALF_EVEN);
                seeds[i] = Decimal.fromBigDecimal(decimal, decimalType.precision(), decimalType.scale());
            } else if (type instanceof ArrayType || type.equals(InternalTypes.BINARY)) {
                byte[] bytes = new byte[rnd.nextInt(16) + 1];
                rnd.nextBytes(bytes);
                seeds[i] = type.equals(InternalTypes.BINARY) ? bytes : BinaryArray.fromPrimitiveArray(bytes);
            } else if (type instanceof RowType) {
                RowType rowType = (RowType) type;
                if (rowType.getTypeAt(0).equals(InternalTypes.INT)) {
                    seeds[i] = GenericRow.of(rnd.nextInt());
                } else {
                    seeds[i] = GenericRow.of(GenericRow.of(rnd.nextInt()));
                }
            } else if (type instanceof GenericType) {
                seeds[i] = new BinaryGeneric<>(rnd.nextInt(), IntSerializer.INSTANCE);
            } else {
                throw new RuntimeException("Not support!");
            }
        }

        // result values
        Object[] results = new Object[RECORD_NUM];
        for (int i = 0; i < RECORD_NUM; i++) {
            results[i] = seeds[rnd.nextInt(seedNum)];
        }
        return results;
    }

    private Object value1(InternalType type, Random rnd) {
        if (type.equals(InternalTypes.BOOLEAN)) {
            return false;
        } else if (type.equals(InternalTypes.BYTE)) {
            return Byte.MIN_VALUE;
        } else if (type.equals(InternalTypes.SHORT)) {
            return Short.MIN_VALUE;
        } else if (type.equals(InternalTypes.INT)) {
            return Integer.MIN_VALUE;
        } else if (type.equals(InternalTypes.LONG)) {
            return Long.MIN_VALUE;
        } else if (type.equals(InternalTypes.FLOAT)) {
            return Float.MIN_VALUE;
        } else if (type.equals(InternalTypes.DOUBLE)) {
            return Double.MIN_VALUE;
        } else if (type.equals(InternalTypes.STRING)) {
            return BinaryString.fromString("");
        } else if (type instanceof DecimalType) {
            DecimalType decimalType = (DecimalType) type;
            return Decimal.fromBigDecimal(new BigDecimal(Integer.MIN_VALUE), decimalType.precision(),
                    decimalType.scale());
        } else if (type instanceof ArrayType) {
            byte[] bytes = new byte[rnd.nextInt(7) + 1];
            rnd.nextBytes(bytes);
            BinaryArray array = BinaryArray.fromPrimitiveArray(bytes);
            for (int i = 0; i < bytes.length; i++) {
                array.setNullByte(i);
            }
            return array;
        } else if (type.equals(InternalTypes.BINARY)) {
            byte[] bytes = new byte[rnd.nextInt(7) + 1];
            rnd.nextBytes(bytes);
            return bytes;
        } else if (type instanceof RowType) {
            return GenericRow.of(new Object[] { null });
        } else if (type instanceof GenericType) {
            return new BinaryGeneric<>(rnd.nextInt(), IntSerializer.INSTANCE);
        } else {
            throw new RuntimeException("Not support!");
        }
    }

    private Object value2(InternalType type, Random rnd) {
        if (type.equals(InternalTypes.BOOLEAN)) {
            return false;
        } else if (type.equals(InternalTypes.BYTE)) {
            return (byte) 0;
        } else if (type.equals(InternalTypes.SHORT)) {
            return (short) 0;
        } else if (type.equals(InternalTypes.INT)) {
            return 0;
        } else if (type.equals(InternalTypes.LONG)) {
            return 0L;
        } else if (type.equals(InternalTypes.FLOAT)) {
            return 0f;
        } else if (type.equals(InternalTypes.DOUBLE)) {
            return 0d;
        } else if (type.equals(InternalTypes.STRING)) {
            return BinaryString.fromString("0");
        } else if (type instanceof DecimalType) {
            DecimalType decimalType = (DecimalType) type;
            return Decimal.fromBigDecimal(new BigDecimal(0), decimalType.precision(), decimalType.scale());
        } else if (type instanceof ArrayType || type.equals(InternalTypes.BINARY)) {
            byte[] bytes = new byte[rnd.nextInt(7) + 10];
            rnd.nextBytes(bytes);
            return type.equals(InternalTypes.BINARY) ? bytes : BinaryArray.fromPrimitiveArray(bytes);
        } else if (type instanceof RowType) {
            RowType rowType = (RowType) type;
            if (rowType.getTypeAt(0).equals(InternalTypes.INT)) {
                return GenericRow.of(rnd.nextInt());
            } else {
                return GenericRow.of(GenericRow.of(new Object[] { null }));
            }
        } else if (type instanceof GenericType) {
            return new BinaryGeneric<>(rnd.nextInt(), IntSerializer.INSTANCE);
        } else {
            throw new RuntimeException("Not support!");
        }
    }

    private Object value3(InternalType type, Random rnd) {
        if (type.equals(InternalTypes.BOOLEAN)) {
            return true;
        } else if (type.equals(InternalTypes.BYTE)) {
            return Byte.MAX_VALUE;
        } else if (type.equals(InternalTypes.SHORT)) {
            return Short.MAX_VALUE;
        } else if (type.equals(InternalTypes.INT)) {
            return Integer.MAX_VALUE;
        } else if (type.equals(InternalTypes.LONG)) {
            return Long.MAX_VALUE;
        } else if (type.equals(InternalTypes.FLOAT)) {
            return Float.MAX_VALUE;
        } else if (type.equals(InternalTypes.DOUBLE)) {
            return Double.MAX_VALUE;
        } else if (type.equals(InternalTypes.STRING)) {
            return BinaryString.fromString(RandomStringUtils.random(100));
        } else if (type instanceof DecimalType) {
            DecimalType decimalType = (DecimalType) type;
            return Decimal.fromBigDecimal(new BigDecimal(Integer.MAX_VALUE), decimalType.precision(),
                    decimalType.scale());
        } else if (type instanceof ArrayType || type.equals(InternalTypes.BINARY)) {
            byte[] bytes = new byte[rnd.nextInt(100) + 100];
            rnd.nextBytes(bytes);
            return type.equals(InternalTypes.BINARY) ? bytes : BinaryArray.fromPrimitiveArray(bytes);
        } else if (type instanceof RowType) {
            RowType rowType = (RowType) type;
            if (rowType.getTypeAt(0).equals(InternalTypes.INT)) {
                return GenericRow.of(rnd.nextInt());
            } else {
                return GenericRow.of(GenericRow.of(rnd.nextInt()));
            }
        } else if (type instanceof GenericType) {
            return new BinaryGeneric<>(rnd.nextInt(), IntSerializer.INSTANCE);
        } else {
            throw new RuntimeException("Not support!");
        }
    }

    private InternalType[] getFieldTypes() {
        InternalType[] result = new InternalType[fields.length];
        for (int i = 0; i < fields.length; i++) {
            result[i] = types[fields[i]];
        }
        return result;
    }

    private InternalType[] getKeyTypes() {
        InternalType[] result = new InternalType[keys.length];
        for (int i = 0; i < keys.length; i++) {
            result[i] = types[fields[keys[i]]];
        }
        return result;
    }

    private void testInner() throws Exception {
        List<MemorySegment> segments = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            segments.add(MemorySegmentFactory.wrap(new byte[32768]));
        }

        InternalType[] fieldTypes = getFieldTypes();
        InternalType[] keyTypes = getKeyTypes();

        Tuple2<NormalizedKeyComputer, RecordComparator> tuple2 = getSortBaseWithNulls(
                this.getClass().getSimpleName(), keyTypes, keys, orders, nullsIsLast);

        BinaryRowSerializer serializer = new BinaryRowSerializer(fieldTypes.length);

        BinaryInMemorySortBuffer sortBuffer = BinaryInMemorySortBuffer.createBuffer(tuple2.f0,
                (AbstractRowSerializer) serializer, serializer, tuple2.f1, segments);

        BinaryRow[] dataArray = getTestData();

        List<BinaryRow> data = Arrays.asList(dataArray.clone());
        List<BinaryRow> binaryRows = Arrays.asList(dataArray.clone());
        Collections.shuffle(binaryRows);

        for (BinaryRow row : binaryRows) {
            if (!sortBuffer.write(row)) {
                throw new RuntimeException();
            }
        }

        new QuickSort().sort(sortBuffer);

        MutableObjectIterator<BinaryRow> iter = sortBuffer.getIterator();
        List<BinaryRow> result = new ArrayList<>();
        BinaryRow row = serializer.createInstance();
        while ((row = iter.next(row)) != null) {
            result.add(row.copy());
        }

        data.sort((o1, o2) -> {
            for (int i = 0; i < keys.length; i++) {
                InternalType t = types[fields[keys[i]]];
                boolean order = orders[i];
                Object first = null;
                Object second = null;
                if (!o1.isNullAt(keys[i])) {
                    first = TypeGetterSetters.get(o1, keys[i], keyTypes[i]);
                }
                if (!o2.isNullAt(keys[i])) {
                    second = TypeGetterSetters.get(o2, keys[i], keyTypes[i]);
                }

                if (first == null && second == null) {
                } else if (first == null) {
                    return order ? -1 : 1;
                } else if (second == null) {
                    return order ? 1 : -1;
                } else if (first instanceof Comparable) {
                    int ret = ((Comparable) first).compareTo(second);
                    if (ret != 0) {
                        return order ? ret : -ret;
                    }
                } else if (t instanceof ArrayType) {
                    BinaryArray leftArray = (BinaryArray) first;
                    BinaryArray rightArray = (BinaryArray) second;
                    int minLength = Math.min(leftArray.numElements(), rightArray.numElements());
                    for (int j = 0; j < minLength; j++) {
                        boolean isNullLeft = leftArray.isNullAt(j);
                        boolean isNullRight = rightArray.isNullAt(j);
                        if (isNullLeft && isNullRight) {
                            // Do nothing.
                        } else if (isNullLeft) {
                            return order ? -1 : 1;
                        } else if (isNullRight) {
                            return order ? 1 : -1;
                        } else {
                            int comp = Byte.compare(leftArray.getByte(j), rightArray.getByte(j));
                            if (comp != 0) {
                                return order ? comp : -comp;
                            }
                        }
                    }
                    if (leftArray.numElements() < rightArray.numElements()) {
                        return order ? -1 : 1;
                    } else if (leftArray.numElements() > rightArray.numElements()) {
                        return order ? 1 : -1;
                    }
                } else if (t.equals(InternalTypes.BINARY)) {
                    int comp = org.apache.flink.table.runtime.sort.SortUtil.compareBinary((byte[]) first,
                            (byte[]) second);
                    if (comp != 0) {
                        return order ? comp : -comp;
                    }
                } else if (t instanceof RowType) {
                    RowType rowType = (RowType) t;
                    int comp;
                    if (rowType.getTypeAt(0).equals(InternalTypes.INT)) {
                        comp = INT_ROW_COMP.compare(INT_ROW_CONV.toExternal(first),
                                INT_ROW_CONV.toExternal(second));
                    } else {
                        comp = NEST_ROW_COMP.compare(NEST_ROW_CONV.toExternal(first),
                                NEST_ROW_CONV.toExternal(second));
                    }
                    if (comp != 0) {
                        return order ? comp : -comp;
                    }
                } else if (t instanceof GenericType) {
                    Integer i1 = BinaryGeneric.getJavaObjectFromBinaryGeneric((BinaryGeneric) first,
                            IntSerializer.INSTANCE);
                    Integer i2 = BinaryGeneric.getJavaObjectFromBinaryGeneric((BinaryGeneric) second,
                            IntSerializer.INSTANCE);
                    int comp = Integer.compare(i1, i2);
                    if (comp != 0) {
                        return order ? comp : -comp;
                    }
                } else {
                    throw new RuntimeException();
                }
            }
            return 0;
        });

        StringBuilder builder = new StringBuilder();
        for (int i = 0; i < data.size(); i++) {
            builder.append("\n").append("expect: ").append(data.get(i).toOriginString(fieldTypes))
                    .append("; actual: ").append(result.get(i).toOriginString(fieldTypes));
        }
        builder.append("\n").append("types: ").append(Arrays.asList(fieldTypes));
        builder.append("\n").append("keys: ").append(Arrays.toString(keys));
        String msg = builder.toString();
        for (int i = 0; i < data.size(); i++) {
            for (int j = 0; j < keys.length; j++) {
                boolean isNull1 = data.get(i).isNullAt(keys[j]);
                boolean isNull2 = result.get(i).isNullAt(keys[j]);
                Assert.assertEquals(msg, isNull1, isNull2);
                if (!isNull1 || !isNull2) {
                    Object o1 = TypeGetterSetters.get(data.get(i), keys[j], keyTypes[j]);
                    Object o2 = TypeGetterSetters.get(result.get(i), keys[j], keyTypes[j]);
                    if (keyTypes[j].equals(InternalTypes.BINARY)) {
                        Assert.assertArrayEquals(msg, (byte[]) o1, (byte[]) o2);
                    } else {
                        Assert.assertEquals(msg, o1, o2);
                    }
                }
            }
        }
    }

    public static Tuple2<NormalizedKeyComputer, RecordComparator> getSortBaseWithNulls(String namePrefix,
            InternalType[] keyTypes, int[] keys, boolean[] orders, boolean[] nullsIsLast)
            throws IllegalAccessException, InstantiationException {
        SortCodeGenerator generator = new SortCodeGenerator(new TableConfig(), keys, keyTypes, orders, nullsIsLast);
        GeneratedNormalizedKeyComputer computer = generator.generateNormalizedKeyComputer(namePrefix + "Computer");
        GeneratedRecordComparator comparator = generator.generateRecordComparator(namePrefix + "Comparator");
        ClassLoader cl = Thread.currentThread().getContextClassLoader();
        return new Tuple2<>(computer.newInstance(cl), comparator.newInstance(cl));
    }
}