org.apache.beam.sdk.io.aws.s3.S3WritableByteChannel.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.beam.sdk.io.aws.s3.S3WritableByteChannel.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io.aws.s3;

import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;

import com.amazonaws.AmazonClientException;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest;
import com.amazonaws.services.s3.model.InitiateMultipartUploadRequest;
import com.amazonaws.services.s3.model.InitiateMultipartUploadResult;
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.model.PartETag;
import com.amazonaws.services.s3.model.UploadPartRequest;
import com.amazonaws.services.s3.model.UploadPartResult;
import com.amazonaws.util.Base64;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.WritableByteChannel;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.io.aws.options.S3Options;
import org.apache.beam.sdk.io.aws.options.S3Options.S3UploadBufferSizeBytesFactory;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;

/** A writable S3 object, as a {@link WritableByteChannel}. */
/**
 * A writable S3 object, as a {@link WritableByteChannel}.
 *
 * <p>Bytes are staged in a fixed-size in-memory buffer and sent to S3 as a multipart upload:
 * every time the buffer fills, its contents are uploaded as one part, and {@link #close()}
 * uploads the final (possibly smaller, possibly empty) part and completes the upload.
 *
 * <p>This class is not thread-safe; it is expected to be used by a single writer.
 */
class S3WritableByteChannel implements WritableByteChannel {
    private final AmazonS3 amazonS3;
    private final S3Options options;
    private final S3ResourceId path;

    private final String uploadId;
    private final ByteBuffer uploadBuffer;
    private final List<PartETag> eTags;

    // AWS S3 parts are 1-indexed, not zero-indexed.
    private int partNumber = 1;
    private boolean open = true;
    // Digest of exactly the bytes currently staged in uploadBuffer; reset after each part upload
    // so the per-part Content-MD5 sent in flush() matches the part's payload.
    private final MessageDigest md5 = md5();

    /**
     * Initiates a multipart upload for {@code path}.
     *
     * @param amazonS3 client used for all S3 calls; must not be null
     * @param path destination bucket/key; must not be null
     * @param contentType Content-Type metadata recorded on the object
     * @param options S3 options; at most one server-side-encryption mechanism may be configured,
     *     and the upload buffer size must meet the S3 minimum part size
     * @throws IOException if the InitiateMultipartUpload request fails
     */
    S3WritableByteChannel(AmazonS3 amazonS3, S3ResourceId path, String contentType, S3Options options)
            throws IOException {
        this.amazonS3 = checkNotNull(amazonS3, "amazonS3");
        this.options = checkNotNull(options, "options");
        this.path = checkNotNull(path, "path");
        checkArgument(
                atMostOne(options.getSSECustomerKey() != null, options.getSSEAlgorithm() != null,
                        options.getSSEAwsKeyManagementParams() != null),
                "Either SSECustomerKey (SSE-C) or SSEAlgorithm (SSE-S3)"
                        + " or SSEAwsKeyManagementParams (SSE-KMS) must not be set at the same time.");
        // Amazon S3 API docs: Each part must be at least 5 MB in size, except the last part.
        checkArgument(options
                .getS3UploadBufferSizeBytes() >= S3UploadBufferSizeBytesFactory.MINIMUM_UPLOAD_BUFFER_SIZE_BYTES,
                "S3UploadBufferSizeBytes must be at least %s bytes",
                S3UploadBufferSizeBytesFactory.MINIMUM_UPLOAD_BUFFER_SIZE_BYTES);
        this.uploadBuffer = ByteBuffer.allocate(options.getS3UploadBufferSizeBytes());
        eTags = new ArrayList<>();

        ObjectMetadata objectMetadata = new ObjectMetadata();
        objectMetadata.setContentType(contentType);
        if (options.getSSEAlgorithm() != null) {
            objectMetadata.setSSEAlgorithm(options.getSSEAlgorithm());
        }
        InitiateMultipartUploadRequest request = new InitiateMultipartUploadRequest(path.getBucket(), path.getKey())
                .withStorageClass(options.getS3StorageClass()).withObjectMetadata(objectMetadata);
        // Setters tolerate null, so these are no-ops when the corresponding option is unset.
        request.setSSECustomerKey(options.getSSECustomerKey());
        request.setSSEAwsKeyManagementParams(options.getSSEAwsKeyManagementParams());
        InitiateMultipartUploadResult result;
        try {
            result = amazonS3.initiateMultipartUpload(request);
        } catch (AmazonClientException e) {
            throw new IOException(e);
        }
        uploadId = result.getUploadId();
    }

