org.apache.avro.ipc.RestRequestor.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.avro.ipc.RestRequestor.java

Source

/*
 * Copyright (c) 2013
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package org.apache.avro.ipc;

import org.apache.avro.AvroRemoteException;
import org.apache.avro.AvroRuntimeException;
import org.apache.avro.Protocol;
import org.apache.avro.Schema;
import org.apache.avro.io.BinaryDecoder;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.io.Decoder;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.io.Encoder;
import org.apache.avro.io.EncoderFactory;
import org.apache.avro.specific.SpecificData;
import org.apache.avro.specific.SpecificDatumWriter;
import org.apache.avro.util.ByteBufferInputStream;
import org.apache.avro.util.ByteBufferOutputStream;
import org.apache.avro.util.Utf8;
import org.apache.commons.lang.CharSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.ExecutionException;

/**
 * TODO: Document this
 */
public class RestRequestor implements InvocationHandler {

    private static final Logger LOG = LoggerFactory.getLogger(RestRequestor.class);

    private static final EncoderFactory ENCODER_FACTORY = new EncoderFactory();

    private final Protocol protocol;
    private volatile Protocol remote;
    private final RestTransceiver transceiver;
    private final SpecificData data;

    public RestRequestor(Protocol protocol, RestTransceiver transciever, SpecificData data) {
        this.protocol = protocol;
        this.transceiver = transciever;
        this.data = data;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        try {
            // Check if this is a callback-based RPC:
            //            Type[] parameterTypes = method.getParameterTypes();
            //            if ((parameterTypes.length > 0) &&
            //                    (parameterTypes[parameterTypes.length - 1] instanceof Class) &&
            //                    Callback.class.isAssignableFrom(((Class<?>) parameterTypes[parameterTypes.length - 1]))) {
            //                // Extract the Callback from the end of of the argument list
            //                Object[] finalArgs = Arrays.copyOf(args, args.length - 1);
            //                Callback<?> callback = (Callback<?>) args[args.length - 1];
            //                request(method.getName(), finalArgs, callback);
            //                return null;
            //            } else {
            return request(method.getName(), args);
            //            }
        } catch (Exception e) {
            // Check if this is a declared Exception:
            for (Class<?> exceptionClass : method.getExceptionTypes()) {
                if (exceptionClass.isAssignableFrom(e.getClass())) {
                    throw e;
                }
            }

            // Next, check for RuntimeExceptions:
            if (e instanceof RuntimeException) {
                throw e;
            }

            // Not an expected Exception, so wrap it in AvroRemoteException:
            throw new AvroRemoteException(e);
        }
    }

    <T> void request(RestRequest request, Callback<T> callback) throws Exception {
        if (!transceiver.isConnected()) {
            // Acquire handshake lock so that only one thread is performing the
            // handshake and other threads block until the handshake is completed
            //            handshakeLock.lock();
            try {
                if (transceiver.isConnected()) {
                    // Another thread already completed the handshake; no need to hold
                    // the write lock
                    //                    handshakeLock.unlock();
                } else {
                    CallFuture<T> callFuture = new CallFuture<>(callback);
                    transceiver.transceive(request, new TransceiverCallback<>(request, callFuture));
                    // Block until handshake complete
                    callFuture.await();
                    if (request.getMessage().isOneWay()) {
                        Throwable error = callFuture.getError();
                        if (error != null) {
                            if (error instanceof Exception) {
                                throw (Exception) error;
                            } else {
                                throw new AvroRemoteException(error);
                            }
                        }
                    }
                    return;
                }
            } finally {
                //                if (handshakeLock.isHeldByCurrentThread()) {
                //                    handshakeLock.unlock();
                //                }
            }
        }

        if (request.getMessage().isOneWay()) {
            transceiver.lockChannel();
            try {
                transceiver.writeBuffers(request.getBytes());
                if (callback != null) {
                    callback.handleResult(null);
                }
            } finally {
                transceiver.unlockChannel();
            }
        } else {
            CallFuture<T> callFuture = new CallFuture<T>(callback);
            transceiver.transceive(request, new TransceiverCallback<T>(request, callFuture));
        }

    }

    private Object request(String messageName, Object request) throws Exception {
        // Initialize request
        RestRequest rpcRequest = new RestRequest(messageName, request, new RPCContext());
        CallFuture<Object> future = /* only need a Future for two-way messages */
                rpcRequest.getMessage().isOneWay() ? null : new CallFuture<Object>();

        // Send request
        request(rpcRequest, future);

        if (future == null) // the message is one-way, so return immediately
            return null;
        try { // the message is two-way, wait for the result
            return future.get();
        } catch (ExecutionException e) {
            if (e.getCause() instanceof Exception) {
                throw (Exception) e.getCause();
            } else {
                throw new AvroRemoteException(e.getCause());
            }
        }
    }

    public class RestRequest {
        private final String messageName;
        private final Object request;
        private final RPCContext context;
        private Protocol.Message message;
        private HttpBuffers buffers;

