org.graylog2.plugin.inputs.transports.AbstractTcpTransport.java Source code

Java tutorial

Introduction

Here is the source code for org.graylog2.plugin.inputs.transports.AbstractTcpTransport.java.

Source

/**
 * This file is part of Graylog.
 *
 * Graylog is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Graylog is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Graylog.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.graylog2.plugin.inputs.transports;

import com.codahale.metrics.Gauge;
import com.codahale.metrics.MetricRegistry;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.socket.ServerSocketChannelConfig;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.graylog2.inputs.transports.NettyTransportConfiguration;
import org.graylog2.inputs.transports.netty.ChannelRegistrationHandler;
import org.graylog2.inputs.transports.netty.EventLoopGroupFactory;
import org.graylog2.inputs.transports.netty.ExceptionLoggingChannelHandler;
import org.graylog2.inputs.transports.netty.ServerSocketChannelFactory;
import org.graylog2.plugin.LocalMetricRegistry;
import org.graylog2.plugin.configuration.Configuration;
import org.graylog2.plugin.configuration.ConfigurationRequest;
import org.graylog2.plugin.configuration.fields.BooleanField;
import org.graylog2.plugin.configuration.fields.ConfigurationField;
import org.graylog2.plugin.configuration.fields.DropdownField;
import org.graylog2.plugin.configuration.fields.TextField;
import org.graylog2.plugin.inputs.MessageInput;
import org.graylog2.plugin.inputs.MisfireException;
import org.graylog2.plugin.inputs.annotations.ConfigClass;
import org.graylog2.plugin.inputs.transports.util.KeyUtil;
import org.graylog2.plugin.inputs.util.ConnectionCounter;
import org.graylog2.plugin.inputs.util.ThroughputCounter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import java.io.File;
import java.io.IOException;
import java.net.SocketAddress;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.EnumSet;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import static com.google.common.base.Preconditions.checkState;

/**
 * Base class for TCP-based input transports built on a Netty {@link ServerBootstrap}.
 * <p>
 * In addition to binding the server socket, this transport:
 * <ul>
 * <li>optionally wraps every accepted connection in TLS, falling back to a self-signed certificate when the
 * configured certificate or key file does not exist;</li>
 * <li>supports optional or mandatory TLS client authentication against a configured set of trusted
 * certificates;</li>
 * <li>registers accepted child channels in a {@link ChannelGroup} that is closed on {@link #stop()};</li>
 * <li>exposes {@code open_connections} and {@code total_connections} gauges in the local metric registry.</li>
 * </ul>
 */
public abstract class AbstractTcpTransport extends NettyTransport {
    private static final Logger LOG = LoggerFactory.getLogger(AbstractTcpTransport.class);

    private static final String CK_TLS_CERT_FILE = "tls_cert_file";
    private static final String CK_TLS_KEY_FILE = "tls_key_file";
    private static final String CK_TLS_ENABLE = "tls_enable";
    private static final String CK_TLS_KEY_PASSWORD = "tls_key_password";
    private static final String CK_TLS_CLIENT_AUTH = "tls_client_auth";
    private static final String CK_TLS_CLIENT_AUTH_TRUSTED_CERT_FILE = "tls_client_auth_cert_file";
    private static final String CK_TCP_KEEPALIVE = "tcp_keepalive";

    private static final String TLS_CLIENT_AUTH_DISABLED = "disabled";
    private static final String TLS_CLIENT_AUTH_OPTIONAL = "optional";
    private static final String TLS_CLIENT_AUTH_REQUIRED = "required";
    /** Dropdown options for the client-auth setting; each value doubles as its own display label. */
    private static final Map<String, String> TLS_CLIENT_AUTH_OPTIONS = ImmutableMap.of(TLS_CLIENT_AUTH_DISABLED,
            TLS_CLIENT_AUTH_DISABLED, TLS_CLIENT_AUTH_OPTIONAL, TLS_CLIENT_AUTH_OPTIONAL, TLS_CLIENT_AUTH_REQUIRED,
            TLS_CLIENT_AUTH_REQUIRED);

