org.apache.flink.statistics.StatisticsRequest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.flink.statistics.StatisticsRequest.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.statistics;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.InputViewDataInputStreamWrapper;
import org.apache.flink.core.memory.OutputViewDataOutputStreamWrapper;
import org.apache.flink.core.statistics.TaskMonitorRequest;
import org.apache.flink.runtime.operators.util.CorruptConfigurationException;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.util.StringUtils;

/**
 * A collection of annotations describing which statistics the system should gather at execution time, if possible.
 * <p>
 * Requests are kept per input and per output. Each request is an {@link Entry} consisting of the record (key)
 * positions it applies to and a {@link ConfPair} naming the statistic plus an optional {@link Configuration}.
 * In addition, "special" statistics that are not bound to an input or output can be requested by name.
 * <p>
 * A request can be serialized with {@link #getBytes()} and stored in task monitor parameters
 * ({@link #writeToTask(TaskConfig)}, {@link #insertInto(DataSet)}), from where it is read back again.
 * <p>
 * NOTE(review): instances are mutable and unsynchronized; presumably confined to a single thread.
 */
public class StatisticsRequest {
    private static final Log LOG = LogFactory.getLog(StatisticsRequest.class);

    // Wild card for any position
    public static final int UNSPECIFIED_POSITION = -1;

    // Enable general statistics collection for all Match keys.
    public static final String SPECIAL_COLLECT_MATCH_KEYS = "special.collectAllMatchKeys";
    // Wild card, indicating that every statistic should be collected
    public static final String COLLECT_ANY = "io.collectAny";

    public static final String COLLECT_HISTOGRAM = "io.histogram";
    public static final String COLLECT_MULTIHISTOGRAM = "io.multihistogram";
    public static final String COLLECT_BASE_STATISTICS = "io.base_stats";

    // List of all statistics enabled by using COLLECT_ANY
    public static final String[] STATISTICS_LIST = { COLLECT_HISTOGRAM, COLLECT_BASE_STATISTICS };

    // One set of requested entries per input; starts with two slots and grows on demand.
    @SuppressWarnings("unchecked")
    private Set<Entry>[] inputStats = new Set[2];
    // One set of requested entries per output; starts with one slot and grows on demand.
    @SuppressWarnings("unchecked")
    private Set<Entry>[] outputStats = new Set[1];
    // Statistics requested by name only, not bound to an input or output.
    private Set<String> specialStats = new HashSet<String>();

    /**
     * Creates an empty request with two (empty) input slots and one (empty) output slot.
     */
    public StatisticsRequest() {
        for (int i = 0; i < this.inputStats.length; i++) {
            this.inputStats[i] = new HashSet<Entry>();
        }
        for (int i = 0; i < this.outputStats.length; i++) {
            this.outputStats[i] = new HashSet<Entry>();
        }
    }

    /**
     * Copy constructor. The per-slot sets are copied; the {@link Entry} objects themselves are shared.
     *
     * @param originalRequest the request to copy; if {@code null}, an empty request is created
     */
    @SuppressWarnings("unchecked")
    public StatisticsRequest(StatisticsRequest originalRequest) {
        this();
        if (originalRequest == null) {
            return;
        }

        this.inputStats = new Set[originalRequest.inputStats.length];
        for (int i = 0; i < this.inputStats.length; i++) {
            this.inputStats[i] = new HashSet<Entry>();
            this.inputStats[i].addAll(originalRequest.inputStats[i]);
        }

        this.outputStats = new Set[originalRequest.outputStats.length];
        for (int i = 0; i < this.outputStats.length; i++) {
            this.outputStats[i] = new HashSet<Entry>();
            this.outputStats[i].addAll(originalRequest.outputStats[i]);
        }

        this.specialStats.addAll(originalRequest.specialStats);
    }

    /**
     * Requests statistic {@code statName} on record position {@code recordPos} of input {@code inputNumber},
     * without a configuration.
     */
    public void collectForInput(int inputNumber, int recordPos, String statName) {
        collectForInput(inputNumber, recordPos, statName, null);
    }