        public RestRequest(String messageName, Object request, RPCContext context) {
            this.messageName = messageName;
            this.request = request;
            this.context = context;
        }

        /**
         * Copy constructor.
         *
         * @param other Request from which to copy fields.
         */
        public RestRequest(RestRequest other) {
            this.messageName = other.messageName;
            this.request = other.request;
            this.context = other.context;
            this.buffers = other.buffers;
        }

        /**
         * Gets the message name.
         *
         * @return the message name.
         */
        public String getMessageName() {
            return messageName;
        }

        /**
         * Gets the RPC context.
         *
         * @return the RPC context.
         */
        public RPCContext getContext() {
            return context;
        }

        /**
         * Gets the Message associated with this request.
         *
         * @return this request's message.
         */
        public Protocol.Message getMessage() {
            if (message == null) {
                message = protocol.getMessages().get(messageName);
                if (message == null) {
                    throw new AvroRuntimeException("Not a local message: " + messageName);
                }
            }
            return message;
        }

        public HttpBuffers getBytes(RestTransceiver.ContentType contentType) throws Exception {
            if (buffers == null || buffers.contentType != contentType) {
                HttpBuffers httpBuffers = parseMessage();
                List<ByteBuffer> bytes;
                if (contentType == RestTransceiver.ContentType.JSON) {
                    bytes = getJsonBytes();
                } else {
                    bytes = getBytes();
                }
                buffers = HttpBuffers.newBuilder(httpBuffers).setBuffers(bytes).build();
            }
            return buffers;
        }

        private List<ByteBuffer> getBytes() throws Exception {
            ByteBufferOutputStream bbo = new ByteBufferOutputStream();
            Encoder out = ENCODER_FACTORY.binaryEncoder(bbo, null);
            return getBytes(out, bbo);
        }

        private List<ByteBuffer> getJsonBytes() throws Exception {
            ByteBufferOutputStream bbo = new ByteBufferOutputStream();
            Protocol.Message m = getMessage();
            Encoder out = ENCODER_FACTORY.jsonEncoder(m.getRequest(), bbo);
            return getBytes(out, bbo);
        }

        private List<ByteBuffer> getBytes(Encoder out, ByteBufferOutputStream bbo) throws Exception {

            // use local protocol to write request
            Protocol.Message m = getMessage();
            context.setMessage(m);

            writeRequest(m.getRequest(), request, out); // write request payload

            out.flush();
            List<ByteBuffer> payload = bbo.getBufferList();

            //            writeHandshake(out);                     // prepend handshake if needed

            context.setRequestPayload(payload);

            //            out.writeString(m.getName());             // write message name

            out.flush();
            bbo.append(payload);

            return bbo.getBufferList();
        }

        private HttpBuffers parseMessage() throws IOException {
            Protocol.Message m = getMessage();
            String name = m.getName();
            String path;
            RestTransceiver.Verb verb;
            int fieldCount = m.getRequest().getFields().size();
            if (name.startsWith("find")) {
                verb = RestTransceiver.Verb.GET;
                path = "/" + name.substring(4) + getParametersAsPath(fieldCount);
            } else if (name.startsWith("findAll")) {
                verb = RestTransceiver.Verb.GET;
                path = "/" + name.substring(7) + getParametersAsPath(fieldCount);
            } else if (name.startsWith("delete")) {
                verb = RestTransceiver.Verb.DELETE;
                path = "/" + name.substring(7) + getParametersAsPath(fieldCount);
            } else if (name.startsWith("update")) {
                verb = RestTransceiver.Verb.PUT;
                path = "/" + name.substring(6) + getParametersAsPath(fieldCount - 1);
            } else if (name.startsWith("create")) {
                verb = RestTransceiver.Verb.POST;
                path = "/" + name.substring(6) + getParametersAsPath(fieldCount - 1);
            } else {
                verb = RestTransceiver.Verb.POST;
                path = "/";
            }
            return HttpBuffers.newBuilder().setUri(path).setVerb(verb).build();
        }

        private String getParametersAsPath(int fieldCount) throws IOException {
            StringBuilder sb = new StringBuilder();
            Object[] args = (Object[]) request;
            Protocol.Message m = getMessage();
            Schema s = m.getRequest();
            for (int i = 0; i < fieldCount; i++) {
                Schema.Field param = s.getFields().get(i);
                sb.append("/");
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                Encoder out = new UrlStringEncoder(s, baos);
                getDatumWriter(param.schema()).write(args[i], out);
                sb.append(new Utf8(baos.toByteArray()).toString());
            }
            return sb.toString();
        }

    }

    public void writeRequest(Schema schema, Object request, Encoder out) throws IOException {
        Object[] args = (Object[]) request;
        int idx = schema.getFields().size() - 1;
        Schema.Field param = schema.getFields().get(idx);
        getDatumWriter(param.schema()).write(args[idx], out);
    }

