com.datatorrent.contrib.kinesis.KinesisConsumer.java Source code

Java tutorial

Introduction

Here is the source code for com.datatorrent.contrib.kinesis.KinesisConsumer.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package com.datatorrent.contrib.kinesis;

import java.io.Closeable;
import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import javax.validation.constraints.NotNull;
import javax.validation.constraints.Pattern;
import javax.validation.constraints.Pattern.Flag;

import com.amazonaws.services.kinesis.model.Record;
import com.amazonaws.services.kinesis.model.Shard;
import com.amazonaws.services.kinesis.model.ShardIteratorType;
import com.google.common.collect.Maps;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.commons.io.IOUtils;

import com.datatorrent.common.util.Pair;

/**
 *
 * Base KinesisConsumer class used by kinesis input operator
 * @since 2.0.0
 */

public class KinesisConsumer implements Closeable {
    private static final Logger logger = LoggerFactory.getLogger(KinesisConsumer.class);

    protected Integer recordsLimit = 100;

    private int bufferSize = 1024;

    protected transient boolean isAlive = false;

    private transient ArrayBlockingQueue<Pair<String, Record>> holdingBuffer;

    /**
     * The streamName that this consumer consumes
     */
    @NotNull
    protected String streamName;

    /**
     * The initialOffset could be either earliest or latest
     * Earliest means the beginning the shard
     * Latest means the current record to consume from the shard
     * By default it always consume from the beginning of the shard
     */
    @Pattern(flags = { Flag.CASE_INSENSITIVE }, regexp = "earliest|latest")
    protected String initialOffset = "latest";

    protected transient ExecutorService consumerThreadExecutor = null;

    protected ConcurrentHashMap<String, String> shardPosition = new ConcurrentHashMap<String, String>();
    protected final transient HashSet<Shard> simpleConsumerThreads = new HashSet<Shard>();
    private Set<String> shardIds = new HashSet<String>();
    protected Set<Shard> closedShards = new HashSet<Shard>();
    protected long recordsCheckInterval = 500;

    protected transient KinesisShardStats stats = new KinesisShardStats();

    public KinesisConsumer() {
    }

    public KinesisConsumer(String streamName) {
        this.streamName = streamName;
    }

    public KinesisConsumer(String streamName, Set<String> newShardIds) {
        this(streamName);
        shardIds = newShardIds;
    }

    /**
     * This method is called in setup method of the operator
     */
    public void create() {
        holdingBuffer = new ArrayBlockingQueue<Pair<String, Record>>(bufferSize);
        boolean defaultSelect = (shardIds == null) || (shardIds.size() == 0);
        final List<Shard> pms = KinesisUtil.getInstance().getShardList(streamName);
        for (final Shard shId : pms) {
            if ((shardIds.contains(shId.getShardId()) || defaultSelect) && !closedShards.contains(shId)) {
                simpleConsumerThreads.add(shId);
            }
        }
    }

    /**
     * This method returns the iterator type of the given shard
     */
    public ShardIteratorType getIteratorType(String shardId) {
        if (shardPosition.containsKey(shardId)) {
            return ShardIteratorType.AFTER_SEQUENCE_NUMBER;
        }
        return initialOffset.equalsIgnoreCase("earliest") ? ShardIteratorType.TRIM_HORIZON
                : ShardIteratorType.LATEST;
    }

