org.apache.flink.api.java.Utils.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.flink.api.java.Utils.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.api.java;

import org.apache.commons.lang3.StringUtils;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.SerializedListAccumulator;
import org.apache.flink.api.common.accumulators.SimpleAccumulator;
import org.apache.flink.api.common.io.RichOutputFormat;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import org.apache.flink.configuration.Configuration;

import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

import static org.apache.flink.api.java.functions.FunctionAnnotation.SkipCodeAnalysis;

/**
 * Utility class that contains helper methods to work with Java APIs.
 */
@Internal
public final class Utils {

    /** Shared random number generator, created once with the no-argument {@link Random} constructor. */
    public static final Random RNG = new Random();

    /**
     * Describes the stack frame at index 4 of the stack trace captured by
     * {@link #getCallLocationName(int)}.
     *
     * <p>This is shorthand for {@code getCallLocationName(4)}.
     *
     * @return the formatted frame description produced by {@link #getCallLocationName(int)}
     */
    public static String getCallLocationName() {
        final int defaultDepth = 4;
        return getCallLocationName(defaultDepth);
    }

    /**
     * Describes the stack frame found at the given index of the current thread's stack trace.
     *
     * <p>The trace is captured inside this method, so small indices refer to this method and
     * the stack-trace machinery itself; larger indices walk outwards towards the callers.
     *
     * @param depth index into the stack trace of the current thread
     * @return a string of the form {@code method(File.java:line)}, or {@code "<unknown>"} when
     *         the trace has no frame at that index (negative index or stack too shallow)
     */
    public static String getCallLocationName(int depth) {
        StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();

        // Valid indices are 0 .. length - 1, so depth == length must be rejected as well;
        // otherwise the array access below throws ArrayIndexOutOfBoundsException.
        if (depth < 0 || depth >= stackTrace.length) {
            return "<unknown>";
        }

        StackTraceElement elem = stackTrace[depth];

        return String.format("%s(%s:%d)", elem.getMethodName(), elem.getFileName(), elem.getLineNumber());
    }

    // --------------------------------------------------------------------------------------------

    /**
     * Output format that tallies the records written to it and, when closed, adds the tally
     * to the long counter registered under the given accumulator id. The client can retrieve
     * the count from that accumulator. This sink is used by the {@link DataSet#count()} function.
     *
     * @param <T> Type of elements to count.
     */
    @SkipCodeAnalysis
    public static class CountHelper<T> extends RichOutputFormat<T> {

        private static final long serialVersionUID = 1L;

        /** Name of the accumulator that receives the count. */
        private final String id;

        /** Number of records written so far. */
        private long counter = 0L;

        /**
         * @param id name of the long counter to which the count is added on close
         */
        public CountHelper(String id) {
            this.id = id;
        }

        /** No configuration is needed. */
        @Override
        public void configure(Configuration parameters) {
        }

        /** Nothing to set up when the task opens. */
        @Override
        public void open(int taskNumber, int numTasks) {
        }

        /** Counts the record; its content is not inspected. */
        @Override
        public void writeRecord(T record) {
            counter += 1;
        }

        /** Adds the number of records seen to the long counter named by the id. */
        @Override
        public void close() {
            getRuntimeContext().getLongCounter(id).add(counter);
        }
    }

    /**
     * Utility sink function that collects elements into an accumulator,
     * from which they can be retrieved by the client. This sink is used by the
     * {@link DataSet#collect()} function.
     *
     * @param <T> Type of elements to collect.
     */
    @SkipCodeAnalysis
    public static class CollectHelper<T> extends RichOutputFormat<T> {

        private static final long serialVersionUID = 1L;

        /** Name under which the accumulator is registered on close. */
        private final String id;

        /** Serializer handed to the accumulator together with every record. */
        private final TypeSerializer<T> serializer;

        /** Accumulator holding the collected records; created in {@link #open(int, int)}. */
        private SerializedListAccumulator<T> accumulator;

        /**
         * @param id name under which the accumulator is registered
         * @param serializer serializer passed to the accumulator for each added record
         */
        public CollectHelper(String id, TypeSerializer<T> serializer) {
            this.id = id;
            this.serializer = serializer;
        }

        /** No configuration is needed. */
        @Override
        public void configure(Configuration parameters) {
        }