    /** Returns a fresh MD5 digest; MD5 is a mandatory JDK algorithm, so failure is a JVM bug. */
    private static MessageDigest md5() {
        try {
            return MessageDigest.getInstance("MD5");
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Copies {@code sourceBuffer} into the upload buffer, flushing a part to S3 whenever the
     * upload buffer fills.
     *
     * @return the number of bytes consumed from {@code sourceBuffer} (always all of them)
     * @throws ClosedChannelException if the channel has been closed
     * @throws IOException if a part upload fails
     */
    @Override
    public int write(ByteBuffer sourceBuffer) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }

        int totalBytesWritten = 0;
        while (sourceBuffer.hasRemaining()) {
            int bytesWritten = Math.min(sourceBuffer.remaining(), uploadBuffer.remaining());
            totalBytesWritten += bytesWritten;

            // Bulk-copy straight into uploadBuffer's backing array instead of allocating a
            // temporary chunk per iteration; the digest is updated over the same region.
            int offset = uploadBuffer.position();
            sourceBuffer.get(uploadBuffer.array(), offset, bytesWritten);
            uploadBuffer.position(offset + bytesWritten);
            md5.update(uploadBuffer.array(), offset, bytesWritten);

            // If source bytes remain after the copy, the copy necessarily filled uploadBuffer,
            // so checking the upload buffer alone is sufficient.
            if (!uploadBuffer.hasRemaining()) {
                flush();
            }
        }

        return totalBytesWritten;
    }

    /** Uploads the buffered bytes as the next part, then resets the buffer and digest. */
    private void flush() throws IOException {
        uploadBuffer.flip();
        // Bound the stream to the valid region of the backing array so it matches partSize.
        ByteArrayInputStream inputStream =
                new ByteArrayInputStream(uploadBuffer.array(), 0, uploadBuffer.limit());

        UploadPartRequest request = new UploadPartRequest().withBucketName(path.getBucket()).withKey(path.getKey())
                .withUploadId(uploadId).withPartNumber(partNumber++).withPartSize(uploadBuffer.remaining())
                .withMD5Digest(Base64.encodeAsString(md5.digest())).withInputStream(inputStream);
        request.setSSECustomerKey(options.getSSECustomerKey());

        UploadPartResult result;
        try {
            result = amazonS3.uploadPart(request);
        } catch (AmazonClientException e) {
            throw new IOException(e);
        }
        uploadBuffer.clear();
        md5.reset();
        eTags.add(result.getPartETag());
    }

    @Override
    public boolean isOpen() {
        return open;
    }

    /**
     * Uploads any remaining buffered bytes as the final part and completes the multipart upload.
     *
     * <p>S3 permits the last part to be smaller than the minimum part size; an empty final part
     * also covers the zero-byte-object case.
     *
     * @throws IOException if the final part upload or the CompleteMultipartUpload request fails
     */
    @Override
    public void close() throws IOException {
        open = false;
        if (uploadBuffer.remaining() > 0) {
            flush();
        }
        CompleteMultipartUploadRequest request = new CompleteMultipartUploadRequest()
                .withBucketName(path.getBucket()).withKey(path.getKey()).withUploadId(uploadId)
                .withPartETags(eTags);
        try {
            amazonS3.completeMultipartUpload(request);
        } catch (AmazonClientException e) {
            throw new IOException(e);
        }
    }

    /** Returns true iff at most one of {@code values} is true. */
    @VisibleForTesting
    static boolean atMostOne(boolean... values) {
        boolean one = false;
        for (boolean value : values) {
            if (!one && value) {
                one = true;
            } else if (value) {
                return false;
            }
        }
        return true;
    }
}