    /**
     * Requests statistic {@code statName} on record position {@code recordPos} of input {@code inputNumber}.
     * Input slots are created on demand. If an entry with the same position and statistic name already exists,
     * the set is left unchanged (equality ignores the configuration).
     *
     * @param conf optional statistic configuration, may be {@code null}
     */
    public void collectForInput(int inputNumber, int recordPos, String statName, Configuration conf) {
        this.inputStats = ensureSlot(this.inputStats, inputNumber);
        this.inputStats[inputNumber].add(new Entry(recordPos, new ConfPair(statName, conf)));
    }

    /**
     * Requests statistic {@code statName} on the combination of record positions {@code recordPos} of input
     * {@code inputNumber}. This mirrors
     * {@link #collectForOutput(int, int[], String, Configuration)}. Input slots are created on demand.
     *
     * @param conf optional statistic configuration, may be {@code null}
     */
    public void collectForInput(int inputNumber, int[] recordPos, String statName, Configuration conf) {
        this.inputStats = ensureSlot(this.inputStats, inputNumber);
        this.inputStats[inputNumber].add(new Entry(recordPos, new ConfPair(statName, conf)));
    }

    /**
     * Removes the given entries from input {@code inputNumber}. Passing the live set returned by
     * {@link #getAllInputStatistics(int)} is allowed. Does nothing if the input slot does not exist.
     */
    public void removeInputStatistics(int inputNumber, Iterable<Entry> entries) {
        removeEntries(slotOrNull(this.inputStats, inputNumber), entries);
    }

    /**
     * Removes the request for statistic {@code statName} on position {@code recordPos} of input
     * {@code inputNumber}, regardless of its configuration. Does nothing if the input slot does not exist.
     */
    public void removeInputStatistics(int inputNumber, int recordPos, String statName) {
        Set<Entry> set = slotOrNull(this.inputStats, inputNumber);
        if (set != null) {
            set.remove(new Entry(recordPos, new ConfPair(statName, null)));
        }
    }

    /**
     * Requests statistic {@code statName} on record position {@code recordPos} of output {@code outputNumber},
     * without a configuration.
     */
    public void collectForOutput(int outputNumber, int recordPos, String statName) {
        collectForOutput(outputNumber, recordPos, statName, null);
    }

    /**
     * Requests statistic {@code statName} on record position {@code recordPos} of output {@code outputNumber}.
     * Output slots are created on demand.
     *
     * @param conf optional statistic configuration, may be {@code null}
     */
    public void collectForOutput(int outputNumber, int recordPos, String statName, Configuration conf) {
        this.outputStats = ensureSlot(this.outputStats, outputNumber);
        this.outputStats[outputNumber].add(new Entry(recordPos, new ConfPair(statName, conf)));
    }

    /**
     * Requests statistic {@code statName} on the combination of record positions {@code recordPos} of output
     * {@code outputNumber}. Output slots are created on demand.
     *
     * @param conf optional statistic configuration, may be {@code null}
     */
    public void collectForOutput(int outputNumber, int[] recordPos, String statName, Configuration conf) {
        this.outputStats = ensureSlot(this.outputStats, outputNumber);
        this.outputStats[outputNumber].add(new Entry(recordPos, new ConfPair(statName, conf)));
    }

    /**
     * Removes the request for statistic {@code statName} on position {@code recordPos} of output
     * {@code outputNumber}, regardless of its configuration. Does nothing if the output slot does not exist.
     */
    public void removeOutputStatistics(int outputNumber, int recordPos, String statName) {
        Set<Entry> set = slotOrNull(this.outputStats, outputNumber);
        if (set != null) {
            set.remove(new Entry(recordPos, new ConfPair(statName, null)));
        }
    }

    /**
     * Removes the given entries from output {@code outputNumber}. Passing the live set returned by
     * {@link #getAllOutputStatistics(int)} is allowed. Does nothing if the output slot does not exist.
     */
    public void removeOutputStatistics(int outputNumber, Iterable<Entry> entries) {
        removeEntries(slotOrNull(this.outputStats, outputNumber), entries);
    }

