// Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.lens.ml; import java.io.IOException; import java.io.ObjectOutputStream; import java.util.*; import javax.ws.rs.client.Client; import javax.ws.rs.client.ClientBuilder; import javax.ws.rs.client.Entity; import javax.ws.rs.client.WebTarget; import javax.ws.rs.core.MediaType; import org.apache.lens.api.LensConf; import org.apache.lens.api.LensException; import org.apache.lens.api.LensSessionHandle; import org.apache.lens.api.query.LensQuery; import org.apache.lens.api.query.QueryHandle; import org.apache.lens.api.query.QueryStatus; import org.apache.lens.ml.spark.SparkMLDriver; import org.apache.lens.ml.spark.algos.BaseSparkAlgo; import org.apache.lens.server.api.LensConfConstants; import org.apache.commons.io.IOUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.spark.api.java.JavaSparkContext; import org.glassfish.jersey.media.multipart.FormDataBodyPart; import 
org.glassfish.jersey.media.multipart.FormDataContentDisposition;
import org.glassfish.jersey.media.multipart.FormDataMultiPart;
import org.glassfish.jersey.media.multipart.MultiPartFeature;

/**
 * Default {@link LensML} implementation.
 * <p>
 * Algorithms are provided by pluggable {@link MLDriver}s whose class names are listed in the {@code lens.ml.drivers}
 * configuration property. Trained models are serialized under the directory configured by
 * {@code ModelLoader.MODEL_PATH_BASE_DIR}, and test reports are saved/loaded through {@link ModelLoader}.
 * </p>
 */
public class LensMLImpl implements LensML {

  /** The Constant LOG. */
  public static final Log LOG = LogFactory.getLog(LensMLImpl.class);

  /** Database used when there is no Hive session, or the session has no current database. */
  private static final String DEFAULT_DATABASE = "default";

  /** Delay between two status polls of a query submitted by {@link RemoteQueryRunner}. */
  private static final long QUERY_POLL_INTERVAL_MILLIS = 500L;

  /** The ML drivers loaded by {@link #init(HiveConf)}; {@code null} until init has run. */
  protected List<MLDriver> drivers;

  /** Configuration used for model/report storage and handed to the drivers. */
  private HiveConf conf;

  /** Optional externally supplied Spark context, passed to Spark drivers in {@link #start()}. */
  private JavaSparkContext sparkContext;

  /**
   * Instantiates a new lens ml impl.
   *
   * @param conf the configuration; replaced by the argument of {@link #init(HiveConf)} when init is called
   */
  public LensMLImpl(HiveConf conf) {
    this.conf = conf;
  }

  /**
   * Returns the configuration currently in use.
   *
   * @return the configuration
   */
  public HiveConf getConf() {
    return conf;
  }

  /**
   * Use an existing Spark context. When a non-null context is set, {@link #start()} hands it to every
   * {@link SparkMLDriver} (via {@code useSparkContext}) before starting that driver.
   *
   * @param jsc JavaSparkContext instance, or {@code null} to leave the drivers untouched
   */
  public void setSparkContext(JavaSparkContext jsc) {
    this.sparkContext = jsc;
  }

