org.wso2.carbon.ml.core.utils.MLUtils.java Source code

Introduction

Here is the source code for org.wso2.carbon.ml.core.utils.MLUtils.java. MLUtils collects the common utility methods used by the WSO2 Machine Learner core: sampling datasets with Spark, reading DAS (Data Analytics Server) tables, building header maps, resolving feature indices, and a few small conversion helpers.

Source

/*
 * Copyright (c) 2015, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
 *
 * WSO2 Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.wso2.carbon.ml.core.utils;

import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.regex.Pattern;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.lang.math.NumberUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.wso2.carbon.analytics.api.AnalyticsDataAPI;
import org.wso2.carbon.analytics.datasource.commons.AnalyticsSchema;
import org.wso2.carbon.analytics.datasource.commons.ColumnDefinition;
import org.wso2.carbon.analytics.datasource.commons.exception.AnalyticsException;
import org.wso2.carbon.analytics.datasource.commons.exception.AnalyticsTableNotAvailableException;
import org.wso2.carbon.context.PrivilegedCarbonContext;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.Feature;
import org.wso2.carbon.ml.commons.domain.MLDatasetVersion;
import org.wso2.carbon.ml.commons.domain.SamplePoints;
import org.wso2.carbon.ml.commons.domain.Workflow;
import org.wso2.carbon.ml.commons.domain.config.MLProperty;
import org.wso2.carbon.ml.core.exceptions.MLMalformedDatasetException;
import org.wso2.carbon.ml.core.spark.transformations.DiscardedRowsFilter;
import org.wso2.carbon.ml.core.spark.transformations.HeaderFilter;
import org.wso2.carbon.ml.core.spark.transformations.LineToTokens;
import org.wso2.carbon.ml.core.spark.transformations.RowsToLines;

/**
 * Common utility methods used in ML core.
 */
public class MLUtils {

    /**
     * Generate a random sample of the dataset using Spark.
     */
    public static SamplePoints getSample(String path, String dataType, int sampleSize, boolean containsHeader)
            throws MLMalformedDatasetException {

        JavaSparkContext sparkContext = null;
        try {
            Map<String, Integer> headerMap = null;
            // List containing actual data of the sample.
            List<List<String>> columnData = new ArrayList<List<String>>();
            CSVFormat dataFormat = DataTypeFactory.getCSVFormat(dataType);

            // java spark context
            sparkContext = MLCoreServiceValueHolder.getInstance().getSparkContext();
            JavaRDD<String> lines;

            // parse lines in the dataset
            lines = sparkContext.textFile(path);
            // validates the data format of the file
            String firstLine = lines.first();
            if (!firstLine.contains("" + dataFormat.getDelimiter())) {
                throw new MLMalformedDatasetException(String.format(
                        "File content does not match the data format. [First Line] %s [Data Format] %s", firstLine,
                        dataType));
            }
            return getSamplePoints(sampleSize, containsHeader, headerMap, columnData, dataFormat, lines);

        } catch (Exception e) {
            throw new MLMalformedDatasetException(
                    "Failed to extract the sample points from path: " + path + ". Cause: " + e, e);
        }
    }

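    /**
     * Read the first line of a data file using Spark.
     *
     * @param filePath Path of the data file
     * @return First line of the file
     */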
    public static String getFirstLine(String filePath) {
        JavaSparkContext sparkContext = MLCoreServiceValueHolder.getInstance().getSparkContext();
        return sparkContext.textFile(filePath).first();
    }

    /**
     * Generate a random sample of a dataset stored in a DAS table, using Spark.
     */
    public static SamplePoints getSampleFromDAS(String path, int sampleSize, String sourceType, int tenantId)
            throws MLMalformedDatasetException {

        JavaSparkContext sparkContext = null;
        try {
            Map<String, Integer> headerMap = null;
            // List containing actual data of the sample.
            List<List<String>> columnData = new ArrayList<List<String>>();

            // java spark context
            sparkContext = MLCoreServiceValueHolder.getInstance().getSparkContext();
            JavaRDD<String> lines;
            String headerLine = extractHeaderLine(path, tenantId);
            headerMap = generateHeaderMap(headerLine, CSVFormat.RFC4180);

            // DAS case path = table name
            lines = getLinesFromDASTable(path, tenantId, sparkContext);

            return getSamplePoints(sampleSize, true, headerMap, columnData, CSVFormat.RFC4180, lines);

        } catch (Exception e) {
            throw new MLMalformedDatasetException(
                    "Failed to extract the sample points from path: " + path + ". Cause: " + e, e);
        }
    }

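    /**
     * Read a DAS (Data Analytics Server) table into a JavaRDD of CSV lines by registering it as a
     * temporary Spark SQL table and converting its rows to comma-separated strings.
     *
     * @param tableName Name of the DAS table
     * @param tenantId Tenant ID of the current user
     * @param sparkContext Java Spark context to use
     * @return Rows of the table as RFC 4180 (comma separated) lines
     */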
    public static JavaRDD<String> getLinesFromDASTable(String tableName, int tenantId,
            JavaSparkContext sparkContext) throws AnalyticsTableNotAvailableException, AnalyticsException {
        JavaRDD<String> lines;
        String tableSchema = extractTableSchema(tableName, tenantId);
        SQLContext sqlCtx = new SQLContext(sparkContext);
        sqlCtx.sql(
                "CREATE TEMPORARY TABLE ML_REF USING org.wso2.carbon.analytics.spark.core.sources.AnalyticsRelationProvider "
                        + "OPTIONS (" + "tenantId \"" + tenantId + "\", " + "tableName \"" + tableName + "\", "
                        + "schema \"" + tableSchema + "\"" + ")");

        DataFrame dataFrame = sqlCtx.sql("select * from ML_REF");
        // Additional auto-generated column "_timestamp" needs to be dropped because it is not in the schema.
        JavaRDD<Row> rows = dataFrame.drop("_timestamp").javaRDD();
        lines = rows.map(new RowsToLines.Builder().separator(CSVFormat.RFC4180.getDelimiter() + "").build());
        return lines;
    }

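    /**
     * Take a random sample of the given lines and collect the values column-wise, together with
     * per-column counts of missing, string and decimal cells.
     */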
    private static SamplePoints getSamplePoints(int sampleSize, boolean containsHeader,
            Map<String, Integer> headerMap, List<List<String>> columnData, CSVFormat dataFormat,
            JavaRDD<String> lines) {
        int featureSize;
        int[] missing;
        int[] stringCellCount;
        int[] decimalCellCount;
        // take the first line
        String firstLine = lines.first();
        // count the number of features
        featureSize = getFeatureSize(firstLine, dataFormat);

        List<Integer> featureIndices = new ArrayList<Integer>();
        for (int i = 0; i < featureSize; i++) {
            featureIndices.add(i);
        }

        String columnSeparator = String.valueOf(dataFormat.getDelimiter());
        HeaderFilter headerFilter = new HeaderFilter.Builder().header(lines.first()).build();
        JavaRDD<String> data = lines.filter(headerFilter).cache();
        Pattern pattern = MLUtils.getPatternFromDelimiter(columnSeparator);
        LineToTokens lineToTokens = new LineToTokens.Builder().separator(pattern).build();
        JavaRDD<String[]> tokens = data.map(lineToTokens);

        // remove from cache
        data.unpersist();
        // add to cache
        tokens.cache();

        missing = new int[featureSize];
        stringCellCount = new int[featureSize];
        decimalCellCount = new int[featureSize];
        if (sampleSize >= 0 && featureSize > 0) {
            sampleSize = sampleSize / featureSize;
        }
        for (int i = 0; i < featureSize; i++) {
            columnData.add(new ArrayList<String>());
        }

        if (headerMap == null) {
            // generate the header map
            if (containsHeader) {
                headerMap = generateHeaderMap(lines.first(), dataFormat);
            } else {
                headerMap = generateHeaderMap(featureSize);
            }
        }

        // take a random sample
        List<String[]> sampleLines = tokens.takeSample(false, sampleSize);

        // remove from cache
        tokens.unpersist();

        // iterate through sample lines
        for (String[] columnValues : sampleLines) {
            for (int currentCol = 0; currentCol < featureSize; currentCol++) {
                // Check whether the row is complete.
                if (currentCol < columnValues.length) {
                    // Append the cell to the respective column.
                    columnData.get(currentCol).add(columnValues[currentCol]);

                    if (MLConstants.MISSING_VALUES.contains(columnValues[currentCol])) {
                        // If the cell is empty, increase the missing value count.
                        missing[currentCol]++;
                    } else {
                        // check whether a column value is a string
                        if (!NumberUtils.isNumber(columnValues[currentCol])) {
                            stringCellCount[currentCol]++;
                        } else if (columnValues[currentCol].indexOf('.') != -1) {
                            // if it is a number and has the decimal point
                            decimalCellCount[currentCol]++;
                        }
                    }
                } else {
                    columnData.get(currentCol).add(null);
                    missing[currentCol]++;
                }
            }
        }

        SamplePoints samplePoints = new SamplePoints();
        samplePoints.setHeader(headerMap);
        samplePoints.setSamplePoints(columnData);
        samplePoints.setMissing(missing);
        samplePoints.setStringCellCount(stringCellCount);
        samplePoints.setDecimalCellCount(decimalCellCount);
        return samplePoints;
    }

    public static String extractTableSchema(String path, int tenantId)
            throws AnalyticsTableNotAvailableException, AnalyticsException {
        if (path == null) {
            return null;
        }
        AnalyticsDataAPI analyticsDataApi = (AnalyticsDataAPI) PrivilegedCarbonContext.getThreadLocalCarbonContext()
                .getOSGiService(AnalyticsDataAPI.class, null);
        // table schema will be something like: <col1_name> <col1_type>,<col2_name> <col2_type>
        StringBuilder sb = new StringBuilder();
        AnalyticsSchema analyticsSchema = analyticsDataApi.getTableSchema(tenantId, path);
        Map<String, ColumnDefinition> columnsMap = analyticsSchema.getColumns();
        for (Iterator<Map.Entry<String, ColumnDefinition>> iterator = columnsMap.entrySet().iterator(); iterator
                .hasNext();) {
            Map.Entry<String, ColumnDefinition> column = iterator.next();
            sb.append(column.getKey() + " " + column.getValue().getType().name() + ",");
        }

        return sb.substring(0, sb.length() - 1);
    }

    public static String extractHeaderLine(String path, int tenantId)
            throws AnalyticsTableNotAvailableException, AnalyticsException {
        if (path == null) {
            return null;
        }

        AnalyticsDataAPI analyticsDataApi = (AnalyticsDataAPI) PrivilegedCarbonContext.getThreadLocalCarbonContext()
                .getOSGiService(AnalyticsDataAPI.class, null);
        // header line will be something like: <col1_name>,<col2_name>
        StringBuilder sb = new StringBuilder();
        AnalyticsSchema analyticsSchema = analyticsDataApi.getTableSchema(tenantId, path);
        Map<String, ColumnDefinition> columnsMap = analyticsSchema.getColumns();
        for (String columnName : columnsMap.keySet()) {
            sb.append(columnName + ",");
        }

        return sb.substring(0, sb.length() - 1);
    }

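    /**
     * Resolves the {@link CSVFormat} for a given data type: "TSV" maps to the tab-delimited format,
     * anything else to RFC 4180 CSV.
     */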
    public static class DataTypeFactory {
        public static CSVFormat getCSVFormat(String dataType) {
            if ("TSV".equalsIgnoreCase(dataType)) {
                return CSVFormat.TDF;
            }
            return CSVFormat.RFC4180;
        }
    }

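    /**
     * Resolves the column separator character (as a String) for a given data type.
     */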
    public static class ColumnSeparatorFactory {
        public static String getColumnSeparator(String dataType) {
            CSVFormat csvFormat = DataTypeFactory.getCSVFormat(dataType);
            return csvFormat.getDelimiter() + "";
        }
    }

    /**
     * Retrieve the indices of features to which the given impute option (e.g. discard row) is applied.
     *
     * @param workflow Machine learning workflow
     * @param newToOldIndicesList List in which the value at each new (reordered) index is the original feature index
     * @param imputeOption Impute option
     * @return Indices of features to which the given impute option is applied
     */
    public static List<Integer> getImputeFeatureIndices(Workflow workflow, List<Integer> newToOldIndicesList,
            String imputeOption) {
        List<Integer> imputeFeatureIndices = new ArrayList<Integer>();
        for (Feature feature : workflow.getFeatures()) {
            if (feature.getImputeOption().equals(imputeOption) && feature.isInclude()) {
                int currentIndex = feature.getIndex();
                int newIndex = newToOldIndicesList.indexOf(currentIndex) != -1
                        ? newToOldIndicesList.indexOf(currentIndex)
                        : currentIndex;
                imputeFeatureIndices.add(newIndex);
            }
        }
        return imputeFeatureIndices;
    }

    /**
     * Retrieve the index of a feature in the dataset.
     *
     * @param feature Feature name
     * @param headerRow First row (header) in the data file
     * @param columnSeparator Column separator character
     * @return Index of the given feature in the header row, or 0 if the feature is not found
     */
    public static int getFeatureIndex(String feature, String headerRow, String columnSeparator) {
        int featureIndex = 0;
        String[] headerItems = headerRow.split(columnSeparator);
        for (int i = 0; i < headerItems.length; i++) {
            if (headerItems[i] != null) {
                String column = headerItems[i].replace("\"", "").trim();
                if (feature.equals(column)) {
                    featureIndex = i;
                    break;
                }
            }
        }
        return featureIndex;
    }

    /**
     * Retrieve the index of a feature in the dataset.
     */
    public static int getFeatureIndex(String featureName, List<Feature> features) {
        if (featureName == null || features == null) {
            return -1;
        }
        for (Feature feature : features) {
            if (featureName.equals(feature.getName())) {
                return feature.getIndex();
            }
        }
        return -1;
    }

    /**
     * @param workflow Workflow
     * @param newToOldIndicesList List in which the value at each new (reordered) index is the original feature index
     * @param responseIndex Index of the response variable
     * @return A sorted map from reordered feature index to feature name, for the features included in the model
     */
    public static SortedMap<Integer, String> getIncludedFeaturesAfterReordering(Workflow workflow,
            List<Integer> newToOldIndicesList, int responseIndex) {
        SortedMap<Integer, String> includedFeatures = new TreeMap<Integer, String>();
        List<Feature> features = workflow.getFeatures();
        for (Feature feature : features) {
            if (feature.isInclude() && feature.getIndex() != responseIndex) {
                int currentIndex = feature.getIndex();
                int newIndex = newToOldIndicesList.indexOf(currentIndex);
                includedFeatures.put(newIndex, feature.getName());
            }
        }
        return includedFeatures;
    }

    /**
     * @param workflow Workflow
     * @param responseIndex Index of the response variable
     * @return A sorted map from feature index to feature name, for the features included in the model
     */
    public static SortedMap<Integer, String> getIncludedFeatures(Workflow workflow, int responseIndex) {
        SortedMap<Integer, String> includedFeatures = new TreeMap<Integer, String>();
        List<Feature> features = workflow.getFeatures();
        for (Feature feature : features) {
            if (feature.isInclude() && feature.getIndex() != responseIndex) {
                includedFeatures.put(feature.getIndex(), feature.getName());
            }
        }
        return includedFeatures;
    }

    /**
     * @param tenantId Tenant ID of the current user
     * @param datasetId ID of the dataset
     * @param userName Name of the current user
     * @param name Dataset name
     * @param version Dataset version
     * @param targetPath Path of the stored dataset
     * @return Dataset version object
     */
    public static MLDatasetVersion getMLDatsetVersion(int tenantId, long datasetId, String userName, String name,
            String version, String targetPath) {
        MLDatasetVersion valueSet = new MLDatasetVersion();
        valueSet.setTenantId(tenantId);
        valueSet.setDatasetId(datasetId);
        valueSet.setName(name);
        valueSet.setVersion(version);
        valueSet.setTargetPath(targetPath);
        valueSet.setUserName(userName);
        return valueSet;
    }

    /**
     * @return Current date and time
     */
    public static String getDate() {
        DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss");
        Date date = new Date();
        return dateFormat.format(date);
    }

    /**
     * Get {@link Properties} from a list of {@link MLProperty}
     *
     * @param mlProperties list of {@link MLProperty}
     * @return {@link Properties}
     */
    public static Properties getProperties(List<MLProperty> mlProperties) {
        Properties properties = new Properties();
        for (MLProperty mlProperty : mlProperties) {
            if (mlProperty != null) {
                properties.put(mlProperty.getName(), mlProperty.getValue());
            }
        }
        return properties;

    }

    /**
     * @param inArray String array
     * @return Double array
     */
    public static double[] toDoubleArray(String[] inArray) {
        double[] outArray = new double[inArray.length];
        int idx = 0;
        for (String string : inArray) {
            outArray[idx] = Double.parseDouble(string);
            idx++;
        }

        return outArray;
    }

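    /**
     * Generate a default header map (V1, V2, ...) for a dataset without a header row.
     *
     * @param numberOfFeatures Number of features in the dataset
     * @return Map of generated column name to zero-based column index
     */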
    public static Map<String, Integer> generateHeaderMap(int numberOfFeatures) {
        Map<String, Integer> headerMap = new HashMap<String, Integer>();
        for (int i = 1; i <= numberOfFeatures; i++) {
            headerMap.put("V" + i, i - 1);
        }
        return headerMap;
    }

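    /**
     * Generate a header map from the header line of a dataset.
     *
     * @param line Header line
     * @param format CSV format of the dataset
     * @return Map of column name to zero-based column index
     */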
    public static Map<String, Integer> generateHeaderMap(String line, CSVFormat format) {
        Map<String, Integer> headerMap = new HashMap<String, Integer>();
        String[] values = line.split("" + format.getDelimiter());
        int i = 0;
        for (String value : values) {
            headerMap.put(value, i);
            i++;
        }
        return headerMap;
    }

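    /**
     * Count the number of features (columns) in a line.
     */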
    public static int getFeatureSize(String line, CSVFormat format) {
        String[] values = line.split("" + format.getDelimiter());
        return values.length;
    }

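    /**
     * Split a line into its feature (column) values.
     */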
    public static String[] getFeatures(String line, CSVFormat format) {
        String[] values = line.split("" + format.getDelimiter());
        return values;
    }

    /**
     * Applies the discard filter to a JavaRDD
     *
     * @param delimiter Column separator of the dataset
     * @param headerRow Header row
     * @param lines JavaRDD which contains the dataset
     * @param featureIndices Indices of the features to apply filter
     * @return filtered JavaRDD
     */
    public static JavaRDD<String[]> filterRows(String delimiter, String headerRow, JavaRDD<String> lines,
            List<Integer> featureIndices) {
        String columnSeparator = String.valueOf(delimiter);
        HeaderFilter headerFilter = new HeaderFilter.Builder().header(headerRow).build();
        JavaRDD<String> data = lines.filter(headerFilter).cache();
        Pattern pattern = MLUtils.getPatternFromDelimiter(columnSeparator);
        LineToTokens lineToTokens = new LineToTokens.Builder().separator(pattern).build();
        JavaRDD<String[]> tokens = data.map(lineToTokens).cache();

        // get feature indices for discard imputation
        DiscardedRowsFilter discardedRowsFilter = new DiscardedRowsFilter.Builder().indices(featureIndices).build();
        // Discard the row if any of the impute indices content have a missing value
        JavaRDD<String[]> tokensDiscardedRemoved = tokens.filter(discardedRowsFilter).cache();

        return tokensDiscardedRemoved;
    }

    /**
     * Format an error message, appending the cause if one is given.
     */
    public static String getErrorMsg(String customMessage, Exception ex) {
        if (ex != null) {
            return customMessage + " Cause: " + ex.getClass().getName() + " - " + ex.getMessage();
        }
        return customMessage;
    }

    /**
     * Utility method to get the key for a given value of a map.
     *
     * @param map Map to be searched for a key
     * @param value Value of the key
     * @return The first key mapped to the given value, or null if no such key exists
     */
    public static <T, E> T getKeyByValue(Map<T, E> map, E value) {
        for (Map.Entry<T, E> entry : map.entrySet()) {
            if (Objects.equals(value, entry.getValue())) {
                return entry.getKey();
            }
        }
        return null;
    }

    /**
     * Utility method to convert a String array to a CSV/TSV row string.
     *
     * @param array String array to be converted
     * @param delimiter Delimiter to be used (comma for CSV, tab for TSV)
     * @return CSV/TSV row string (including a trailing delimiter)
     */
    public static String arrayToCsvString(String[] array, char delimiter) {
        StringBuilder arrayString = new StringBuilder();
        for (String arrayElement : array) {
            arrayString.append(arrayElement);
            arrayString.append(delimiter);
        }
        return arrayString.toString();
    }

    /**
     * Generates a pattern that splits rows on the CSV or TSV delimiter.
     *
     * @param delimiter "," or "\t"
     * @return Pattern that matches the delimiter only when it occurs outside double-quoted values
     */
    public static Pattern getPatternFromDelimiter(String delimiter) {
        return Pattern.compile(delimiter + "(?=([^\"]*\"[^\"]*\")*(?![^\"]*\"))");
    }
}
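
Example

The helpers above that need neither a running Spark context nor a Carbon/DAS runtime can be exercised directly. The following sketch is illustrative only and is not part of the original file; the class name MLUtilsExample and the sample values are made up.

import java.util.Map;
import java.util.regex.Pattern;

import org.apache.commons.csv.CSVFormat;
import org.wso2.carbon.ml.core.utils.MLUtils;

public class MLUtilsExample {
    public static void main(String[] args) {
        // Build a header map from a CSV header line: {age=0, salary=1, label=2}
        Map<String, Integer> headerMap = MLUtils.generateHeaderMap("age,salary,label", CSVFormat.RFC4180);

        // Find a feature's position in the header row (2 here).
        int labelIndex = MLUtils.getFeatureIndex("label", "age,salary,label", ",");

        // Split a CSV row only on commas that appear outside double quotes.
        Pattern csv = MLUtils.getPatternFromDelimiter(",");
        String[] tokens = csv.split("30,\"1,200\",yes"); // ["30", "\"1,200\"", "yes"]

        // Convert string cells to doubles, and an array back to a delimited row (note the trailing comma).
        double[] values = MLUtils.toDoubleArray(new String[] { "1.5", "2.0" });
        String row = MLUtils.arrayToCsvString(new String[] { "a", "b" }, ','); // "a,b,"

        System.out.println(headerMap + ", " + labelIndex + ", " + tokens.length + ", " + values[0] + ", " + row);
    }
}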