com.alibaba.dubbo.rpc.protocol.thrift.ThriftCodec.java Source code

Java tutorial

Introduction

Here is the source code for com.alibaba.dubbo.rpc.protocol.thrift.ThriftCodec.java

Source

/**
 * File Created at 2011-12-05
 * $Id$
 *
 * Copyright 2008 Alibaba.com Croporation Limited.
 * All rights reserved.
 *
 * This software is the confidential and proprietary information of
 * Alibaba Company. ("Confidential Information").  You shall not
 * disclose such Confidential Information and shall use it only in
 * accordance with the terms of the license agreement you entered into
 * with Alibaba.com.
 */
package com.alibaba.dubbo.rpc.protocol.thrift;

import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.lang.StringUtils;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TFramedTransport;
import org.apache.thrift.transport.TIOStreamTransport;

import com.alibaba.dubbo.common.Constants;
import com.alibaba.dubbo.common.extension.ExtensionLoader;
import com.alibaba.dubbo.common.utils.ClassHelper;
import com.alibaba.dubbo.remoting.Channel;
import com.alibaba.dubbo.remoting.Codec2;
import com.alibaba.dubbo.remoting.buffer.ChannelBuffer;
import com.alibaba.dubbo.remoting.buffer.ChannelBufferInputStream;
import com.alibaba.dubbo.remoting.exchange.Request;
import com.alibaba.dubbo.remoting.exchange.Response;
import com.alibaba.dubbo.rpc.RpcException;
import com.alibaba.dubbo.rpc.RpcInvocation;
import com.alibaba.dubbo.rpc.RpcResult;
import com.alibaba.dubbo.rpc.protocol.thrift.io.RandomAccessByteArrayOutputStream;

/**
 * Thrift framed protocol codec.
 *
 * <pre>
 * |<-                                  message header                                  ->|<- message body ->|
 * +----------------+----------------------+------------------+---------------------------+------------------+
 * | magic (2 bytes)|message size (4 bytes)|head size(2 bytes)| version (1 byte) | header |   message body   |
 * +----------------+----------------------+------------------+---------------------------+------------------+
 * |<-                                               message size                                          ->|
 * </pre>
 *
 * <p>
 * <b>header fields in version 1</b>
 * <ol>
 *     <li>string - service name</li>
 *     <li>long   - dubbo request id</li>
 * </ol>
 * </p>
 *
 * @author <a href="mailto:gang.lvg@alibaba-inc.com">gang.lvg</a>
 */
public class ThriftCodec implements Codec2 {

    private static final AtomicInteger THRIFT_SEQ_ID = new AtomicInteger(0);

    private static final ConcurrentMap<String, Class<?>> cachedClass = new ConcurrentHashMap<String, Class<?>>();

    static final ConcurrentMap<Long, RequestData> cachedRequest = new ConcurrentHashMap<Long, RequestData>();

    public static final int MESSAGE_LENGTH_INDEX = 2;

    public static final int MESSAGE_HEADER_LENGTH_INDEX = 6;

    public static final int MESSAGE_SHORTEST_LENGTH = 10;

    public static final String NAME = "thrift";

    public static final String PARAMETER_CLASS_NAME_GENERATOR = "class.name.generator";

    public static final byte VERSION = (byte) 1;

    public static final short MAGIC = (short) 0xdabc;

    public void encode(Channel channel, ChannelBuffer buffer, Object message) throws IOException {

        if (message instanceof Request) {
            encodeRequest(channel, buffer, (Request) message);
        } else if (message instanceof Response) {
            encodeResponse(channel, buffer, (Response) message);
        } else {
            throw new UnsupportedOperationException(new StringBuilder(32)
                    .append("Thrift codec only support encode ").append(Request.class.getName()).append(" and ")
                    .append(Response.class.getName()).toString());
        }

    }

    public Object decode(Channel channel, ChannelBuffer buffer) throws IOException {

        int available = buffer.readableBytes();

        if (available < MESSAGE_SHORTEST_LENGTH) {

            return DecodeResult.NEED_MORE_INPUT;

        } else {

            TIOStreamTransport transport = new TIOStreamTransport(new ChannelBufferInputStream(buffer));

            TBinaryProtocol protocol = new TBinaryProtocol(transport);

            short magic;
            int messageLength;

            try {
                //                protocol.readI32(); // skip the first message length
                byte[] bytes = new byte[4];
                transport.read(bytes, 0, 4);
                magic = protocol.readI16();
                messageLength = protocol.readI32();

            } catch (TException e) {
                throw new IOException(e.getMessage(), e);
            }

            if (MAGIC != magic) {
                throw new IOException(new StringBuilder(32).append("Unknown magic code ").append(magic).toString());
            }

            if (available < messageLength) {
                return DecodeResult.NEED_MORE_INPUT;
            }

            return decode(protocol);

        }

    }