        /** Starts with a fresh, empty accumulator. */
        @Override
        public void open(int taskNumber, int numTasks) {
            accumulator = new SerializedListAccumulator<T>();
        }

        /** Adds the record to the accumulator using the configured serializer. */
        @Override
        public void writeRecord(T record) throws IOException {
            accumulator.add(record, serializer);
        }

        /** Registers the filled accumulator with the runtime context. */
        @Override
        public void close() {
            // Important: registration happens only here, in close(), to minimize accumulator traffic.
            getRuntimeContext().addAccumulator(id, accumulator);
        }
    }

    /**
     * Accumulator value made of a {@code count} and a {@code checksum}. Adding or merging
     * another instance sums both components.
     */
    public static class ChecksumHashCode implements SimpleAccumulator<ChecksumHashCode> {

        private static final long serialVersionUID = 1L;

        private long count;
        private long checksum;

        /** Creates an instance whose count and checksum are both zero. */
        public ChecksumHashCode() {
        }

        /**
         * @param count initial count
         * @param checksum initial checksum
         */
        public ChecksumHashCode(long count, long checksum) {
            this.count = count;
            this.checksum = checksum;
        }

        /** @return the accumulated count */
        public long getCount() {
            return count;
        }

        /** @return the accumulated checksum */
        public long getChecksum() {
            return checksum;
        }

        /** Adds the count and the checksum of {@code value} to this instance. */
        @Override
        public void add(ChecksumHashCode value) {
            count += value.count;
            checksum += value.checksum;
        }

        /** @return this instance itself */
        @Override
        public ChecksumHashCode getLocalValue() {
            return this;
        }

        /** Sets count and checksum back to zero. */
        @Override
        public void resetLocal() {
            count = 0;
            checksum = 0;
        }

        /** Adds the local value of {@code other} to this instance. */
        @Override
        public void merge(Accumulator<ChecksumHashCode, ChecksumHashCode> other) {
            add(other.getLocalValue());
        }

        /** @return a new instance carrying the same count and checksum */
        @Override
        public ChecksumHashCode clone() {
            return new ChecksumHashCode(count, checksum);
        }