    /** Child channel handler that updates {@link #connections} and {@link #totalConnections}. */
    private final ConnectionCounter connectionCounter;
    private final AtomicInteger connections;
    private final AtomicLong totalConnections;

    protected final Configuration configuration;
    protected final EventLoopGroup parentEventLoopGroup;
    private final NettyTransportConfiguration nettyTransportConfiguration;
    /** The bound server channel; set by {@link InputLaunchListener} once the bind succeeds. */
    private final AtomicReference<Channel> channelReference;

    private final boolean tlsEnable;
    private final String tlsKeyPassword;
    private final File tlsCertFile;
    private final File tlsKeyFile;
    /** Trusted client certificates (file or directory), used only when client authentication is enabled. */
    private final File tlsClientAuthCertFile;
    private final String tlsClientAuth;
    private final boolean tcpKeepalive;

    /** Accepted client connections; all of them are closed in {@link #stop()}. */
    private final ChannelGroup childChannels;
    protected EventLoopGroup childEventLoopGroup;
    private ServerBootstrap bootstrap;

    /**
     * Reads the TLS and TCP keepalive settings from {@code configuration} and registers the connection gauges.
     *
     * @param configuration               input configuration holding the {@code tls_*} and {@code tcp_keepalive}
     *                                    keys
     * @param throughputCounter           handler that counts traffic on child channels
     * @param localRegistry               registry that receives the {@code open_connections} and
     *                                    {@code total_connections} gauges
     * @param parentEventLoopGroup        event loop group used for the server (accepting) channel
     * @param eventLoopGroupFactory       factory used to create the child (worker) event loop group
     * @param nettyTransportConfiguration provides the socket channel type and the TLS provider
     */
    public AbstractTcpTransport(Configuration configuration, ThroughputCounter throughputCounter,
            LocalMetricRegistry localRegistry, EventLoopGroup parentEventLoopGroup,
            EventLoopGroupFactory eventLoopGroupFactory, NettyTransportConfiguration nettyTransportConfiguration) {
        super(configuration, eventLoopGroupFactory, throughputCounter, localRegistry);
        this.configuration = configuration;
        this.parentEventLoopGroup = parentEventLoopGroup;
        this.nettyTransportConfiguration = nettyTransportConfiguration;
        this.channelReference = new AtomicReference<>();
        this.childChannels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

        this.tlsEnable = configuration.getBoolean(CK_TLS_ENABLE);
        this.tlsCertFile = getTlsFile(configuration, CK_TLS_CERT_FILE);
        this.tlsKeyFile = getTlsFile(configuration, CK_TLS_KEY_FILE);
        this.tlsKeyPassword = configuration.getString(CK_TLS_KEY_PASSWORD);
        this.tlsClientAuth = configuration.getString(CK_TLS_CLIENT_AUTH, TLS_CLIENT_AUTH_DISABLED);
        this.tlsClientAuthCertFile = getTlsFile(configuration, CK_TLS_CLIENT_AUTH_TRUSTED_CERT_FILE);

        this.tcpKeepalive = configuration.getBoolean(CK_TCP_KEEPALIVE);

        this.connections = new AtomicInteger();
        this.totalConnections = new AtomicLong();
        this.connectionCounter = new ConnectionCounter(connections, totalConnections);

        final Gauge<Integer> openConnectionsGauge = connections::get;
        final Gauge<Long> totalConnectionsGauge = totalConnections::get;
        this.localRegistry.register("open_connections", openConnectionsGauge);
        this.localRegistry.register("total_connections", totalConnectionsGauge);
    }

    /**
     * Returns the file named by {@code configKey}. An unset key yields {@code new File("")}, for which
     * {@link File#exists()} is {@code false}.
     */
    private static File getTlsFile(Configuration configuration, String configKey) {
        return new File(configuration.getString(configKey, ""));
    }

