// Java tutorial
/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.cloudera.livy.rsc.rpc; import java.io.Closeable; import java.io.IOException; import java.net.InetSocketAddress; import java.security.SecureRandom; import java.util.Map; import java.util.Properties; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.NameCallback; import javax.security.auth.callback.PasswordCallback; import javax.security.sasl.AuthorizeCallback; import javax.security.sasl.RealmCallback; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.Maps; import com.google.common.util.concurrent.ThreadFactoryBuilder; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import 
io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.ScheduledFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.cloudera.livy.rsc.RSCConf; import static com.cloudera.livy.rsc.RSCConf.Entry.*; /** * An RPC server. The server matches remote clients based on a secret that is generated on * the server - the secret needs to be given to the client through some other mechanism for * this to work. */ public class RpcServer implements Closeable { private static final Logger LOG = LoggerFactory.getLogger(RpcServer.class); private static final SecureRandom RND = new SecureRandom(); private final String address; private final Channel channel; private final EventLoopGroup group; private final int port; private final ConcurrentMap<String, ClientInfo> pendingClients; private final RSCConf config; public RpcServer(RSCConf lconf) throws IOException, InterruptedException { this.config = lconf; this.group = new NioEventLoopGroup(this.config.getInt(RPC_MAX_THREADS), new ThreadFactoryBuilder().setNameFormat("RPC-Handler-%d").setDaemon(true).build()); this.channel = new ServerBootstrap().group(group).channel(NioServerSocketChannel.class) .childHandler(new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) throws Exception { SaslServerHandler saslHandler = new SaslServerHandler(config); final Rpc newRpc = Rpc.createServer(saslHandler, config, ch, group); saslHandler.rpc = newRpc; Runnable cancelTask = new Runnable() { @Override public void run() { LOG.warn("Timed out waiting for hello from client."); newRpc.close(); } }; saslHandler.cancelTask = group.schedule(cancelTask, config.getTimeAsMs(RPC_CLIENT_HANDSHAKE_TIMEOUT), TimeUnit.MILLISECONDS); } }).option(ChannelOption.SO_BACKLOG, 
1).option(ChannelOption.SO_REUSEADDR, true) .childOption(ChannelOption.SO_KEEPALIVE, true).bind(0).sync().channel(); this.port = ((InetSocketAddress) channel.localAddress()).getPort(); this.pendingClients = Maps.newConcurrentMap(); String address = config.get(RPC_SERVER_ADDRESS); if (address == null) { address = config.findLocalAddress(); } this.address = address; } /** * Tells the RPC server to expect connections from clients. * * @param clientId An identifier for the client. Must be unique. * @param secret The secret the client will send to the server to identify itself. * @param callback The callback for when a new client successfully connects with the given * credentials. */ public void registerClient(String clientId, String secret, ClientCallback callback) { final ClientInfo client = new ClientInfo(clientId, secret, callback); if (pendingClients.putIfAbsent(clientId, client) != null) { throw new IllegalStateException(String.format("Client '%s' already registered.", clientId)); } } /** * Stop waiting for connections for a given client ID. * * @param clientId The client ID to forget. */ public void unregisterClient(String clientId) { pendingClients.remove(clientId); } /** * Creates a secret for identifying a client connection. */ public String createSecret() { byte[] secret = new byte[config.getInt(RPC_SECRET_RANDOM_BITS) / 8]; RND.nextBytes(secret); StringBuilder sb = new StringBuilder(); for (byte b : secret) { if (b < 10) { sb.append("0"); } sb.append(Integer.toHexString(b)); } return sb.toString(); } public String getAddress() { return address; } public int getPort() { return port; } public EventLoopGroup getEventLoopGroup() { return group; } @Override public void close() { try { channel.close(); pendingClients.clear(); } finally { group.shutdownGracefully(); } } /** * A callback that can be registered to be notified when new clients are created and * successfully authenticate against the server. 
*/ public interface ClientCallback { /** * Called when a new client successfully connects. * * @param client The RPC instance for the new client. * @return The RpcDispatcher to be used for the client. */ RpcDispatcher onNewClient(Rpc client); } private class SaslServerHandler extends SaslHandler implements CallbackHandler { private final SaslServer server; private Rpc rpc; private ScheduledFuture<?> cancelTask; private String clientId; private ClientInfo client; SaslServerHandler(RSCConf config) throws IOException { super(config); this.server = Sasl.createSaslServer(config.get(SASL_MECHANISMS), Rpc.SASL_PROTOCOL, Rpc.SASL_REALM, config.getSaslOptions(), this); } @Override protected boolean isComplete() { return server.isComplete(); } @Override protected String getNegotiatedProperty(String name) { return (String) server.getNegotiatedProperty(name); } @Override protected Rpc.SaslMessage update(Rpc.SaslMessage challenge) throws IOException { if (clientId == null) { Preconditions.checkArgument(challenge.clientId != null, "Missing client ID in SASL handshake."); clientId = challenge.clientId; client = pendingClients.get(clientId); Preconditions.checkArgument(client != null, "Unexpected client ID '%s' in SASL handshake.", clientId); } return new Rpc.SaslMessage(server.evaluateResponse(challenge.payload)); } @Override public byte[] wrap(byte[] data, int offset, int len) throws IOException { return server.wrap(data, offset, len); } @Override public byte[] unwrap(byte[] data, int offset, int len) throws IOException { return server.unwrap(data, offset, len); } @Override public void dispose() throws IOException { if (!server.isComplete()) { onError(new SaslException("Server closed before SASL negotiation finished.")); } server.dispose(); } @Override protected void onComplete() throws Exception { cancelTask.cancel(true); RpcDispatcher dispatcher = null; try { dispatcher = client.callback.onNewClient(rpc); } catch (Exception e) { LOG.warn("Client callback threw an exception.", 
e); } if (dispatcher != null) { rpc.setDispatcher(dispatcher); } } @Override protected void onError(Throwable error) { cancelTask.cancel(true); } @Override public void handle(Callback[] callbacks) { Preconditions.checkState(client != null, "Handshake not initialized yet."); for (Callback cb : callbacks) { if (cb instanceof NameCallback) { ((NameCallback) cb).setName(clientId); } else if (cb instanceof PasswordCallback) { ((PasswordCallback) cb).setPassword(client.secret.toCharArray()); } else if (cb instanceof AuthorizeCallback) { ((AuthorizeCallback) cb).setAuthorized(true); } else if (cb instanceof RealmCallback) { RealmCallback rb = (RealmCallback) cb; rb.setText(rb.getDefaultText()); } } } } private static class ClientInfo { final String id; final String secret; final ClientCallback callback; private ClientInfo(String id, String secret, ClientCallback callback) { this.id = id; this.secret = secret; this.callback = callback; } } }