org.wso2.carbon.ml.core.impl.MLModelHandler.java Source code

Introduction

Here is the source code for org.wso2.carbon.ml.core.impl.MLModelHandler.java, the WSO2 Machine Learner class that handles model creation, asynchronous model building, prediction, model persistence, and publishing. A minimal usage sketch follows the listing.

Source

/*
 * Copyright (c) 2015, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
 *
 * WSO2 Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.wso2.carbon.ml.core.impl;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.regex.Pattern;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.lang.math.NumberUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.InvalidRequestException;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.pmml.PMMLExportable;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.wso2.carbon.context.CarbonContext;
import org.wso2.carbon.context.PrivilegedCarbonContext;
import org.wso2.carbon.metrics.manager.Level;
import org.wso2.carbon.metrics.manager.MetricManager;
import org.wso2.carbon.metrics.manager.Timer.Context;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.*;
import org.wso2.carbon.ml.commons.domain.config.Storage;
import org.wso2.carbon.ml.core.exceptions.*;
import org.wso2.carbon.ml.core.factories.DatasetType;
import org.wso2.carbon.ml.core.factories.ModelBuilderFactory;
import org.wso2.carbon.ml.core.interfaces.MLInputAdapter;
import org.wso2.carbon.ml.core.interfaces.MLModelBuilder;
import org.wso2.carbon.ml.core.interfaces.MLOutputAdapter;
import org.wso2.carbon.ml.core.interfaces.PMMLModelContainer;
import org.wso2.carbon.ml.core.internal.MLModelConfigurationContext;
import org.wso2.carbon.ml.core.spark.algorithms.KMeans;
import org.wso2.carbon.ml.core.spark.algorithms.SparkModelUtils;
import org.wso2.carbon.ml.core.spark.models.MLDeeplearningModel;
import org.wso2.carbon.ml.core.spark.models.MLMatrixFactorizationModel;
import org.wso2.carbon.ml.core.spark.recommendation.CollaborativeFiltering;
import org.wso2.carbon.ml.core.spark.transformations.HeaderFilter;
import org.wso2.carbon.ml.core.spark.transformations.LineToTokens;
import org.wso2.carbon.ml.core.spark.transformations.MissingValuesFilter;
import org.wso2.carbon.ml.core.spark.transformations.TokensToVectors;
import org.wso2.carbon.ml.core.utils.BlockingExecutor;
import org.wso2.carbon.ml.core.utils.MLCoreServiceValueHolder;
import org.wso2.carbon.ml.core.utils.MLUtils;
import org.wso2.carbon.ml.core.utils.MLUtils.ColumnSeparatorFactory;
import org.wso2.carbon.ml.core.utils.MLUtils.DataTypeFactory;
import org.wso2.carbon.ml.database.DatabaseService;
import org.wso2.carbon.ml.database.exceptions.DatabaseHandlerException;
import org.wso2.carbon.registry.core.RegistryConstants;
import org.wso2.carbon.utils.ConfigurationContextService;
import org.xml.sax.InputSource;

import scala.Tuple2;
import hex.deeplearning.DeepLearningModel;

/**
 * {@link MLModelHandler} is responsible for handling/delegating all model-related requests.
 */
public class MLModelHandler {
    private static final Log log = LogFactory.getLog(MLModelHandler.class);
    private DatabaseService databaseService;
    private Properties mlProperties;
    private BlockingExecutor threadExecutor;

    public enum Format {
        SERIALIZED, PMML
    }

    public MLModelHandler() {
        MLCoreServiceValueHolder valueHolder = MLCoreServiceValueHolder.getInstance();
        databaseService = valueHolder.getDatabaseService();
        mlProperties = valueHolder.getMlProperties();
        threadExecutor = valueHolder.getThreadExecutor();
    }

    /**
     * Create a new model.
     *
     * @param model model to be created.
     * @throws MLModelHandlerException
     */
    public MLModelData createModel(MLModelData model) throws MLModelHandlerException {
        try {
            // set the model storage configurations
            Storage modelStorage = MLCoreServiceValueHolder.getInstance().getModelStorage();
            model.setStorageType(modelStorage.getStorageType());
            model.setStorageDirectory(modelStorage.getStorageDirectory());

            int tenantId = model.getTenantId();
            String userName = model.getUserName();
            MLAnalysis analysis = databaseService.getAnalysis(tenantId, userName, model.getAnalysisId());
            if (analysis == null) {
                throw new MLModelHandlerException("Invalid analysis [id] " + model.getAnalysisId());
            }

            MLDatasetVersion versionSet = databaseService.getVersionset(tenantId, userName,
                    model.getVersionSetId());
            if (versionSet == null) {
                throw new MLModelHandlerException("Invalid version set [id] " + model.getVersionSetId());

            }
            // set model name
            String modelName = analysis.getName();
            modelName = modelName + "." + MLConstants.MODEL_NAME + "." + MLUtils.getDate();
            model.setName(modelName);
            model.setStatus(MLConstants.MODEL_STATUS_NOT_STARTED);

            databaseService.insertModel(model);
            log.info(String.format("[Created] %s", model));
            return model;
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Delete a model using the model ID.
     *
     * @param tenantId  Unique ID of the tenant.
     * @param userName  Username of the user.
     * @param modelId   modelId of the model to be deleted.
     * @throws MLModelHandlerException
     */
    public void deleteModel(int tenantId, String userName, long modelId) throws MLModelHandlerException {
        try {
            databaseService.deleteModel(tenantId, userName, modelId);
            log.info(String.format("[Deleted] Model [id] %s", modelId));
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Get a model using the model name.
     *
     * @param tenantId  Unique ID of the tenant.
     * @param userName  Username of the user.
     * @param modelName modelName of the model to be retrieved.
     * @throws MLModelHandlerException
     */
    public MLModelData getModel(int tenantId, String userName, String modelName) throws MLModelHandlerException {
        try {
            return databaseService.getModel(tenantId, userName, modelName);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Get a model using the model ID.
     *
     * @param tenantId  Unique ID of the tenant.
     * @param userName  Username of the user.
     * @param modelId   modelId of the model to be retrieved.
     * @throws MLModelHandlerException
     */
    public MLModelData getModel(int tenantId, String userName, long modelId) throws MLModelHandlerException {
        try {
            return databaseService.getModel(tenantId, userName, modelId);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Get all models of a user.
     *
     * @param tenantId  Unique ID of the tenant.
     * @param userName  Username of the user.
     * @throws MLModelHandlerException
     */
    public List<MLModelData> getAllModels(int tenantId, String userName) throws MLModelHandlerException {
        try {
            return databaseService.getAllModels(tenantId, userName);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Check the validity of a model ID.
     *
     * @param tenantId  Unique ID of the tenant.
     * @param userName  Username of the user.
     * @param modelId   modelId to be validated
     * @throws MLModelHandlerException
     */
    public boolean isValidModelId(int tenantId, String userName, long modelId) throws MLModelHandlerException {
        try {
            return databaseService.isValidModelId(tenantId, userName, modelId);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Check whether the status of a model is valid for prediction (i.e. the model build has completed).
     *
     * @param modelId   ID of the model whose status needs to be validated.
     * @param tenantId  Unique ID of the tenant.
     * @param userName  Username of the user.
     * @throws MLModelHandlerException
     */
    public boolean isValidModelStatus(long modelId, int tenantId, String userName) throws MLModelHandlerException {
        try {
            return databaseService.isValidModelStatus(modelId, tenantId, userName);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(
                    "Model status for model [id] " + modelId + " is invalid :" + e.getMessage(), e);
        }
    }

    /**
     * Update the storage information of a model.
     *
     * @param modelId unique id of the model
     * @param storage MLStorage to be updated
     * @throws MLModelHandlerException
     */
    public void addStorage(long modelId, MLStorage storage) throws MLModelHandlerException {
        try {
            databaseService.updateModelStorage(modelId, storage.getType(), storage.getLocation());
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Get the summary of a model
     *
     * @param modelId ID of the model
     * @return Model Summary
     * @throws MLModelHandlerException
     */
    public ModelSummary getModelSummary(long modelId) throws MLModelHandlerException {
        try {
            return databaseService.getModelSummary(modelId);
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(e.getMessage(), e);
        }
    }

    /**
     * Build an ML model asynchronously and persist the built model in a given storage.
     *
     * @param tenantId tenant id
     * @param userName tenant user name
     * @param modelId  id of the model to be built
     * @throws MLModelHandlerException
     * @throws MLModelBuilderException
     */
    public Workflow buildModel(int tenantId, String userName, long modelId)
            throws MLModelHandlerException, MLModelBuilderException {

        if (!isValidModelId(tenantId, userName, modelId)) {
            String msg = String.format(
                    "Failed to build the model. Invalid model id: %s for tenant: %s and user: %s", modelId,
                    tenantId, userName);
            throw new MLModelHandlerException(msg);
        }

        try {
            long datasetVersionId = databaseService.getDatasetVersionIdOfModel(modelId);
            long datasetId = databaseService.getDatasetId(datasetVersionId);
            MLDataset dataset = databaseService.getDataset(tenantId, userName, datasetId);
            String dataSourceType = dataset.getDataSourceType();
            String dataType = databaseService.getDataTypeOfModel(modelId);
            String columnSeparator = ColumnSeparatorFactory.getColumnSeparator(dataType);
            String dataUrl = databaseService.getDatasetVersionUri(datasetVersionId);
            handleNull(dataUrl, "Target path is null for dataset version [id]: " + datasetVersionId);
            MLModelData model = databaseService.getModel(tenantId, userName, modelId);
            Workflow facts = databaseService.getWorkflow(model.getAnalysisId());
            facts.setDatasetVersion(databaseService.getVersionset(tenantId, userName, datasetVersionId).getName());
            facts.setDatasetURL(dataUrl);

            JavaRDD<String> lines;

            // java spark context
            JavaSparkContext sparkContext = MLCoreServiceValueHolder.getInstance().getSparkContext();

            try {
                lines = extractLines(tenantId, datasetId, sparkContext, dataUrl, dataSourceType, dataType);
            } catch (MLMalformedDatasetException e) {
                throw new MLModelBuilderException("Failed to build the model [id] " + modelId, e);
            }

            MLModelConfigurationContext context = buildMLModelConfigurationContext(modelId, datasetVersionId,
                    columnSeparator, model, facts, lines, sparkContext);

            // build the model asynchronously
            ModelBuilder task = new ModelBuilder(modelId, context);
            threadExecutor.execute(task);
            threadExecutor.afterExecute(task, null);

            databaseService.updateModelStatus(modelId, MLConstants.MODEL_STATUS_IN_PROGRESS);
            log.info(String.format("Build model [id] %s job is successfully submitted to Spark.", modelId));

            return facts;
        } catch (DatabaseHandlerException e) {
            throw new MLModelBuilderException(
                    "An error occurred while saving model [id] " + modelId + " to database: " + e.getMessage(), e);
        }
    }

    private MLModelConfigurationContext buildMLModelConfigurationContext(long modelId, long datasetVersionId,
            String columnSeparator, MLModelData model, Workflow facts, JavaRDD<String> lines,
            JavaSparkContext sparkContext) throws DatabaseHandlerException {
        MLModelConfigurationContext context = new MLModelConfigurationContext();
        context.setModelId(modelId);
        context.setColumnSeparator(columnSeparator);
        context.setFacts(facts);
        context.setModel(model);
        Map<String, String> summaryStatsOfFeatures = databaseService.getSummaryStats(datasetVersionId);
        context.setSummaryStatsOfFeatures(summaryStatsOfFeatures);
        int responseIndex = MLUtils.getFeatureIndex(facts.getResponseVariable(), facts.getFeatures());
        context.setIncludedFeaturesMap(MLUtils.getIncludedFeatures(facts, responseIndex));
        context.setNewToOldIndicesList(getNewToOldIndicesList(context.getIncludedFeaturesMap()));
        context.setResponseIndex(responseIndex);
        context.setSparkContext(sparkContext);
        context.setLines(lines);
        // get header line
        String headerRow = databaseService.getFeatureNamesInOrderUsingDatasetVersion(datasetVersionId,
                columnSeparator);
        context.setHeaderRow(headerRow);
        return context;
    }
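
    /**
     * Predict using a model for data read from an input stream.
     *
     * @param tenantId   Unique ID of the tenant.
     * @param userName   Username of the user.
     * @param modelId    Unique ID of the model.
     * @param dataFormat Data format of the input (e.g. CSV or TSV).
     * @param dataStream Input stream containing the data points, one row per line.
     * @return List of predictions.
     * @throws MLModelHandlerException
     */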

    public List<?> predict(int tenantId, String userName, long modelId, String dataFormat, InputStream dataStream)
            throws MLModelHandlerException {
        List<String[]> data = new ArrayList<String[]>();
        CSVFormat csvFormat = DataTypeFactory.getCSVFormat(dataFormat);
        BufferedReader br = new BufferedReader(new InputStreamReader(dataStream, StandardCharsets.UTF_8));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                String[] dataRow = line.split(csvFormat.getDelimiter() + "");
                data.add(dataRow);
            }
            return predict(tenantId, userName, modelId, data);
        } catch (IOException e) {
            String msg = "Failed to read the data points for prediction for model [id] " + modelId;
            log.error(msg, e);
            throw new MLModelHandlerException(msg, e);
        } finally {
            try {
                dataStream.close();
                br.close();
            } catch (IOException e) {
                String msg = "Error in closing input stream after prediction for model [id] " + modelId;
                log.error(msg, e);
            }
        }

    }
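
    /**
     * Predict for data read from an input stream and return the input rows with their predictions
     * appended, as delimited text.
     *
     * @param tenantId     Unique ID of the tenant.
     * @param userName     Username of the user.
     * @param modelId      Unique ID of the model.
     * @param dataFormat   Data format of the input (e.g. CSV or TSV).
     * @param columnHeader Whether the input contains a header row.
     * @param dataStream   Input stream containing the data points, one row per line.
     * @return Data points with the corresponding predictions appended.
     * @throws MLModelHandlerException
     */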

    public String streamingPredict(int tenantId, String userName, long modelId, String dataFormat,
            String columnHeader, InputStream dataStream) throws MLModelHandlerException {
        List<String[]> data = new ArrayList<String[]>();
        CSVFormat csvFormat = DataTypeFactory.getCSVFormat(dataFormat);
        MLModel mlModel = retrieveModel(modelId);
        BufferedReader br = new BufferedReader(new InputStreamReader(dataStream, StandardCharsets.UTF_8));
        StringBuilder predictionsWithData = new StringBuilder();
        try {
            String line;
            if ((line = br.readLine()) != null && line.split(csvFormat.getDelimiter() + "").length == mlModel
                    .getNewToOldIndicesList().size()) {
                if (columnHeader.equalsIgnoreCase(MLConstants.NO)) {
                    String[] dataRow = line.split(csvFormat.getDelimiter() + "");
                    data.add(dataRow);
                } else {
                    predictionsWithData.append(line).append(MLConstants.NEW_LINE);
                }
                while ((line = br.readLine()) != null) {
                    String[] dataRow = line.split(csvFormat.getDelimiter() + "");
                    data.add(dataRow);
                }
                // cloning unencoded data to append with predictions
                List<String[]> unencodedData = new ArrayList<String[]>(data.size());
                for (String[] item : data) {
                    unencodedData.add(item.clone());
                }
                List<?> predictions = predict(tenantId, userName, modelId, data);
                for (int i = 0; i < predictions.size(); i++) {
                    predictionsWithData
                            .append(MLUtils.arrayToCsvString(unencodedData.get(i), csvFormat.getDelimiter()))
                            .append(String.valueOf(predictions.get(i))).append(MLConstants.NEW_LINE);
                }
            } else {
                int responseVariableIndex = mlModel.getResponseIndex();
                List<Integer> includedFeatureIndices = mlModel.getNewToOldIndicesList();
                List<String[]> unencodedData = new ArrayList<String[]>();
                if (columnHeader.equalsIgnoreCase(MLConstants.NO)) {
                    int count = 0;
                    String[] dataRow = line.split(csvFormat.getDelimiter() + "");
                    unencodedData.add(dataRow.clone());
                    String[] includedFeatureValues = new String[includedFeatureIndices.size()];
                    for (int index : includedFeatureIndices) {
                        includedFeatureValues[count++] = dataRow[index];
                    }
                    data.add(includedFeatureValues);
                } else {
                    predictionsWithData.append(line).append(MLConstants.NEW_LINE);
                }
                while ((line = br.readLine()) != null) {
                    int count = 0;
                    String[] dataRow = line.split(csvFormat.getDelimiter() + "");
                    unencodedData.add(dataRow.clone());
                    String[] includedFeatureValues = new String[includedFeatureIndices.size()];
                    for (int index : includedFeatureIndices) {
                        includedFeatureValues[count++] = dataRow[index];
                    }
                    data.add(includedFeatureValues);
                }

                List<?> predictions = predict(tenantId, userName, modelId, data);
                for (int i = 0; i < predictions.size(); i++) {
                    // replace with predicted value
                    unencodedData.get(i)[responseVariableIndex] = String.valueOf(predictions.get(i));
                    predictionsWithData
                            .append(MLUtils.arrayToCsvString(unencodedData.get(i), csvFormat.getDelimiter()));
                    predictionsWithData.deleteCharAt(predictionsWithData.length() - 1);
                    predictionsWithData.append(MLConstants.NEW_LINE);
                }
            }
            return predictionsWithData.toString();
        } catch (IOException e) {
            String msg = "Failed to read the data points for prediction for model [id] " + modelId;
            log.error(msg, e);
            throw new MLModelHandlerException(msg, e);
        } finally {
            try {
                if (dataStream != null && br != null) {
                    dataStream.close();
                    br.close();
                }
            } catch (IOException e) {
                String msg = MLUtils.getErrorMsg(String.format(
                        "Error occurred while closing the streams for model [id] %s of tenant [id] %s and [user] %s.",
                        modelId, tenantId, userName), e);
                log.warn(msg, e);
            }
        }

    }
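
    /**
     * Predict using a model for a list of in-memory data points.
     *
     * @param tenantId Unique ID of the tenant.
     * @param userName Username of the user.
     * @param modelId  Unique ID of the model.
     * @param data     Data points; one String array per row, in the feature order expected by the model.
     * @return List of predictions.
     * @throws MLModelHandlerException
     */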

    public List<?> predict(int tenantId, String userName, long modelId, List<String[]> data)
            throws MLModelHandlerException {

        if (!isValidModelId(tenantId, userName, modelId)) {
            String msg = String.format(
                    "Failed to predict. Invalid model id: %s for tenant: %s and user: %s", modelId,
                    tenantId, userName);
            throw new MLModelHandlerException(msg);
        }

        if (!isValidModelStatus(modelId, tenantId, userName)) {
            String msg = String.format(
                    "This model cannot be used for prediction. Status of the model for model id: %s for tenant: %s and user: %s is not 'Complete'",
                    modelId, tenantId, userName);
            throw new MLModelHandlerException(msg);
        }

        MLModel builtModel = retrieveModel(modelId);

        // Validate number of features in predict dataset
        if (builtModel.getNewToOldIndicesList().size() != data.get(0).length) {
            String msg = String.format(
                    "Prediction failed from model [id] %s since [number of features of model]"
                            + " %s does not match [number of features in the input data] %s",
                    modelId, builtModel.getNewToOldIndicesList().size(), data.get(0).length);
            throw new MLModelHandlerException(msg);
        }

        // Validate numerical feature type in predict dataset
        for (Feature feature : builtModel.getFeatures()) {
            if (feature.getType().equals(FeatureType.NUMERICAL)) {
                int actualIndex = builtModel.getNewToOldIndicesList().indexOf(feature.getIndex());
                for (String[] dataPoint : data) {
                    if (!NumberUtils.isNumber(dataPoint[actualIndex])) {
                        String msg = String.format("Invalid value: %s for the feature: %s at feature index: %s",
                                dataPoint[actualIndex], feature.getName(), actualIndex);
                        throw new MLModelHandlerException(msg);
                    }
                }
            }
        }

        // predict
        Predictor predictor = new Predictor(modelId, builtModel, data);
        List<?> predictions = predictor.predict();

        return predictions;
    }
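
    /**
     * Predict for data read from an input stream, with an explicit percentile and decoding option.
     *
     * @param tenantId     Unique ID of the tenant.
     * @param userName     Username of the user.
     * @param modelId      Unique ID of the model.
     * @param dataFormat   Data format of the input (e.g. CSV or TSV).
     * @param dataStream   Input stream containing the data points, one row per line.
     * @param percentile   Percentile value passed to the predictor.
     * @param skipDecoding Whether to skip decoding of the predicted values.
     * @return List of predictions.
     * @throws MLModelHandlerException
     */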

    public List<?> predict(int tenantId, String userName, long modelId, String dataFormat, InputStream dataStream,
            double percentile, boolean skipDecoding) throws MLModelHandlerException {
        List<String[]> data = new ArrayList<String[]>();
        CSVFormat csvFormat = DataTypeFactory.getCSVFormat(dataFormat);
        BufferedReader br = new BufferedReader(new InputStreamReader(dataStream, StandardCharsets.UTF_8));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                String[] dataRow = line.split(csvFormat.getDelimiter() + "");
                data.add(dataRow);
            }
            return predict(tenantId, userName, modelId, data, percentile, skipDecoding);
        } catch (IOException e) {
            String msg = "Failed to read the data points for prediction for model [id] " + modelId;
            log.error(msg, e);
            throw new MLModelHandlerException(msg, e);
        } finally {
            try {
                dataStream.close();
                br.close();
            } catch (IOException e) {
                String msg = "Error in closing input stream after prediction for model [id] " + modelId;
                log.error(msg, e);
            }
        }

    }
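
    /**
     * Predict for data read from an input stream and return the input rows with their predictions
     * appended, with an explicit percentile and decoding option.
     *
     * @param tenantId     Unique ID of the tenant.
     * @param userName     Username of the user.
     * @param modelId      Unique ID of the model.
     * @param dataFormat   Data format of the input (e.g. CSV or TSV).
     * @param columnHeader Whether the input contains a header row.
     * @param dataStream   Input stream containing the data points, one row per line.
     * @param percentile   Percentile value passed to the predictor.
     * @param skipDecoding Whether to skip decoding of the predicted values.
     * @return Data points with the corresponding predictions appended.
     * @throws MLModelHandlerException
     */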

    public String streamingPredict(int tenantId, String userName, long modelId, String dataFormat,
            String columnHeader, InputStream dataStream, double percentile, boolean skipDecoding)
            throws MLModelHandlerException {
        List<String[]> data = new ArrayList<String[]>();
        CSVFormat csvFormat = DataTypeFactory.getCSVFormat(dataFormat);
        MLModel mlModel = retrieveModel(modelId);
        BufferedReader br = new BufferedReader(new InputStreamReader(dataStream, StandardCharsets.UTF_8));
        StringBuilder predictionsWithData = new StringBuilder();
        try {
            String line;
            if ((line = br.readLine()) != null && line.split(csvFormat.getDelimiter() + "").length == mlModel
                    .getNewToOldIndicesList().size()) {
                if (columnHeader.equalsIgnoreCase(MLConstants.NO)) {
                    String[] dataRow = line.split(csvFormat.getDelimiter() + "");
                    data.add(dataRow);
                } else {
                    predictionsWithData.append(line).append(MLConstants.NEW_LINE);
                }
                while ((line = br.readLine()) != null) {
                    String[] dataRow = line.split(csvFormat.getDelimiter() + "");
                    data.add(dataRow);
                }
                // cloning unencoded data to append with predictions
                List<String[]> unencodedData = new ArrayList<String[]>(data.size());
                for (String[] item : data) {
                    unencodedData.add(item.clone());
                }
                List<?> predictions = predict(tenantId, userName, modelId, data, percentile, skipDecoding);
                for (int i = 0; i < predictions.size(); i++) {
                    predictionsWithData
                            .append(MLUtils.arrayToCsvString(unencodedData.get(i), csvFormat.getDelimiter()))
                            .append(String.valueOf(predictions.get(i))).append(MLConstants.NEW_LINE);
                }
            } else {
                int responseVariableIndex = mlModel.getResponseIndex();
                List<Integer> includedFeatureIndices = mlModel.getNewToOldIndicesList();
                List<String[]> unencodedData = new ArrayList<String[]>();
                if (columnHeader.equalsIgnoreCase(MLConstants.NO)) {
                    int count = 0;
                    String[] dataRow = line.split(csvFormat.getDelimiter() + "");
                    unencodedData.add(dataRow.clone());
                    String[] includedFeatureValues = new String[includedFeatureIndices.size()];
                    for (int index : includedFeatureIndices) {
                        includedFeatureValues[count++] = dataRow[index];
                    }
                    data.add(includedFeatureValues);
                } else {
                    predictionsWithData.append(line).append(MLConstants.NEW_LINE);
                }
                while ((line = br.readLine()) != null) {
                    int count = 0;
                    String[] dataRow = line.split(csvFormat.getDelimiter() + "");
                    unencodedData.add(dataRow.clone());
                    String[] includedFeatureValues = new String[includedFeatureIndices.size()];
                    for (int index : includedFeatureIndices) {
                        includedFeatureValues[count++] = dataRow[index];
                    }
                    data.add(includedFeatureValues);
                }

                List<?> predictions = predict(tenantId, userName, modelId, data, percentile, skipDecoding);
                for (int i = 0; i < predictions.size(); i++) {
                    // replace with predicted value
                    unencodedData.get(i)[responseVariableIndex] = String.valueOf(predictions.get(i));
                    predictionsWithData
                            .append(MLUtils.arrayToCsvString(unencodedData.get(i), csvFormat.getDelimiter()));
                    predictionsWithData.deleteCharAt(predictionsWithData.length() - 1);
                    predictionsWithData.append(MLConstants.NEW_LINE);
                }
            }
            return predictionsWithData.toString();
        } catch (IOException | ArrayIndexOutOfBoundsException e) {
            String msg = "Failed to read the data points for prediction for model [id] " + modelId;
            log.error(msg, e);
            throw new MLModelHandlerException(msg, e);
        } finally {
            try {
                if (dataStream != null && br != null) {
                    dataStream.close();
                    br.close();
                }
            } catch (IOException e) {
                String msg = MLUtils.getErrorMsg(String.format(
                        "Error occurred while closing the streams for model [id] %s of tenant [id] %s and [user] %s.",
                        modelId, tenantId, userName), e);
                log.warn(msg, e);
            }
        }

    }
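
    /**
     * Predict using a model for a list of in-memory data points, with an explicit percentile and
     * decoding option.
     *
     * @param tenantId     Unique ID of the tenant.
     * @param userName     Username of the user.
     * @param modelId      Unique ID of the model.
     * @param data         Data points; one String array per row, in the feature order expected by the model.
     * @param percentile   Percentile value passed to the predictor.
     * @param skipDecoding Whether to skip decoding of the predicted values.
     * @return List of predictions.
     * @throws MLModelHandlerException
     */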

    public List<?> predict(int tenantId, String userName, long modelId, List<String[]> data, double percentile,
            boolean skipDecoding) throws MLModelHandlerException {

        if (!isValidModelId(tenantId, userName, modelId)) {
            String msg = String.format(
                    "Failed to predict. Invalid model id: %s for tenant: %s and user: %s", modelId,
                    tenantId, userName);
            throw new MLModelHandlerException(msg);
        }

        if (!isValidModelStatus(modelId, tenantId, userName)) {
            String msg = String.format(
                    "This model cannot be used for prediction. Status of the model for model id: %s for tenant: %s and user: %s is not 'Complete'",
                    modelId, tenantId, userName);
            throw new MLModelHandlerException(msg);
        }

        if (data.size() == 0) {
            throw new MLModelHandlerException("Predict dataset is empty.");
        }

        MLModel builtModel = retrieveModel(modelId);

        // Validate number of features in predict dataset
        if (builtModel.getNewToOldIndicesList().size() != data.get(0).length) {
            String msg = String.format(
                    "Prediction failed from model [id] %s since [number of features of model]"
                            + " %s does not match [number of features in the input data] %s",
                    modelId, builtModel.getNewToOldIndicesList().size(), data.get(0).length);
            throw new MLModelHandlerException(msg);
        }

        // Validate numerical feature type in predict dataset
        for (Feature feature : builtModel.getFeatures()) {
            if (feature.getType().equals(FeatureType.NUMERICAL)) {
                int actualIndex = builtModel.getNewToOldIndicesList().indexOf(feature.getIndex());
                for (String[] dataPoint : data) {
                    if (!NumberUtils.isNumber(dataPoint[actualIndex])) {
                        String msg = String.format("Invalid value: %s for the feature: %s at feature index: %s",
                                dataPoint[actualIndex], feature.getName(), actualIndex);
                        throw new MLModelHandlerException(msg);
                    }
                }
            }
        }

        // predict
        Predictor predictor = new Predictor(modelId, builtModel, data, percentile, skipDecoding);
        List<?> predictions = predictor.predict();

        return predictions;
    }
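
    /**
     * Get product recommendations for a user from a recommendation (matrix factorization) model.
     *
     * @param tenantId     Unique ID of the tenant.
     * @param userName     Username of the user.
     * @param modelId      Unique ID of the model.
     * @param userId       ID of the user for whom products are recommended.
     * @param noOfProducts Number of products to recommend.
     * @return List of recommended products.
     * @throws MLModelHandlerException
     */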

    public List<?> getProductRecommendations(int tenantId, String userName, long modelId, int userId,
            int noOfProducts) throws MLModelHandlerException {

        MatrixFactorizationModel model = getMatrixFactorizationModel(tenantId, userName, modelId);
        List<?> recommendations = CollaborativeFiltering.recommendProducts(model, userId, noOfProducts);

        log.info(String.format("Successfully generated recommendations from model [id] %s.", modelId));
        return recommendations;

    }
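
    /**
     * Get user recommendations for a product from a recommendation (matrix factorization) model.
     *
     * @param tenantId  Unique ID of the tenant.
     * @param userName  Username of the user.
     * @param modelId   Unique ID of the model.
     * @param productId ID of the product for which users are recommended.
     * @param noOfUsers Number of users to recommend.
     * @return List of recommended users.
     * @throws MLModelHandlerException
     */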

    public List<?> getUserRecommendations(int tenantId, String userName, long modelId, int productId, int noOfUsers)
            throws MLModelHandlerException {

        MatrixFactorizationModel model = getMatrixFactorizationModel(tenantId, userName, modelId);
        List<?> recommendations = CollaborativeFiltering.recommendUsers(model, productId, noOfUsers);

        log.info(String.format("Successfully generated recommendations from model [id] %s.", modelId));
        return recommendations;

    }
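
    /**
     * Retrieve the underlying {@link MatrixFactorizationModel} of a recommendation model.
     *
     * @throws MLModelHandlerException if the model ID is invalid or the model was not built by a
     *             recommendation algorithm.
     */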

    private MatrixFactorizationModel getMatrixFactorizationModel(int tenantId, String userName, long modelId)
            throws MLModelHandlerException {
        if (!isValidModelId(tenantId, userName, modelId)) {
            String msg = String.format(
                    "Failed to get recommendations. Invalid model id: %s for tenant: %s and user: %s", modelId,
                    tenantId, userName);
            throw new MLModelHandlerException(msg);
        }

        MLModel builtModel = retrieveModel(modelId);

        //validate if retrieved model is a MatrixFactorizationModel
        if (!(builtModel.getModel() instanceof MLMatrixFactorizationModel)) {
            String msg = String
                    .format("Cannot get recommendations for model [id] %s , since it is not generated from a "
                            + "Recommendation algorithm.", modelId);
            throw new MLModelHandlerException(msg);
        }
        //get recommendations
        MatrixFactorizationModel model = ((MLMatrixFactorizationModel) builtModel.getModel()).getModel();
        return model;
    }
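
    /**
     * Serialize the built model and write it to the configured model storage. For deep learning models,
     * a generated POJO and a lightweight copy of the model (without the deep learning internals) are
     * written as well.
     *
     * @param modelId   Unique ID of the model.
     * @param modelName Name of the model, used as the output file name.
     * @param model     Built model to be persisted.
     * @throws MLModelBuilderException
     */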

    private void persistModel(long modelId, String modelName, MLModel model) throws MLModelBuilderException {
        try {
            MLStorage storage = databaseService.getModelStorage(modelId);
            if (storage == null) {
                throw new MLModelBuilderException("Invalid model ID: " + modelId);
            }
            String storageType = storage.getType();
            String storageLocation = storage.getLocation();
            String outPath = storageLocation + File.separator + modelName;

            // If this is a deep learning model, set the storage location before serialization:
            // the model uses ObjectTreeBinarySerializer to write the underlying DeepLearningModel
            // to that directory as a .bin file.
            if (MLConstants.DEEPLEARNING.equalsIgnoreCase(model.getAlgorithmClass())) {
                MLDeeplearningModel mlDeeplearningModel = (MLDeeplearningModel) model.getModel();
                mlDeeplearningModel.setStorageLocation(storageLocation);
                model.setModel(mlDeeplearningModel);

                // Write POJO if it is a Deep Learning model
                // convert model name
                String dlModelName = modelName.replace('.', '_').replace('-', '_');
                File file = new File(storageLocation + "/" + dlModelName + "_dl" + ".java");
                FileOutputStream fileOutputStream = new FileOutputStream(file);
                DeepLearningModel deepLearningModel = mlDeeplearningModel.getDlModel();
                deepLearningModel.toJava(fileOutputStream, false, false);
                fileOutputStream.close();

                MLModel dlModel = new MLModel();
                dlModel.setAlgorithmClass(model.getAlgorithmClass());
                dlModel.setAlgorithmName(model.getAlgorithmName());
                dlModel.setEncodings(model.getEncodings());
                dlModel.setFeatures(model.getFeatures());
                dlModel.setResponseIndex(model.getResponseIndex());
                dlModel.setResponseVariable(model.getResponseVariable());
                dlModel.setNewToOldIndicesList(model.getNewToOldIndicesList());

                // Writing the DL model without Deep Learning logic
                // For prediction with POJO
                MLIOFactory ioFactoryDl = new MLIOFactory(mlProperties);
                MLOutputAdapter outputAdapterDl = ioFactoryDl
                        .getOutputAdapter(storageType + MLConstants.OUT_SUFFIX);
                ByteArrayOutputStream baosDl = new ByteArrayOutputStream();
                ObjectOutputStream oosDl = new ObjectOutputStream(baosDl);
                oosDl.writeObject(dlModel);
                oosDl.flush();
                oosDl.close();
                InputStream isDl = new ByteArrayInputStream(baosDl.toByteArray());
                // adapter will write the model and close the stream.
                outputAdapterDl.write(outPath + "_dl", isDl);
            }

            MLIOFactory ioFactory = new MLIOFactory(mlProperties);
            MLOutputAdapter outputAdapter = ioFactory.getOutputAdapter(storageType + MLConstants.OUT_SUFFIX);
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(baos);
            oos.writeObject(model);
            oos.flush();
            oos.close();
            InputStream is = new ByteArrayInputStream(baos.toByteArray());
            // adapter will write the model and close the stream.
            outputAdapter.write(outPath, is);
            databaseService.updateModelStorage(modelId, storageType, outPath);
            log.info(String.format("Successfully persisted the model [id] %s", modelId));
        } catch (Exception e) {
            throw new MLModelBuilderException("Failed to persist the model [id] " + modelId + ". " + e.getMessage(),
                    e);
        }
    }
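
    /**
     * Extract the original (dataset) column indices of the included features, in ascending order.
     */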

    private List<Integer> getNewToOldIndicesList(SortedMap<Integer, String> includedFeatures) {
        List<Integer> indicesList = new ArrayList<Integer>();
        for (int featureIdx : includedFeatures.keySet()) {
            indicesList.add(featureIdx);
        }
        return indicesList;
    }
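
    /**
     * Retrieve a persisted model from its storage location and deserialize it.
     *
     * @param modelId Unique ID of the model.
     * @return the deserialized {@link MLModel}
     * @throws MLModelHandlerException
     */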

    public MLModel retrieveModel(long modelId) throws MLModelHandlerException {
        InputStream in = null;
        ObjectInputStream ois = null;
        String storageLocation = null;
        try {
            MLStorage storage = databaseService.getModelStorage(modelId);
            if (storage == null) {
                throw new MLModelHandlerException("Invalid model ID: " + modelId);
            }
            String storageType = storage.getType();
            storageLocation = storage.getLocation();
            MLIOFactory ioFactory = new MLIOFactory(mlProperties);
            MLInputAdapter inputAdapter = ioFactory.getInputAdapter(storageType + MLConstants.IN_SUFFIX);
            in = inputAdapter.read(storageLocation);
            ois = new ObjectInputStream(in);

            // For deep learning models the storage location is serialized with the model,
            // so the ObjectTreeBinarySerializer can locate and deserialize the underlying model from it.
            MLModel model = (MLModel) ois.readObject();

            if (log.isDebugEnabled()) {
                log.debug("Successfully retrieved model");
            }

            return model;
        } catch (Exception e) {
            throw new MLModelHandlerException("Failed to retrieve the model [id] " + modelId, e);
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException e) {
                    String msg = "Error in closing input stream while retrieving model";
                    log.error(msg, e);
                }
            }
            if (ois != null) {
                try {
                    ois.close();
                } catch (IOException e) {
                    String msg = "Error in closing object input stream while retrieving model";
                    log.error(msg, e);
                }
            }
        }
    }

    /**
     * Publish an ML model to the registry.
     *
     * @param tenantId Unique ID of the tenant.
     * @param userName Username of the user.
     * @param modelId  Unique ID of the built ML model
     * @param mode     Format in which to publish the model: SERIALIZED or PMML.
     * @return Registry path of the published model.
     * @throws InvalidRequestException
     * @throws MLModelPublisherException
     * @throws MLModelHandlerException
     * @throws MLPmmlExportException
     */
    public String publishModel(int tenantId, String userName, long modelId, Format mode)
            throws InvalidRequestException, MLModelPublisherException, MLModelHandlerException,
            MLPmmlExportException {
        InputStream in = null;
        String errorMsg = "Failed to publish the model [id] " + modelId;
        RegistryOutputAdapter registryOutputAdapter = new RegistryOutputAdapter();
        String relativeRegistryPath = null;

        switch (mode) {
        case SERIALIZED:
            try {
                // read model
                MLStorage storage = databaseService.getModelStorage(modelId);
                if (storage == null) {
                    throw new InvalidRequestException("Invalid model [id] " + modelId);
                }
                String storageType = storage.getType();
                String storageLocation = storage.getLocation();
                MLIOFactory ioFactory = new MLIOFactory(mlProperties);
                MLInputAdapter inputAdapter = ioFactory.getInputAdapter(storageType + MLConstants.IN_SUFFIX);
                in = inputAdapter.read(storageLocation);
                if (in == null) {
                    throw new InvalidRequestException("Invalid model [id] " + modelId);
                }
                // create registry path
                MLCoreServiceValueHolder valueHolder = MLCoreServiceValueHolder.getInstance();
                String modelName = databaseService.getModel(tenantId, userName, modelId).getName();
                relativeRegistryPath = "/" + valueHolder.getModelRegistryLocation() + "/" + modelName;
                // publish to registry
                registryOutputAdapter.write(relativeRegistryPath, in);
            } catch (DatabaseHandlerException e) {
                throw new MLModelPublisherException(errorMsg, e);
            } catch (MLInputAdapterException e) {
                throw new MLModelPublisherException(errorMsg, e);
            } catch (MLOutputAdapterException e) {
                throw new MLModelPublisherException(errorMsg, e);
            } finally {
                if (in != null) {
                    try {
                        in.close();
                    } catch (IOException e) {
                        String msg = "Error in closing input stream while publishing model";
                        log.error(msg, e);
                    }
                }
            }
            break;

        case PMML:
            MLCoreServiceValueHolder valueHolder = MLCoreServiceValueHolder.getInstance();
            try {
                String modelName = databaseService.getModel(tenantId, userName, modelId).getName();
                relativeRegistryPath = "/" + valueHolder.getModelRegistryLocation() + "/" + modelName + ".xml";

                MLModel model = retrieveModel(modelId);
                String pmmlModel = exportAsPMML(model);
                in = new ByteArrayInputStream(pmmlModel.getBytes(StandardCharsets.UTF_8));
                registryOutputAdapter.write(relativeRegistryPath, in);

            } catch (DatabaseHandlerException e) {
                throw new MLModelPublisherException(errorMsg, e);
            } catch (MLModelHandlerException e) {
                throw new MLModelHandlerException("Failed to retrieve the model [id] " + modelId, e);
            } catch (MLOutputAdapterException e) {
                throw new MLModelPublisherException(errorMsg, e);
            } catch (MLPmmlExportException e) {
                // rethrow without masking the original error message
                throw e;
            } finally {
                if (in != null) {
                    try {
                        in.close();
                    } catch (IOException e) {
                        String msg = "Error in closing input stream while publishing model";
                        log.error(msg, e);
                    }
                }
            }
            break;

        default:
            throw new MLModelPublisherException(errorMsg);
        }

        return RegistryConstants.GOVERNANCE_REGISTRY_BASE_PATH + relativeRegistryPath;
    }
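
    /**
     * Run k-means clustering on the selected features of a dataset (sampled if it is larger than the
     * configured sample size) and return the resulting cluster points.
     *
     * @param tenantId          Unique ID of the tenant.
     * @param userName          Username of the user.
     * @param datasetId         Unique ID of the dataset.
     * @param featureListString Comma-separated list of feature names to cluster on.
     * @param noOfClusters      Number of clusters.
     * @return List of cluster points, each carrying its cluster index and feature values.
     * @throws MLMalformedDatasetException
     * @throws MLModelHandlerException
     */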

    public List<ClusterPoint> getClusterPoints(int tenantId, String userName, long datasetId,
            String featureListString, int noOfClusters)
            throws MLMalformedDatasetException, MLModelHandlerException {
        JavaSparkContext sparkContext = null;
        List<String> features = Arrays.asList(featureListString.split("\\s*,\\s*"));

        try {
            List<ClusterPoint> clusterPoints = new ArrayList<ClusterPoint>();

            String datasetURL = databaseService.getDatasetUri(datasetId);
            MLDataset dataset = databaseService.getDataset(tenantId, userName, datasetId);
            String dataSourceType = dataset.getDataSourceType();
            String dataType = dataset.getDataType();
            // java spark context
            sparkContext = MLCoreServiceValueHolder.getInstance().getSparkContext();
            JavaRDD<String> lines;
            // parse lines in the dataset
            lines = extractLines(tenantId, datasetId, sparkContext, datasetURL, dataSourceType, dataType);
            // get column separator
            String columnSeparator = ColumnSeparatorFactory.getColumnSeparator(dataType);
            // get header line
            String headerRow = databaseService.getFeatureNamesInOrder(datasetId, columnSeparator);
            Pattern pattern = MLUtils.getPatternFromDelimiter(columnSeparator);
            // get selected feature indices
            List<Integer> featureIndices = new ArrayList<Integer>();
            for (String feature : features) {
                featureIndices.add(MLUtils.getFeatureIndex(feature, headerRow, columnSeparator));
            }
            JavaRDD<org.apache.spark.mllib.linalg.Vector> featureVectors = null;

            double sampleSize = (double) MLCoreServiceValueHolder.getInstance().getSummaryStatSettings()
                    .getSampleSize();
            double sampleFraction = sampleSize / (lines.count() - 1);
            HeaderFilter headerFilter = new HeaderFilter.Builder().header(headerRow).build();
            LineToTokens lineToTokens = new LineToTokens.Builder().separator(pattern).build();
            MissingValuesFilter missingValuesFilter = new MissingValuesFilter.Builder().build();
            TokensToVectors tokensToVectors = new TokensToVectors.Builder().indices(featureIndices).build();

            // Use the entire dataset if the number of records does not exceed the configured sample size
            if (sampleFraction >= 1.0) {
                featureVectors = lines.filter(headerFilter).map(lineToTokens).filter(missingValuesFilter)
                        .map(tokensToVectors);
            }
            // Otherwise use a randomly selected fraction of the rows (sampled without replacement)
            else {
                featureVectors = lines.filter(headerFilter).sample(false, sampleFraction).map(lineToTokens)
                        .filter(missingValuesFilter).map(tokensToVectors);
            }
            KMeans kMeans = new KMeans();
            KMeansModel kMeansModel = kMeans.train(featureVectors, noOfClusters, 100);
            // Populate cluster points list with predicted clusters and features
            List<Tuple2<Integer, org.apache.spark.mllib.linalg.Vector>> kMeansPredictions = kMeansModel
                    .predict(featureVectors).zip(featureVectors).collect();
            for (Tuple2<Integer, org.apache.spark.mllib.linalg.Vector> kMeansPrediction : kMeansPredictions) {
                ClusterPoint clusterPoint = new ClusterPoint();
                clusterPoint.setCluster(kMeansPrediction._1());
                clusterPoint.setFeatures(kMeansPrediction._2().toArray());
                clusterPoints.add(clusterPoint);
            }
            return clusterPoints;
        } catch (DatabaseHandlerException e) {
            throw new MLModelHandlerException(
                    "An error occurred while generating cluster points: " + e.getMessage(), e);
        }
    }
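
    /**
     * Read the dataset as an RDD of lines, either from a DAS table or from a file/URI via the Spark context.
     */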

    private JavaRDD<String> extractLines(int tenantId, long datasetId, JavaSparkContext sparkContext,
            String datasetURL, String dataSourceType, String dataType) throws MLMalformedDatasetException {
        JavaRDD<String> lines;
        if (DatasetType.DAS == DatasetType.getDatasetType(dataSourceType)) {
            try {
                lines = MLUtils.getLinesFromDASTable(datasetURL, tenantId, sparkContext);
            } catch (Exception e) {
                throw new MLMalformedDatasetException("Unable to extract the data from DAS table: " + datasetURL,
                        e);
            }
        } else {
            // parse lines in the dataset
            lines = sparkContext.textFile(datasetURL);
        }
        return lines;
    }

    /**
     * Export an ML model in PMML format.
     *
     * @param model the model to be exported
     * @return PMML model as a String
     * @throws MLPmmlExportException
     */
    public String exportAsPMML(MLModel model) throws MLPmmlExportException {
        Externalizable extModel = model.getModel();

        try {
            if (extModel instanceof PMMLModelContainer) {
                PMMLExportable pmmlExportableModel = ((PMMLModelContainer) extModel).getPMMLExportable();
                String pmmlString = pmmlExportableModel.toPMML();
                try {
                    //temporary fix for appending version
                    String pmmlWithVersion = appendVersionToPMML(pmmlString);
                    // print the model in the log
                    log.info(pmmlWithVersion);
                    return pmmlWithVersion;
                } catch (Exception e) {
                    String msg = "Error while appending version attribute to pmml";
                    log.error(msg, e);
                    throw new MLPmmlExportException(msg);
                }
            } else {
                throw new MLPmmlExportException("PMML export not supported for model type");
            }
        } catch (MLPmmlExportException e) {
            // rethrow without masking the original error message
            throw e;
        }
    }

    /**
     * Append version attribute to pmml (temporary fix)
     *
     * @param pmmlString the pmml string to be appended
     * @return PMML with version as a String
     * @throws MLPmmlExportException
     */
    private String appendVersionToPMML(String pmmlString) throws MLPmmlExportException {
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        DocumentBuilder builder;
        StringWriter stringWriter = null;

        try {
            //convert the string to xml to append the version attribute
            builder = factory.newDocumentBuilder();
            Document document = builder.parse(new InputSource(new StringReader(pmmlString)));
            Element root = document.getDocumentElement();
            root.setAttribute("version", "4.2");

            // convert it back to string
            stringWriter = new StringWriter();
            TransformerFactory tf = TransformerFactory.newInstance();
            Transformer transformer = tf.newTransformer();
            transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "no");
            transformer.setOutputProperty(OutputKeys.METHOD, "xml");
            transformer.setOutputProperty(OutputKeys.INDENT, "yes");
            transformer.setOutputProperty(OutputKeys.ENCODING, "UTF-8");

            transformer.transform(new DOMSource(document), new StreamResult(stringWriter));
            return stringWriter.toString();
        } catch (Exception e) {
            String msg = "Error while appending version attribute to pmml";
            log.error(msg, e);
            throw new MLPmmlExportException(msg);
        } finally {
            try {
                if (stringWriter != null) {
                    stringWriter.close();
                }
            } catch (IOException e) {
                String msg = "Error while closing stringWriter stream resource";
                log.error(msg, e);
                throw new MLPmmlExportException(msg);
            }
        }
    }
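
    /**
     * Runnable that builds a model on Spark, persists the result, updates the model status and sends an
     * email notification (if an endpoint is configured) on completion or failure.
     */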

    class ModelBuilder implements Runnable {

        private long id;
        private MLModelConfigurationContext ctxt;
        private int tenantId;
        private String tenantDomain;
        private String username;
        private String emailNotificationEndpoint = MLCoreServiceValueHolder.getInstance()
                .getEmailNotificationEndpoint();

        public ModelBuilder(long modelId, MLModelConfigurationContext context) {
            id = modelId;
            ctxt = context;
            CarbonContext carbonContext = PrivilegedCarbonContext.getThreadLocalCarbonContext();
            tenantId = carbonContext.getTenantId();
            tenantDomain = carbonContext.getTenantDomain();
            username = carbonContext.getUsername();
        }

        @Override
        public void run() {
            org.wso2.carbon.metrics.manager.Timer timer = MetricManager.timer(Level.INFO,
                    "org.wso2.carbon.ml.model-building-time." + ctxt.getFacts().getAlgorithmName());
            Context context = timer.start();
            String[] emailTemplateParameters = new String[2];
            try {
                long t1 = System.currentTimeMillis();
                emailTemplateParameters[0] = username;
                // Set tenant info in the carbon context
                PrivilegedCarbonContext.startTenantFlow();
                PrivilegedCarbonContext.getThreadLocalCarbonContext().setTenantId(tenantId);
                PrivilegedCarbonContext.getThreadLocalCarbonContext().setTenantDomain(tenantDomain);

                String algorithmType = ctxt.getFacts().getAlgorithmClass();
                List<Map<String, Integer>> encodings = SparkModelUtils.buildEncodings(ctxt);
                ctxt.setEncodings(encodings);

                // gets the model builder
                MLModelBuilder modelBuilder = ModelBuilderFactory.getModelBuilder(algorithmType, ctxt);
                // pre-process and build the model
                MLModel model = modelBuilder.build();
                log.info(String.format("Successfully built the model [id] %s in %s seconds.", id,
                        (double) (System.currentTimeMillis() - t1) / 1000));

                persistModel(id, ctxt.getModel().getName(), model);

                if (emailNotificationEndpoint != null) {

                    emailTemplateParameters[1] = getLink(ctxt, MLConstants.MODEL_STATUS_COMPLETE);
                    EmailNotificationSender.sendModelBuildingCompleteNotification(emailNotificationEndpoint,
                            emailTemplateParameters);
                }
            } catch (MLInputValidationException e) {
                log.error(String.format("Failed to build the model [id] %s ", id), e);
                try {
                    databaseService.updateModelStatus(id, MLConstants.MODEL_STATUS_FAILED);
                    databaseService.updateModelError(id, e.getMessage() + "\n" + ctxt.getFacts().toString());
                    emailTemplateParameters[1] = getLink(ctxt, MLConstants.MODEL_STATUS_FAILED);
                } catch (DatabaseHandlerException e1) {
                    log.error(String.format("Failed to update the status of model [id] %s ", id), e1);
                }
                EmailNotificationSender.sendModelBuildingFailedNotification(emailNotificationEndpoint,
                        emailTemplateParameters);
            } catch (MLModelBuilderException e) {
                log.error(String.format("Failed to build the model [id] %s ", id), e);
                try {
                    databaseService.updateModelStatus(id, MLConstants.MODEL_STATUS_FAILED);
                    databaseService.updateModelError(id, e.getMessage() + "\n" + ctxt.getFacts().toString());
                    emailTemplateParameters[1] = getLink(ctxt, MLConstants.MODEL_STATUS_FAILED);
                } catch (DatabaseHandlerException e1) {
                    log.error(String.format("Failed to update the status of model [id] %s ", id), e1);
                }
                EmailNotificationSender.sendModelBuildingFailedNotification(emailNotificationEndpoint,
                        emailTemplateParameters);
            } finally {
                context.stop();
                PrivilegedCarbonContext.endTenantFlow();
            }
        }
    }
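
    /**
     * Throw an {@link MLModelHandlerException} with the given message if the given object is null.
     */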

    private void handleNull(Object obj, String msg) throws MLModelHandlerException {
        if (obj == null) {
            throw new MLModelHandlerException(msg);
        }
    }

    /**
     * Method to get the link to model build result page
     *
     * @param context ML model configuration context
     * @param status Model building status
     * @return link to model build result page
     */
    private String getLink(MLModelConfigurationContext context, String status) {

        MLModelData mlModelData = context.getModel();
        long modelId = mlModelData.getId();
        String modelName = mlModelData.getName();
        long analysisId = mlModelData.getAnalysisId();
        int tenantId = mlModelData.getTenantId();
        String userName = mlModelData.getUserName();

        MLAnalysis analysis;
        String analysisName;
        MLProject mlProject;
        String projectName;
        long datasetId;
        DatabaseService databaseService = MLCoreServiceValueHolder.getInstance().getDatabaseService();

        try {
            analysis = databaseService.getAnalysis(tenantId, userName, analysisId);
            analysisName = analysis.getName();
            long projectId = analysis.getProjectId();

            mlProject = databaseService.getProject(tenantId, userName, projectId);
            projectName = mlProject.getName();
            datasetId = mlProject.getDatasetId();
        } catch (DatabaseHandlerException e) {
            log.warn(String.format("Failed to generate link for model [id] %s ", modelId), e);
            return "[Failed to generate link for model ID: " + modelId + "]";
        }

        ConfigurationContextService configContextService = MLCoreServiceValueHolder.getInstance()
                .getConfigurationContextService();
        String mlUrl = configContextService.getServerConfigContext().getProperty("ml.url").toString();
        String link = mlUrl + "/site/analysis/analysis.jag?analysisId=" + analysisId + "&analysisName="
                + analysisName + "&datasetId=" + datasetId;
        if (status.equals(MLConstants.MODEL_STATUS_COMPLETE)) {
            link = mlUrl + "/site/analysis/view-model.jag?analysisId=" + analysisId + "&datasetId=" + datasetId
                    + "&modelId=" + modelId + "&projectName=" + projectName + "&" + "analysisName=" + analysisName
                    + "&modelName=" + modelName + "&fromCompare=false";
        }
        return link;
    }

}
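
For orientation, here is a minimal usage sketch; it is not part of the original source. It assumes the ML Core OSGi services that the handler obtains from MLCoreServiceValueHolder (database service, Spark context, thread executor, model storage) are already registered, and that the caller supplies an MLModelData populated with the tenant, user, analysis and dataset-version identifiers, as the REST API layer normally does. The class and method names below are placeholders.

import java.util.List;

import org.wso2.carbon.ml.commons.domain.MLModelData;
import org.wso2.carbon.ml.core.impl.MLModelHandler;

public final class MLModelHandlerUsageSketch {

    // Creates a model record, submits the asynchronous Spark build, and predicts on in-memory rows.
    public static List<?> buildAndPredict(MLModelData modelData, List<String[]> rows) throws Exception {
        MLModelHandler handler = new MLModelHandler();

        // Create the model record and submit the build job; the model status becomes 'In progress'.
        MLModelData created = handler.createModel(modelData);
        handler.buildModel(created.getTenantId(), created.getUserName(), created.getId());

        // The build runs on a separate thread; prediction only succeeds once the model status
        // has changed to 'Complete', so a real caller would poll the status before this call.
        // Each String[] in 'rows' is one data point, in the feature order the model expects.
        return handler.predict(created.getTenantId(), created.getUserName(), created.getId(), rows);
    }
}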