Java tutorial
/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.controller.livy; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.net.ConnectException; import java.net.HttpURLConnection; import java.net.SocketTimeoutException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.security.KeyManagementException; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.UnrecoverableKeyException; import java.security.cert.CertificateException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.lifecycle.OnDisabled; import 
org.apache.nifi.annotation.lifecycle.OnEnabled;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.controller.AbstractControllerService;
import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.controller.ControllerServiceInitializationContext;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.ssl.SSLContextService;
import org.codehaus.jackson.map.ObjectMapper;
import org.codehaus.jettison.json.JSONException;
import org.codehaus.jettison.json.JSONObject;
import org.apache.nifi.controller.api.livy.LivySessionService;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;

/**
 * Controller service that keeps a pool of Livy sessions of one configured kind and hands out
 * idle ones to callers.
 *
 * <p>While the service is enabled, a background thread periodically lists the sessions known to
 * the Livy server, refreshes the local pool from that listing, prunes sessions that are gone or
 * unusable, and opens new sessions as needed. {@link #getSession()} reads the pool from other
 * threads, which is why the pool is a {@link ConcurrentHashMap} and the configuration fields are
 * {@code volatile}.
 */
@Tags({ "Livy", "REST", "Spark", "http" })
@CapabilityDescription("Manages pool of Spark sessions over HTTP")
public class LivySessionController extends AbstractControllerService implements LivySessionService {

    public static final PropertyDescriptor LIVY_HOST = new PropertyDescriptor.Builder()
            .name("livy-cs-livy-host")
            .displayName("Livy Host")
            .description("The hostname (or IP address) of the Livy server.")
            .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
            .required(true)
            .expressionLanguageSupported(true)
            .build();

    public static final PropertyDescriptor LIVY_PORT = new PropertyDescriptor.Builder()
            .name("livy-cs-livy-port")
            .displayName("Livy Port")
            .description("The port number for the Livy server.")
            .required(true)
            .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
            .expressionLanguageSupported(true)
            .defaultValue("8998")
            .build();

    public static final PropertyDescriptor SESSION_POOL_SIZE = new PropertyDescriptor.Builder()
            .name("livy-cs-session-pool-size")
            .displayName("Session Pool Size")
            .description("Number of sessions to keep open")
            .required(true)
            .defaultValue("2")
            .addValidator(StandardValidators.POSITIVE_INTEGER_VALIDATOR)
            .expressionLanguageSupported(true)
            .build();

    public static final PropertyDescriptor SESSION_TYPE = new PropertyDescriptor.Builder()
            .name("livy-cs-session-kind")
            .displayName("Session Type")
            .description("The type of Spark session to start (spark, pyspark, pyspark3, sparkr, e.g.)")
            .required(true)
            .allowableValues("spark", "pyspark", "pyspark3", "sparkr")
            .defaultValue("spark")
            .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
            .build();