    /**
     * Builds the (not yet bound) server bootstrap for {@code input}.
     * <p>
     * Creates a new child event loop group with {@code workerThreads} threads. It is stored in
     * {@link #childEventLoopGroup} so that {@link #stop()} can shut it down.
     *
     * @param input the input whose parent and child channel handlers are installed
     * @return the configured bootstrap
     */
    protected ServerBootstrap getBootstrap(MessageInput input) {
        final LinkedHashMap<String, Callable<? extends ChannelHandler>> parentHandlers = getChannelHandlers(input);
        final LinkedHashMap<String, Callable<? extends ChannelHandler>> childHandlers = getChildChannelHandlers(
                input);

        childEventLoopGroup = eventLoopGroupFactory.create(workerThreads, localRegistry, "workers");

        return new ServerBootstrap().group(parentEventLoopGroup, childEventLoopGroup)
                .channelFactory(new ServerSocketChannelFactory(nettyTransportConfiguration.getType()))
                // Kept on the server channel: InputLaunchListener verifies SO_RCVBUF on it after binding.
                .option(ChannelOption.SO_RCVBUF, getRecvBufferSize())
                // The server (parent) channel only accepts connections; message bytes are read on the accepted
                // child channels, so the buffer allocators must be child options to have any effect.
                .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
                .childOption(ChannelOption.RCVBUF_ALLOCATOR, new FixedRecvByteBufAllocator(8192))
                .childOption(ChannelOption.SO_RCVBUF, getRecvBufferSize())
                .childOption(ChannelOption.SO_KEEPALIVE, tcpKeepalive)
                .handler(getChannelInitializer(parentHandlers))
                .childHandler(getChannelInitializer(childHandlers));
    }

    /**
     * Builds the bootstrap and binds it to {@code socketAddress}, blocking until the bind attempt completes.
     *
     * @param input the input this transport feeds
     * @throws MisfireException wrapping any exception raised while building the bootstrap or binding
     */
    @Override
    public void launch(final MessageInput input) throws MisfireException {
        try {
            bootstrap = getBootstrap(input);
            bootstrap.bind(socketAddress)
                    .addListener(new InputLaunchListener(channelReference, input, getRecvBufferSize()))
                    .syncUninterruptibly();
        } catch (Exception e) {
            throw new MisfireException(e);
        }
    }

    /**
     * @return the local address of the bound server channel, or {@code null} if no bind has succeeded yet
     */
    @Nullable
    @Override
    public SocketAddress getLocalAddress() {
        final Channel channel = channelReference.get();
        if (channel != null) {
            return channel.localAddress();
        }

        return null;
    }

    /**
     * Closes the server channel and waits for it to close, then closes all accepted client connections. Finally,
     * it starts a graceful shutdown of the child event loop group without waiting for that shutdown to finish.
     */
    @Override
    public void stop() {
        final Channel channel = channelReference.get();
        if (channel != null) {
            channel.close();
            channel.closeFuture().syncUninterruptibly();
        }

        childChannels.close().syncUninterruptibly();

        if (childEventLoopGroup != null) {
            childEventLoopGroup.shutdownGracefully();
        }
        bootstrap = null;
    }

    /**
     * Returns the handlers installed on every accepted connection, in pipeline order: channel registration,
     * traffic counting, connection counting, TLS (only if enabled), and then the handlers of the superclass.
     */
    @Override
    protected LinkedHashMap<String, Callable<? extends ChannelHandler>> getChildChannelHandlers(
            MessageInput input) {
        final LinkedHashMap<String, Callable<? extends ChannelHandler>> handlers = new LinkedHashMap<>();

        handlers.put("channel-registration", () -> new ChannelRegistrationHandler(childChannels));
        handlers.put("traffic-counter", () -> throughputCounter);
        handlers.put("connection-counter", () -> connectionCounter);
        if (tlsEnable) {
            LOG.info("Enabled TLS for input [{}/{}]. key-file=\"{}\" cert-file=\"{}\"", input.getName(),
                    input.getId(), tlsKeyFile, tlsCertFile);
            handlers.put("tls", getSslHandlerCallable(input));
        }
        handlers.putAll(super.getChildChannelHandlers(input));

        return handlers;
    }