    private Object decode(TProtocol protocol) throws IOException {

        // version
        String serviceName;
        long id;

        TMessage message;

        try {
            protocol.readI16();
            protocol.readByte();
            serviceName = protocol.readString();
            id = protocol.readI64();
            message = protocol.readMessageBegin();
        } catch (TException e) {
            throw new IOException(e.getMessage(), e);
        }

        if (message.type == TMessageType.CALL) {

            RpcInvocation result = new RpcInvocation();
            result.setAttachment(Constants.INTERFACE_KEY, serviceName);
            result.setMethodName(message.name);

            String argsClassName = ExtensionLoader.getExtensionLoader(ClassNameGenerator.class)
                    .getExtension(ThriftClassNameGenerator.NAME).generateArgsClassName(serviceName, message.name);

            if (StringUtils.isEmpty(argsClassName)) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION,
                        "The specified interface name incorrect.");
            }

            Class clazz = cachedClass.get(argsClassName);

            if (clazz == null) {
                try {

                    clazz = ClassHelper.forNameWithThreadContextClassLoader(argsClassName);

                    cachedClass.putIfAbsent(argsClassName, clazz);

                } catch (ClassNotFoundException e) {
                    throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
                }
            }

            TBase args;

            try {
                args = (TBase) clazz.newInstance();
            } catch (InstantiationException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            } catch (IllegalAccessException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }

            try {
                args.read(protocol);
                protocol.readMessageEnd();
            } catch (TException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }

            List<Object> parameters = new ArrayList<Object>();
            List<Class<?>> parameterTypes = new ArrayList<Class<?>>();
            int index = 1;

            while (true) {

                TFieldIdEnum fieldIdEnum = args.fieldForId(index++);

                if (fieldIdEnum == null) {
                    break;
                }

                String fieldName = fieldIdEnum.getFieldName();

                String getMethodName = ThriftUtils.generateGetMethodName(fieldName);

                Method getMethod;

                try {
                    getMethod = clazz.getMethod(getMethodName);
                } catch (NoSuchMethodException e) {
                    throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
                }

                parameterTypes.add(getMethod.getReturnType());
                try {
                    parameters.add(getMethod.invoke(args));
                } catch (IllegalAccessException e) {
                    throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
                } catch (InvocationTargetException e) {
                    throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
                }

            }

            result.setArguments(parameters.toArray());
            result.setParameterTypes(parameterTypes.toArray(new Class[parameterTypes.size()]));

            Request request = new Request(id);
            request.setData(result);

            cachedRequest.putIfAbsent(id, RequestData.create(message.seqid, serviceName, message.name));

