org.apache.flink.runtime.operators.util.OutputEmitterTest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.flink.runtime.operators.util.OutputEmitterTest.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.operators.util;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;

import junit.framework.Assert;
import junit.framework.TestCase;

import org.apache.commons.lang.NotImplementedException;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.typeutils.runtime.record.RecordComparatorFactory;
import org.apache.flink.api.java.typeutils.runtime.record.RecordSerializerFactory;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.InputViewDataInputStreamWrapper;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.OutputViewDataOutputStreamWrapper;
import org.apache.flink.runtime.io.network.api.ChannelSelector;
import org.apache.flink.runtime.operators.shipping.OutputEmitter;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.plugable.SerializationDelegate;
import org.apache.flink.types.DeserializationException;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.KeyFieldOutOfBoundsException;
import org.apache.flink.types.NullKeyFieldException;
import org.apache.flink.types.Record;
import org.apache.flink.types.StringValue;
import org.junit.Test;

public class OutputEmitterTest extends TestCase {

    //   private static final long SEED = 485213591485399L;

    /**
     * Exercises the {@code PARTITION_HASH} ship strategy.
     *
     * <p>For IntValue keys and for StringValue keys, the test asserts that each of the 100
     * channels is selected at least once and that the total number of channel selections
     * equals the number of emitted records. It then feeds a handful of boundary integers
     * through an emitter backed by {@link TestIntComparator} and asserts that exactly one
     * channel inside the valid index range is selected for each of them.
     */
    @Test
    public void testPartitionHash() {
        final int channelCount = 100;

        // ---- IntValue keys ----
        @SuppressWarnings("unchecked")
        final TypeComparator<Record> intKeyComparator = new RecordComparatorFactory(new int[] { 0 },
                new Class[] { IntValue.class }).createComparator();
        final ChannelSelector<SerializationDelegate<Record>> intEmitter = new OutputEmitter<Record>(
                ShipStrategyType.PARTITION_HASH, intKeyComparator);
        final SerializationDelegate<Record> recordDelegate = new SerializationDelegate<Record>(
                new RecordSerializerFactory().getSerializer());

        final int intRecordCount = 50000;
        final int[] intHits = new int[channelCount];

        for (int key = 0; key < intRecordCount; key++) {
            recordDelegate.setInstance(new Record(new IntValue(key)));
            for (int channel : intEmitter.selectChannels(recordDelegate, intHits.length)) {
                intHits[channel]++;
            }
        }

        int intTotal = 0;
        for (int hits : intHits) {
            // no channel may be left empty
            assertTrue(hits > 0);
            intTotal += hits;
        }
        assertTrue(intTotal == intRecordCount);

        // ---- StringValue keys ----
        @SuppressWarnings("unchecked")
        final TypeComparator<Record> stringKeyComparator = new RecordComparatorFactory(new int[] { 0 },
                new Class[] { StringValue.class }).createComparator();
        final ChannelSelector<SerializationDelegate<Record>> stringEmitter = new OutputEmitter<Record>(
                ShipStrategyType.PARTITION_HASH, stringKeyComparator);

        final int stringRecordCount = 10000;
        final int[] stringHits = new int[channelCount];

        for (int key = 0; key < stringRecordCount; key++) {
            recordDelegate.setInstance(new Record(new StringValue(Integer.toString(key))));
            for (int channel : stringEmitter.selectChannels(recordDelegate, stringHits.length)) {
                stringHits[channel]++;
            }
        }

        int stringTotal = 0;
        for (int hits : stringHits) {
            assertTrue(hits > 0);
            stringTotal += hits;
        }
        assertTrue(stringTotal == stringRecordCount);

        // ---- hash corner cases ----
        final ChannelSelector<SerializationDelegate<Integer>> cornerEmitter = new OutputEmitter<Integer>(
                ShipStrategyType.PARTITION_HASH, new TestIntComparator());
        final SerializationDelegate<Integer> intDelegate = new SerializationDelegate<Integer>(new IntSerializer());

        // TestIntComparator hashes an Integer to its own value, so these keys are also the hash codes.
        final int[] cornerKeys = { Integer.MIN_VALUE, -1, 0, 1, Integer.MAX_VALUE };
        for (int cornerKey : cornerKeys) {
            intDelegate.setInstance(cornerKey);
            final int[] selected = cornerEmitter.selectChannels(intDelegate, channelCount);
            assertTrue(selected.length == 1);
            assertTrue(selected[0] >= 0 && selected[0] < channelCount);
        }
    }