    /** Returns the number of input slots (at least two, unless replaced by deserialization or copying). */
    public int getNumInputs() {
        return this.inputStats.length;
    }

    /** Returns the number of output slots (at least one, unless replaced by deserialization or copying). */
    public int getNumOutputs() {
        return this.outputStats.length;
    }

    /** Requests the special (input/output independent) statistic {@code specialStat}. */
    public void collectSpecial(String specialStat) {
        this.specialStats.add(specialStat);
    }

    /** Removes the request for the special statistic {@code specialStat}, if present. */
    public void removeSpecial(String specialStat) {
        this.specialStats.remove(specialStat);
    }

    /** Returns whether any statistic is requested for input {@code inputNumber}; false for nonexistent slots. */
    public boolean inputStatisticsRequested(int inputNumber) {
        Set<Entry> set = slotOrNull(this.inputStats, inputNumber);
        return set != null && !set.isEmpty();
    }

    /** Returns whether any statistic is requested for output {@code outputNumber}; false for nonexistent slots. */
    public boolean outputStatisticsRequested(int outputNumber) {
        Set<Entry> set = slotOrNull(this.outputStats, outputNumber);
        return set != null && !set.isEmpty();
    }

    /** Returns the live (internal, mutable) set of entries requested for input {@code inputNumber}. */
    public Iterable<Entry> getAllInputStatistics(int inputNumber) {
        return this.inputStats[inputNumber];
    }

    /**
     * Returns a new set containing the entries of input {@code inputNumber} whose statistic name equals
     * {@code type}.
     *
     * @deprecated use {@link #getAllInputStatistics(int)} and filter by {@code entry.conf.stat}
     */
    @Deprecated
    public Iterable<Entry> getAllInputStatisticsByType(int inputNumber, String type) {
        Set<Entry> retVal = new HashSet<StatisticsRequest.Entry>();
        for (Entry e : this.inputStats[inputNumber]) {
            if (hasStat(e, type)) {
                retVal.add(e);
            }
        }
        return retVal;
    }

    /** Returns the live (internal, mutable) set of entries requested for output {@code outputNumber}. */
    public Iterable<Entry> getAllOutputStatistics(int outputNumber) {
        return this.outputStats[outputNumber];
    }

    /**
     * Returns the list of output indices for which a statistic named {@code type} has been requested, in
     * ascending order and without duplicates.
     * At the moment, this list either contains 0 (for the only output) or is empty.
     * However, to be flexible in case of multiple outputs, this function is implemented analogously to
     * {@link #getInputStatisticIndicesByType(String)}.
     *
     * @param type the statistic name to look for
     * @return the matching output indices
     */
    public List<Integer> getOutputStatisticIndicesByType(String type) {
        return indicesWithStat(this.outputStats, type);
    }

    /**
     * Returns the list of input indices for which a statistic named {@code typeName} has been requested, in
     * ascending order and without duplicates.
     *
     * @param typeName the statistic name to look for
     * @return the matching input indices
     */
    public List<Integer> getInputStatisticIndicesByType(String typeName) {
        return indicesWithStat(this.inputStats, typeName);
    }

    /** Returns the live (internal, mutable) set of requested special statistics. */
    public Iterable<String> getSpecialStatistics() {
        return this.specialStats;
    }

    // Returns sets unchanged if it has a slot at index; otherwise a copy grown to index + 1 whose new slots
    // hold empty sets. Negative indices are returned unchanged (the caller's array access then fails as before).
    private static Set<Entry>[] ensureSlot(Set<Entry>[] sets, int index) {
        if (index < sets.length) {
            return sets;
        }
        int oldLen = sets.length;
        Set<Entry>[] grown = Arrays.copyOf(sets, index + 1);
        for (int i = oldLen; i < grown.length; i++) {
            grown[i] = new HashSet<Entry>();
        }
        return grown;
    }

    // Returns the set at index, or null if index is outside the array.
    private static Set<Entry> slotOrNull(Set<Entry>[] sets, int index) {
        return index >= 0 && index < sets.length ? sets[index] : null;
    }

