com.facebook.presto.kinesis.util.MockKinesisClient.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.kinesis.util.MockKinesisClient.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.kinesis.util;

import com.amazonaws.AmazonClientException;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.ResponseMetadata;
import com.amazonaws.regions.Region;

import com.amazonaws.services.kinesis.AmazonKinesisClient;
import com.amazonaws.services.kinesis.model.CreateStreamRequest;
import com.amazonaws.services.kinesis.model.CreateStreamResult;
import com.amazonaws.services.kinesis.model.DescribeStreamRequest;
import com.amazonaws.services.kinesis.model.DescribeStreamResult;
import com.amazonaws.services.kinesis.model.GetRecordsRequest;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.amazonaws.services.kinesis.model.GetShardIteratorRequest;
import com.amazonaws.services.kinesis.model.GetShardIteratorResult;
import com.amazonaws.services.kinesis.model.ListStreamsRequest;
import com.amazonaws.services.kinesis.model.ListStreamsResult;
import com.amazonaws.services.kinesis.model.ListTagsForStreamRequest;
import com.amazonaws.services.kinesis.model.ListTagsForStreamResult;
import com.amazonaws.services.kinesis.model.PutRecordRequest;
import com.amazonaws.services.kinesis.model.PutRecordResult;
import com.amazonaws.services.kinesis.model.PutRecordsRequest;
import com.amazonaws.services.kinesis.model.PutRecordsRequestEntry;
import com.amazonaws.services.kinesis.model.PutRecordsResult;
import com.amazonaws.services.kinesis.model.PutRecordsResultEntry;
import com.amazonaws.services.kinesis.model.Record;
import com.amazonaws.services.kinesis.model.SequenceNumberRange;
import com.amazonaws.services.kinesis.model.Shard;
import com.amazonaws.services.kinesis.model.StreamDescription;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Date;

/**
 * Mock Kinesis client for testing that is primarily used for reading from the
 * stream as we do here in Presto.
 *
 * This is to help prove that the API is being used correctly and debug any
 * issues that arise without incurring AWS load and charges.  It is far from a complete
 * implementation of Kinesis.
 *
 * Created by derekbennett on 6/20/16.
 */
public class MockKinesisClient extends AmazonKinesisClient {
    private String endpoint = "";
    private Region region = null;
    ArrayList<InternalStream> streams = new ArrayList<InternalStream>();

    //// Support classes

    /**
     * Shard backed by an in-memory record list.
     *
     * The shard ID is set at construction to {@code <owning stream name>_<index>}.
     */
    public static class InternalShard extends Shard {
        private ArrayList<Record> recs = new ArrayList<Record>();
        private String streamName = "";
        private int index = 0;

        /**
         * @param owningStream name of the stream this shard belongs to
         * @param anIndex position of this shard within its stream
         */
        public InternalShard(String owningStream, int anIndex) {
            super();
            streamName = owningStream;
            index = anIndex;
            setShardId(streamName + "_" + index);
        }

        /** Returns the backing record list itself (not a copy). */
        public ArrayList<Record> getRecords() {
            return recs;
        }

        /**
         * Collects the records whose numeric sequence number is at or beyond the
         * iterator's record index, in stored order.
         *
         * @param iter iterator whose {@code recordIndex} is the inclusive lower bound
         * @return a new list with the matching records; empty when none match
         */
        public ArrayList<Record> getRecordsFrom(ShardIterator iter) {
            ArrayList<Record> matching = new ArrayList<Record>();

            for (int position = 0; position < recs.size(); position++) {
                Record candidate = recs.get(position);
                int sequence = Integer.parseInt(candidate.getSequenceNumber());
                if (sequence >= iter.recordIndex) {
                    matching.add(candidate);
                }
            }

            return matching;
        }

        /** Returns the name of the owning stream. */
        public String getStreamName() {
            return streamName;
        }

        /** Returns this shard's index within its stream. */
        public int getIndex() {
            return index;
        }

        /** Appends a record to the end of this shard. */
        public void addRecord(Record rec) {
            recs.add(rec);
        }

        /** Removes every record from this shard. */
        public void clearRecords() {
            recs.clear();
        }
    }

