Java tutorial
/* * Copyright (c) 2015, WSO2 Inc. (http://www.wso2.org) All Rights Reserved. * * WSO2 Inc. licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except * in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.wso2.carbon.notebook.core.util; import org.apache.commons.csv.CSVFormat; import org.apache.commons.lang.math.NumberUtils; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.wso2.carbon.analytics.dataservice.core.AnalyticsDataService; import org.wso2.carbon.analytics.datasource.commons.AnalyticsSchema; import org.wso2.carbon.analytics.datasource.commons.ColumnDefinition; import org.wso2.carbon.analytics.datasource.commons.exception.AnalyticsException; import org.wso2.carbon.analytics.datasource.commons.exception.AnalyticsTableNotAvailableException; import org.wso2.carbon.ml.commons.domain.Feature; import org.wso2.carbon.ml.commons.domain.FeatureType; import org.wso2.carbon.ml.commons.domain.SamplePoints; import org.wso2.carbon.ml.core.exceptions.MLMalformedDatasetException; import org.wso2.carbon.ml.core.spark.transformations.RowsToLines; import org.wso2.carbon.notebook.commons.constants.MLConstants; import org.wso2.carbon.notebook.core.ServiceHolder; import org.wso2.carbon.notebook.core.ml.transformation.HeaderFilter; import org.wso2.carbon.notebook.core.ml.transformation.LineToTokens; import java.util.*; import java.util.regex.Pattern; /** * Machine learner utility functions for the notebook */ public class MLUtils { /** * Generate a random sample of the data set using Spark. * * @param tableName Name of the table * @param sampleSize Sample size * @param tenantId Tenant ID * @return Sample points */ public static SamplePoints getSampleFromDAS(String tableName, int sampleSize, int tenantId) throws MLMalformedDatasetException { JavaSparkContext sparkContext; try { Map<String, Integer> headerMap; // List containing actual data of the sample. List<List<String>> columnData = new ArrayList<List<String>>(); // java spark context sparkContext = ServiceHolder.getSparkContextService().getJavaSparkContext(); JavaRDD<String> lines; String headerLine = extractHeaderLine(tableName, tenantId); headerMap = generateHeaderMap(headerLine, CSVFormat.RFC4180); // DAS case path = table name lines = getLinesFromDASTable(tableName, tenantId, sparkContext); return getSamplePoints(sampleSize, true, headerMap, columnData, CSVFormat.RFC4180, lines); } catch (Exception e) { throw new MLMalformedDatasetException( "Failed to extract the sample points from path: " + tableName + ". Cause: " + e, e); } } /** * Get the rows as lines of a table in the DAS * * @param tableName Name of the table * @param tenantId Tenant ID * @param sparkContext Java spark context * @return Table rows as lines */ public static JavaRDD<String> getLinesFromDASTable(String tableName, int tenantId, JavaSparkContext sparkContext) throws AnalyticsException { JavaRDD<String> lines; String tableSchema = extractTableSchema(tableName, tenantId); SQLContext sqlCtx = new SQLContext(sparkContext); sqlCtx.sql( "CREATE TEMPORARY TABLE ML_REF USING org.wso2.carbon.analytics.spark.core.sources.AnalyticsRelationProvider " + "OPTIONS (" + "tenantId \"" + tenantId + "\", " + "tableName \"" + tableName + "\", " + "schema \"" + tableSchema + "\"" + ")"); DataFrame dataFrame = sqlCtx.sql("select * from ML_REF"); // Additional auto-generated column "_timestamp" needs to be dropped because it is not in the schema. JavaRDD<Row> rows = dataFrame.drop("_timestamp").javaRDD(); lines = rows.map(new RowsToLines.Builder().separator(CSVFormat.RFC4180.getDelimiter() + "").build()); return lines; } /** * Get the cell values as String tokens from table lines * * @param dataFormat Data format of the lines * @param lines Table lines from which taken should be fetched * @return The string tokens of the table cell values */ private static JavaRDD<String[]> getTokensFromLines(CSVFormat dataFormat, JavaRDD<String> lines) { String columnSeparator = String.valueOf(dataFormat.getDelimiter()); HeaderFilter headerFilter = new HeaderFilter.Builder().init(lines.first()).build(); JavaRDD<String> data = lines.filter(headerFilter).cache(); Pattern pattern = getPatternFromDelimiter(columnSeparator); LineToTokens lineToTokens = new LineToTokens.Builder().separator(pattern).build(); JavaRDD<String[]> tokens = data.map(lineToTokens); // remove from cache data.unpersist(); return tokens; } /** * Take a sample from the data provided * * @param sampleSize Sample size * @param containsHeader Whether the header should be contained in the sample * @param headerMap Header map form column name to column data list index * @param columnData Empty list of values in the columns as a list to which items will be filled * @param dataFormat Data format of the lines provided * @param lines Rows as lines * @return */ private static SamplePoints getSamplePoints(int sampleSize, boolean containsHeader, Map<String, Integer> headerMap, List<List<String>> columnData, CSVFormat dataFormat, JavaRDD<String> lines) { // take the first line String firstLine = lines.first(); // count the number of features int featureSize = getFeatureSize(firstLine, dataFormat); int[] missing = new int[featureSize]; int[] stringCellCount = new int[featureSize]; int[] decimalCellCount = new int[featureSize]; JavaRDD<String[]> tokens = getTokensFromLines(dataFormat, lines); tokens.cache(); if (sampleSize >= 0 && featureSize > 0) { sampleSize = sampleSize / featureSize; } for (int i = 0; i < featureSize; i++) { columnData.add(new ArrayList<String>()); } if (headerMap == null) { // generate the header map if (containsHeader) { headerMap = generateHeaderMap(lines.first(), dataFormat); } else { headerMap = generateHeaderMap(featureSize); } } // take a random sample List<String[]> sampleLines = tokens.takeSample(false, sampleSize); // remove from cache tokens.unpersist(); // iterate through sample lines for (String[] columnValues : sampleLines) { for (int currentCol = 0; currentCol < featureSize; currentCol++) { // Check whether the row is complete. if (currentCol < columnValues.length) { // Append the cell to the respective column. columnData.get(currentCol).add(columnValues[currentCol]); if (MLConstants.MISSING_VALUES.contains(columnValues[currentCol])) { // If the cell is empty, increase the missing value count. missing[currentCol]++; } else { // check whether a column value is a string if (!NumberUtils.isNumber(columnValues[currentCol])) { stringCellCount[currentCol]++; } else if (columnValues[currentCol].indexOf('.') != -1) { // if it is a number and has the decimal point decimalCellCount[currentCol]++; } } } else { columnData.get(currentCol).add(null); missing[currentCol]++; } } } SamplePoints samplePoints = new SamplePoints(); samplePoints.setHeader(headerMap); samplePoints.setSamplePoints(columnData); samplePoints.setMissing(missing); samplePoints.setStringCellCount(stringCellCount); samplePoints.setDecimalCellCount(decimalCellCount); return samplePoints; } /** * Get the table schema of the table specified * * @param tableName Name of the table * @param tenantId Tenant ID * @return Table schema as string */ public static String extractTableSchema(String tableName, int tenantId) throws AnalyticsException { if (tableName == null) { return null; } AnalyticsDataService analyticsDataApi = ServiceHolder.getAnalyticsDataService(); // table schema will be something like; <col1_name> <col1_type>,<col2_name> <col2_type> StringBuilder sb = new StringBuilder(); AnalyticsSchema analyticsSchema = analyticsDataApi.getTableSchema(tenantId, tableName); Map<String, ColumnDefinition> columnsMap = analyticsSchema.getColumns(); for (Iterator<Map.Entry<String, ColumnDefinition>> iterator = columnsMap.entrySet().iterator(); iterator .hasNext();) { Map.Entry<String, ColumnDefinition> column = iterator.next(); sb.append(column.getKey() + " " + column.getValue().getType().name() + ","); } return sb.substring(0, sb.length() - 1); } /** * Header line of the table sepcified * * @param tableName Name of the table * @param tenantId Tenant ID * @return Header line */ public static String extractHeaderLine(String tableName, int tenantId) throws AnalyticsTableNotAvailableException, AnalyticsException { if (tableName == null) { return null; } AnalyticsDataService analyticsDataService = ServiceHolder.getAnalyticsDataService(); // header line will be something like; <col1_name>,<col2_name> StringBuilder sb = new StringBuilder(); AnalyticsSchema analyticsSchema = analyticsDataService.getTableSchema(tenantId, tableName); Map<String, ColumnDefinition> columnsMap = analyticsSchema.getColumns(); for (String columnName : columnsMap.keySet()) { sb.append(columnName + ","); } return sb.substring(0, sb.length() - 1); } /** * Retrieve the indices of features where discard row imputaion is applied. * * @param features The list of features of the dataset * @param imputeOption Impute option * @return Returns indices of features where discard row imputaion is applied */ public static List<Integer> getImputeFeatureIndices(List<Feature> features, List<Integer> newToOldIndicesList, String imputeOption) { List<Integer> imputeFeatureIndices = new ArrayList<Integer>(); for (Feature feature : features) { if (feature.getImputeOption().equals(imputeOption) && feature.isInclude() == true) { int currentIndex = feature.getIndex(); int newIndex = newToOldIndicesList.indexOf(currentIndex) != -1 ? newToOldIndicesList.indexOf(currentIndex) : currentIndex; imputeFeatureIndices.add(newIndex); } } return imputeFeatureIndices; } /** * Retrieve the index of a feature in the dataset. * * @param feature Feature name * @param headerRow First row (header) in the data file * @param columnSeparator ColumnDefinition separator character * @return Index of the response variable */ public static int getFeatureIndex(String feature, String headerRow, String columnSeparator) { int featureIndex = 0; String[] headerItems = headerRow.split(columnSeparator); for (int i = 0; i < headerItems.length; i++) { if (headerItems[i] != null) { String column = headerItems[i].replace("\"", "").trim(); if (feature.equals(column)) { featureIndex = i; break; } } } return featureIndex; } /** * Get the indices of the included features * * @param features list of features of the dataset * @return A list of indices of features to be included after processed */ public static List<Integer> getIncludedFeatureIndices(List<Feature> features) { List<Integer> includedFeatureIndices = new ArrayList<Integer>(); for (Feature feature : features) { if (feature.isInclude()) { includedFeatureIndices.add(feature.getIndex()); } } return includedFeatureIndices; } /** * Generate the header map without column names * * @param numberOfFeatures Number of features for which the header map is created * @return Header map */ public static Map<String, Integer> generateHeaderMap(int numberOfFeatures) { Map<String, Integer> headerMap = new HashMap<String, Integer>(); for (int i = 1; i <= numberOfFeatures; i++) { headerMap.put("V" + i, i - 1); } return headerMap; } /** * Generate the header map with column names * * @param line Lines of the table for which the header map is created * @param format Data format of the lines * @return Header map */ public static Map<String, Integer> generateHeaderMap(String line, CSVFormat format) { Map<String, Integer> headerMap = new HashMap<String, Integer>(); String[] values = line.split("" + format.getDelimiter()); int i = 0; for (String value : values) { headerMap.put(value, i); i++; } return headerMap; } /** * Get the number of features in the lines * * @param line Lines of the table * @param format Data format of the lines * @return Number of features */ public static int getFeatureSize(String line, CSVFormat format) { String[] values = line.split("" + format.getDelimiter()); return values.length; } /** * Generates a pattern to represent CSV or TSV format. * * @param delimiter "," or "\t" * @return Pattern */ public static Pattern getPatternFromDelimiter(String delimiter) { return Pattern.compile(delimiter + "(?=([^\"]*\"[^\"]*\")*(?![^\"]*\"))"); } /** * Get the column types from the sample points : Categorical or Numerical * * @param samplePoints Sample points * @return Column types list */ public static String[] getColumnTypes(SamplePoints samplePoints) { Map<String, Integer> headerMap = samplePoints.getHeader(); int[] stringCellCount = samplePoints.getStringCellCount(); int[] decimalCellCount = samplePoints.getDecimalCellCount(); String[] type = new String[headerMap.size()]; List<Integer> numericDataColumnPositions = new ArrayList<Integer>(); // If at least one cell contains strings, then the column is considered to has string data. for (int col = 0; col < headerMap.size(); col++) { if (stringCellCount[col] > 0) { type[col] = FeatureType.CATEGORICAL; } else { numericDataColumnPositions.add(col); type[col] = FeatureType.NUMERICAL; } } List<List<String>> columnData = samplePoints.getSamplePoints(); // Marking categorical data encoded as numerical data as categorical features for (int currentCol = 0; currentCol < headerMap.size(); currentCol++) { if (numericDataColumnPositions.contains(currentCol)) { // Create a unique set from the column. List<String> data = columnData.get(currentCol); // Check whether it is an empty column // Rows with missing values are not filtered. Therefore it is possible to // have all rows in sample with values missing in a column. if (data.size() == 0) { continue; } Set<String> uniqueSet = new HashSet<String>(data); int multipleOccurences = 0; for (String uniqueValue : uniqueSet) { int frequency = Collections.frequency(data, uniqueValue); if (frequency > 1) { multipleOccurences++; } } // if a column has at least one decimal value, then it can't be categorical. // if a feature has more than X% of repetitive distinct values, then that feature can be a categorical // one. X = categoricalThreshold if (decimalCellCount[currentCol] == 0 && (multipleOccurences / uniqueSet.size()) * 100 >= MLConstants.CATEGORICAL_THRESHOLD) { type[currentCol] = FeatureType.CATEGORICAL; } } } return type; } /** * Factory for getting the data type from the data type name */ public static class DataTypeFactory { public static CSVFormat getCSVFormat(String dataType) { if ("TSV".equalsIgnoreCase(dataType)) { return CSVFormat.TDF; } return CSVFormat.RFC4180; } } /** * Factory for getting the column separator from data type name */ public static class ColumnSeparatorFactory { public static String getColumnSeparator(String dataType) { CSVFormat csvFormat = DataTypeFactory.getCSVFormat(dataType); return csvFormat.getDelimiter() + ""; } } }