    // Removes all given entries from target; no-op if either argument is null.
    private static void removeEntries(Set<Entry> target, Iterable<Entry> entries) {
        if (target == null || entries == null) {
            return;
        }
        // Snapshot first: entries may be target itself (e.g. from getAllInputStatistics), and removing from a
        // HashSet while iterating over it throws ConcurrentModificationException.
        List<Entry> snapshot = new ArrayList<Entry>();
        for (Entry entry : entries) {
            snapshot.add(entry);
        }
        for (Entry entry : snapshot) {
            target.remove(entry);
        }
    }

    // Null-safe check whether the entry requests the statistic named type.
    private static boolean hasStat(Entry entry, String type) {
        return entry != null && entry.conf != null && entry.conf.stat != null && entry.conf.stat.equals(type);
    }

    // Collects each slot index (once) whose set contains an entry for the statistic named type.
    private static List<Integer> indicesWithStat(Set<Entry>[] sets, String type) {
        List<Integer> retList = new ArrayList<Integer>();
        for (int i = 0; i < sets.length; i++) {
            if (sets[i] == null) {
                continue;
            }
            for (Entry e : sets[i]) {
                if (hasStat(e, type)) {
                    retList.add(i);
                    break; // report each index once, even if several entries request this statistic
                }
            }
        }
        return retList;
    }

    // Reads a set written by writeSet: an int count followed by that many serialized entries.
    private static Set<Entry> readSet(DataInputView in) throws IOException {
        Set<Entry> set = new HashSet<Entry>();
        for (int numEntries = in.readInt(); numEntries > 0; --numEntries) {
            set.add(new Entry(in));
        }
        return set;
    }

    // Reads an array written by writeArray: an int slot count followed by one serialized set per slot.
    private static Set<Entry>[] readArray(DataInputView in) throws IOException {
        int numEntries = in.readInt();
        @SuppressWarnings("unchecked")
        Set<Entry>[] sets = new Set[numEntries];
        for (int i = 0; i < numEntries; i++) {
            sets[i] = readSet(in);
        }
        return sets;
    }

    // Writes the set size followed by each entry.
    private static void writeSet(Set<Entry> set, DataOutputView out) throws IOException {
        out.writeInt(set.size());
        for (Entry entry : set) {
            entry.write(out);
        }
    }

    // Writes the number of slots followed by each slot's set.
    private static void writeArray(Set<Entry>[] sets, DataOutputView out) throws IOException {
        out.writeInt(sets.length);
        for (int i = 0; i < sets.length; i++) {
            writeSet(sets[i], out);
        }
    }

    /**
     * Serializes this request: the input slots, the output slots, then the count and names of the special
     * statistics.
     *
     * @return the serialized bytes, or {@code null} if serialization failed (the error is logged)
     */
    public byte[] getBytes() {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        DataOutputStream dos = new DataOutputStream(bos);
        DataOutputView outputView = new OutputViewDataOutputStreamWrapper(dos);

        try {
            writeArray(this.inputStats, outputView);
            writeArray(this.outputStats, outputView);

            outputView.writeInt(this.specialStats.size());
            for (String entry : this.specialStats) {
                outputView.writeUTF(entry);
            }
            dos.close();
        } catch (IOException ioe) {
            LOG.error("Serialisation error: " + StringUtils.stringifyException(ioe));
            return null;
        }

        return bos.toByteArray();
    }

    /** Stores the serialized form of this request in {@code parameters} under {@code key}. */
    public void write(Configuration parameters, String key) {
        parameters.setBytes(key, this.getBytes());
    }

    /**
     * Deserializes a request produced by {@link #getBytes()}.
     *
     * @return the request, or {@code null} if deserialization failed (the error is logged)
     */
    public static StatisticsRequest readBytes(byte[] statReqBytes) {
        StatisticsRequest sr = new StatisticsRequest();

        ByteArrayInputStream bis = new ByteArrayInputStream(statReqBytes);
        DataInputStream dis = new DataInputStream(bis);
        DataInputView inputView = new InputViewDataInputStreamWrapper(dis);
        try {
            sr.inputStats = readArray(inputView);
            sr.outputStats = readArray(inputView);

            for (int numEntries = inputView.readInt(); numEntries > 0; --numEntries) {
                sr.specialStats.add(inputView.readUTF());
            }

            dis.close();
        } catch (IOException ioe) {
            LOG.error("Deserialisation error: " + StringUtils.stringifyException(ioe));
            return null;
        }

        return sr;
    }

