Java tutorial
/* * Copyright 2014 Basho Technologies Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.basho.riak.client.core.netty; import com.basho.riak.client.core.RiakMessage; import com.basho.riak.client.core.util.Constants; import com.basho.riak.protobuf.RiakMessageCodes; import com.basho.riak.protobuf.RiakPB; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.ssl.SslHandler; import io.netty.util.concurrent.DefaultPromise; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; import java.io.IOException; import java.util.List; import java.util.concurrent.CountDownLatch; import javax.net.ssl.SSLEngine; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * * @author Brian Roach <roach at basho dot com> */ public class RiakSecurityDecoder extends ByteToMessageDecoder { private final CountDownLatch promiseLatch = new CountDownLatch(1); private final SSLEngine sslEngine; private final String username; private final String password; private final Logger logger = LoggerFactory.getLogger(RiakSecurityDecoder.class); private volatile DefaultPromise<Void> promise; private enum State { TLS_START, TLS_WAIT, SSL_WAIT, AUTH_WAIT } private volatile State state = State.TLS_START; public RiakSecurityDecoder(SSLEngine engine, String username, String password) { this.sslEngine = engine; this.username = username; this.password = password; } @Override protected void decode(ChannelHandlerContext chc, ByteBuf in, List<Object> out) throws Exception { // Make sure we have 4 bytes if (in.readableBytes() >= 4) { in.markReaderIndex(); int length = in.readInt(); // See if we have the full frame. if (in.readableBytes() < length) { in.resetReaderIndex(); } else { byte code = in.readByte(); byte[] protobuf = new byte[length - 1]; in.readBytes(protobuf); switch (state) { case TLS_WAIT: switch (code) { case RiakMessageCodes.MSG_StartTls: logger.debug("Received MSG_RpbStartTls reply"); // change state this.state = State.SSL_WAIT; // insert SSLHandler SslHandler sslHandler = new SslHandler(sslEngine); // get promise Future<Channel> hsFuture = sslHandler.handshakeFuture(); // register callback hsFuture.addListener(new SslListener()); // Add handler chc.channel().pipeline().addFirst(Constants.SSL_HANDLER, sslHandler); break; case RiakMessageCodes.MSG_ErrorResp: logger.debug("Received MSG_ErrorResp reply to startTls"); promise.tryFailure((riakErrorToException(protobuf))); break; default: promise.tryFailure( new RiakResponseException(0, "Invalid return code during StartTLS; " + code)); } break; case AUTH_WAIT: chc.channel().pipeline().remove(this); switch (code) { case RiakMessageCodes.MSG_AuthResp: logger.debug("Received MSG_RpbAuthResp reply"); promise.trySuccess(null); break; case RiakMessageCodes.MSG_ErrorResp: logger.debug("Received MSG_ErrorResp reply to auth"); promise.tryFailure(riakErrorToException(protobuf)); break; default: promise.tryFailure( new RiakResponseException(0, "Invalid return code during Auth; " + code)); } break; default: // WTF? logger.error("Received message while not in TLS_WAIT or AUTH_WAIT"); promise.tryFailure( new IllegalStateException("Received message while not in TLS_WAIT or AUTH_WAIT")); } } } } private RiakResponseException riakErrorToException(byte[] protobuf) { try { RiakPB.RpbErrorResp error = RiakPB.RpbErrorResp.parseFrom(protobuf); return new RiakResponseException(error.getErrcode(), error.getErrmsg().toStringUtf8()); } catch (InvalidProtocolBufferException ex) { return null; } } private void init(ChannelHandlerContext ctx) { state = State.TLS_WAIT; promise = new DefaultPromise<Void>(ctx.executor()); promiseLatch.countDown(); ctx.channel().writeAndFlush(new RiakMessage(RiakMessageCodes.MSG_StartTls, new byte[0])); } @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { logger.debug("MyStartTlsDecoder Handler Added"); if (ctx.channel().isActive()) { init(ctx); } } @Override public void channelActive(final ChannelHandlerContext ctx) throws Exception { logger.debug("MyStartTlsDecoder Channel Active"); init(ctx); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { promise.tryFailure(new IOException("Channel closed during auth")); ctx.fireChannelInactive(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { if (cause.getCause() instanceof javax.net.ssl.SSLHandshakeException) { // consume } else { ctx.fireExceptionCaught(cause); } } public DefaultPromise<Void> getPromise() throws InterruptedException { promiseLatch.await(); return promise; } private class SslListener implements GenericFutureListener<Future<Channel>> { @Override public void operationComplete(Future<Channel> future) throws Exception { if (future.isSuccess()) { Channel c = future.getNow(); state = State.AUTH_WAIT; RiakPB.RpbAuthReq authReq = RiakPB.RpbAuthReq.newBuilder() .setUser(ByteString.copyFromUtf8(username)).setPassword(ByteString.copyFromUtf8(password)) .build(); c.writeAndFlush(new RiakMessage(RiakMessageCodes.MSG_AuthReq, authReq.toByteArray())); } else { promise.tryFailure(future.cause()); } } } }