    /**
     * Resolves the certificate/key pair and client-auth mode and returns a factory for {@link SslHandler}s.
     * <p>
     * If either the configured certificate or key file is missing, a self-signed certificate is generated for
     * {@code <bind_address>:<port>} instead.
     *
     * @throws IllegalStateException    if the temporary directory is unusable or certificate generation fails
     * @throws IllegalArgumentException if the configured client authentication mode is unknown
     */
    private Callable<ChannelHandler> getSslHandlerCallable(MessageInput input) {
        final File certFile;
        final File keyFile;
        if (tlsCertFile.exists() && tlsKeyFile.exists()) {
            certFile = tlsCertFile;
            keyFile = tlsKeyFile;
        } else {
            LOG.warn(
                    "TLS key file or certificate file does not exist, creating a self-signed certificate for input [{}/{}].",
                    input.getName(), input.getId());

            final String tmpDir = System.getProperty("java.io.tmpdir");
            checkState(tmpDir != null, "The temporary directory must not be null!");
            final Path tmpPath = Paths.get(tmpDir);
            if (!Files.isDirectory(tmpPath) || !Files.isWritable(tmpPath)) {
                throw new IllegalStateException(
                        "Couldn't write to temporary directory: " + tmpPath.toAbsolutePath());
            }

            try {
                final SelfSignedCertificate ssc = new SelfSignedCertificate(
                        configuration.getString(CK_BIND_ADDRESS) + ":" + configuration.getString(CK_PORT));
                certFile = ssc.certificate();
                keyFile = ssc.privateKey();
            } catch (CertificateException e) {
                final String msg = String.format(Locale.ENGLISH,
                        "Problem creating a self-signed certificate for input [%s/%s].", input.getName(),
                        input.getId());
                throw new IllegalStateException(msg, e);
            }
        }

        final ClientAuth clientAuth;
        switch (tlsClientAuth) {
        case TLS_CLIENT_AUTH_DISABLED:
            LOG.debug("Not using TLS client authentication");
            clientAuth = ClientAuth.NONE;
            break;
        case TLS_CLIENT_AUTH_OPTIONAL:
            LOG.debug("Using optional TLS client authentication");
            clientAuth = ClientAuth.OPTIONAL;
            break;
        case TLS_CLIENT_AUTH_REQUIRED:
            LOG.debug("Using mandatory TLS client authentication");
            clientAuth = ClientAuth.REQUIRE;
            break;
        default:
            throw new IllegalArgumentException("Unknown TLS client authentication mode: " + tlsClientAuth);
        }

        return buildSslHandlerCallable(nettyTransportConfiguration.getTlsProvider(), certFile, keyFile,
                tlsKeyPassword, clientAuth, tlsClientAuthCertFile);
    }

    /**
     * Returns a callable that creates a new {@link SslHandler} on each call.
     * <p>
     * Every call re-reads the certificate and key files and builds a fresh {@link SslContext}. NOTE(review):
     * this is presumably invoked once per accepted connection by the channel initializer; confirm before
     * caching the context.
     *
     * @param password           key password; an empty string is treated as "no password"
     * @param clientAuthCertFile trusted client certificates. It is only consulted when {@code clientAuth} is
     *                           OPTIONAL or REQUIRE. If the file does not exist, a warning is logged and no
     *                           explicit trust material is passed to the context builder.
     */
    private Callable<ChannelHandler> buildSslHandlerCallable(SslProvider tlsProvider, File certFile, File keyFile,
            String password, ClientAuth clientAuth, File clientAuthCertFile) {
        return new Callable<ChannelHandler>() {
            @Override
            public ChannelHandler call() throws Exception {
                try {
                    return new SslHandler(createSslEngine());
                } catch (SSLException e) {
                    LOG.error(
                            "Error creating SSL context. Make sure the certificate and key are in the correct format: cert=X.509 key=PKCS#8");
                    throw e;
                }
            }

            private SSLEngine createSslEngine() throws IOException, CertificateException {
                final X509Certificate[] clientAuthCerts;
                if (EnumSet.of(ClientAuth.OPTIONAL, ClientAuth.REQUIRE).contains(clientAuth)) {
                    if (clientAuthCertFile.exists()) {
                        // Keep only X.509 certificates; anything else loaded from the file/directory is ignored.
                        clientAuthCerts = KeyUtil.loadCertificates(clientAuthCertFile.toPath()).stream()
                                .filter(certificate -> certificate instanceof X509Certificate)
                                .map(certificate -> (X509Certificate) certificate).toArray(X509Certificate[]::new);
                    } else {
                        LOG.warn(
                                "Client auth configured, but no authorized certificates / certificate authorities configured");
                        clientAuthCerts = null;
                    }
                } else {
                    clientAuthCerts = null;
                }

                final SslContext sslContext = SslContextBuilder
                        .forServer(certFile, keyFile, Strings.emptyToNull(password)).sslProvider(tlsProvider)
                        .clientAuth(clientAuth).trustManager(clientAuthCerts).build();

                // TODO: Use byte buffer allocator of channel
                return sslContext.newEngine(ByteBufAllocator.DEFAULT);
            }
        };
    }