    /**
     * Exercises the {@code FORWARD} ship strategy with IntValue keys and with StringValue keys.
     *
     * <p>For each key type the test asserts that every channel is selected either
     * {@code records / channels} times or one time fewer, and that the total number of
     * channel selections equals the number of emitted records.
     */
    @Test
    public void testForward() {
        final int channelCount = 100;

        // ---- IntValue keys ----
        @SuppressWarnings("unchecked")
        final TypeComparator<Record> intKeyComparator = new RecordComparatorFactory(new int[] { 0 },
                new Class[] { IntValue.class }).createComparator();
        final ChannelSelector<SerializationDelegate<Record>> intEmitter = new OutputEmitter<Record>(
                ShipStrategyType.FORWARD, intKeyComparator);
        final SerializationDelegate<Record> recordDelegate = new SerializationDelegate<Record>(
                new RecordSerializerFactory().getSerializer());

        final int intRecordCount = 50000;
        final int[] intHits = new int[channelCount];

        for (int key = 0; key < intRecordCount; key++) {
            recordDelegate.setInstance(new Record(new IntValue(key)));
            for (int channel : intEmitter.selectChannels(recordDelegate, intHits.length)) {
                intHits[channel]++;
            }
        }

        // an even spread gives each channel this many records (or one fewer)
        final int intQuota = intRecordCount / channelCount;
        int intTotal = 0;
        for (int hits : intHits) {
            assertTrue(hits == intQuota || hits == intQuota - 1);
            intTotal += hits;
        }
        assertTrue(intTotal == intRecordCount);

        // ---- StringValue keys ----
        @SuppressWarnings("unchecked")
        final TypeComparator<Record> stringKeyComparator = new RecordComparatorFactory(new int[] { 0 },
                new Class[] { StringValue.class }).createComparator();
        final ChannelSelector<SerializationDelegate<Record>> stringEmitter = new OutputEmitter<Record>(
                ShipStrategyType.FORWARD, stringKeyComparator);

        final int stringRecordCount = 10000;
        final int[] stringHits = new int[channelCount];

        for (int key = 0; key < stringRecordCount; key++) {
            recordDelegate.setInstance(new Record(new StringValue(Integer.toString(key))));
            for (int channel : stringEmitter.selectChannels(recordDelegate, stringHits.length)) {
                stringHits[channel]++;
            }
        }

        final int stringQuota = stringRecordCount / channelCount;
        int stringTotal = 0;
        for (int hits : stringHits) {
            assertTrue(hits == stringQuota || hits == stringQuota - 1);
            stringTotal += hits;
        }
        assertTrue(stringTotal == stringRecordCount);

    }