    /**
     * Reads a request stored under {@code key} in {@code parameters}.
     *
     * @throws CorruptConfigurationException if no bytes are stored under {@code key}
     */
    public static StatisticsRequest read(Configuration parameters, String key) {
        byte[] statReqBytes = parameters.getBytes(key, null);
        if (statReqBytes == null) {
            throw new CorruptConfigurationException("No StatisticsRequest was set.");
        }
        return StatisticsRequest.readBytes(statReqBytes);
    }

    /**
     * Reads the request attached to the task's statistics monitor.
     *
     * @return the request, or {@code null} if the task has no statistics monitor request
     */
    public static StatisticsRequest readFromTask(TaskConfig config) {
        int pos = config.findTaskMonitorRequest(StatisticTaskMonitor.DEFAULT_NAME);
        if (pos == -1) {
            return null;
        }
        return read(config.getTaskMonitorParameters(pos), StatisticTaskMonitor.STATISTICS_REQUEST);
    }

    /**
     * Reads the request attached to the operator's statistics monitor.
     *
     * @return the request, or a new empty request if the operator has no statistics monitor request
     */
    public static StatisticsRequest readFromOperator(Operator op) {
        TaskMonitorRequest monitor = op.getMonitorRequests().getRequest(StatisticTaskMonitor.DEFAULT_NAME);
        if (monitor == null) {
            return new StatisticsRequest();
        }
        return read(monitor.getParameters(), StatisticTaskMonitor.STATISTICS_REQUEST);
    }

    /**
     * A single statistic request: the record (key) positions it applies to plus the statistic name and
     * configuration. Equality and hashing use the key positions and the statistic name only.
     */
    public static class Entry implements IOReadableWritable {
        public int[] keyPos;
        public ConfPair conf;

        /** Creates an empty entry, e.g. to be filled by {@link #read(DataInputView)}. */
        public Entry() {
        }

        /** Creates an entry by deserializing it from {@code in}. */
        public Entry(DataInputView in) throws IOException {
            this.read(in);
        }

        /**
         * Creates an entry for a combination of key positions. The array is copied because it participates
         * in {@link #hashCode()}; later mutation by the caller must not corrupt the sets holding this entry.
         */
        public Entry(int[] keyPos, ConfPair conf) {
            this.keyPos = keyPos == null ? null : keyPos.clone();
            this.conf = conf;
        }

        /** Creates an entry for a single key position. */
        public Entry(int keyPos, ConfPair conf) {
            this.keyPos = new int[] { keyPos };
            this.conf = conf;
        }