    /** Adds the TLS and TCP keepalive fields to the configuration request of {@link NettyTransport.Config}. */
    @ConfigClass
    public static class Config extends NettyTransport.Config {
        @Override
        public ConfigurationRequest getRequestedConfiguration() {
            final ConfigurationRequest x = super.getRequestedConfiguration();

            x.addField(new TextField(CK_TLS_CERT_FILE, "TLS cert file", "", "Path to the TLS certificate file",
                    ConfigurationField.Optional.OPTIONAL));
            x.addField(new TextField(CK_TLS_KEY_FILE, "TLS private key file", "",
                    "Path to the TLS private key file", ConfigurationField.Optional.OPTIONAL));
            x.addField(new BooleanField(CK_TLS_ENABLE, "Enable TLS", false, "Accept TLS connections"));
            x.addField(new TextField(CK_TLS_KEY_PASSWORD, "TLS key password", "",
                    "The password for the encrypted key file.", ConfigurationField.Optional.OPTIONAL,
                    TextField.Attribute.IS_PASSWORD));
            x.addField(new DropdownField(CK_TLS_CLIENT_AUTH, "TLS client authentication", TLS_CLIENT_AUTH_DISABLED,
                    TLS_CLIENT_AUTH_OPTIONS, "Whether clients need to authenticate themselves in a TLS connection",
                    ConfigurationField.Optional.OPTIONAL));
            x.addField(new TextField(CK_TLS_CLIENT_AUTH_TRUSTED_CERT_FILE, "TLS Client Auth Trusted Certs", "",
                    "TLS Client Auth Trusted Certs  (File or Directory)", ConfigurationField.Optional.OPTIONAL));
            x.addField(new BooleanField(CK_TCP_KEEPALIVE, "TCP keepalive", false, "Enable TCP keepalive packets"));

            return x;
        }
    }

    /**
     * Handles completion of the bind. On success, it publishes the server channel and warns if the effective
     * SO_RCVBUF differs from the requested size. On failure, it logs the cause.
     */
    private static class InputLaunchListener implements ChannelFutureListener {
        private final AtomicReference<Channel> channelReference;
        private final MessageInput input;
        private final int expectedRecvBufferSize;

        public InputLaunchListener(AtomicReference<Channel> channelReference, MessageInput input,
                int expectedRecvBufferSize) {
            this.channelReference = channelReference;
            this.input = input;
            this.expectedRecvBufferSize = expectedRecvBufferSize;
        }

        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            if (future.isSuccess()) {
                final Channel channel = future.channel();
                channelReference.set(channel);
                LOG.debug("Started channel {}", channel);

                final ServerSocketChannelConfig channelConfig = (ServerSocketChannelConfig) channel.config();
                final int receiveBufferSize = channelConfig.getReceiveBufferSize();
                if (receiveBufferSize != expectedRecvBufferSize) {
                    LOG.warn("receiveBufferSize (SO_RCVBUF) for input {} (channel {}) should be {} but is {}.",
                            input, channel, expectedRecvBufferSize, receiveBufferSize);
                }
            } else {
                LOG.warn("Failed to start channel for input {}", input, future.cause());
            }
        }
    }
}