            return request;

        } else if (message.type == TMessageType.EXCEPTION) {

            TApplicationException exception;

            try {
                exception = TApplicationException.read(protocol);
                protocol.readMessageEnd();
            } catch (TException e) {
                throw new IOException(e.getMessage(), e);
            }

            RpcResult result = new RpcResult();

            result.setException(new RpcException(exception.getMessage()));

            Response response = new Response();

            response.setResult(result);

            response.setId(id);

            return response;

        } else if (message.type == TMessageType.REPLY) {

            String resultClassName = ExtensionLoader.getExtensionLoader(ClassNameGenerator.class)
                    .getExtension(ThriftClassNameGenerator.NAME).generateResultClassName(serviceName, message.name);

            if (StringUtils.isEmpty(resultClassName)) {
                throw new IllegalArgumentException(new StringBuilder(32)
                        .append("Could not infer service result class name from service name ").append(serviceName)
                        .append(", the service name you specified may not generated by thrift idl compiler")
                        .toString());
            }

            Class<?> clazz = cachedClass.get(resultClassName);

            if (clazz == null) {

                try {

                    clazz = ClassHelper.forNameWithThreadContextClassLoader(resultClassName);

                    cachedClass.putIfAbsent(resultClassName, clazz);

                } catch (ClassNotFoundException e) {
                    throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
                }

            }

            TBase<?, ? extends TFieldIdEnum> result;
            try {
                result = (TBase<?, ?>) clazz.newInstance();
            } catch (InstantiationException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            } catch (IllegalAccessException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }

            try {
                result.read(protocol);
                protocol.readMessageEnd();
            } catch (TException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }

            Object realResult = null;

            int index = 0;

            while (true) {

                TFieldIdEnum fieldIdEnum = result.fieldForId(index++);

                if (fieldIdEnum == null) {
                    break;
                }

                Field field;

                try {
                    field = clazz.getDeclaredField(fieldIdEnum.getFieldName());
                    field.setAccessible(true);
                } catch (NoSuchFieldException e) {
                    throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
                }

                try {
                    realResult = field.get(result);
                } catch (IllegalAccessException e) {
                    throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
                }

                if (realResult != null) {
                    break;
                }

            }

            Response response = new Response();

            response.setId(id);

            RpcResult rpcResult = new RpcResult();

            if (realResult instanceof Throwable) {
                rpcResult.setException((Throwable) realResult);
            } else {
                rpcResult.setValue(realResult);
            }

            response.setResult(rpcResult);

            return response;

        } else {
            // Impossible
            throw new IOException();
        }

    }

    private void encodeRequest(Channel channel, ChannelBuffer buffer, Request request) throws IOException {

        RpcInvocation inv = (RpcInvocation) request.getData();

        int seqId = nextSeqId();

        String serviceName = inv.getAttachment(Constants.INTERFACE_KEY);

        if (StringUtils.isEmpty(serviceName)) {
            throw new IllegalArgumentException(
                    new StringBuilder(32).append("Could not find service name in attachment with key ")
                            .append(Constants.INTERFACE_KEY).toString());
        }

        TMessage message = new TMessage(inv.getMethodName(), TMessageType.CALL, seqId);

        String methodArgs = ExtensionLoader
                .getExtensionLoader(ClassNameGenerator.class).getExtension(channel.getUrl()
                        .getParameter(ThriftConstants.CLASS_NAME_GENERATOR_KEY, ThriftClassNameGenerator.NAME))
                .generateArgsClassName(serviceName, inv.getMethodName());

        if (StringUtils.isEmpty(methodArgs)) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, new StringBuilder(32)
                    .append("Could not encode request, the specified interface may be incorrect.").toString());
        }

        Class<?> clazz = cachedClass.get(methodArgs);

        if (clazz == null) {

            try {

                clazz = ClassHelper.forNameWithThreadContextClassLoader(methodArgs);

                cachedClass.putIfAbsent(methodArgs, clazz);

            } catch (ClassNotFoundException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }

        }

        TBase args;

        try {
            args = (TBase) clazz.newInstance();
        } catch (InstantiationException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        } catch (IllegalAccessException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        }

        for (int i = 0; i < inv.getArguments().length; i++) {

            Object obj = inv.getArguments()[i];

            if (obj == null) {
                continue;
            }

            TFieldIdEnum field = args.fieldForId(i + 1);

            String setMethodName = ThriftUtils.generateSetMethodName(field.getFieldName());

            Method method;

            try {
                method = clazz.getMethod(setMethodName, inv.getParameterTypes()[i]);
            } catch (NoSuchMethodException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }

            try {
                method.invoke(args, obj);
            } catch (IllegalAccessException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            } catch (InvocationTargetException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }

        }

        RandomAccessByteArrayOutputStream bos = new RandomAccessByteArrayOutputStream(1024);

        TIOStreamTransport transport = new TIOStreamTransport(bos);

        TBinaryProtocol protocol = new TBinaryProtocol(transport);

        int headerLength, messageLength;

        byte[] bytes = new byte[4];
        try {
            // magic
            protocol.writeI16(MAGIC);
            // message length placeholder
            protocol.writeI32(Integer.MAX_VALUE);
            // message header length placeholder
            protocol.writeI16(Short.MAX_VALUE);
            // version
            protocol.writeByte(VERSION);
            // service name
            protocol.writeString(serviceName);
            // dubbo request id
            protocol.writeI64(request.getId());
            protocol.getTransport().flush();
            // header size
            headerLength = bos.size();

            // message body
            protocol.writeMessageBegin(message);
            args.write(protocol);
            protocol.writeMessageEnd();
            protocol.getTransport().flush();
            int oldIndex = messageLength = bos.size();

            // fill in message length and header length
            try {
                TFramedTransport.encodeFrameSize(messageLength, bytes);
                bos.setWriteIndex(MESSAGE_LENGTH_INDEX);
                protocol.writeI32(messageLength);
                bos.setWriteIndex(MESSAGE_HEADER_LENGTH_INDEX);
                protocol.writeI16((short) (0xffff & headerLength));
            } finally {
                bos.setWriteIndex(oldIndex);
            }

        } catch (TException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        }

        buffer.writeBytes(bytes);
        buffer.writeBytes(bos.toByteArray());

    }

    private void encodeResponse(Channel channel, ChannelBuffer buffer, Response response) throws IOException {

        RpcResult result = (RpcResult) response.getResult();

        RequestData rd = cachedRequest.get(response.getId());

        String resultClassName = ExtensionLoader
                .getExtensionLoader(ClassNameGenerator.class).getExtension(channel.getUrl()
                        .getParameter(ThriftConstants.CLASS_NAME_GENERATOR_KEY, ThriftClassNameGenerator.NAME))
                .generateResultClassName(rd.serviceName, rd.methodName);

        if (StringUtils.isEmpty(resultClassName)) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, new StringBuilder(32)
                    .append("Could not encode response, the specified interface may be incorrect.").toString());
        }

        Class clazz = cachedClass.get(resultClassName);

        if (clazz == null) {

            try {
                clazz = ClassHelper.forNameWithThreadContextClassLoader(resultClassName);
                cachedClass.putIfAbsent(resultClassName, clazz);
            } catch (ClassNotFoundException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }

        }

        TBase resultObj;

        try {
            resultObj = (TBase) clazz.newInstance();
        } catch (InstantiationException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        } catch (IllegalAccessException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        }

        TApplicationException applicationException = null;
        TMessage message;

        if (result.hasException()) {
            Throwable throwable = result.getException();
            int index = 1;
            boolean found = false;
            while (true) {
                TFieldIdEnum fieldIdEnum = resultObj.fieldForId(index++);
                if (fieldIdEnum == null) {
                    break;
                }
                String fieldName = fieldIdEnum.getFieldName();
                String getMethodName = ThriftUtils.generateGetMethodName(fieldName);
                String setMethodName = ThriftUtils.generateSetMethodName(fieldName);
                Method getMethod;
                Method setMethod;
                try {
                    getMethod = clazz.getMethod(getMethodName);
                    if (getMethod.getReturnType().equals(throwable.getClass())) {
                        found = true;
                        setMethod = clazz.getMethod(setMethodName, throwable.getClass());
                        setMethod.invoke(resultObj, throwable);
                    }
                } catch (NoSuchMethodException e) {
                    throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
                } catch (InvocationTargetException e) {
                    throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
                } catch (IllegalAccessException e) {
                    throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
                }
            }

            if (!found) {
                applicationException = new TApplicationException(throwable.getMessage());
            }

        } else {
            Object realResult = result.getResult();
            // result field id is 0
            String fieldName = resultObj.fieldForId(0).getFieldName();
            String setMethodName = ThriftUtils.generateSetMethodName(fieldName);
            String getMethodName = ThriftUtils.generateGetMethodName(fieldName);
            Method getMethod;
            Method setMethod;
            try {
                getMethod = clazz.getMethod(getMethodName);
                setMethod = clazz.getMethod(setMethodName, getMethod.getReturnType());
                setMethod.invoke(resultObj, realResult);
            } catch (NoSuchMethodException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            } catch (InvocationTargetException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            } catch (IllegalAccessException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }

        }

        if (applicationException != null) {
            message = new TMessage(rd.methodName, TMessageType.EXCEPTION, rd.id);
        } else {
            message = new TMessage(rd.methodName, TMessageType.REPLY, rd.id);
        }

        RandomAccessByteArrayOutputStream bos = new RandomAccessByteArrayOutputStream(1024);

        TIOStreamTransport transport = new TIOStreamTransport(bos);

        TBinaryProtocol protocol = new TBinaryProtocol(transport);

        int messageLength;
        int headerLength;

        byte[] bytes = new byte[4];
        try {
            // magic
            protocol.writeI16(MAGIC);
            // message length
            protocol.writeI32(Integer.MAX_VALUE);
            // message header length
            protocol.writeI16(Short.MAX_VALUE);
            // version
            protocol.writeByte(VERSION);
            // service name
            protocol.writeString(rd.serviceName);
            // id
            protocol.writeI64(response.getId());
            protocol.getTransport().flush();
            headerLength = bos.size();

            // message
            protocol.writeMessageBegin(message);
            switch (message.type) {
            case TMessageType.EXCEPTION:
                applicationException.write(protocol);
                break;
            case TMessageType.REPLY:
                resultObj.write(protocol);
                break;
            }
            protocol.writeMessageEnd();
            protocol.getTransport().flush();
            int oldIndex = messageLength = bos.size();

            try {
                TFramedTransport.encodeFrameSize(messageLength, bytes);
                bos.setWriteIndex(MESSAGE_LENGTH_INDEX);
                protocol.writeI32(messageLength);
                bos.setWriteIndex(MESSAGE_HEADER_LENGTH_INDEX);
                protocol.writeI16((short) (0xffff & headerLength));
            } finally {
                bos.setWriteIndex(oldIndex);
            }

        } catch (TException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        }

        buffer.writeBytes(bytes);
        buffer.writeBytes(bos.toByteArray());

    }

    private static int nextSeqId() {
        return THRIFT_SEQ_ID.incrementAndGet();
    }

    // just for test
    static int getSeqId() {
        return THRIFT_SEQ_ID.get();
    }

    static class RequestData {
        int id;
        String serviceName;
        String methodName;

        static RequestData create(int id, String sn, String mn) {
            RequestData result = new RequestData();
            result.id = id;
            result.serviceName = sn;
            result.methodName = mn;
            return result;
        }

    }

}