    protected DatumWriter<Object> getDatumWriter(Schema schema) {
        return new SpecificDatumWriter<Object>(schema);
    }

    protected class TransceiverCallback<T> implements Callback<HttpBuffers> {
        private final RestRequest request;
        private final Callback<T> callback;

        /**
         * Creates a TransceiverCallback.
         *
         * @param request  the request to set.
         * @param callback the callback to set.
         */
        public TransceiverCallback(RestRequest request, Callback<T> callback) {
            this.request = request;
            this.callback = callback;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void handleResult(HttpBuffers responseBuffers) {
            ByteBufferInputStream bbi = new ByteBufferInputStream(responseBuffers.buffers);
            BinaryDecoder in = DecoderFactory.get().binaryDecoder(bbi, null);
            //            try {
            //                if (!readHandshake(in)) {
            //                    // Resend the handshake and return
            //                    RestRequest handshake = new RestRequest(request);
            //                    transceiver.transceive
            //                            (handshake.getBytes(),
            //                                    new TransceiverCallback<T>(handshake, callback));
            //                    return;
            //                }
            //            } catch (Exception e) {
            //                LOG.error("Error handling transceiver callback: " + e, e);
            //            }

            // Read response; invoke callback
            RestResponse response = new RestResponse(request, in);
            Object responseObject;
            try {
                try {
                    responseObject = response.getResponse();
                } catch (Exception e) {
                    if (callback != null) {
                        callback.handleError(e);
                    }
                    return;
                }
                if (callback != null) {
                    callback.handleResult((T) responseObject);
                }
            } catch (Throwable t) {
                LOG.error("Error in callback handler: " + t, t);
            }
        }

        @Override
        public void handleError(Throwable error) {
            callback.handleError(error);
        }

        public class RestResponse {
            private final RestRequest request;
            private final Decoder in;

            /**
             * Creates a Response.
             *
             * @param request the Request associated with this response.
             */
            public RestResponse(RestRequest request) {
                this(request, null);
            }

            /**
             * Creates a Creates a Response.
             *
             * @param request the Request associated with this response.
             * @param in      the BinaryDecoder to use to deserialize the response.
             */
            public RestResponse(RestRequest request, Decoder in) {
                this.request = request;
                this.in = in;
            }

            /**
             * Gets the RPC response, reading/deserializing it first if necessary.
             *
             * @return the RPC response.
             * @throws Exception if an error occurs reading/deserializing the response.
             */
            public Object getResponse() throws Exception {
                Protocol.Message lm = request.getMessage();
                Protocol.Message rm = protocol.getMessages().get(request.getMessageName());
                if (rm == null)
                    throw new AvroRuntimeException("Not a remote message: " + request.getMessageName());

                if ((lm.isOneWay() != rm.isOneWay()) && transceiver.isConnected())
                    throw new AvroRuntimeException("Not both one-way messages: " + request.getMessageName());

                if (lm.isOneWay() && transceiver.isConnected())
                    return null; // one-way w/ handshake

                //                RPCContext context = request.getContext();
                //                context.setResponseCallMeta(META_READER.read(null, in));
                //
                //                if (!in.readBoolean()) {                      // no error
                //                    Object response = readResponse(rm.getResponse(), lm.getResponse(), in);
                //                    context.setResponse(response);
                //                    for (RPCPlugin plugin : rpcMetaPlugins) {
                //                        plugin.clientReceiveResponse(context);
                //                    }
                //                    return response;
                //
                //                } else {
                //                    Exception error = readError(rm.getErrors(), lm.getErrors(), in);
                //                    context.setError(error);
                //                    for (RPCPlugin plugin : rpcMetaPlugins) {
                //                        plugin.clientReceiveResponse(context);
                //                    }
                //                    throw error;
                //                }
                return null;
            }
        }
    }

    //    private void writeHandshake(Encoder out) throws IOException {
    //        if (transceiver.isConnected()) return;
    //        MD5 localHash = new MD5();
    //        localHash.bytes(local.getMD5());
    //        String remoteName = transceiver.getRemoteName();
    //        MD5 remoteHash = REMOTE_HASHES.get(remoteName);
    //        if (remoteHash == null) {                     // guess remote is local
    //            remoteHash = localHash;
    //            remote = local;
    //        } else {
    //            remote = REMOTE_PROTOCOLS.get(remoteHash);
    //        }
    //        HandshakeRequest handshake = new HandshakeRequest();
    //        handshake.clientHash = localHash;
    //        handshake.serverHash = remoteHash;
    //        if (sendLocalText)
    //            handshake.clientProtocol = local.toString();
    //
    //        RPCContext context = new RPCContext();
    //        context.setHandshakeRequest(handshake);
    //        for (RPCPlugin plugin : rpcMetaPlugins) {
    //            plugin.clientStartConnect(context);
    //        }
    //        handshake.meta = context.requestHandshakeMeta();
    //
    //        HANDSHAKE_WRITER.write(handshake, out);
    //    }
}