    /**
     * Exercises the {@code BROADCAST} ship strategy with IntValue keys and with StringValue keys.
     *
     * <p>For each key type the test asserts that every one of the 100 channels is selected
     * once per emitted record, i.e. its hit count equals the number of records.
     */
    @Test
    public void testBroadcast() {
        final int channelCount = 100;

        // ---- IntValue keys ----
        @SuppressWarnings("unchecked")
        final TypeComparator<Record> intKeyComparator = new RecordComparatorFactory(new int[] { 0 },
                new Class[] { IntValue.class }).createComparator();
        final ChannelSelector<SerializationDelegate<Record>> intEmitter = new OutputEmitter<Record>(
                ShipStrategyType.BROADCAST, intKeyComparator);
        final SerializationDelegate<Record> recordDelegate = new SerializationDelegate<Record>(
                new RecordSerializerFactory().getSerializer());

        final int intRecordCount = 50000;
        final int[] intHits = new int[channelCount];

        for (int key = 0; key < intRecordCount; key++) {
            recordDelegate.setInstance(new Record(new IntValue(key)));
            for (int channel : intEmitter.selectChannels(recordDelegate, intHits.length)) {
                intHits[channel]++;
            }
        }

        for (int hits : intHits) {
            // the observed count doubles as the failure message
            assertTrue(Integer.toString(hits), hits == intRecordCount);
        }

        // ---- StringValue keys ----
        @SuppressWarnings("unchecked")
        final TypeComparator<Record> stringKeyComparator = new RecordComparatorFactory(new int[] { 0 },
                new Class[] { StringValue.class }).createComparator();
        final ChannelSelector<SerializationDelegate<Record>> stringEmitter = new OutputEmitter<Record>(
                ShipStrategyType.BROADCAST, stringKeyComparator);

        final int stringRecordCount = 5000;
        final int[] stringHits = new int[channelCount];

        for (int key = 0; key < stringRecordCount; key++) {
            recordDelegate.setInstance(new Record(new StringValue(Integer.toString(key))));
            for (int channel : stringEmitter.selectChannels(recordDelegate, stringHits.length)) {
                stringHits[channel]++;
            }
        }

        for (int hits : stringHits) {
            assertTrue(Integer.toString(hits), hits == stringRecordCount);
        }
    }

    /**
     * Exercises {@code PARTITION_HASH} with a composite key made of record fields 0, 1 and 3
     * (IntValue, StringValue, DoubleValue).
     *
     * <p>Asserts that each of the 100 channels is selected at least once and that the total
     * number of channel selections equals the number of emitted records.
     */
    @Test
    public void testMultiKeys() {
        @SuppressWarnings("unchecked")
        final TypeComparator<Record> compositeComparator = new RecordComparatorFactory(new int[] { 0, 1, 3 },
                new Class[] { IntValue.class, StringValue.class, DoubleValue.class }).createComparator();
        final ChannelSelector<SerializationDelegate<Record>> emitter = new OutputEmitter<Record>(
                ShipStrategyType.PARTITION_HASH, compositeComparator);
        final SerializationDelegate<Record> recordDelegate = new SerializationDelegate<Record>(
                new RecordSerializerFactory().getSerializer());

        final int recordCount = 5000;
        final int[] hitsPerChannel = new int[100];

        for (int n = 0; n < recordCount; n++) {
            // four-field record; field 2 is never set and is not part of the key
            final Record record = new Record(4);
            record.setField(0, new IntValue(n));
            record.setField(1, new StringValue("AB" + n + "CD" + n));
            record.setField(3, new DoubleValue(n * 3.141d));
            recordDelegate.setInstance(record);

            for (int channel : emitter.selectChannels(recordDelegate, hitsPerChannel.length)) {
                hitsPerChannel[channel]++;
            }
        }

        int total = 0;
        for (int hits : hitsPerChannel) {
            assertTrue(hits > 0);
            total += hits;
        }
        assertTrue(total == recordCount);

    }

    /**
     * Verifies that hash partitioning on key position 1 fails with a
     * {@link KeyFieldOutOfBoundsException} reporting field number 1 when the record
     * only has field 0 set.
     */
    @Test
    public void testMissingKey() {
        // comparator keyed on position 1, which the record below does not have
        @SuppressWarnings("unchecked")
        final TypeComparator<Record> keyComparator = new RecordComparatorFactory(new int[] { 1 },
                new Class[] { IntValue.class }).createComparator();
        final ChannelSelector<SerializationDelegate<Record>> emitter = new OutputEmitter<Record>(
                ShipStrategyType.PARTITION_HASH, keyComparator);
        final SerializationDelegate<Record> recordDelegate = new SerializationDelegate<Record>(
                new RecordSerializerFactory().getSerializer());

        final Record record = new Record(0);
        record.setField(0, new IntValue(1));
        recordDelegate.setInstance(record);

        try {
            emitter.selectChannels(recordDelegate, 100);
            // only reached when no exception was thrown
            Assert.fail("Expected a KeyFieldOutOfBoundsException.");
        } catch (KeyFieldOutOfBoundsException expected) {
            Assert.assertEquals(1, expected.getFieldNumber());
        }
    }