        @Override
        public boolean equals(final Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Entry)) {
                return false;
            }
            final Entry other = (Entry) obj;
            if (!Arrays.equals(this.keyPos, other.keyPos)) {
                return false;
            }
            return this.conf == null ? other.conf == null : this.conf.equals(other.conf);
        }

        @Override
        public int hashCode() {
            int confHash = this.conf == null ? 0 : this.conf.hashCode();
            return Arrays.hashCode(this.keyPos) ^ ~confHash;
        }

        /** Writes the number of key positions, each position, then the {@link ConfPair}. */
        @Override
        public void write(DataOutputView out) throws IOException {
            out.writeInt(this.keyPos.length);
            for (int i = 0; i < this.keyPos.length; i++) {
                out.writeInt(this.keyPos[i]);
            }
            this.conf.write(out);
        }

        /** Reads the format produced by {@link #write(DataOutputView)}. */
        @Override
        public void read(DataInputView in) throws IOException {
            this.keyPos = new int[in.readInt()];
            for (int i = 0; i < this.keyPos.length; i++) {
                this.keyPos[i] = in.readInt();
            }
            this.conf = new ConfPair(in);
        }

        /** Renders as {@code [p0, p1, ...]: statName}. */
        @Override
        public String toString() {
            StringBuilder b = new StringBuilder();
            b.append('[');
            if (this.keyPos != null) {
                for (int i = 0; i < this.keyPos.length; i++) {
                    if (i > 0) {
                        b.append(", ");
                    }
                    b.append(this.keyPos[i]);
                }
            }
            b.append("]: ");
            b.append(this.conf);
            return b.toString();
        }
    }

    /**
     * A statistic name together with an optional {@link Configuration}. Equality and hashing consider only the
     * statistic name, so that a request can be looked up or removed by name without knowing its configuration.
     */
    public static class ConfPair implements IOReadableWritable {
        public String stat;
        public Configuration conf;

        /** Creates a pair by deserializing it from {@code in}. */
        public ConfPair(DataInputView in) throws IOException {
            this.read(in);
        }

        /** Creates a pair; {@code conf} may be {@code null}. */
        public ConfPair(String stat, Configuration conf) {
            this.stat = stat;
            this.conf = conf;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof ConfPair)) {
                return false;
            }
            ConfPair other = (ConfPair) obj;
            return this.stat == null ? other.stat == null : this.stat.equals(other.stat);
        }

        // Must agree with equals(): only the statistic name participates. Mixing in conf made lookups with a
        // null conf (as done by the remove*Statistics(int, int, String) methods) miss configured entries.
        @Override
        public int hashCode() {
            return this.stat == null ? 0 : this.stat.hashCode();
        }

        /** Writes a presence flag and the name, then a presence flag and the configuration. */
        @Override
        public void write(DataOutputView out) throws IOException {
            if (this.stat != null) {
                out.writeBoolean(true);
                out.writeUTF(this.stat);
            } else {
                out.writeBoolean(false);
            }

            if (this.conf != null) {
                out.writeBoolean(true);
                this.conf.write(out);
            } else {
                out.writeBoolean(false);
            }
        }

        /** Reads the format produced by {@link #write(DataOutputView)}. */
        @Override
        public void read(DataInputView in) throws IOException {
            this.stat = in.readBoolean() ? in.readUTF() : null;
            if (in.readBoolean()) {
                this.conf = new Configuration();
                this.conf.read(in);
            } else {
                this.conf = null;
            }
        }

        /** Returns the statistic name (may be {@code null}). */
        @Override
        public String toString() {
            return this.stat;
        }
    }

    /**
     * Stores this request in the task's statistics monitor parameters. If the task has no statistics monitor
     * request yet, one is added; otherwise the stored request bytes are overwritten.
     */
    public void writeToTask(TaskConfig config) {
        int pos = config.findTaskMonitorRequest(StatisticTaskMonitor.DEFAULT_NAME);
        if (pos == -1) {
            config.addTaskMonitorRequest(this.asMonitorRequest());
        } else {
            config.getTaskMonitorParameters(pos).setBytes(StatisticTaskMonitor.STATISTICS_REQUEST, this.getBytes());
        }
    }

    /** Wraps the serialized request in a new {@link TaskMonitorRequest} for {@link StatisticTaskMonitor}. */
    protected TaskMonitorRequest asMonitorRequest() {
        TaskMonitorRequest monitorRequest = new TaskMonitorRequest(StatisticTaskMonitor.DEFAULT_NAME,
                StatisticTaskMonitor.class);
        monitorRequest.getParameters().setBytes(StatisticTaskMonitor.STATISTICS_REQUEST, this.getBytes());
        return monitorRequest;
    }

    /** Adds this request as a monitor request to the given data set. */
    public void insertInto(DataSet op) {
        op.getMonitorRequests().addRequest(this.asMonitorRequest());
    }

    /**
     * Returns the bytes stored under {@code confRequestKey} in the task's statistics monitor parameters, or
     * {@code null} if the task has no statistics monitor request or the key is absent.
     */
    protected static final byte[] getConfBytes(TaskConfig config, String confRequestKey) {
        int pos = config.findTaskMonitorRequest(StatisticTaskMonitor.DEFAULT_NAME);
        if (pos == -1) {
            return null;
        }
        return config.getTaskMonitorParameters(pos).getBytes(confRequestKey, null);
    }
}