    /**
     * In-memory stream: a named collection of {@link InternalShard}s plus the counters
     * used to assign sequence numbers and distribute records across shards.
     */
    public static class InternalStream {
        private String streamName = "";
        private String streamARN = "";
        private String streamStatus = "CREATING";
        private int retentionPeriodHours = 24;
        private ArrayList<InternalShard> shards = new ArrayList<InternalShard>();
        // Sequence number handed to the next record; shared by all shards of the stream.
        private int sequenceNo = 100;
        // Index of the shard that receives the next record (round robin).
        private int nextShard = 0;

        /**
         * @param aName stream name; also used to build the fake ARN and the shard IDs
         * @param nbShards number of shards to create
         * @param isActive when true the stream starts as "ACTIVE" instead of "CREATING"
         */
        public InternalStream(String aName, int nbShards, boolean isActive) {
            this.streamName = aName;
            this.streamARN = "local:fake.stream:" + aName;
            if (isActive) {
                this.streamStatus = "ACTIVE";
            }

            for (int i = 0; i < nbShards; i++) {
                InternalShard newShard = new InternalShard(this.streamName, i);
                newShard.setSequenceNumberRange((new SequenceNumberRange()).withStartingSequenceNumber("100")
                        .withEndingSequenceNumber("999"));
                this.shards.add(newShard);
            }
        }

        public String getStreamName() {
            return streamName;
        }

        public String getStreamARN() {
            return streamARN;
        }

        public String getStreamStatus() {
            return streamStatus;
        }

        public int getRetentionPeriodHours() {
            return retentionPeriodHours;
        }

        /** Returns the backing shard list itself (not a copy). */
        public ArrayList<InternalShard> getShards() {
            return shards;
        }

        /**
         * Returns the shards whose index is strictly greater than the index encoded in
         * the given shard ID.
         *
         * The index is taken from the text after the LAST underscore, so stream names
         * that themselves contain underscores are handled. A null ID, an ID without an
         * underscore, or a non-numeric index yields an empty list.
         *
         * @param afterShardId shard ID of the form {@code <streamName>_<index>}
         * @return a new list of the later shards, possibly empty
         */
        public ArrayList<InternalShard> getShardsFrom(String afterShardId) {
            ArrayList<InternalShard> returnArray = new ArrayList<InternalShard>();
            if (afterShardId == null) {
                return returnArray;
            }

            int separator = afterShardId.lastIndexOf('_');
            if (separator < 0) {
                return returnArray;
            }

            String indexText = afterShardId.substring(separator + 1);
            if (!indexText.matches("[0-9]+")) {
                return returnArray;
            }

            int afterIndex;
            try {
                afterIndex = Integer.parseInt(indexText);
            } catch (NumberFormatException tooLarge) {
                // More digits than an int can hold: no shard can come after it.
                return returnArray;
            }

            for (InternalShard shard : shards) {
                if (shard.getIndex() > afterIndex) {
                    returnArray.add(shard);
                }
            }

            return returnArray;
        }

        /** Marks the stream as "ACTIVE". */
        public void activate() {
            this.streamStatus = "ACTIVE";
        }

        /**
         * Creates a record from the given payload and appends it to the next shard in
         * round-robin order.
         *
         * @param data record payload
         * @param partitionKey partition key stored on the record (not used for routing)
         * @return result carrying the assigned sequence number and the receiving shard's ID
         */
        public PutRecordResult putRecord(ByteBuffer data, String partitionKey) {
            // Create record and insert into the shards.  Initially just do it
            // on a round robin basis.
            // The arrival timestamp is set 50 seconds before the current time.
            long ts = System.currentTimeMillis() - 50000;
            Record rec = new Record();
            rec = rec.withData(data).withPartitionKey(partitionKey).withSequenceNumber(String.valueOf(sequenceNo));
            rec.setApproximateArrivalTimestamp(new Date(ts));

            // Wrap around once every shard has received a record.
            if (nextShard == shards.size()) {
                nextShard = 0;
            }
            InternalShard shard = shards.get(nextShard);
            shard.addRecord(rec);

            PutRecordResult result = new PutRecordResult();
            result.setSequenceNumber(String.valueOf(sequenceNo));
            result.setShardId(shard.getShardId());

            nextShard++;
            sequenceNo++;

            return result;
        }