  /**
   * Lists the names of all algorithms supported by the loaded drivers.
   *
   * @return algorithm names, in driver order
   */
  public List<String> getAlgorithms() {
    List<String> algos = new ArrayList<String>();
    for (MLDriver driver : drivers) {
      algos.addAll(driver.getAlgoNames());
    }
    return algos;
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String)
   */
  public MLAlgo getAlgoForName(String algorithm) throws LensException {
    for (MLDriver driver : drivers) {
      if (driver.isAlgoSupported(algorithm)) {
        return driver.getAlgoInstance(algorithm);
      }
    }
    throw new LensException("Algo not supported " + algorithm);
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
   */
  public String train(String table, String algorithm, String[] args) throws LensException {
    MLAlgo algo = getAlgoForName(algorithm);

    String modelId = UUID.randomUUID().toString();

    LOG.info("Begin training model " + modelId + ", algo=" + algorithm + ", table=" + table + ", params="
      + Arrays.toString(args));

    MLModel model = algo.train(toLensConf(conf), getCurrentDatabase(), table, modelId, args);
    LOG.info("Done training model: " + modelId);

    model.setCreatedAt(new Date());
    model.setAlgoName(algorithm);

    try {
      Path modelLocation = persistModel(model);
      LOG.info("Model saved: " + modelId + ", algo: " + algorithm + ", path: " + modelLocation);
      return model.getId();
    } catch (IOException e) {
      throw new LensException("Error saving model " + modelId + " for algo " + algorithm, e);
    }
  }

  /**
   * Resolves the database of the current Hive session.
   *
   * @return the session's current database, or {@code "default"} when there is no session or it has no current
   *         database
   */
  private static String getCurrentDatabase() {
    SessionState state = SessionState.get();
    String database = (state == null) ? null : state.getCurrentDatabase();
    return (database == null) ? DEFAULT_DATABASE : database;
  }

  /**
   * Gets the directory holding the saved models of one algorithm.
   *
   * @param algoName the algo name
   * @return {@code <model base dir>/<algoName>}
   * @throws IOException Signals that an I/O exception has occurred.
   */
  private Path getAlgoDir(String algoName) throws IOException {
    String modelSaveBaseDir = conf.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT);
    return new Path(new Path(modelSaveBaseDir), algoName);
  }

  /**
   * Serializes a model (Java serialization) to {@code <algo dir>/<model id>}, creating the algo dir if needed.
   * An existing file at that path is not overwritten ({@code fs.create(path, false)}).
   *
   * @param model the model
   * @return the path the model was written to
   * @throws IOException if the model could not be written or the file could not be finalized
   */
  private Path persistModel(MLModel model) throws IOException {
    // Get model save path
    Path algoDir = getAlgoDir(model.getAlgoName());
    FileSystem fs = algoDir.getFileSystem(conf);

    if (!fs.exists(algoDir)) {
      fs.mkdirs(algoDir);
    }

    Path modelSavePath = new Path(algoDir, model.getId());
    ObjectOutputStream outputStream = null;

    try {
      outputStream = new ObjectOutputStream(fs.create(modelSavePath, false));
      outputStream.writeObject(model);
      outputStream.flush();
      // Close explicitly so that a failure to finalize the file is reported to the caller
      // instead of being swallowed by closeQuietly in the finally block.
      outputStream.close();
      outputStream = null;
    } catch (IOException io) {
      LOG.error("Error saving model " + model.getId() + " reason: " + io.getMessage(), io);
      throw io;
    } finally {
      // Only reached with a non-null stream on the error path.
      IOUtils.closeQuietly(outputStream);
    }
    return modelSavePath;
  }