    /**
     * Verifies that hash partitioning on key position 0 fails with a
     * {@link NullKeyFieldException} reporting field number 0 when the record has two
     * fields but only field 1 was assigned a value.
     */
    @Test
    public void testNullKey() {
        @SuppressWarnings("unchecked")
        final TypeComparator<Record> keyComparator = new RecordComparatorFactory(new int[] { 0 },
                new Class[] { IntValue.class }).createComparator();
        final ChannelSelector<SerializationDelegate<Record>> emitter = new OutputEmitter<Record>(
                ShipStrategyType.PARTITION_HASH, keyComparator);
        final SerializationDelegate<Record> recordDelegate = new SerializationDelegate<Record>(
                new RecordSerializerFactory().getSerializer());

        // two-field record whose key field (position 0) is left unset
        final Record record = new Record(2);
        record.setField(1, new IntValue(1));
        recordDelegate.setInstance(record);

        try {
            emitter.selectChannels(recordDelegate, 100);
            // only reached when no exception was thrown
            Assert.fail("Expected a NullKeyFieldException.");
        } catch (NullKeyFieldException expected) {
            Assert.assertEquals(0, expected.getFieldNumber());
        }
    }

    /**
     * Verifies that hash partitioning fails with a {@link DeserializationException} when the
     * comparator is configured for a {@link DoubleValue} key at position 0 but the record
     * actually carries an {@link IntValue} there.
     *
     * <p>The record is written to and read back from an in-memory pipe before it is handed to
     * the emitter (presumably so that the key field is held in serialized form and is only
     * materialized, with the wrong type, when the comparator accesses it; confirm against
     * {@code Record}).
     */
    @Test
    public void testWrongKeyClass() {

        // Comparator expects a DoubleValue key at position 0
        @SuppressWarnings("unchecked")
        final TypeComparator<Record> doubleComp = new RecordComparatorFactory(new int[] { 0 },
                new Class[] { DoubleValue.class }).createComparator();
        final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(
                ShipStrategyType.PARTITION_HASH, doubleComp);
        final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(
                new RecordSerializerFactory().getSerializer());

        // 1 MiB pipe buffer so the single-threaded write-then-read below cannot block
        PipedInputStream pipedInput = new PipedInputStream(1024 * 1024);
        DataInputStream in = new DataInputStream(pipedInput);
        DataOutputStream out;
        Record rec = null;

        try {
            out = new DataOutputStream(new PipedOutputStream(pipedInput));

            // the record holds an IntValue where the comparator expects a DoubleValue
            rec = new Record(1);
            rec.setField(0, new IntValue());

            // round-trip the record through its binary representation
            rec.write(new OutputViewDataOutputStreamWrapper(out));
            rec = new Record();
            rec.read(new InputViewDataInputStreamWrapper(in));

        } catch (IOException e) {
            // a failure here is a problem with the test setup, not with the emitter
            fail("Test erroneous: " + e.getMessage());
        }

        try {
            delegate.setInstance(rec);
            oe1.selectChannels(delegate, 100);
        } catch (DeserializationException re) {
            return;
        }
        Assert.fail("Expected a DeserializationException.");
    }

    /**
     * Minimal {@link TypeComparator} stub for {@link Integer} values.
     *
     * <p>Only {@link #hash(Integer)} is functional; it returns the integer's own value so that a
     * test can dictate the exact hash code. Every other operation throws a
     * {@link NotImplementedException}.
     */
    @SuppressWarnings("serial")
    private static class TestIntComparator extends TypeComparator<Integer> {

        /** Creates the exception thrown by every operation this stub does not support. */
        private static NotImplementedException notSupportedInTest() {
            return new NotImplementedException();
        }

        /** Returns the value itself as its hash code. */
        @Override
        public int hash(Integer record) {
            return record.intValue();
        }

        // ---- everything below is intentionally unsupported ----

        @Override
        public void setReference(Integer toCompare) {
            throw notSupportedInTest();
        }

