org.apache.lens.ml.ModelLoader.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.lens.ml.ModelLoader.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.lens.ml;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import org.apache.commons.io.IOUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

/**
 * Loads and stores ML models and test reports on a Hadoop {@link FileSystem} location.
 *
 * <p>Models are Java-serialized {@link MLModel} objects stored at
 * {@code <model base dir>/<algorithm>/<modelID>}; test reports are Java-serialized
 * {@link MLTestReport} objects stored at {@code <report base dir>/<algorithm>/<reportID>}.
 * Loaded models are kept in a size-bounded in-memory cache whose entries expire after a period
 * without access.
 *
 * <p>NOTE(review): models and reports are read back with Java native deserialization, so the
 * configured base directories must only contain trusted data.
 */
public final class ModelLoader {
    private ModelLoader() {
    }

    // NOTE(review): Hadoop configuration keys are case-sensitive; confirm the capitalized "Lens."
    // prefix of the two *_BASE_DIR keys below matches the keys used in deployment configs.

    /** Configuration key for the base directory under which models are stored. */
    public static final String MODEL_PATH_BASE_DIR = "Lens.ml.model.basedir";

    /** Model base directory used when {@link #MODEL_PATH_BASE_DIR} is not set. */
    public static final String MODEL_PATH_BASE_DIR_DEFAULT = "file:///tmp";

    /** The Constant LOG. */
    public static final Log LOG = LogFactory.getLog(ModelLoader.class);

    /** Configuration key for the base directory under which test reports are stored. */
    public static final String TEST_REPORT_BASE_DIR = "Lens.ml.test.basedir";

    /** Test report base directory used when {@link #TEST_REPORT_BASE_DIR} is not set. */
    public static final String TEST_REPORT_BASE_DIR_DEFAULT = "file:///tmp/ml_reports";

    // Model cache settings
    /** Maximum number of models held in the model cache. */
    public static final long MODEL_CACHE_SIZE = 10;

    /** Milliseconds after the last access before a cached model is evicted. */
    public static final long MODEL_CACHE_TIMEOUT = 3600000L; // one hour

    /** Loaded models, keyed by the model file path returned by {@link #getModelLocation}. */
    private static final Cache<Path, MLModel> MODEL_CACHE = CacheBuilder.newBuilder()
            .maximumSize(MODEL_CACHE_SIZE)
            .expireAfterAccess(MODEL_CACHE_TIMEOUT, TimeUnit.MILLISECONDS).build();

    /**
     * Resolves the location of a model file as {@code <base>/<algorithm>/<modelID>}, where the base
     * directory is read from {@link #MODEL_PATH_BASE_DIR} (default
     * {@link #MODEL_PATH_BASE_DIR_DEFAULT}).
     *
     * @param conf      the configuration supplying the base directory
     * @param algorithm the algorithm name, used as a sub-directory
     * @param modelID   the model id, used as the file name
     * @return the model location
     */
    public static Path getModelLocation(Configuration conf, String algorithm, String modelID) {
        String modelDataBaseDir = conf.get(MODEL_PATH_BASE_DIR, MODEL_PATH_BASE_DIR_DEFAULT);
        // Model location format - <modelDataBaseDir>/<algorithm>/modelID
        return new Path(new Path(new Path(modelDataBaseDir), algorithm), modelID);
    }

    /**
     * Loads a model, returning the cached instance when one is present for the model's path.
     * On a cache miss the model file is opened through {@code conf}'s file system and deserialized.
     *
     * @param conf      the configuration used to resolve the model location and its file system
     * @param algorithm the algorithm
     * @param modelID   the model id
     * @return the ML model
     * @throws IOException if the model file does not exist, cannot be read, or references a class
     *                     that is not on the classpath
     */
    public static MLModel loadModel(final Configuration conf, String algorithm, String modelID)
            throws IOException {
        final Path modelPath = getModelLocation(conf, algorithm, modelID);
        LOG.info("Loading model for algorithm: " + algorithm + " modelID: " + modelID + " At path: "
                + modelPath.toUri().toString());
        try {
            return MODEL_CACHE.get(modelPath, new Callable<MLModel>() {
                @Override
                public MLModel call() throws IOException {
                    return readModel(conf, modelPath);
                }
            });
        } catch (ExecutionException exc) {
            // The loader only throws IOException; surface it directly instead of double-wrapping.
            Throwable cause = exc.getCause();
            if (cause instanceof IOException) {
                throw (IOException) cause;
            }
            throw new IOException(cause != null ? cause : exc);
        }
    }

    /**
     * Reads and deserializes one model file; called by {@link #loadModel} on a cache miss.
     *
     * @param conf      the configuration used to obtain the file system
     * @param modelPath the model file to read
     * @return the deserialized model
     * @throws IOException if the file is missing, unreadable, or references an unknown class
     */
    private static MLModel readModel(Configuration conf, Path modelPath) throws IOException {
        FileSystem fs = modelPath.getFileSystem(conf);
        if (!fs.exists(modelPath)) {
            throw new IOException("Model path not found " + modelPath.toString());
        }

        InputStream rawIn = null;
        ObjectInputStream ois = null;
        try {
            rawIn = fs.open(modelPath);
            // The ObjectInputStream constructor reads the stream header and may throw, so the raw
            // stream is tracked separately to make sure it is always closed.
            ois = new ObjectInputStream(rawIn);
            MLModel model = (MLModel) ois.readObject();
            LOG.info("Loaded model " + model.getId() + " from location " + modelPath);
            return model;
        } catch (ClassNotFoundException e) {
            throw new IOException(e);
        } finally {
            IOUtils.closeQuietly(ois);
            IOUtils.closeQuietly(rawIn);
        }
    }