  /**
   * Lists the saved model IDs of an algorithm.
   *
   * @param algorithm the algorithm
   * @return the model IDs (file names in the algorithm's model directory), or {@code null} when the directory does
   *         not exist or is empty
   * @throws LensException if the file system could not be read
   */
  public List<String> getModels(String algorithm) throws LensException {
    try {
      Path algoDir = getAlgoDir(algorithm);
      FileSystem fs = algoDir.getFileSystem(conf);
      if (!fs.exists(algoDir)) {
        return null;
      }

      List<String> models = new ArrayList<String>();

      for (FileStatus stat : fs.listStatus(algoDir)) {
        models.add(stat.getPath().getName());
      }

      if (models.isEmpty()) {
        return null;
      }

      return models;
    } catch (IOException ioex) {
      throw new LensException(ioex);
    }
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#getModel(java.lang.String, java.lang.String)
   */
  public MLModel getModel(String algorithm, String modelId) throws LensException {
    try {
      return ModelLoader.loadModel(conf, algorithm, modelId);
    } catch (IOException e) {
      throw new LensException(e);
    }
  }

  /**
   * Loads and initializes the ML drivers named in {@code lens.ml.drivers}.
   * <p>
   * Classes that cannot be found, do not implement {@link MLDriver}, or fail to instantiate/initialize are logged
   * and skipped.
   * </p>
   *
   * @param hiveConf the hive conf; also becomes the configuration of this instance
   * @throws RuntimeException if no drivers are configured, or none could be loaded
   */
  public synchronized void init(HiveConf hiveConf) {
    this.conf = hiveConf;

    // Get all the drivers
    String[] driverClasses = hiveConf.getStrings("lens.ml.drivers");

    if (driverClasses == null || driverClasses.length == 0) {
      throw new RuntimeException("No ML Drivers specified in conf");
    }

    LOG.info("Loading drivers " + Arrays.toString(driverClasses));
    drivers = new ArrayList<MLDriver>(driverClasses.length);

    for (String configuredName : driverClasses) {
      // Tolerate whitespace around the comma-separated entries, e.g. "a.Driver, b.Driver".
      String driverClass = configuredName.trim();
      if (driverClass.isEmpty()) {
        continue;
      }

      Class<?> cls;
      try {
        cls = Class.forName(driverClass);
      } catch (ClassNotFoundException e) {
        LOG.error("Driver class not found " + driverClass, e);
        continue;
      }

      if (!MLDriver.class.isAssignableFrom(cls)) {
        LOG.warn("Not a driver class " + driverClass);
        continue;
      }

      try {
        // asSubclass performs a checked cast, so no unchecked conversion is needed.
        MLDriver driver = cls.asSubclass(MLDriver.class).newInstance();
        driver.init(toLensConf(conf));
        drivers.add(driver);
        LOG.info("Added driver " + driverClass);
      } catch (Exception e) {
        LOG.error("Failed to create driver " + driverClass + " reason: " + e.getMessage(), e);
      }
    }

    if (drivers.isEmpty()) {
      throw new RuntimeException("No ML drivers loaded");
    }

    LOG.info("Inited ML service");
  }

  /**
   * Starts every loaded driver. If a Spark context was supplied through {@link #setSparkContext}, it is handed to
   * each {@link SparkMLDriver} first. A driver that fails to start is logged and skipped.
   */
  public synchronized void start() {
    for (MLDriver driver : drivers) {
      try {
        if (driver instanceof SparkMLDriver && sparkContext != null) {
          ((SparkMLDriver) driver).useSparkContext(sparkContext);
        }
        driver.start();
      } catch (LensException e) {
        LOG.error("Failed to start driver " + driver, e);
      }
    }
    LOG.info("Started ML service");
  }

  /**
   * Stops every loaded driver and forgets them. Failures are logged per driver. Safe to call when
   * {@link #init(HiveConf)} never loaded any drivers.
   */
  public synchronized void stop() {
    if (drivers != null) {
      for (MLDriver driver : drivers) {
        try {
          driver.stop();
        } catch (LensException e) {
          LOG.error("Failed to stop driver " + driver, e);
        }
      }
      drivers.clear();
    }
    LOG.info("Stopped ML service");
  }

  /**
   * Returns the configuration currently in use.
   *
   * @return the configuration
   */
  public synchronized HiveConf getHiveConf() {
    return conf;
  }

  /**
   * Clears the {@link ModelLoader} cache.
   */
  public void clearModels() {
    ModelLoader.clearCache();
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#getModelPath(java.lang.String, java.lang.String)
   */
  public String getModelPath(String algorithm, String modelID) {
    return ModelLoader.getModelLocation(conf, algorithm, modelID).toString();
  }

  /**
   * Not implemented in this class: always returns {@code null}. Use
   * {@link #testModelRemote(LensSessionHandle, String, String, String, String, String)} or
   * {@link #testModel(LensSessionHandle, String, String, String, TestQueryRunner, String)} instead.
   *
   * @see org.apache.lens.ml.LensML#testModel(org.apache.lens.api.LensSessionHandle, java.lang.String,
   * java.lang.String, java.lang.String, java.lang.String)
   */
  @Override
  public MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
    String outputTable) throws LensException {
    return null;
  }

  /**
   * Tests a model by submitting the evaluation queries to a remote Lens server.
   *
   * @param sessionHandle the session handle
   * @param table         the table
   * @param algorithm     the algorithm
   * @param modelID       the model id
   * @param queryApiUrl   the query api url of the Lens server
   * @param outputTable   table where test output will be written
   * @return the ML test report
   * @throws LensException the lens exception
   */
  public MLTestReport testModelRemote(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
    String queryApiUrl, String outputTable) throws LensException {
    return testModel(sessionHandle, table, algorithm, modelID, new RemoteQueryRunner(sessionHandle, queryApiUrl),
      outputTable);
  }

  /**
   * Evaluate a model. Evaluation is done on data selected from an input table. The model is run as a UDF and its
   * output is inserted into a table with a partition. Each evaluation is given a unique ID. The partition label is
   * associated with this unique ID.
   * <p>
   * This call also requires a query runner. The query runner is responsible for executing the evaluation query
   * against the Lens server. The output table is created first (through the same runner) if it does not exist.
   * </p>
   *
   * @param sessionHandle the session handle
   * @param table         the table
   * @param algorithm     the algorithm
   * @param modelID       the model id
   * @param queryRunner   the query runner
   * @param outputTable   table where test output will be written
   * @return the ML test report
   * @throws LensException if the algorithm or model is unknown, the test spec is invalid, or a query fails
   */
  public MLTestReport testModel(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
    TestQueryRunner queryRunner, String outputTable) throws LensException {
    // check if algorithm exists
    if (!getAlgorithms().contains(algorithm)) {
      throw new LensException("No such algorithm " + algorithm);
    }

    MLModel<?> model;
    try {
      model = ModelLoader.loadModel(conf, algorithm, modelID);
    } catch (IOException e) {
      throw new LensException(e);
    }

    if (model == null) {
      throw new LensException("Model not found: " + modelID + " algorithm=" + algorithm);
    }

    String testID = UUID.randomUUID().toString().replace("-", "_");
    final String testTable = outputTable;
    final String testResultColumn = "prediction_result";

    // TODO support error metric UDAFs
    TableTestingSpec spec = TableTestingSpec.newBuilder().hiveConf(conf)
      .database(getCurrentDatabase()).inputTable(table)
      .featureColumns(model.getFeatureColumns()).outputColumn(testResultColumn)
      .lableColumn(model.getLabelColumn()).algorithm(algorithm).modelID(modelID).outputTable(testTable)
      .testID(testID).build();

    String testQuery = spec.getTestQuery();
    if (testQuery == null) {
      throw new LensException("Invalid test spec. table=" + table + " algorithm=" + algorithm + " modelID="
        + modelID);
    }

    if (!spec.isOutputTableExists()) {
      LOG.info("Output table '" + testTable + "' does not exist for test algorithm = " + algorithm + " modelid="
        + modelID + ", Creating table using query: " + spec.getCreateOutputTableQuery());
      // create the output table
      String createOutputTableQuery = spec.getCreateOutputTableQuery();
      queryRunner.runQuery(createOutputTableQuery);
      LOG.info("Table created " + testTable);
    }

    LOG.info("Running evaluation query " + testQuery);
    QueryHandle testQueryHandle = queryRunner.runQuery(testQuery);

    MLTestReport testReport = new MLTestReport();
    testReport.setReportID(testID);
    testReport.setAlgorithm(algorithm);
    testReport.setFeatureColumns(model.getFeatureColumns());
    testReport.setLabelColumn(model.getLabelColumn());
    testReport.setModelID(model.getId());
    testReport.setOutputColumn(testResultColumn);
    testReport.setOutputTable(testTable);
    testReport.setTestTable(table);
    testReport.setQueryID(testQueryHandle.toString());

    // Saving is best-effort: the evaluation query has already run, so the report is returned
    // even if it could not be persisted (the failure is logged by persistTestReport).
    if (persistTestReport(testReport)) {
      LOG.info("Saved test report " + testReport.getReportID());
    }
    return testReport;
  }

  /**
   * Persists a test report through {@link ModelLoader}. Failures are logged, not thrown.
   *
   * @param testReport the test report
   * @return {@code true} if the report was saved, {@code false} if saving failed
   * @throws LensException the lens exception
   */
  private boolean persistTestReport(MLTestReport testReport) throws LensException {
    LOG.info("saving test report " + testReport.getReportID());
    try {
      ModelLoader.saveTestReport(conf, testReport);
      LOG.info("Saved report " + testReport.getReportID());
      return true;
    } catch (IOException e) {
      LOG.error("Error saving report " + testReport.getReportID() + " reason: " + e.getMessage(), e);
      return false;
    }
  }

  /**
   * Lists the saved test report IDs of an algorithm.
   *
   * @param algorithm the algorithm
   * @return the report IDs, or {@code null} if the report directory does not exist or could not be read
   * @throws LensException the lens exception
   */
  public List<String> getTestReports(String algorithm) throws LensException {
    Path reportBaseDir = new Path(
      conf.get(ModelLoader.TEST_REPORT_BASE_DIR, ModelLoader.TEST_REPORT_BASE_DIR_DEFAULT));
    FileSystem fs = null;

    try {
      fs = reportBaseDir.getFileSystem(conf);
      if (!fs.exists(reportBaseDir)) {
        return null;
      }

      Path algoDir = new Path(reportBaseDir, algorithm);
      if (!fs.exists(algoDir)) {
        return null;
      }

      List<String> reports = new ArrayList<String>();
      for (FileStatus stat : fs.listStatus(algoDir)) {
        reports.add(stat.getPath().getName());
      }
      return reports;
    } catch (IOException e) {
      LOG.error("Error reading report list for " + algorithm, e);
      return null;
    }
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#getTestReport(java.lang.String, java.lang.String)
   */
  public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
    try {
      return ModelLoader.loadReport(conf, algorithm, reportID);
    } catch (IOException e) {
      throw new LensException(e);
    }
  }

  /**
   * Runs a saved model on one feature vector.
   *
   * @param algorithm the algorithm
   * @param modelID   the model id
   * @param features  the feature values
   * @return the model's prediction
   * @throws LensException if the model cannot be loaded or does not exist
   */
  public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
    // Load the model instance
    MLModel<?> model = getModel(algorithm, modelID);
    // ModelLoader may return null for an unknown model; report it instead of failing with an NPE.
    if (model == null) {
      throw new LensException("Model not found: " + modelID + " algorithm=" + algorithm);
    }
    return model.predict(features);
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#deleteModel(java.lang.String, java.lang.String)
   */
  public void deleteModel(String algorithm, String modelID) throws LensException {
    try {
      ModelLoader.deleteModel(conf, algorithm, modelID);
      LOG.info("DELETED model " + modelID + " algorithm=" + algorithm);
    } catch (IOException e) {
      LOG.error(
        "Error deleting model file. algorithm=" + algorithm + " model=" + modelID + " reason: " + e.getMessage(), e);
      throw new LensException("Unable to delete model " + modelID + " for algorithm " + algorithm, e);
    }
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#deleteTestReport(java.lang.String, java.lang.String)
   */
  public void deleteTestReport(String algorithm, String reportID) throws LensException {
    try {
      ModelLoader.deleteTestReport(conf, algorithm, reportID);
      LOG.info("DELETED report=" + reportID + " algorithm=" + algorithm);
    } catch (IOException e) {
      LOG.error("Error deleting report " + reportID + " algorithm=" + algorithm + " reason: " + e.getMessage(), e);
      throw new LensException("Unable to delete report " + reportID + " for algorithm " + algorithm, e);
    }
  }

  /**
   * Describes the parameters accepted by an algorithm.
   *
   * @param algorithm the algorithm
   * @return parameter usage for Spark algorithms ({@link BaseSparkAlgo#getArgUsage()}), or {@code null} if the
   *         algorithm is unknown or is not a {@link BaseSparkAlgo}
   */
  public Map<String, String> getAlgoParamDescription(String algorithm) {
    MLAlgo algo = null;
    try {
      algo = getAlgoForName(algorithm);
    } catch (LensException e) {
      LOG.error("Error getting algo description : " + algorithm, e);
      return null;
    }
    if (algo instanceof BaseSparkAlgo) {
      return ((BaseSparkAlgo) algo).getArgUsage();
    }
    return null;
  }

  /**
   * Submit model test query to a remote Lens server.
   */
  class RemoteQueryRunner extends TestQueryRunner {

    /** The query api url. */
    final String queryApiUrl;

    /**
     * Instantiates a new remote query runner.
     *
     * @param sessionHandle the session handle
     * @param queryApiUrl   the query api url
     */
    public RemoteQueryRunner(LensSessionHandle sessionHandle, String queryApiUrl) {
      super(sessionHandle);
      this.queryApiUrl = queryApiUrl;
    }

    /**
     * Submits the query for execution, then polls its status until it finishes.
     *
     * @param query the query to run
     * @return the handle of the finished query
     * @throws LensException if the query does not succeed or the wait is interrupted
     */
    @Override
    public QueryHandle runQuery(String query) throws LensException {
      // Create jersey client for query endpoint
      Client client = ClientBuilder.newBuilder().register(MultiPartFeature.class).build();
      try {
        WebTarget target = client.target(queryApiUrl);
        final FormDataMultiPart mp = new FormDataMultiPart();
        mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("sessionid").build(), sessionHandle,
          MediaType.APPLICATION_XML_TYPE));
        mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("query").build(), query));
        mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("operation").build(), "execute"));

        LensConf lensConf = new LensConf();
        lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, false + "");
        lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, false + "");
        mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("conf").fileName("conf").build(), lensConf,
          MediaType.APPLICATION_XML_TYPE));

        final QueryHandle handle = target.request().post(Entity.entity(mp, MediaType.MULTIPART_FORM_DATA_TYPE),
          QueryHandle.class);

        LensQuery ctx = fetchQuery(target, handle);
        QueryStatus stat = ctx.getStatus();
        // Sleep before re-polling so no time is wasted once the query has finished.
        while (!stat.isFinished()) {
          try {
            Thread.sleep(QUERY_POLL_INTERVAL_MILLIS);
          } catch (InterruptedException e) {
            // Restore the interrupt flag so callers up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new LensException(e);
          }
          ctx = fetchQuery(target, handle);
          stat = ctx.getStatus();
        }

        if (stat.getStatus() != QueryStatus.Status.SUCCESSFUL) {
          throw new LensException("Query failed " + ctx.getQueryHandle().getHandleId() + " reason:"
            + stat.getErrorMessage());
        }

        return ctx.getQueryHandle();
      } finally {
        // Release the client's connections.
        client.close();
      }
    }

    /**
     * Fetches the current state of a submitted query.
     *
     * @param target the query API target
     * @param handle the query handle
     * @return the query details, including its status
     */
    private LensQuery fetchQuery(WebTarget target, QueryHandle handle) {
      return target.path(handle.toString()).queryParam("sessionid", sessionHandle).request()
        .get(LensQuery.class);
    }
  }

  /**
   * Copies every property of a Hive configuration into a new {@link LensConf}.
   *
   * @param conf the conf
   * @return the lens conf
   */
  private LensConf toLensConf(HiveConf conf) {
    LensConf lensConf = new LensConf();
    lensConf.getProperties().putAll(conf.getValByRegex(".*"));
    return lensConf;
  }
}