Java tutorial
/* Copyright (C) 2013-2014 Computer Sciences Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ezbake.thrift; import java.lang.reflect.Constructor; import java.net.InetSocketAddress; import java.util.Properties; import ezbakehelpers.ezconfigurationhelpers.thrift.ThriftConfigurationHelper; import org.apache.commons.codec.binary.Base64; import org.apache.thrift.TBase; import org.apache.thrift.TDeserializer; import org.apache.thrift.TException; import org.apache.thrift.TProcessor; import org.apache.thrift.TSerializer; import org.apache.thrift.TServiceClient; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TCompactProtocol; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.server.THsHaServer; import org.apache.thrift.server.TServer; import org.apache.thrift.server.TSimpleServer; import org.apache.thrift.server.TThreadPoolServer; import org.apache.thrift.transport.*; import ezbake.thrift.transport.EzSSLTransportFactory; import ezbake.thrift.transport.EzSecureClientTransport; import ezbake.thrift.transport.EzSecureServerTransport; import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HostAndPort; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * A class for some common thrift functions. */ public class ThriftUtils { private static final Logger logger = LoggerFactory.getLogger(ThriftUtils.class); /** * Start a thrift service for TESTING. You should be using the thrift service runner in production * * @param processor The thrift processor for the service * @param portNumber The port to run the service on */ @VisibleForTesting public static TServer startSimpleServer(TProcessor processor, int portNumber) throws Exception { final TServerTransport transport = new TServerSocket(portNumber); return startSimpleServer(transport, processor); } @VisibleForTesting public static TServer startSslSimpleServer(TProcessor processor, int portNumber, Properties properties) throws Exception { final TServerTransport transport = getSslServerSocket(portNumber, properties); return startSimpleServer(transport, processor, properties); } private static TServer startSimpleServer(final TServerTransport transport, final TProcessor processor) throws Exception { return ThriftUtils.startSimpleServer(transport, processor, null); } private static TServer startSimpleServer(final TServerTransport transport, final TProcessor processor, Properties properties) throws Exception { TServer.AbstractServerArgs<?> serverArgs; if (properties == null) { serverArgs = new TServer.Args(transport).processor(processor); } else { serverArgs = ThriftUtils.getServerArgs(transport, properties).processor(processor); } final TServer server = new TSimpleServer(serverArgs); new Thread(new Runnable() { @Override public void run() { server.serve(); } }).start(); return server; } @VisibleForTesting public static TServer startThreadedPoolServer(TProcessor processor, int portNumber) throws Exception { final TServerTransport transport = new TServerSocket(portNumber); return startThreadedPoolServer(transport, processor); } @VisibleForTesting public static TServer startSslThreadedPoolServer(TProcessor processor, int portNumber, Properties properties) throws Exception { final TServerTransport transport = getSslServerSocket(portNumber, properties); return startThreadedPoolServer(transport, processor, properties); } private static TServer startThreadedPoolServer(final TServerTransport transport, final TProcessor processor) throws Exception { return ThriftUtils.startThreadedPoolServer(transport, processor, null); } private static TServer startThreadedPoolServer(final TServerTransport transport, final TProcessor processor, Properties properties) throws Exception { TThreadPoolServer.Args serverArgs; if (properties == null) { serverArgs = new TThreadPoolServer.Args(transport).processor(processor); } else { serverArgs = (TThreadPoolServer.Args) ThriftUtils.getServerArgs(transport, properties) .processor(processor); } final TServer server = new TThreadPoolServer(serverArgs); new Thread(new Runnable() { @Override public void run() { server.serve(); } }).start(); return server; } @VisibleForTesting public static TServer startHshaServer(TProcessor processor, int portNumber) throws Exception { final TNonblockingServerSocket socket = new TNonblockingServerSocket(portNumber); final THsHaServer.Args serverArgs = new THsHaServer.Args(socket); serverArgs.processor(processor); serverArgs.inputProtocolFactory(new TCompactProtocol.Factory()); serverArgs.outputProtocolFactory(new TCompactProtocol.Factory()); final TServer server = new THsHaServer(serverArgs); final Thread t = new Thread(new Runnable() { @Override public void run() { server.serve(); } }); t.start(); return server; } /** * Serialize a thrift object to binary. * * @param object The object to serialize * @return A byte array of the object * @throws TException */ public static byte[] serialize(TBase<?, ?> object) throws TException { return new TSerializer().serialize(object); } /** * Serialize a thrift object to a base64-encoded string. * * @param object The object to serialize * @return A string of the base64-encoded serialized binary of the object * @throws TException */ public static String serializeToBase64(TBase<?, ?> object) throws TException { return Base64.encodeBase64String(serialize(object)); } /** * Deserialize a thrift object * * @param type The type of object * @param bytes The bytes of the object * @param <T> The type of object * @return The object */ public static <T extends TBase<?, ?>> T deserialize(Class<T> type, byte[] bytes) throws TException { final TDeserializer deserializer = new TDeserializer(); try { final T object = type.newInstance(); deserializer.deserialize(object, bytes); return object; } catch (final Exception ex) { throw new TException(ex); } } /** * Deserialize a thrift object from a base64-encoded string. * * @param type The type of object * @param base64 The base64-encoded string of the serialization of the object * @param <T> The type of object * @return The object */ public static <T extends TBase<?, ?>> T deserializeFromBase64(Class<T> type, String base64) throws TException { return deserialize(type, Base64.decodeBase64(base64)); } public static TTransport getSslClientSocket(String host, int port, Properties properties) throws TTransportException { logger.debug("connecting via SSL to {}:{}", host, port); TTransport transport = EzSSLTransportFactory.getClientSocket(host, port, 0, new EzSSLTransportFactory.EzSSLTransportParameters(properties)); return transport; } public static TServerSocket getSslServerSocket(int port, Properties properties) throws TTransportException { return EzSSLTransportFactory.getServerSocket(port, 0, null, new EzSSLTransportFactory.EzSSLTransportParameters(properties)); } public static TServerSocket getSslServerSocket(InetSocketAddress addr, Properties properties) throws TTransportException { return EzSSLTransportFactory.getServerSocket(addr.getPort(), 0, addr.getAddress(), new EzSSLTransportFactory.EzSSLTransportParameters(properties)); } @SuppressWarnings("null") public static TServer.AbstractServerArgs<?> getServerArgs(TServerTransport transport, Properties properties) { TServer.AbstractServerArgs<?> args = null; ThriftConfigurationHelper thriftConfiguration = new ThriftConfigurationHelper(properties); switch (thriftConfiguration.getServerMode()) { case Simple: args = new TServer.Args(transport); break; case ThreadedPool: args = new TThreadPoolServer.Args(transport); break; case HsHa: throw new IllegalArgumentException("Unable to create an HsHa Server Args at this time"); } // Use the EzSecureTransport (exposes peer ssl certs) if using SSL if (thriftConfiguration.useSSL()) { args.inputTransportFactory(new EzSecureServerTransport.Factory(properties)); } return args; } @SuppressWarnings("unchecked") public static <Y extends TServiceClient> Y getClient(Class<Y> clazz, HostAndPort hostAndPort, Properties properties) throws NoSuchMethodException, TException, Exception { final Constructor<?> constructor = clazz.getConstructor(TProtocol.class); final Object ds = constructor.newInstance(getProtocol(hostAndPort, properties)); return (Y) ds; } @SuppressWarnings("unchecked") public static <Y extends TServiceClient> Y getClient(Class<Y> clazz, HostAndPort hostAndPort, String securityId, Properties properties) throws NoSuchMethodException, TException, Exception { final Constructor<?> constructor = clazz.getConstructor(TProtocol.class); final Object ds = constructor.newInstance(getProtocol(hostAndPort, securityId, properties)); return (Y) ds; } @SuppressWarnings("unchecked") public static <Y extends TServiceClient> Y getClient(Class<Y> clazz, HostAndPort hostAndPort, String securityId, Properties properties, TTransportFactory transportFactory) throws NoSuchMethodException, TException, Exception { final Constructor<?> constructor = clazz.getConstructor(TProtocol.class); final Object ds = constructor .newInstance(getProtocol(hostAndPort, securityId, properties, transportFactory)); return (Y) ds; } public static void quietlyClose(TServiceClient client) { try { client.getOutputProtocol().getTransport().close(); } catch (Exception ignore) { //do nothing } try { client.getInputProtocol().getTransport().close(); } catch (Exception ignore) { //do nothing } } protected static TProtocol getProtocol(HostAndPort hostAndPort, Properties properties) throws Exception { return getProtocol(hostAndPort, null, properties); } protected static TProtocol getProtocol(HostAndPort hostAndPort, String securityId, Properties properties) throws Exception { return getProtocol(hostAndPort, securityId, properties, null); } protected static TProtocol getProtocol(HostAndPort hostAndPort, String securityId, Properties properties, TTransportFactory transportFactory) throws Exception { logger.debug("getProtocol for host:port {} and security id {}", hostAndPort, securityId); TProtocol protocol; ThriftConfigurationHelper thriftConfiguration = new ThriftConfigurationHelper(properties); logger.debug("about to getTransport for host:port {} and security id {}", hostAndPort, securityId); TTransport transport = getTransport(properties, hostAndPort, securityId, transportFactory); // HsHa is using framed transport with Compact protocol, but others are using binary (for now at least) if (thriftConfiguration.getServerMode() == ThriftConfigurationHelper.ThriftServerMode.HsHa) { protocol = new TCompactProtocol(transport); } else { protocol = new TBinaryProtocol(transport); } if (!transport.isOpen()) { transport.open(); } return protocol; } protected static TTransport getTransport(Properties configuration, HostAndPort hostAndPort, String securityId, TTransportFactory transportFactory) throws TTransportException { TTransport transport; ThriftConfigurationHelper thriftConfiguration = new ThriftConfigurationHelper(configuration); logger.debug("getTransport for hostAndPort {}", hostAndPort); if (thriftConfiguration.getServerMode() == ThriftConfigurationHelper.ThriftServerMode.HsHa) { logger.debug("opening framed transport to {}", hostAndPort); transport = new TFramedTransport(new TSocket(hostAndPort.getHostText(), hostAndPort.getPort())); } else { if (thriftConfiguration.useSSL()) { logger.debug("opening SSL connection to {}", hostAndPort); transport = ThriftUtils.getSslClientSocket(hostAndPort.getHostText(), hostAndPort.getPort(), configuration); transport = new EzSecureClientTransport(transport, configuration, securityId); } else { logger.debug("opening connection in the clear (without SSL) to {}", hostAndPort); transport = new TSocket(hostAndPort.getHostText(), hostAndPort.getPort()); } } // Wrap the transport using the transportFactory (if provided) if (transportFactory != null) { transport = transportFactory.getTransport(transport); } return transport; } }