        /** Two instances are equal when both count and checksum match. */
        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof ChecksumHashCode)) {
                return false;
            }
            ChecksumHashCode that = (ChecksumHashCode) obj;
            return count == that.count && checksum == that.checksum;
        }

        /** @return the sum of count and checksum, narrowed to {@code int} */
        @Override
        public int hashCode() {
            long sum = count + checksum;
            return (int) sum;
        }

        /** Formats the checksum as 16 hex digits followed by the decimal count. */
        @Override
        public String toString() {
            return String.format("ChecksumHashCode 0x%016x, count %d", checksum, count);
        }
    }

    /**
     * Output format that counts the records written to it and sums their hash codes, then
     * registers the result as a {@link ChecksumHashCode} accumulator when closed.
     *
     * @param <T> Type of the records to checksum.
     */
    @SkipCodeAnalysis
    public static class ChecksumHashCodeHelper<T> extends RichOutputFormat<T> {

        private static final long serialVersionUID = 1L;

        /** Name under which the accumulator is registered on close. */
        private final String id;

        /** Number of records written so far. */
        private long counter = 0L;

        /** Running sum of the records' hash codes, each taken as an unsigned 32-bit value. */
        private long checksum = 0L;

        /**
         * @param id name under which the resulting accumulator is registered
         */
        public ChecksumHashCodeHelper(String id) {
            this.id = id;
        }

        /** No configuration is needed. */
        @Override
        public void configure(Configuration parameters) {
        }

        /** Nothing to set up when the task opens. */
        @Override
        public void open(int taskNumber, int numTasks) {
        }

        /** Counts the record and adds its hash code to the running checksum. */
        @Override
        public void writeRecord(T record) throws IOException {
            counter += 1;
            // Keep only the low 32 bits so that a negative hash code contributes a non-negative long.
            long unsignedHash = record.hashCode() & 0xffffffffL;
            checksum += unsignedHash;
        }

        /** Registers the count and checksum gathered so far under the configured id. */
        @Override
        public void close() throws IOException {
            getRuntimeContext().addAccumulator(id, new ChecksumHashCode(counter, checksum));
        }
    }

    // --------------------------------------------------------------------------------------------

    /**
     * Debugging utility that renders the hierarchy of type information created by the Java API
     * as an indented, multi-line string, starting at indentation zero.
     * Tested in GroupReduceITCase.testGroupByGenericType()
     *
     * @param ti the type information to render
     * @return the textual tree, one entry per line
     */
    public static <T> String getSerializerTree(TypeInformation<T> ti) {
        final int rootIndent = 0;
        return getSerializerTree(ti, rootIndent);
    }

    /**
     * Renders {@code ti} at the given indentation.
     *
     * <ul>
     *   <li>Composite types print their class name, followed by one {@code name:} entry per
     *       field with the field's own tree appended directly after the colon.</li>
     *   <li>Generic types print a {@code GenericTypeInfo (...)} line followed by the field
     *       tree of the underlying class, indented by four more spaces.</li>
     *   <li>Every other type prints its {@code toString()} on a single line.</li>
     * </ul>
     */
    private static <T> String getSerializerTree(TypeInformation<T> ti, int indent) {
        final String pad = StringUtils.repeat(' ', indent);

        if (ti instanceof CompositeType) {
            StringBuilder tree = new StringBuilder();
            tree.append(pad).append(ti.getClass().getSimpleName()).append('\n');

            CompositeType<T> composite = (CompositeType<T>) ti;
            String[] names = composite.getFieldNames();
            String fieldPad = StringUtils.repeat(' ', indent + 2);
            for (int pos = 0; pos < composite.getArity(); pos++) {
                TypeInformation<?> fieldType = composite.getTypeAt(pos);
                // The field's subtree is rendered with the parent's indent, not a deeper one.
                tree.append(fieldPad)
                        .append(names[pos])
                        .append(':')
                        .append(getSerializerTree(fieldType, indent));
            }
            return tree.toString();
        }

        if (ti instanceof GenericTypeInfo) {
            return pad + "GenericTypeInfo (" + ti.getTypeClass().getSimpleName() + ")\n"
                    + getGenericTypeTree(ti.getTypeClass(), indent + 4);
        }

        return pad + ti.toString() + "\n";
    }

    /**
     * Renders the non-static, non-transient declared fields of {@code type}, one per line as
     * {@code name:typeName} (with {@code " (is enum)"} appended for enum-typed fields), and
     * descends into every non-primitive field type with four more spaces of indentation.
     *
     * <p>A type that is already being expanded further up the same branch is still listed but
     * is not expanded again. Without that guard a self-referential type (for example a node
     * class holding a field of its own type) would recurse until the stack overflows.
     *
     * @param type class whose declared fields are rendered
     * @param indent number of spaces put in front of each field line of {@code type}
     * @return the textual field tree; empty if there are no eligible fields
     */
    private static String getGenericTypeTree(Class<?> type, int indent) {
        Set<Class<?>> expanding = new HashSet<>();
        return getGenericTypeTree(type, indent, expanding);
    }

    /**
     * Recursive worker for {@link #getGenericTypeTree(Class, int)}; {@code expanding} holds the
     * types on the path from the root to the type currently being rendered.
     */
    private static String getGenericTypeTree(Class<?> type, int indent, Set<Class<?>> expanding) {
        if (!expanding.add(type)) {
            // Already on the current path: expanding it again would never terminate.
            return "";
        }

        StringBuilder tree = new StringBuilder();
        String pad = StringUtils.repeat(' ', indent);
        for (Field field : type.getDeclaredFields()) {
            int modifiers = field.getModifiers();
            if (Modifier.isStatic(modifiers) || Modifier.isTransient(modifiers)) {
                continue;
            }
            Class<?> fieldType = field.getType();
            tree.append(pad).append(field.getName()).append(':').append(fieldType.getName());
            if (fieldType.isEnum()) {
                tree.append(" (is enum)");
            }
            tree.append('\n');
            if (!fieldType.isPrimitive()) {
                tree.append(getGenericTypeTree(fieldType, indent + 4, expanding));
            }
        }

        // Leave the path again so the same type can still be expanded in sibling branches.
        expanding.remove(type);
        return tree.toString();
    }

    /**
     * Private constructor to prevent instantiation. It throws a {@link RuntimeException}
     * unconditionally, so any invocation of it fails, including one made from inside this class.
     */
    private Utils() {
        throw new RuntimeException();
    }
}