        /** Removes the records from every shard; the shards themselves remain. */
        public void clearRecords() {
            for (InternalShard shard : this.shards) {
                shard.clearRecords();
            }
        }
    }

    /**
     * Position within a shard, serialized as
     * {@code <streamId>_<shardIndex>_<recordIndex>}.
     *
     * Parsing works from the right-hand side of the string so that stream names
     * containing underscores round-trip correctly.
     */
    public static class ShardIterator {
        public String streamId = "";
        public int shardIndex = 0;
        public int recordIndex = 0;

        /**
         * @param aStreamId name of the stream
         * @param aShard index of the shard within the stream
         * @param aRecord sequence number at which reading starts
         */
        public ShardIterator(String aStreamId, int aShard, int aRecord) {
            this.streamId = aStreamId;
            this.shardIndex = aShard;
            this.recordIndex = aRecord;
        }

        /** Serializes this iterator; the result is readable by {@link #fromString(String)}. */
        public String makeString() {
            return this.streamId + "_" + this.shardIndex + "_" + this.recordIndex;
        }

        /**
         * Builds an iterator at record index 0 from a stream name and a shard ID of the
         * form {@code <streamName>_<shardIndex>}.
         *
         * Anything following a further underscore after the shard index is ignored.
         *
         * @param streamName name of the stream the shard must belong to
         * @param shardId shard ID to parse
         * @return the iterator, or null when either argument is null, the shard ID does
         *         not start with {@code streamName + "_"}, or the index is not numeric
         */
        public static ShardIterator fromStreamAndShard(String streamName, String shardId) {
            if (streamName == null || shardId == null) {
                return null;
            }

            String prefix = streamName + "_";
            if (!shardId.startsWith(prefix)) {
                return null;
            }

            String remainder = shardId.substring(prefix.length());
            int end = remainder.indexOf('_');
            String indexText = (end < 0) ? remainder : remainder.substring(0, end);

            int shard = parseIndex(indexText);
            if (shard < 0) {
                return null;
            }

            return new ShardIterator(streamName, shard, 0);
        }

        /**
         * Parses a string produced by {@link #makeString()}.
         *
         * The last two underscore-separated components are the shard index and the
         * record index; everything before them is the stream ID.
         *
         * @param input serialized iterator
         * @return the iterator, or null when the input is null or malformed
         */
        public static ShardIterator fromString(String input) {
            if (input == null) {
                return null;
            }

            int recordSeparator = input.lastIndexOf('_');
            if (recordSeparator < 0) {
                return null;
            }
            int shardSeparator = input.lastIndexOf('_', recordSeparator - 1);
            if (shardSeparator < 0) {
                return null;
            }

            int shard = parseIndex(input.substring(shardSeparator + 1, recordSeparator));
            int record = parseIndex(input.substring(recordSeparator + 1));
            if (shard < 0 || record < 0) {
                return null;
            }

            return new ShardIterator(input.substring(0, shardSeparator), shard, record);
        }

        /** Parses a non-negative decimal int; returns -1 when the text is not one. */
        private static int parseIndex(String text) {
            if (!text.matches("[0-9]+")) {
                return -1;
            }
            try {
                return Integer.parseInt(text);
            } catch (NumberFormatException tooLarge) {
                // All digits, but too many to fit in an int.
                return -1;
            }
        }
    }

    /**
     * Creates a mock client with no streams defined. Delegates to the no-argument
     * superclass constructor.
     */
    public MockKinesisClient() {
        super();
    }