    public static final PropertyDescriptor SESSION_MGR_STATUS_INTERVAL = new PropertyDescriptor.Builder()
            .name("livy-cs-session-manager-status-interval")
            .displayName("Session Manager Status Interval")
            .description("The amount of time to wait between requesting session information updates.")
            .required(true)
            .defaultValue("2 sec")
            .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR)
            .expressionLanguageSupported(true)
            .build();

    public static final PropertyDescriptor JARS = new PropertyDescriptor.Builder()
            .name("livy-cs-session-jars")
            .displayName("Session JARs")
            .description("JARs to be used in the Spark session.")
            .required(false)
            .addValidator(StandardValidators.createListValidator(true, true, StandardValidators.FILE_EXISTS_VALIDATOR))
            .expressionLanguageSupported(true)
            .build();

    public static final PropertyDescriptor FILES = new PropertyDescriptor.Builder()
            .name("livy-cs-session-files")
            .displayName("Session Files")
            .description("Files to be used in the Spark session.")
            .required(false)
            .addValidator(StandardValidators.createListValidator(true, true, StandardValidators.FILE_EXISTS_VALIDATOR))
            .defaultValue(null)
            .build();

    public static final PropertyDescriptor SSL_CONTEXT_SERVICE = new PropertyDescriptor.Builder()
            .name("SSL Context Service")
            .description(
                    "The SSL Context Service used to provide client certificate information for TLS/SSL (https) connections.")
            .required(false)
            .identifiesControllerService(SSLContextService.class)
            .build();

    public static final PropertyDescriptor CONNECT_TIMEOUT = new PropertyDescriptor.Builder()
            .name("Connection Timeout")
            .description("Max wait time for connection to remote service.")
            .required(true)
            .defaultValue("5 secs")
            .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR)
            .build();

    private volatile String livyUrl;
    private volatile int sessionPoolSize;
    private volatile String controllerKind;
    private volatile String jars;
    private volatile String files;
    /** Session pool keyed by Livy session id; written by the manager thread, read by {@link #getSession()}. */
    private final Map<Integer, JSONObject> sessions = new ConcurrentHashMap<>();
    private volatile SSLContextService sslContextService;
    private volatile SSLContext sslContext;
    private volatile int connectTimeout;
    private volatile Thread livySessionManagerThread = null;
    /** Loop flag of the manager thread; cleared on disable or interruption. */
    private volatile boolean enabled = true;

    private List<PropertyDescriptor> properties;

    /**
     * Registers the property descriptors supported by this service.
     *
     * @param config initialization context (not used here)
     */
    @Override
    protected void init(ControllerServiceInitializationContext config) {
        final List<PropertyDescriptor> props = new ArrayList<>();
        props.add(LIVY_HOST);
        props.add(LIVY_PORT);
        props.add(SESSION_POOL_SIZE);
        props.add(SESSION_TYPE);
        props.add(SESSION_MGR_STATUS_INTERVAL);
        props.add(SSL_CONTEXT_SERVICE);
        props.add(CONNECT_TIMEOUT);
        props.add(JARS);
        props.add(FILES);
        properties = Collections.unmodifiableList(props);
    }

    @Override
    protected List<PropertyDescriptor> getSupportedPropertyDescriptors() {
        return properties;
    }

    /**
     * Reads the configuration and starts the background thread that maintains the session pool.
     *
     * @param context the configuration the service is being enabled with
     */
    @OnEnabled
    public void onConfigured(final ConfigurationContext context) {
        final ComponentLog log = getLogger();

        final String livyHost = context.getProperty(LIVY_HOST).evaluateAttributeExpressions().getValue();
        final String livyPort = context.getProperty(LIVY_PORT).evaluateAttributeExpressions().getValue();
        final String sessionPoolSize = context.getProperty(SESSION_POOL_SIZE).evaluateAttributeExpressions()
                .getValue();
        final String sessionKind = context.getProperty(SESSION_TYPE).getValue();
        // The descriptor declares expression-language support, so the value has to be evaluated
        // before it is parsed as a time period.
        final long sessionManagerStatusInterval = context.getProperty(SESSION_MGR_STATUS_INTERVAL)
                .evaluateAttributeExpressions().asTimePeriod(TimeUnit.MILLISECONDS);
        final String jars = context.getProperty(JARS).evaluateAttributeExpressions().getValue();
        final String files = context.getProperty(FILES).evaluateAttributeExpressions().getValue();

        sslContextService = context.getProperty(SSL_CONTEXT_SERVICE).asControllerService(SSLContextService.class);
        sslContext = sslContextService == null ? null
                : sslContextService.createSSLContext(SSLContextService.ClientAuth.NONE);
        connectTimeout = Math.toIntExact(context.getProperty(CONNECT_TIMEOUT).asTimePeriod(TimeUnit.MILLISECONDS));

        // https is used exactly when an SSL context service has been configured.
        this.livyUrl = "http" + (sslContextService != null ? "s" : "") + "://" + livyHost + ":" + livyPort;
        this.controllerKind = sessionKind;
        this.jars = jars;
        this.files = files;
        this.sessionPoolSize = Integer.parseInt(sessionPoolSize);
        this.enabled = true;

        livySessionManagerThread = new Thread(() -> {
            while (enabled) {
                try {
                    try {
                        manageSessions();
                    } catch (IOException | RuntimeException e) {
                        // A failed round must not terminate the manager thread, otherwise the pool
                        // would never be refreshed again. Log it and retry after the usual interval.
                        log.error("Livy Session Manager failed to update the session pool, will retry", e);
                    }
                    Thread.sleep(sessionManagerStatusInterval);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    enabled = false;
                }
            }
        });
        livySessionManagerThread.setName("Livy-Session-Manager-" + controllerKind);
        livySessionManagerThread.start();
    }

    /**
     * Stops the session manager thread and waits for it to finish.
     */
    @OnDisabled
    public void shutdown() {
        enabled = false;
        final Thread managerThread = livySessionManagerThread;
        if (managerThread == null) {
            // The service was never (successfully) enabled, so there is nothing to stop.
            return;
        }
        try {
            managerThread.interrupt();
            managerThread.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            getLogger().error("Livy Session Manager Thread interrupted");
        }
    }

    /**
     * Looks up an idle session of the configured kind in the pool.
     *
     * @return a map holding {@code sessionId} and {@code livyUrl} for an idle session, or an empty
     *         map when no idle session of the configured kind is currently pooled
     */
    @Override
    public Map<String, String> getSession() {
        final Map<String, String> sessionMap = new HashMap<>();
        try {
            // Iterate over entries rather than keys: the manager thread may remove a session at any
            // time, and a key-then-get lookup could then yield null.
            for (final Map.Entry<Integer, JSONObject> entry : sessions.entrySet()) {
                final JSONObject session = entry.getValue();
                final String state = session.getString("state");
                final String sessionKind = session.getString("kind");
                if (state.equalsIgnoreCase("idle") && sessionKind.equalsIgnoreCase(controllerKind)) {
                    sessionMap.put("sessionId", String.valueOf(entry.getKey()));
                    sessionMap.put("livyUrl", livyUrl);
                    break;
                }
            }
        } catch (JSONException e) {
            getLogger().error("Unexpected data found when looking for JSON object with 'state' and 'kind' fields", e);
        }
        return sessionMap;
    }

    /**
     * Opens a connection to the given URL with the configured connect timeout, installing the SSL
     * socket factory when an SSL context service is configured.
     *
     * @param urlString the URL to connect to
     * @return the (not yet connected) HTTP connection
     * @throws IOException if the connection cannot be opened or the SSL material cannot be loaded
     */
    @Override
    public HttpURLConnection getConnection(String urlString) throws IOException {
        final URL url = new URL(urlString);
        final HttpURLConnection connection = (HttpURLConnection) url.openConnection();
        connection.setConnectTimeout(connectTimeout);

        if (sslContextService != null) {
            try {
                setSslSocketFactory((HttpsURLConnection) connection, sslContextService, sslContext);
            } catch (KeyStoreException | CertificateException | NoSuchAlgorithmException
                    | UnrecoverableKeyException | KeyManagementException e) {
                throw new IOException(e);
            }
        }
        return connection;
    }

    /**
     * Runs one maintenance round: refreshes pooled sessions from the server listing, prunes
     * sessions that are gone or unusable, and opens new sessions where needed.
     *
     * @throws InterruptedException if interrupted while waiting for a new session to start
     * @throws IOException if the server cannot be read or returns unexpected JSON
     */
    private void manageSessions() throws InterruptedException, IOException {
        final ComponentLog log = getLogger();
        int idleSessions = 0;

        try {
            final Map<Integer, JSONObject> sessionsInfo = listSessions();
            if (sessions.isEmpty()) {
                log.debug("manageSessions() the active session list is empty, populating from acquired list...");
                sessions.putAll(sessionsInfo);
            }

            // Iterate over a copy of the keys because entries are removed while walking the pool.
            for (final Integer sessionId : new ArrayList<>(sessions.keySet())) {
                log.debug("manageSessions() Updating current session: " + sessions.get(sessionId));

                final JSONObject latest = sessionsInfo.get(sessionId);
                if (latest == null) {
                    // Prune sessions that no longer exist
                    log.debug(
                            "manageSessions() session exists in session pool but not in source snapshot, removing from pool...");
                    sessions.remove(sessionId);
                    continue;
                }

                // Decide on the freshly listed state, not on the copy stored in the previous round;
                // otherwise the idle count and the pruning lag one round behind the server.
                final String state = latest.getString("state");
                final String sessionKind = latest.getString("kind");
                log.debug("manageSessions() controller kind: {}, session kind: {}, session state: {}",
                        new Object[] { controllerKind, sessionKind, state });

                final boolean idle = state.equalsIgnoreCase("idle");
                final boolean usable = idle || state.equalsIgnoreCase("busy") || state.equalsIgnoreCase("starting");
                if (usable && sessionKind.equalsIgnoreCase(controllerKind)) {
                    if (idle) {
                        // Keep track of how many sessions are in an idle state and thus available
                        idleSessions++;
                    }
                    sessions.put(sessionId, latest);
                } else {
                    // Prune sessions of kind != controllerKind and whose state is:
                    // not_started, shutting_down, error, dead, success (successfully stopped)
                    sessions.remove(sessionId);
                }
            }

            final int numSessions = sessions.size();
            log.debug("manageSessions() There are " + numSessions + " sessions in the pool");

            if (numSessions == 0) {
                // Open new sessions equal to the number requested by sessionPoolSize
                for (int i = 0; i < sessionPoolSize; i++) {
                    openAndRegisterSession(log);
                }
            } else {
                // Open one new session if there are no idle sessions
                if (idleSessions == 0) {
                    log.debug("manageSessions() There are " + numSessions
                            + " sessions in the pool but none of them are idle sessions, creating...");
                    openAndRegisterSession(log);
                }
                // Open more sessions if number of sessions is less than target pool size
                if (numSessions < sessionPoolSize) {
                    log.debug("manageSessions() There are " + numSessions
                            + ", need more sessions to equal requested pool size of " + sessionPoolSize
                            + ", creating...");
                    for (int i = 0; i < sessionPoolSize - numSessions; i++) {
                        openAndRegisterSession(log);
                    }
                }
            }
        } catch (ConnectException | SocketTimeoutException ce) {
            log.error("Timeout connecting to Livy service to retrieve sessions", ce);
        } catch (JSONException e) {
            throw new IOException(e);
        }
    }

    /** Opens one new session and stores it in the pool under the id reported by the server. */
    private void openAndRegisterSession(final ComponentLog log)
            throws IOException, JSONException, InterruptedException {
        final JSONObject newSessionInfo = openSession();
        sessions.put(newSessionInfo.getInt("id"), newSessionInfo);
        log.debug("manageSessions() Registered new session: " + newSessionInfo);
    }

    /**
     * Fetches the list of sessions from the server.
     *
     * @return the listed sessions keyed by session id
     * @throws IOException if the request fails or the response is not the expected JSON
     */
    private Map<Integer, JSONObject> listSessions() throws IOException {
        final String sessionsUrl = livyUrl + "/sessions";
        final Map<Integer, JSONObject> sessionsMap = new HashMap<>();
        try {
            final JSONObject sessionsInfo = readJSONFromUrl(sessionsUrl, defaultHeaders());
            final int numSessions = sessionsInfo.getJSONArray("sessions").length();
            for (int i = 0; i < numSessions; i++) {
                final JSONObject currentSession = sessionsInfo.getJSONArray("sessions").getJSONObject(i);
                sessionsMap.put(currentSession.getInt("id"), currentSession);
            }
        } catch (JSONException e) {
            throw new IOException(e);
        }
        return sessionsMap;
    }

    /**
     * Fetches the current description of a single session.
     *
     * @param sessionId id of the session to look up
     * @return the session description returned by the server
     * @throws IOException if the request fails or the response is not valid JSON
     */
    private JSONObject getSessionInfo(int sessionId) throws IOException {
        final String sessionUrl = livyUrl + "/sessions/" + sessionId;
        try {
            return readJSONFromUrl(sessionUrl, defaultHeaders());
        } catch (JSONException e) {
            throw new IOException(e);
        }
    }

    /**
     * Requests a new session of the configured kind and polls once per second until the server
     * no longer reports it as {@code starting}.
     *
     * @return the description of the new session once it has left the {@code starting} state
     * @throws IOException if a request fails
     * @throws JSONException if a response lacks the expected fields
     * @throws InterruptedException if interrupted while waiting for the session to start
     */
    private JSONObject openSession() throws IOException, JSONException, InterruptedException {
        final ComponentLog log = getLogger();
        final ObjectMapper mapper = new ObjectMapper();
        final String sessionsUrl = livyUrl + "/sessions";

        final StringBuilder payload = new StringBuilder("{\"kind\":\"" + controllerKind + "\"");
        if (jars != null) {
            payload.append(",\"jars\":");
            payload.append(toJsonArray(mapper, jars));
        }
        if (files != null) {
            payload.append(",\"files\":");
            payload.append(toJsonArray(mapper, files));
        }
        payload.append("}");
        log.debug("openSession() Session Payload: " + payload.toString());

        JSONObject newSessionInfo = readJSONObjectFromUrlPOST(sessionsUrl, defaultHeaders(), payload.toString());
        Thread.sleep(1000);
        while (newSessionInfo.getString("state").equalsIgnoreCase("starting")) {
            log.debug("openSession() Waiting for session to start...");
            newSessionInfo = getSessionInfo(newSessionInfo.getInt("id"));
            log.debug("openSession() newSessionInfo: " + newSessionInfo);
            Thread.sleep(1000);
        }
        return newSessionInfo;
    }

    /** Serializes the non-blank, trimmed items of a comma-separated list as a JSON array. */
    private static String toJsonArray(final ObjectMapper mapper, final String commaSeparated) throws IOException {
        final List<String> items = Arrays.stream(commaSeparated.split(","))
                .filter(StringUtils::isNotBlank)
                .map(String::trim)
                .collect(Collectors.toList());
        return mapper.writeValueAsString(items);
    }

    /** Builds the request headers sent with every call made by this service. */
    private static Map<String, String> defaultHeaders() {
        final Map<String, String> headers = new HashMap<>();
        headers.put("Content-Type", APPLICATION_JSON);
        headers.put("X-Requested-By", USER);
        return headers;
    }

    /**
     * POSTs a payload to the given URL and parses the response body as a JSON object.
     *
     * @param urlString the URL to post to
     * @param headers request headers to set
     * @param payload request body, sent as UTF-8
     * @return the parsed response
     * @throws IOException if the request fails or the response code is neither 200 nor 201
     * @throws JSONException if the response body is not a JSON object
     */
    private JSONObject readJSONObjectFromUrlPOST(String urlString, Map<String, String> headers, String payload)
            throws IOException, JSONException {
        final HttpURLConnection connection = getConnection(urlString);
        connection.setRequestMethod(POST);
        connection.setDoOutput(true);
        for (final Map.Entry<String, String> entry : headers.entrySet()) {
            connection.setRequestProperty(entry.getKey(), entry.getValue());
        }

        try (OutputStream os = connection.getOutputStream()) {
            os.write(payload.getBytes(StandardCharsets.UTF_8));
            os.flush();
        }

        final int responseCode = connection.getResponseCode();
        if (responseCode != HttpURLConnection.HTTP_OK && responseCode != HttpURLConnection.HTTP_CREATED) {
            throw new IOException("Failed : HTTP error code : " + responseCode + " : "
                    + connection.getResponseMessage());
        }
        return readAllIntoJSONObject(connection.getInputStream());
    }

    /**
     * GETs the given URL and parses the response body as a JSON object.
     *
     * @param urlString the URL to read
     * @param headers request headers to set
     * @return the parsed response
     * @throws IOException if the request fails
     * @throws JSONException if the response body is not a JSON object
     */
    private JSONObject readJSONFromUrl(String urlString, Map<String, String> headers)
            throws IOException, JSONException {
        final HttpURLConnection connection = getConnection(urlString);
        for (final Map.Entry<String, String> entry : headers.entrySet()) {
            connection.setRequestProperty(entry.getKey(), entry.getValue());
        }
        // No setDoOutput(true) here: a GET request carries no body.
        connection.setRequestMethod(GET);
        return readAllIntoJSONObject(connection.getInputStream());
    }

    /**
     * Reads the whole stream as UTF-8 text, parses it as a JSON object and closes the stream.
     *
     * @param content the stream to consume
     * @return the parsed JSON object
     * @throws IOException if reading fails
     * @throws JSONException if the text is not a JSON object
     */
    private JSONObject readAllIntoJSONObject(InputStream content) throws IOException, JSONException {
        try (BufferedReader rd = new BufferedReader(new InputStreamReader(content, StandardCharsets.UTF_8))) {
            return new JSONObject(IOUtils.toString(rd));
        }
    }

    /**
     * Initializes the given SSL context from the key store and trust store of the SSL context
     * service and installs its socket factory on the connection.
     *
     * @param httpsURLConnection the connection to configure
     * @param sslService source of the key store / trust store locations, types and passwords
     * @param sslContext the context to initialize and take the socket factory from
     */
    private void setSslSocketFactory(HttpsURLConnection httpsURLConnection, SSLContextService sslService,
            SSLContext sslContext) throws IOException, KeyStoreException, CertificateException,
            NoSuchAlgorithmException, UnrecoverableKeyException, KeyManagementException {
        final String keystoreLocation = sslService.getKeyStoreFile();
        final String keystorePass = sslService.getKeyStorePassword();
        final String keystoreType = sslService.getKeyStoreType();

        // prepare the keystore
        final KeyStore keyStore = KeyStore.getInstance(keystoreType);
        try (FileInputStream keyStoreStream = new FileInputStream(keystoreLocation)) {
            keyStore.load(keyStoreStream, keystorePass.toCharArray());
        }
        final KeyManagerFactory keyManagerFactory = KeyManagerFactory
                .getInstance(KeyManagerFactory.getDefaultAlgorithm());
        keyManagerFactory.init(keyStore, keystorePass.toCharArray());

        // load truststore
        final String truststoreLocation = sslService.getTrustStoreFile();
        final String truststorePass = sslService.getTrustStorePassword();
        final String truststoreType = sslService.getTrustStoreType();

        final KeyStore truststore = KeyStore.getInstance(truststoreType);
        final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance("X509");
        // Close the stream once loaded; this method runs for every connection, so an unclosed
        // stream would leak a file handle per request.
        try (FileInputStream trustStoreStream = new FileInputStream(truststoreLocation)) {
            truststore.load(trustStoreStream, truststorePass.toCharArray());
        }
        trustManagerFactory.init(truststore);

        sslContext.init(keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null);

        final SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
        httpsURLConnection.setSSLSocketFactory(sslSocketFactory);
    }
}