    /**
     * This method is called in the activate method of the operator
     */
    public void start() {
        isAlive = true;
        int realNumStream = simpleConsumerThreads.size();
        if (realNumStream == 0)
            return;

        consumerThreadExecutor = Executors.newFixedThreadPool(realNumStream);
        for (final Shard shd : simpleConsumerThreads) {
            consumerThreadExecutor.submit(new Runnable() {
                @Override
                public void run() {
                    logger.debug("Thread " + Thread.currentThread().getName() + " start consuming Records...");
                    while (isAlive) {
                        Shard shard = shd;
                        try {
                            List<Record> records = KinesisUtil.getInstance().getRecords(streamName, recordsLimit,
                                    shard.getShardId(), getIteratorType(shard.getShardId()),
                                    shardPosition.get(shard.getShardId()));

                            if (records == null || records.isEmpty()) {
                                if (shard.getSequenceNumberRange().getEndingSequenceNumber() != null) {
                                    closedShards.add(shard);
                                    break;
                                }
                                try {
                                    Thread.sleep(recordsCheckInterval);
                                } catch (Exception e) {
                                    throw new RuntimeException(e);
                                }
                            } else {
                                String seqNo = "";
                                for (Record rc : records) {
                                    seqNo = rc.getSequenceNumber();
                                    putRecord(shd.getShardId(), rc);
                                }
                                shardPosition.put(shard.getShardId(), seqNo);
                            }
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }
                    logger.debug("Thread " + Thread.currentThread().getName() + " stop consuming Records...");
                }
            });
        }
    }

    @Override
    public void close() {
        if (consumerThreadExecutor != null) {
            consumerThreadExecutor.shutdown();
        }
        simpleConsumerThreads.clear();
    }

    /**
     * The method is called in the deactivate method of the operator
     */
    public void stop() {
        isAlive = false;
        holdingBuffer.clear();
        IOUtils.closeQuietly(this);
    }

    public void resetShardPositions(Map<String, String> shardPositions) {

        if (shardPositions == null) {
            return;
        }
        shardPosition.clear();

        for (String pid : shardIds) {
            String offsetForPar = shardPositions.get(pid);
            if (offsetForPar != null && !offsetForPar.equals("")) {
                shardPosition.put(pid, offsetForPar);
            }
        }
    }

    protected Map<String, String> getShardPosition() {
        return shardPosition;
    }

    public Set<Shard> getClosedShards() {
        return closedShards;
    }

    public Integer getNumOfShards() {
        return shardIds.size();
    }

    public KinesisShardStats getConsumerStats(Map<String, String> shardStats) {
        stats.updateShardStats(shardStats);
        return stats;
    }

    /**
     * This method is called in teardown method of the operator
     */
    public void teardown() {
        holdingBuffer.clear();
    }

    public boolean isAlive() {
        return isAlive;
    }

    public void setAlive(boolean isAlive) {
        this.isAlive = isAlive;
    }

    public void setStreamName(String streamName) {
        this.streamName = streamName;
    }

    public String getStreamName() {
        return streamName;
    }

    public Pair<String, Record> pollRecord() {
        return holdingBuffer.poll();
    }

    public int getQueueSize() {
        return holdingBuffer.size();
    }

    public void setInitialOffset(String initialOffset) {
        this.initialOffset = initialOffset;
    }

    public String getInitialOffset() {
        return initialOffset;
    }

    final protected void putRecord(String shardId, Record msg) throws InterruptedException {
        holdingBuffer.put(new Pair<String, Record>(shardId, msg));
    };

    public Integer getRecordsLimit() {
        return recordsLimit;
    }

    public void setRecordsLimit(Integer recordsLimit) {
        this.recordsLimit = recordsLimit;
    }

    public Set<String> getShardIds() {
        return shardIds;
    }

    public void setShardIds(Set<String> shardIds) {
        this.shardIds = shardIds;
    }

    public void setQueueSize(int bufferSize) {
        this.bufferSize = bufferSize;
    }

    /**
     * Counter class which gives the statistic value from the consumer
     */
    @SuppressWarnings("serial")
    public static class KinesisShardStats implements Serializable {
        public ConcurrentHashMap<String, String> partitionStats = new ConcurrentHashMap<String, String>();

        public KinesisShardStats() {
        }

        //important API for update
        public void updateShardStats(Map<String, String> shardStats) {
            for (Entry<String, String> ss : shardStats.entrySet()) {
                partitionStats.put(ss.getKey(), ss.getValue());
            }
        }
    }

    public static class KinesisShardStatsUtil {
        public static Map<String, String> getShardStatsForPartitions(List<KinesisShardStats> kinesisshardStats) {
            Map<String, String> result = Maps.newHashMap();
            for (KinesisShardStats kms : kinesisshardStats) {
                for (Entry<String, String> item : kms.partitionStats.entrySet()) {
                    result.put(item.getKey(), item.getValue());
                }
            }
            return result;
        }
    }
}