        @Override
        public boolean equalToReference(Integer candidate) {
            throw notSupportedInTest();
        }

        @Override
        public int compareToReference(TypeComparator<Integer> referencedComparator) {
            throw notSupportedInTest();
        }

        @Override
        public int compare(Integer first, Integer second) {
            throw notSupportedInTest();
        }

        @Override
        public int compare(DataInputView firstSource, DataInputView secondSource) {
            throw notSupportedInTest();
        }

        @Override
        public boolean supportsNormalizedKey() {
            throw notSupportedInTest();
        }

        @Override
        public boolean supportsSerializationWithKeyNormalization() {
            throw notSupportedInTest();
        }

        @Override
        public int getNormalizeKeyLen() {
            throw notSupportedInTest();
        }

        @Override
        public boolean isNormalizedKeyPrefixOnly(int keyBytes) {
            throw notSupportedInTest();
        }

        @Override
        public void putNormalizedKey(Integer record, MemorySegment target, int offset, int numBytes) {
            throw notSupportedInTest();
        }

        @Override
        public void writeWithKeyNormalization(Integer record, DataOutputView target) throws IOException {
            throw notSupportedInTest();
        }

        @Override
        public Integer readWithKeyDenormalization(Integer reuse, DataInputView source) throws IOException {
            throw notSupportedInTest();
        }

        @Override
        public boolean invertNormalizedKey() {
            throw notSupportedInTest();
        }

        @Override
        public TypeComparator<Integer> duplicate() {
            throw notSupportedInTest();
        }

    }

    //   @Test
    //   public void testPartitionRange() {
    //      final Random rnd = new Random(SEED);
    //      
    //      final int DISTR_MIN = 0;
    //      final int DISTR_MAX = 1000000;
    //      final int DISTR_RANGE = DISTR_MAX - DISTR_MIN + 1;
    //      final int NUM_BUCKETS = 137;
    //      final float BUCKET_WIDTH = DISTR_RANGE / ((float) NUM_BUCKETS);
    //      
    //      final int NUM_ELEMENTS = 10000000;
    //      
    //      final DataDistribution distri = new UniformIntegerDistribution(DISTR_MIN, DISTR_MAX);
    //      
    //      @SuppressWarnings("unchecked")
    //      final TypeComparator<Record> intComp = new RecordComparatorFactory(new int[] {0}, new Class[] {IntValue.class}).createComparator();
    //      final ChannelSelector<SerializationDelegate<Record>> oe = new OutputEmitter<Record>(ShipStrategyType.PARTITION_RANGE, intComp, distri);
    //      final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer());
    //      
    //      final IntValue integer = new IntValue();
    //      final Record rec = new Record();
    //      
    //      for (int i = 0; i < NUM_ELEMENTS; i++) {
    //         final int nextValue = rnd.nextInt(DISTR_RANGE) + DISTR_MIN;
    //         integer.setValue(nextValue);
    //         rec.setField(0, integer);
    //         delegate.setInstance(rec);
    //         
    //         final int[] channels = oe.selectChannels(delegate, NUM_BUCKETS);
    //         if (channels.length != 1) {
    //            Assert.fail("Resulting channels array has more than one channel.");
    //         }
    //         
    //         final int bucket = channels[0];
    //         final int shouldBeBucket = (int) ((nextValue - DISTR_MIN) / BUCKET_WIDTH);
    //         
    //         if (shouldBeBucket != bucket) {
    //            // we may have a rounding imprecision in the 'should be bucket' computation.
    //            final int lowerBoundaryForSelectedBucket = DISTR_MIN + (int) ((bucket    ) * BUCKET_WIDTH);
    //            final int upperBoundaryForSelectedBucket = DISTR_MIN + (int) ((bucket + 1) * BUCKET_WIDTH);
    //            if (nextValue <= lowerBoundaryForSelectedBucket || nextValue > upperBoundaryForSelectedBucket) {
    //               Assert.fail("Wrong bucket selected");
    //            }
    //         }
    //         
    //      }
    //   }

}