    /**
     * Removes every model from the in-memory model cache. Files on disk are not touched.
     */
    public static void clearCache() {
        // cleanUp() only performs pending maintenance; invalidateAll() actually discards entries.
        MODEL_CACHE.invalidateAll();
    }

    /**
     * Resolves the location of a test report as {@code <base>/<algorithm>/<report>}, where the base
     * directory is read from {@link #TEST_REPORT_BASE_DIR} (default
     * {@link #TEST_REPORT_BASE_DIR_DEFAULT}).
     *
     * @param conf      the configuration supplying the base directory
     * @param algorithm the algorithm name, used as a sub-directory
     * @param report    the report id, used as the file name
     * @return the test report path
     */
    public static Path getTestReportPath(Configuration conf, String algorithm, String report) {
        String testReportDir = conf.get(TEST_REPORT_BASE_DIR, TEST_REPORT_BASE_DIR_DEFAULT);
        return new Path(new Path(testReportDir, algorithm), report);
    }

    /**
     * Serializes a test report to {@code <report base dir>/<report algorithm>/<report id>}, creating
     * the directories if they do not exist.
     *
     * @param conf   the configuration used to resolve the report location and its file system
     * @param report the report to save
     * @throws IOException if the directories cannot be checked or the report cannot be written
     */
    public static void saveTestReport(Configuration conf, MLTestReport report) throws IOException {
        Path reportDir = new Path(conf.get(TEST_REPORT_BASE_DIR, TEST_REPORT_BASE_DIR_DEFAULT));
        FileSystem fs = reportDir.getFileSystem(conf);

        if (!fs.exists(reportDir)) {
            LOG.info("Creating test report dir " + reportDir.toUri().toString());
            fs.mkdirs(reportDir);
        }

        Path algoDir = new Path(reportDir, report.getAlgorithm());

        if (!fs.exists(algoDir)) {
            LOG.info("Creating algorithm report dir " + algoDir.toUri().toString());
            fs.mkdirs(algoDir);
        }

        Path reportSaveLocation = new Path(algoDir, report.getReportID());
        OutputStream rawOut = null;
        ObjectOutputStream reportOutputStream = null;
        try {
            rawOut = fs.create(reportSaveLocation);
            // The ObjectOutputStream constructor writes the stream header and may throw, so the raw
            // stream is tracked separately to make sure it is always closed.
            reportOutputStream = new ObjectOutputStream(rawOut);
            reportOutputStream.writeObject(report);
            // close() flushes buffered data and closes the file stream. A failure here means the
            // report may not have been written, so it must propagate rather than be swallowed.
            reportOutputStream.close();
        } catch (IOException ioexc) {
            LOG.error("Error saving test report " + report.getReportID(), ioexc);
            throw ioexc;
        } finally {
            // Releases the streams if a step above failed; harmless after a successful close().
            IOUtils.closeQuietly(reportOutputStream);
            IOUtils.closeQuietly(rawOut);
        }
        LOG.info("Saved report " + report.getReportID() + " at location " + reportSaveLocation.toUri());
    }

    /**
     * Loads a previously saved test report.
     *
     * <p>Read failures are handled best-effort: an {@link IOException} while opening or reading the
     * report is logged and {@code null} is returned.
     *
     * @param conf      the configuration used to resolve the report location and its file system
     * @param algorithm the algorithm
     * @param reportID  the report id
     * @return the ML test report, or {@code null} if it could not be read
     * @throws IOException if the report references a class that is not on the classpath, or the
     *                     file system for the report location cannot be obtained
     */
    public static MLTestReport loadReport(Configuration conf, String algorithm, String reportID)
            throws IOException {
        Path reportLocation = getTestReportPath(conf, algorithm, reportID);
        FileSystem fs = reportLocation.getFileSystem(conf);
        InputStream rawIn = null;
        ObjectInputStream reportStream = null;
        MLTestReport report = null;

        try {
            rawIn = fs.open(reportLocation);
            // Tracked separately: the ObjectInputStream constructor may throw after the open.
            reportStream = new ObjectInputStream(rawIn);
            report = (MLTestReport) reportStream.readObject();
        } catch (IOException ioex) {
            LOG.error("Error reading report " + reportLocation, ioex);
        } catch (ClassNotFoundException e) {
            throw new IOException(e);
        } finally {
            IOUtils.closeQuietly(reportStream);
            IOUtils.closeQuietly(rawIn);
        }
        return report;
    }

    /**
     * Deletes a model file (non-recursively) and evicts it from the in-memory model cache.
     *
     * @param conf      the configuration used to resolve the model location and its file system
     * @param algorithm the algorithm
     * @param modelID   the model id
     * @throws IOException if the file system cannot be obtained or the delete fails
     */
    public static void deleteModel(HiveConf conf, String algorithm, String modelID) throws IOException {
        Path modelLocation = getModelLocation(conf, algorithm, modelID);
        FileSystem fs = modelLocation.getFileSystem(conf);
        fs.delete(modelLocation, false);
        // Without this, loadModel() would keep serving the deleted model until the entry expires.
        MODEL_CACHE.invalidate(modelLocation);
    }

    /**
     * Deletes a test report file (non-recursively).
     *
     * @param conf      the configuration used to resolve the report location and its file system
     * @param algorithm the algorithm
     * @param reportID  the report id
     * @throws IOException if the file system cannot be obtained or the delete fails
     */
    public static void deleteTestReport(HiveConf conf, String algorithm, String reportID) throws IOException {
        Path reportPath = getTestReportPath(conf, algorithm, reportID);
        reportPath.getFileSystem(conf).delete(reportPath, false);
    }
}