    /**
     * Looks up a stream by exact name.
     *
     * @param name stream name to search for
     * @return the first stream with that name, or null when none matches
     */
    protected InternalStream getStream(String name) {
        for (InternalStream candidate : this.streams) {
            if (candidate.getStreamName().equals(name)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Returns every shard of the stream as a new list typed to the public
     * {@link Shard} class.
     *
     * @param theStream stream whose shards are listed
     */
    protected ArrayList<Shard> getShards(InternalStream theStream) {
        ArrayList<Shard> publicView = new ArrayList<Shard>();
        publicView.addAll(theStream.getShards());
        return publicView;
    }

    /**
     * Returns the shards that come after the given shard ID, as a new list typed to
     * the public {@link Shard} class.
     *
     * @param theStream stream whose shards are listed
     * @param fromShardId shard ID after which listing starts (exclusive)
     */
    protected ArrayList<Shard> getShards(InternalStream theStream, String fromShardId) {
        ArrayList<Shard> publicView = new ArrayList<Shard>();
        publicView.addAll(theStream.getShardsFrom(fromShardId));
        return publicView;
    }

    /**
     * Drops every stream definition, and with it all shards and records.
     */
    public void clearAll() {
        streams.clear();
    }

    /**
     * Empties the records of every shard in every stream while keeping the stream
     * and shard definitions.
     */
    public void clearRecords() {
        for (int i = 0; i < streams.size(); i++) {
            streams.get(i).clearRecords();
        }
    }

    /**
     * Stores the endpoint string on this mock; the superclass implementation is not called.
     *
     * @param s endpoint value to remember
     */
    @Override
    public void setEndpoint(String s) throws IllegalArgumentException {
        this.endpoint = s;
    }

    /**
     * Stores the region on this mock; the superclass implementation is not called.
     *
     * @param region region value to remember
     */
    @Override
    public void setRegion(Region region) throws IllegalArgumentException {
        this.region = region;
    }

    /**
     * Test-setup entry point that appends one record to the named stream.
     *
     * @param putRecordRequest request carrying the stream name, payload and partition key
     * @return the result produced by the stream (sequence number and shard ID)
     * @throws AmazonClientException when no stream with the requested name exists
     */
    @Override
    public PutRecordResult putRecord(PutRecordRequest putRecordRequest)
            throws AmazonServiceException, AmazonClientException {
        InternalStream target = this.getStream(putRecordRequest.getStreamName());
        if (target == null) {
            throw new AmazonClientException("This stream does not exist!");
        }

        return target.putRecord(putRecordRequest.getData(), putRecordRequest.getPartitionKey());
    }

    /**
     * Test-setup entry point that registers a new stream, created already active,
     * with the requested number of shards.
     *
     * @param createStreamRequest request carrying the stream name and shard count
     * @return an empty result object
     */
    @Override
    public CreateStreamResult createStream(CreateStreamRequest createStreamRequest)
            throws AmazonServiceException, AmazonClientException {
        String requestedName = createStreamRequest.getStreamName();
        int requestedShards = createStreamRequest.getShardCount();
        this.streams.add(new InternalStream(requestedName, requestedShards, true));
        return new CreateStreamResult();
    }

    /**
     * Convenience overload that wraps its arguments in a {@link CreateStreamRequest}
     * and forwards to the request-based method.
     *
     * @param s stream name
     * @param integer shard count
     */
    @Override
    public CreateStreamResult createStream(String s, Integer integer)
            throws AmazonServiceException, AmazonClientException {
        CreateStreamRequest request = new CreateStreamRequest();
        request.setStreamName(s);
        request.setShardCount(integer);
        return this.createStream(request);
    }

    /**
     * Test-setup entry point that appends a batch of records to the named stream,
     * one at a time and in request order.
     *
     * @param putRecordsRequest request carrying the stream name and the entries to add
     * @return a result with one entry (shard ID and sequence number) per added record
     * @throws AmazonClientException when no stream with the requested name exists
     */
    @Override
    public PutRecordsResult putRecords(PutRecordsRequest putRecordsRequest)
            throws AmazonServiceException, AmazonClientException {
        InternalStream target = this.getStream(putRecordsRequest.getStreamName());
        if (target == null) {
            throw new AmazonClientException("This stream does not exist!");
        }

        PutRecordsResult batchResult = new PutRecordsResult();
        ArrayList<PutRecordsResultEntry> outcomes = new ArrayList<PutRecordsResultEntry>();
        for (PutRecordsRequestEntry entry : putRecordsRequest.getRecords()) {
            PutRecordResult single = target.putRecord(entry.getData(), entry.getPartitionKey());

            PutRecordsResultEntry outcome = new PutRecordsResultEntry();
            outcome.setShardId(single.getShardId());
            outcome.setSequenceNumber(single.getSequenceNumber());
            outcomes.add(outcome);
        }

        batchResult.setRecords(outcomes);
        return batchResult;
    }

    /**
     * Describes the named stream: name, status, ARN and shard list.
     *
     * When the request has a non-empty exclusive start shard ID only the shards after
     * it are listed; otherwise all shards are. "Has more shards" is always reported
     * as false.
     *
     * @param describeStreamRequest request carrying the stream name and optional start shard ID
     * @return the description wrapped in a result object
     * @throws AmazonClientException when no stream with the requested name exists
     */
    @Override
    public DescribeStreamResult describeStream(DescribeStreamRequest describeStreamRequest)
            throws AmazonServiceException, AmazonClientException {
        InternalStream theStream = this.getStream(describeStreamRequest.getStreamName());
        if (theStream == null) {
            throw new AmazonClientException("This stream does not exist!");
        }

        StreamDescription description = new StreamDescription();
        description.setStreamName(theStream.getStreamName());
        description.setStreamStatus(theStream.getStreamStatus());
        description.setStreamARN(theStream.getStreamARN());

        String startId = describeStreamRequest.getExclusiveStartShardId();
        boolean listEverything = (startId == null) || startId.isEmpty();
        ArrayList<Shard> listed = listEverything
                ? this.getShards(theStream)
                : this.getShards(theStream, startId);
        description.setShards(listed);
        description.setHasMoreShards(false);

        DescribeStreamResult result = new DescribeStreamResult();
        result.setStreamDescription(description);
        return result;
    }

    /**
     * Builds a shard iterator string for the requested stream and shard.
     *
     * With a non-empty starting sequence number, "AFTER_SEQUENCE_NUMBER" positions the
     * iterator just past that number and "AT_SEQUENCE_NUMBER" positions it on that
     * number. Every other combination positions the iterator at sequence number 100,
     * the first number a stream assigns.
     *
     * @param getShardIteratorRequest request carrying stream name, shard ID, iterator
     *        type and optional starting sequence number
     * @return result holding the serialized iterator
     * @throws AmazonClientException when the shard ID does not belong to the stream,
     *         the stream or shard does not exist, or the starting sequence number is
     *         not a valid integer
     */
    @Override
    public GetShardIteratorResult getShardIterator(GetShardIteratorRequest getShardIteratorRequest)
            throws AmazonServiceException, AmazonClientException {
        ShardIterator iter = ShardIterator.fromStreamAndShard(getShardIteratorRequest.getStreamName(),
                getShardIteratorRequest.getShardId());
        if (iter == null) {
            throw new AmazonClientException("Bad stream or shard iterator!");
        }

        InternalStream theStream = this.getStream(iter.streamId);
        // Reject unknown shards here rather than failing later in getRecords.
        if (theStream == null || iter.shardIndex >= theStream.getShards().size()) {
            throw new AmazonClientException("Unknown stream or bad shard iterator!");
        }

        String iteratorType = getShardIteratorRequest.getShardIteratorType();
        String seqAsString = getShardIteratorRequest.getStartingSequenceNumber();
        boolean hasSequence = seqAsString != null && !seqAsString.isEmpty();
        // Constant-first equals: the iterator type may legitimately be null.
        boolean afterSequence = "AFTER_SEQUENCE_NUMBER".equals(iteratorType);
        boolean atSequence = "AT_SEQUENCE_NUMBER".equals(iteratorType);

        if (hasSequence && (afterSequence || atSequence)) {
            int sequence;
            try {
                sequence = Integer.parseInt(seqAsString);
            } catch (NumberFormatException e) {
                throw new AmazonClientException("Bad starting sequence number: " + seqAsString, e);
            }
            iter.recordIndex = afterSequence ? sequence + 1 : sequence;
        } else {
            iter.recordIndex = 100;
        }

        GetShardIteratorResult result = new GetShardIteratorResult();
        return result.withShardIterator(iter.makeString());
    }

    /**
     * Reads records from the shard position encoded in the request's shard iterator.
     *
     * An iterator at record index 100 reads from the start of the shard; any other
     * index reads the records whose sequence number is at or beyond it. When the
     * request has a positive limit, at most that many records are returned and the
     * next iterator points just past the last one returned. Milliseconds-behind-latest
     * is always reported as 100.
     *
     * @param getRecordsRequest request carrying the shard iterator and optional limit
     * @return the records (possibly none) and the iterator for the next call
     * @throws AmazonClientException when the iterator is missing or malformed, or
     *         refers to an unknown stream or shard
     */
    @Override
    public GetRecordsResult getRecords(GetRecordsRequest getRecordsRequest)
            throws AmazonServiceException, AmazonClientException {
        String iteratorText = getRecordsRequest.getShardIterator();
        ShardIterator iter = (iteratorText == null) ? null : ShardIterator.fromString(iteratorText);
        if (iter == null) {
            throw new AmazonClientException("Bad shard iterator.");
        }

        InternalStream stream = this.getStream(iter.streamId);
        if (stream == null || iter.shardIndex >= stream.getShards().size()) {
            throw new AmazonClientException("Unknown stream or bad shard iterator.");
        }
        InternalShard shard = stream.getShards().get(iter.shardIndex);

        // Index 100 means "from the beginning": take everything without filtering.
        // Copy the list so that trimming to the limit never touches the shard's own list.
        ArrayList<Record> recs = (iter.recordIndex == 100)
                ? new ArrayList<Record>(shard.getRecords())
                : shard.getRecordsFrom(iter);

        Integer limit = getRecordsRequest.getLimit();
        if (limit != null && limit > 0 && recs.size() > limit) {
            recs = new ArrayList<Record>(recs.subList(0, limit));
        }

        GetRecordsResult result = new GetRecordsResult();
        result.setRecords(recs); // may be empty
        result.setNextShardIterator(getNextShardIterator(iter, recs).makeString());
        result.setMillisBehindLatest(100L);
        return result;
    }

    /**
     * Computes the iterator to hand back after a read.
     *
     * @param previousIter iterator the read was performed with
     * @param records records returned by that read
     * @return {@code previousIter} itself when nothing was read; otherwise a new
     *         iterator on the same stream and shard positioned one past the sequence
     *         number of the last record
     */
    protected ShardIterator getNextShardIterator(ShardIterator previousIter, ArrayList<Record> records) {
        if (records.size() == 0) {
            return previousIter;
        }

        Record lastRead = records.get(records.size() - 1);
        int nextIndex = Integer.parseInt(lastRead.getSequenceNumber()) + 1;
        return new ShardIterator(previousIter.streamId, previousIter.shardIndex, nextIndex);
    }

    //// Unsupported methods

    /**
     * Not supported by this mock: the request is ignored.
     *
     * @return always {@code null}
     */
    @Override
    public ListTagsForStreamResult listTagsForStream(ListTagsForStreamRequest listTagsForStreamRequest)
            throws AmazonServiceException, AmazonClientException {
        return null;
    }

    /**
     * Not supported by this mock: the request is ignored and the streams held by the
     * mock are not listed.
     *
     * @return always {@code null}
     */
    @Override
    public ListStreamsResult listStreams(ListStreamsRequest listStreamsRequest)
            throws AmazonServiceException, AmazonClientException {
        return null;
    }

    /**
     * Not supported by this mock.
     *
     * @return always {@code null}
     */
    @Override
    public ListStreamsResult listStreams() throws AmazonServiceException, AmazonClientException {
        return null;
    }

    /**
     * Not supported by this mock; use {@code putRecord(PutRecordRequest)} instead.
     *
     * NOTE(review): the arguments are presumably stream name, data and partition key,
     * as in the AWS SDK overload - confirm.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public PutRecordResult putRecord(String s, ByteBuffer byteBuffer, String s1)
            throws AmazonServiceException, AmazonClientException {
        throw new UnsupportedOperationException("MockKinesisClient doesn't support this.");
    }

    /**
     * Not supported by this mock; use {@code putRecord(PutRecordRequest)} instead.
     *
     * NOTE(review): the arguments are presumably stream name, data, partition key and
     * sequence number for ordering, as in the AWS SDK overload - confirm.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public PutRecordResult putRecord(String s, ByteBuffer byteBuffer, String s1, String s2)
            throws AmazonServiceException, AmazonClientException {
        throw new UnsupportedOperationException("MockKinesisClient doesn't support this.");
    }

    /**
     * Convenience overload that forwards to the request-based
     * {@code describeStream}, in the same way {@code createStream(String, Integer)}
     * forwards to its request-based counterpart.
     *
     * @param s stream name
     * @return the stream description
     * @throws AmazonClientException when the stream does not exist
     */
    @Override
    public DescribeStreamResult describeStream(String s) throws AmazonServiceException, AmazonClientException {
        return this.describeStream((new DescribeStreamRequest()).withStreamName(s));
    }

    /**
     * Convenience overload that forwards to the request-based
     * {@code describeStream}, in the same way {@code createStream(String, Integer)}
     * forwards to its request-based counterpart.
     *
     * @param s stream name
     * @param s1 exclusive start shard ID; null or empty lists all shards
     * @return the stream description
     * @throws AmazonClientException when the stream does not exist
     */
    @Override
    public DescribeStreamResult describeStream(String s, String s1)
            throws AmazonServiceException, AmazonClientException {
        return this.describeStream((new DescribeStreamRequest()).withStreamName(s)
                .withExclusiveStartShardId(s1));
    }

    /**
     * Convenience overload that forwards to the request-based
     * {@code describeStream}, in the same way {@code createStream(String, Integer)}
     * forwards to its request-based counterpart.
     *
     * @param s stream name
     * @param integer maximum number of shards to return, passed through on the request
     * @param s1 exclusive start shard ID; null or empty lists all shards
     * @return the stream description
     * @throws AmazonClientException when the stream does not exist
     */
    @Override
    public DescribeStreamResult describeStream(String s, Integer integer, String s1)
            throws AmazonServiceException, AmazonClientException {
        return this.describeStream((new DescribeStreamRequest()).withStreamName(s).withLimit(integer)
                .withExclusiveStartShardId(s1));
    }

    /**
     * Convenience overload that forwards to the request-based
     * {@code getShardIterator}, in the same way {@code createStream(String, Integer)}
     * forwards to its request-based counterpart.
     *
     * @param s stream name
     * @param s1 shard ID
     * @param s2 shard iterator type
     * @return result holding the serialized iterator
     * @throws AmazonClientException when the stream or shard cannot be resolved
     */
    @Override
    public GetShardIteratorResult getShardIterator(String s, String s1, String s2)
            throws AmazonServiceException, AmazonClientException {
        return this.getShardIterator((new GetShardIteratorRequest()).withStreamName(s).withShardId(s1)
                .withShardIteratorType(s2));
    }

    /**
     * Convenience overload that forwards to the request-based
     * {@code getShardIterator}, in the same way {@code createStream(String, Integer)}
     * forwards to its request-based counterpart.
     *
     * @param s stream name
     * @param s1 shard ID
     * @param s2 shard iterator type
     * @param s3 starting sequence number
     * @return result holding the serialized iterator
     * @throws AmazonClientException when the stream or shard cannot be resolved
     */
    @Override
    public GetShardIteratorResult getShardIterator(String s, String s1, String s2, String s3)
            throws AmazonServiceException, AmazonClientException {
        return this.getShardIterator((new GetShardIteratorRequest()).withStreamName(s).withShardId(s1)
                .withShardIteratorType(s2).withStartingSequenceNumber(s3));
    }

    /**
     * Not supported by this mock: the argument is ignored.
     *
     * @return always {@code null}
     */
    @Override
    public ListStreamsResult listStreams(String s) throws AmazonServiceException, AmazonClientException {
        return null;
    }

    /**
     * Not supported by this mock: both arguments are ignored.
     *
     * @return always {@code null}
     */
    @Override
    public ListStreamsResult listStreams(Integer integer, String s)
            throws AmazonServiceException, AmazonClientException {
        return null;
    }

    /**
     * Does nothing and does not call the superclass implementation.
     */
    @Override
    public void shutdown() {
        // Nothing to shutdown here
    }

    /**
     * Not supported by this mock: the request is ignored.
     *
     * @return always {@code null}
     */
    @Override
    public ResponseMetadata getCachedResponseMetadata(AmazonWebServiceRequest amazonWebServiceRequest) {
        return null;
    }
}