Java tutorial
/* * Copyright 2005-2014 Red Hat, Inc. * Red Hat licenses this file to you under the Apache License, version * 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing * permissions and limitations under the License. */ package org.hornetq.amqp.dealer.util; import java.nio.ByteBuffer; import java.util.concurrent.Executor; import io.netty.buffer.ByteBuf; import org.apache.qpid.proton.Proton; import org.apache.qpid.proton.amqp.transport.AmqpError; import org.apache.qpid.proton.amqp.transport.ErrorCondition; import org.apache.qpid.proton.engine.Collector; import org.apache.qpid.proton.engine.Connection; import org.apache.qpid.proton.engine.Delivery; import org.apache.qpid.proton.engine.EndpointState; import org.apache.qpid.proton.engine.Event; import org.apache.qpid.proton.engine.Link; import org.apache.qpid.proton.engine.Sasl; import org.apache.qpid.proton.engine.Session; import org.apache.qpid.proton.engine.Transport; import org.apache.qpid.proton.engine.TransportResultFactory; import org.hornetq.amqp.dealer.SASL; /** * @author Clebert Suconic */ public abstract class ProtonTrio { static ThreadLocal<Boolean> inDispatch = new ThreadLocal<>(); private Sasl sasl; private Runnable saslCallback; protected final Transport transport = Proton.transport(); protected final Connection connection = Proton.connection(); protected final Collector collector = Proton.collector(); protected final Object lock = new Object(); private String username; private String password; private Executor executor; public ProtonTrio(Executor executor) { // TODO parameterize maxFrameSize //transport.setMaxFrameSize(1024 * 1024); transport.bind(connection); connection.collect(collector); this.executor = executor; } public void setSaslCallback(Runnable runnable) { this.saslCallback = runnable; } public Transport getTransport() { return transport; } public Connection getConnection() { return connection; } final Runnable dispatchRunnable = new Runnable() { public void run() { dispatch(); } }; public String getUsername() { return username; } public String getPassword() { return password; } public Object getLock() { return lock; } public void createServerSasl(String... mechanisms) { sasl = transport.sasl(); sasl.server(); sasl.setMechanisms(mechanisms); } public void createClientSasl(SASL clientSASL) { if (clientSASL != null) { sasl = transport.sasl(); sasl.setMechanisms(clientSASL.getName()); byte[] initialSasl = clientSASL.getBytes(); sasl.send(initialSasl, 0, initialSasl.length); } } public void close() { synchronized (lock) { connection.close(); transport.close(); dispatch(); } } /** * this method will change the readerIndex on bytes to the latest read position */ public void pump(ByteBuf bytes) { if (bytes.readableBytes() < 8) { return; } try { synchronized (lock) { while (bytes.readableBytes() > 0) { int capacity = transport.capacity(); if (capacity > 0) { ByteBuffer tail = transport.tail(); int min = Math.min(capacity, bytes.readableBytes()); tail.limit(min); bytes.readBytes(tail); transport.process(); checkSASL(); dispatch(); } else { if (capacity == 0) { System.out.println("abandoning: " + bytes.readableBytes()); } else { System.out.println("transport closed, discarding: " + bytes.readableBytes()); } bytes.skipBytes(bytes.readableBytes()); } } } } finally { // After everything is processed we still need to check for more dispatches! dispatch(); } } private void checkSASL() { if (sasl != null && sasl.getRemoteMechanisms().length > 0) { byte[] dataSASL = new byte[sasl.pending()]; sasl.recv(dataSASL, 0, dataSASL.length); if (sasl.getRemoteMechanisms()[0].equals("PLAIN")) { setUserPass(dataSASL); } // TODO: do the proper SASL authorization here // call an abstract method (authentication (bytes[]) sasl.done(Sasl.SaslOutcome.PN_SASL_OK); sasl = null; if (saslCallback != null) { saslCallback.run(); } } } /** It will only start a dispatch if it's not on the dispatch thread already */ public void dispatchIfNeeded() { if (inDispatch.get() != null) { return; } dispatch(); } public void dispatch() { if (inDispatch.get() != null) { // new Exception("Already in dispatch mode, using executor").printStackTrace(); executor.execute(dispatchRunnable); return; } inDispatch.set(Boolean.TRUE); try { internalDispatch(); } finally { inDispatch.set(null); } } protected void internalDispatch() { synchronized (lock) { try { Event ev; while ((ev = collector.peek()) != null) { dispatch(ev); collector.pop(); } } catch (Exception e) { connection.setCondition(new ErrorCondition(AmqpError.INTERNAL_ERROR, e.getMessage())); } // forcing transport on every dispatch call onTransport(transport); } } protected void onRemoteState(Connection connection) throws Exception { if (connection.getRemoteState() == EndpointState.ACTIVE) { connection.open(); connectionOpened(connection); } else if (connection.getRemoteState() == EndpointState.CLOSED) { connection.close(); connectionClosed(connection); } } protected abstract void connectionOpened(Connection connection) throws Exception; protected abstract void connectionClosed(Connection connection) throws Exception; protected void onRemoteState(Session session) { if (session.getRemoteState() == EndpointState.ACTIVE) { session.open(); sessionOpened(session); } else if (session.getRemoteState() == EndpointState.CLOSED) { session.close(); sessionClosed(session); } } protected abstract void sessionOpened(Session session); protected abstract void sessionClosed(Session session); protected void onRemoteState(Link link) { if (link.getRemoteState() == EndpointState.ACTIVE) { link.open(); linkOpened(link); } else if (link.getRemoteState() == EndpointState.CLOSED) { link.close(); linkClosed(link); } else if (link.getLocalState() == EndpointState.ACTIVE) { linkActive(link); } } protected abstract void linkOpened(Link link); protected abstract void linkClosed(Link link); /** * Do we really need this? This used to be done with: * <p/> * <p/> * link = (LinkImpl) protonConnection.linkHead(ProtonProtocolManager.ACTIVE, ProtonProtocolManager.ANY_ENDPOINT_STATE); * while (link != null) * { * try * { * protonProtocolManager.handleActiveLink(link); * } * catch (HornetQAMQPException e) * { * link.setCondition(new ErrorCondition(e.getAmqpError(), e.getMessage())); * } * link = (LinkImpl) link.next(ProtonProtocolManager.ACTIVE, ProtonProtocolManager.ANY_ENDPOINT_STATE); * } */ protected abstract void linkActive(Link link); protected void onLocalState(Connection connection) { } protected void onLocalState(Session session) { } protected void onLocalState(Link link) { } protected void onFlow(Link link) { } protected abstract void onDelivery(Delivery delivery); protected abstract void onTransport(Transport transport); private void dispatch(Event event) throws Exception { switch (event.getType()) { case CONNECTION_REMOTE_OPEN: case CONNECTION_REMOTE_CLOSE: onRemoteState(event.getConnection()); break; case CONNECTION_OPEN: case CONNECTION_CLOSE: onLocalState(event.getConnection()); break; case SESSION_REMOTE_OPEN: case SESSION_REMOTE_CLOSE: onRemoteState(event.getSession()); break; case SESSION_OPEN: case SESSION_CLOSE: onLocalState(event.getSession()); break; case LINK_REMOTE_OPEN: case LINK_REMOTE_CLOSE: onRemoteState(event.getLink()); break; case LINK_OPEN: case LINK_CLOSE: onLocalState(event.getLink()); break; case LINK_FLOW: onFlow(event.getLink()); break; case TRANSPORT: onTransport(event.getTransport()); break; case DELIVERY: onDelivery(event.getDelivery()); break; } } /** * this is to be used with SASL */ private void setUserPass(byte[] data) { String bytes = new String(data); String[] credentials = bytes.split(Character.toString((char) 0)); int offSet = 0; if (credentials.length > 0) { if (credentials[0].length() == 0) { offSet = 1; } if (credentials.length >= offSet) { username = credentials[offSet]; } if (credentials.length >= (offSet + 1)) { password = credentials[offSet + 1]; } } } }