Java tutorial
/* * Copyright 2012 The Netty Project * * The Netty Project licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package io.netty.channel.socket.nio; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelException; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOutboundBuffer; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; import io.netty.channel.FileRegion; import io.netty.channel.RecvByteBufAllocator; import io.netty.channel.nio.AbstractNioByteChannel; import io.netty.channel.socket.DefaultSocketChannelConfig; import io.netty.channel.socket.ServerSocketChannel; import io.netty.channel.socket.SocketChannelConfig; import io.netty.util.concurrent.GlobalEventExecutor; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.SocketUtils; import io.netty.util.internal.SuppressJava6Requirement; import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import java.io.IOException; import java.net.InetSocketAddress; import java.net.Socket; import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; import java.nio.channels.spi.SelectorProvider; import java.util.Map; import java.util.concurrent.Executor; import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD; /** * {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation. */ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty.channel.socket.SocketChannel { private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class); private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider(); private static SocketChannel newSocket(SelectorProvider provider) { try { /** * Use the {@link SelectorProvider} to open {@link SocketChannel} and so remove condition in * {@link SelectorProvider#provider()} which is called by each SocketChannel.open() otherwise. * * See <a href="https://github.com/netty/netty/issues/2308">#2308</a>. */ return provider.openSocketChannel(); } catch (IOException e) { throw new ChannelException("Failed to open a socket.", e); } } private final SocketChannelConfig config; /** * Create a new instance */ public NioSocketChannel() { this(DEFAULT_SELECTOR_PROVIDER); } /** * Create a new instance using the given {@link SelectorProvider}. */ public NioSocketChannel(SelectorProvider provider) { this(newSocket(provider)); } /** * Create a new instance using the given {@link SocketChannel}. */ public NioSocketChannel(SocketChannel socket) { this(null, socket); } /** * Create a new instance * * @param parent the {@link Channel} which created this instance or {@code null} if it was created by the user * @param socket the {@link SocketChannel} which will be used */ public NioSocketChannel(Channel parent, SocketChannel socket) { super(parent, socket); config = new NioSocketChannelConfig(this, socket.socket()); } @Override public ServerSocketChannel parent() { return (ServerSocketChannel) super.parent(); } @Override public SocketChannelConfig config() { return config; } @Override protected SocketChannel javaChannel() { return (SocketChannel) super.javaChannel(); } @Override public boolean isActive() { SocketChannel ch = javaChannel(); return ch.isOpen() && ch.isConnected(); } @Override public boolean isOutputShutdown() { return javaChannel().socket().isOutputShutdown() || !isActive(); } @Override public boolean isInputShutdown() { return javaChannel().socket().isInputShutdown() || !isActive(); } @Override public boolean isShutdown() { Socket socket = javaChannel().socket(); return socket.isInputShutdown() && socket.isOutputShutdown() || !isActive(); } @Override public InetSocketAddress localAddress() { return (InetSocketAddress) super.localAddress(); } @Override public InetSocketAddress remoteAddress() { return (InetSocketAddress) super.remoteAddress(); } @SuppressJava6Requirement(reason = "Usage guarded by java version check") @UnstableApi @Override protected final void doShutdownOutput() throws Exception { if (PlatformDependent.javaVersion() >= 7) { javaChannel().shutdownOutput(); } else { javaChannel().socket().shutdownOutput(); } } @Override public ChannelFuture shutdownOutput() { return shutdownOutput(newPromise()); } @Override public ChannelFuture shutdownOutput(final ChannelPromise promise) { final EventLoop loop = eventLoop(); if (loop.inEventLoop()) { ((AbstractUnsafe) unsafe()).shutdownOutput(promise); } else { loop.execute(new Runnable() { @Override public void run() { ((AbstractUnsafe) unsafe()).shutdownOutput(promise); } }); } return promise; } @Override public ChannelFuture shutdownInput() { return shutdownInput(newPromise()); } @Override protected boolean isInputShutdown0() { return isInputShutdown(); } @Override public ChannelFuture shutdownInput(final ChannelPromise promise) { EventLoop loop = eventLoop(); if (loop.inEventLoop()) { shutdownInput0(promise); } else { loop.execute(new Runnable() { @Override public void run() { shutdownInput0(promise); } }); } return promise; } @Override public ChannelFuture shutdown() { return shutdown(newPromise()); } @Override public ChannelFuture shutdown(final ChannelPromise promise) { ChannelFuture shutdownOutputFuture = shutdownOutput(); if (shutdownOutputFuture.isDone()) { shutdownOutputDone(shutdownOutputFuture, promise); } else { shutdownOutputFuture.addListener(new ChannelFutureListener() { @Override public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception { shutdownOutputDone(shutdownOutputFuture, promise); } }); } return promise; } private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) { ChannelFuture shutdownInputFuture = shutdownInput(); if (shutdownInputFuture.isDone()) { shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise); } else { shutdownInputFuture.addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception { shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise); } }); } } private static void shutdownDone(ChannelFuture shutdownOutputFuture, ChannelFuture shutdownInputFuture, ChannelPromise promise) { Throwable shutdownOutputCause = shutdownOutputFuture.cause(); Throwable shutdownInputCause = shutdownInputFuture.cause(); if (shutdownOutputCause != null) { if (shutdownInputCause != null) { logger.debug("Exception suppressed because a previous exception occurred.", shutdownInputCause); } promise.setFailure(shutdownOutputCause); } else if (shutdownInputCause != null) { promise.setFailure(shutdownInputCause); } else { promise.setSuccess(); } } private void shutdownInput0(final ChannelPromise promise) { try { shutdownInput0(); promise.setSuccess(); } catch (Throwable t) { promise.setFailure(t); } } @SuppressJava6Requirement(reason = "Usage guarded by java version check") private void shutdownInput0() throws Exception { if (PlatformDependent.javaVersion() >= 7) { javaChannel().shutdownInput(); } else { javaChannel().socket().shutdownInput(); } } @Override protected SocketAddress localAddress0() { return javaChannel().socket().getLocalSocketAddress(); } @Override protected SocketAddress remoteAddress0() { return javaChannel().socket().getRemoteSocketAddress(); } @Override protected void doBind(SocketAddress localAddress) throws Exception { doBind0(localAddress); } private void doBind0(SocketAddress localAddress) throws Exception { if (PlatformDependent.javaVersion() >= 7) { SocketUtils.bind(javaChannel(), localAddress); } else { SocketUtils.bind(javaChannel().socket(), localAddress); } } @Override protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { if (localAddress != null) { doBind0(localAddress); } boolean success = false; try { boolean connected = SocketUtils.connect(javaChannel(), remoteAddress); if (!connected) { selectionKey().interestOps(SelectionKey.OP_CONNECT); } success = true; return connected; } finally { if (!success) { doClose(); } } } @Override protected void doFinishConnect() throws Exception { if (!javaChannel().finishConnect()) { throw new Error(); } } @Override protected void doDisconnect() throws Exception { doClose(); } @Override protected void doClose() throws Exception { super.doClose(); javaChannel().close(); } @Override protected int doReadBytes(ByteBuf byteBuf) throws Exception { final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle(); allocHandle.attemptedBytesRead(byteBuf.writableBytes()); return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead()); } @Override protected int doWriteBytes(ByteBuf buf) throws Exception { final int expectedWrittenBytes = buf.readableBytes(); return buf.readBytes(javaChannel(), expectedWrittenBytes); } @Override protected long doWriteFileRegion(FileRegion region) throws Exception { final long position = region.transferred(); return region.transferTo(javaChannel(), position); } private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) { // By default we track the SO_SNDBUF when ever it is explicitly set. However some OSes may dynamically change // SO_SNDBUF (and other characteristics that determine how much data can be written at once) so we should try // make a best effort to adjust as OS behavior changes. if (attempted == written) { if (attempted << 1 > oldMaxBytesPerGatheringWrite) { ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1); } } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) { ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1); } } @Override protected void doWrite(ChannelOutboundBuffer in) throws Exception { SocketChannel ch = javaChannel(); int writeSpinCount = config().getWriteSpinCount(); do { if (in.isEmpty()) { // All written so clear OP_WRITE clearOpWrite(); // Directly return here so incompleteWrite(...) is not called. return; } // Ensure the pending writes are made of ByteBufs only. int maxBytesPerGatheringWrite = ((NioSocketChannelConfig) config).getMaxBytesPerGatheringWrite(); ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite); int nioBufferCnt = in.nioBufferCount(); // Always us nioBuffers() to workaround data-corruption. // See https://github.com/netty/netty/issues/2761 switch (nioBufferCnt) { case 0: // We have something else beside ByteBuffers to write so fallback to normal writes. writeSpinCount -= doWrite0(in); break; case 1: { // Only one ByteBuf so use non-gathering write // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need // to check if the total size of all the buffers is non-zero. ByteBuffer buffer = nioBuffers[0]; int attemptedBytes = buffer.remaining(); final int localWrittenBytes = ch.write(buffer); if (localWrittenBytes <= 0) { incompleteWrite(true); return; } adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite); in.removeBytes(localWrittenBytes); --writeSpinCount; break; } default: { // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need // to check if the total size of all the buffers is non-zero. // We limit the max amount to int above so cast is safe long attemptedBytes = in.nioBufferSize(); final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt); if (localWrittenBytes <= 0) { incompleteWrite(true); return; } // Casting to int is safe because we limit the total amount of data in the nioBuffers to int above. adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes, maxBytesPerGatheringWrite); in.removeBytes(localWrittenBytes); --writeSpinCount; break; } } } while (writeSpinCount > 0); incompleteWrite(writeSpinCount < 0); } @Override protected AbstractNioUnsafe newUnsafe() { return new NioSocketChannelUnsafe(); } private final class NioSocketChannelUnsafe extends NioByteUnsafe { @Override protected Executor prepareToClose() { try { if (javaChannel().isOpen() && config().getSoLinger() > 0) { // We need to cancel this key of the channel so we may not end up in a eventloop spin // because we try to read or write until the actual close happens which may be later due // SO_LINGER handling. // See https://github.com/netty/netty/issues/4449 doDeregister(); return GlobalEventExecutor.INSTANCE; } } catch (Throwable ignore) { // Ignore the error as the underlying channel may be closed in the meantime and so // getSoLinger() may produce an exception. In this case we just return null. // See https://github.com/netty/netty/issues/4449 } return null; } } private final class NioSocketChannelConfig extends DefaultSocketChannelConfig { private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE; private NioSocketChannelConfig(NioSocketChannel channel, Socket javaSocket) { super(channel, javaSocket); calculateMaxBytesPerGatheringWrite(); } @Override protected void autoReadCleared() { clearReadPending(); } @Override public NioSocketChannelConfig setSendBufferSize(int sendBufferSize) { super.setSendBufferSize(sendBufferSize); calculateMaxBytesPerGatheringWrite(); return this; } @Override public <T> boolean setOption(ChannelOption<T> option, T value) { if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { return NioChannelOption.setOption(jdkChannel(), (NioChannelOption<T>) option, value); } return super.setOption(option, value); } @Override public <T> T getOption(ChannelOption<T> option) { if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { return NioChannelOption.getOption(jdkChannel(), (NioChannelOption<T>) option); } return super.getOption(option); } @Override public Map<ChannelOption<?>, Object> getOptions() { if (PlatformDependent.javaVersion() >= 7) { return getOptions(super.getOptions(), NioChannelOption.getOptions(jdkChannel())); } return super.getOptions(); } void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) { this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite; } int getMaxBytesPerGatheringWrite() { return maxBytesPerGatheringWrite; } private void calculateMaxBytesPerGatheringWrite() { // Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide. int newSendBufferSize = getSendBufferSize() << 1; if (newSendBufferSize > 0) { setMaxBytesPerGatheringWrite(newSendBufferSize); } } private SocketChannel jdkChannel() { return ((NioSocketChannel) channel).javaChannel(); } } }