us.parr.animl.data.DataTable.java Source code

Java tutorial

Introduction

Here is the source code for us.parr.animl.data.DataTable.java

Source

/*
 * Copyright (c) 2017 Terence Parr. All rights reserved.
 * Use of this file is governed by the BSD 3-clause license that
 * can be found in the LICENSE file in the project root.
 */

package us.parr.animl.data;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.io.input.BOMInputStream;
import org.apache.commons.lang3.StringUtils;
import sun.misc.FloatingDecimal;
import us.parr.lib.ParrtCollections;
import us.parr.lib.ParrtStats;
import us.parr.lib.collections.CountingDenseIntSet;
import us.parr.lib.collections.CountingSet;
import us.parr.lib.collections.DenseIntSet;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.regex.Pattern;

import static java.util.Collections.max;
import static us.parr.animl.data.DataTable.VariableFormat.CENTER;
import static us.parr.animl.data.DataTable.VariableFormat.RIGHT;
import static us.parr.animl.data.DataTable.VariableType.CATEGORICAL_INT;
import static us.parr.animl.data.DataTable.VariableType.CATEGORICAL_STRING;
import static us.parr.animl.data.DataTable.VariableType.INVALID;
import static us.parr.animl.data.DataTable.VariableType.NUMERICAL_FLOAT;
import static us.parr.animl.data.DataTable.VariableType.NUMERICAL_INT;
import static us.parr.animl.data.DataTable.VariableType.TARGET_CATEGORICAL_INT;
import static us.parr.animl.data.DataTable.VariableType.TARGET_CATEGORICAL_STRING;
import static us.parr.animl.data.DataTable.VariableType.UNUSED_FLOAT;
import static us.parr.animl.data.DataTable.VariableType.UNUSED_INT;
import static us.parr.animl.data.DataTable.VariableType.UNUSED_STRING;
import static us.parr.lib.ParrtCollections.indexOf;
import static us.parr.lib.ParrtCollections.join;
import static us.parr.lib.ParrtCollections.map;

public class DataTable implements Iterable<int[]> {
    /**
     * Matches a complete decimal literal: optional minus sign, digits with a
     * decimal point (or a leading point), and an optional exponent, e.g.
     * {@code 9.466524720191955566e-01}, {@code -3.}, {@code .5}.
     * The alternation is grouped so that both anchors and the optional
     * exponent apply to every alternative.
     */
    public static final Pattern floatPattern = Pattern
            .compile("^-?(?:[0-9]+\\.[0-9]*|\\.[0-9]+)(?:[eE][+-]?[0-9]+)?$");
    /** Matches a complete, optionally negative, integer literal such as {@code -42}. */
    public static final Pattern intPattern = Pattern.compile("^-?[0-9]+$");

    /** Input sometimes has NA or blanks for unknown values */
    public static final Set<String> UNKNOWN_VALUE_STRINGS = new HashSet<String>() {
        {
            add("");
            add("NA");
            add("N/A");
        }
    };

    /**
     * Kind of data held in a column. Declaration order matters: the ordinals
     * index {@code varTypeShortNames} and {@code defaultVarFormats}.
     */
    public enum VariableType {
        CATEGORICAL_INT,
        CATEGORICAL_STRING,
        NUMERICAL_INT,
        NUMERICAL_FLOAT,
        TARGET_CATEGORICAL_INT,
        TARGET_CATEGORICAL_STRING,
        UNUSED_INT,
        UNUSED_FLOAT,
        UNUSED_STRING,
        /** Placeholder used by type inference before any typed value has been seen. */
        INVALID
    }

    /** Horizontal alignment of a column's cells when the table is rendered as text. */
    public enum VariableFormat {
        LEFT,
        CENTER,
        RIGHT
    }

    /** Short display name for each {@link VariableType}, indexed by ordinal; INVALID has no entry (null). */
    public static final String[] varTypeShortNames = new String[VariableType.values().length];
    /** Default text alignment for each {@link VariableType}, indexed by ordinal; INVALID has no entry (null). */
    public static final VariableFormat[] defaultVarFormats = new VariableFormat[VariableType.values().length];
    static {
        // Fill both lookup tables in one pass so a type's name and alignment sit side by side.
        for (VariableType type : VariableType.values()) {
            final int slot = type.ordinal();
            switch (type) {
            case CATEGORICAL_INT:
                varTypeShortNames[slot] = "cat";
                defaultVarFormats[slot] = RIGHT;
                break;
            case CATEGORICAL_STRING:
                varTypeShortNames[slot] = "string";
                defaultVarFormats[slot] = CENTER;
                break;
            case NUMERICAL_INT:
                varTypeShortNames[slot] = "int";
                defaultVarFormats[slot] = RIGHT;
                break;
            case NUMERICAL_FLOAT:
                varTypeShortNames[slot] = "float";
                defaultVarFormats[slot] = RIGHT;
                break;
            case TARGET_CATEGORICAL_INT:
                varTypeShortNames[slot] = "target";
                defaultVarFormats[slot] = RIGHT;
                break;
            case TARGET_CATEGORICAL_STRING:
                varTypeShortNames[slot] = "target-string";
                defaultVarFormats[slot] = CENTER;
                break;
            case UNUSED_INT:
            case UNUSED_FLOAT:
                varTypeShortNames[slot] = "unused";
                defaultVarFormats[slot] = RIGHT;
                break;
            case UNUSED_STRING:
                varTypeShortNames[slot] = "unused";
                defaultVarFormats[slot] = CENTER;
                break;
            default:
                // INVALID: both slots intentionally stay null
                break;
            }
        }
    }

    // TODO: this should be int[j][i] stored in columnar form; first index is the column then it goes down rows in that column
    /** Table data, one int[] per row; every cell is stored as an int (floats as their raw bits, strings as table codes). */
    protected List<int[]> rows;
    /** Column names, one per column; may be null for tables built without names. */
    protected String[] colNames;
    /** Column types, one per column; the number of columns is taken from this array's length. */
    protected VariableType[] colTypes;
    /** Per-column string tables used to decode string-typed columns; entries are null for columns without one. */
    protected StringTable[] colStringToIntMap;
    /** Per-column maximum encoded value, as filled in by computeColMaxes(). */
    protected int[] colMaxes;

    /** Lazily computed set of distinct values in the predicted column; null until first requested. */
    protected Set<Integer> cachedPredictionCategories;
    /** Lazily computed max of the prediction categories; -1 means "not computed yet". */
    protected int cachedMaxPredictionCategoryValue = -1;

    /** Creates a table with no rows and no column metadata; every field keeps its default value. */
    public DataTable() {
        super();
    }

    /**
     * Creates a table without string tables by delegating to the five-argument
     * constructor with a null string-table array.
     */
    public DataTable(List<int[]> rows, VariableType[] colTypes, String[] colNames, int[] colMaxes) {
        this(rows, colTypes, colNames, colMaxes, (StringTable[]) null);
    }

    /**
     * Creates a table over the given rows and metadata. The arrays and the row
     * list are stored as-is (no copies).
     *
     * @param colMaxes per-column maxima, or null to have them computed from the rows
     * @param colStringToIntMap per-column string tables, or null
     */
    public DataTable(List<int[]> rows, VariableType[] colTypes, String[] colNames, int[] colMaxes,
            StringTable[] colStringToIntMap) {
        this.rows = rows;
        this.colTypes = colTypes;
        this.colNames = colNames;
        this.colStringToIntMap = colStringToIntMap;
        this.colMaxes = colMaxes;
        // maxima were not supplied: derive them from the data
        if (colMaxes == null) {
            computeColMaxes();
        }
    }

    /** Returns a table with the given column metadata and a fresh, empty, mutable row list. */
    public static DataTable empty(VariableType[] colTypes, String[] colNames) {
        List<int[]> noRows = new ArrayList<>();
        return new DataTable(noRows, colTypes, colNames, null, null);
    }

    /**
     * Makes a new table over {@code rows} that shares the column types, names,
     * maxima and string tables of {@code old} (the arrays are not copied).
     */
    public DataTable(DataTable old, List<int[]> rows) {
        this(rows,
             old.colTypes,
             old.colNames,
             old.colMaxes,
             old.colStringToIntMap);
    }

    /**
     * Makes a new table from an old table with a shallow copy of the rows:
     * the row list is new, but the int[] rows themselves are shared.
     * Column names, types and string tables are shared with {@code old};
     * the column maxima are copied into a fresh array.
     */
    public DataTable(DataTable old) {
        this.rows = new ArrayList<>(old.rows);
        this.colNames = old.colNames;
        // this.colMaxes starts out null, so it must be allocated here; copying
        // into it with System.arraycopy would throw a NullPointerException.
        this.colMaxes = old.colMaxes == null ? null : old.colMaxes.clone();
        this.colTypes = old.colTypes;
        this.colStringToIntMap = old.colStringToIntMap;
    }

    /**
     * Builds a table from already-encoded int rows.
     *
     * @param rows     the data; null yields an empty table
     * @param colTypes column types, or null to use {@link #getDefaultColTypes(int)}
     * @param colNames column names, or null to use {@link #getDefaultColNames(VariableType[], int)}
     */
    public static DataTable fromInts(List<int[]> rows, VariableType[] colTypes, String[] colNames) {
        // nothing to size the table from: no rows at all, or no rows and no types
        boolean nothingToInfer = rows == null || (rows.size() == 0 && colTypes == null);
        if (nothingToInfer) {
            return empty(colTypes, colNames);
        }

        // width comes from the first row when there is one, else from the declared types
        int dim;
        if (rows.size() > 0) {
            dim = rows.get(0).length;
        } else {
            dim = colTypes.length;
        }
        VariableType[] types = colTypes != null ? colTypes : getDefaultColTypes(dim);
        String[] names = colNames != null ? colNames : getDefaultColNames(types, dim);
        return new DataTable(rows, types, names, null);
    }

    /**
     * Builds a table from string rows whose first row is the header. Column
     * types are inferred from the data rows via computeColTypes().
     * A null/empty list or a null header yields an empty table with no metadata;
     * a header-only list yields an empty table that keeps the header as names.
     */
    public static DataTable fromStrings(List<String[]> rows) {
        if (rows == null || rows.size() == 0 || rows.get(0) == null) {
            return empty(null, null);
        }
        String[] header = rows.get(0);
        if (rows.size() == 1) { // just header row?
            return empty(null, header);
        }

        // everything after the header is data
        List<String[]> dataRows = rows.subList(1, rows.size());
        VariableType[] inferredTypes = computeColTypes(dataRows, header.length);
        return fromStrings(dataRows, inferredTypes, header, false);
    }

    /**
     * Builds a table by encoding string cells as ints according to each
     * column's type: int-typed cells are parsed as ints, float-typed cells are
     * stored as their raw float bits, and string-typed cells are replaced by
     * the code returned from that column's {@code StringTable}. Cells listed in
     * {@link #UNKNOWN_VALUE_STRINGS} are encoded as 0.
     *
     * @param rows         the string data; null or empty yields an empty table
     * @param colTypes     column types, or null to use {@link #getDefaultColTypes(int)}
     * @param colNames     column names, or null to take them from the header row
     *                     (if any) or else from {@link #getDefaultColNames(VariableType[], int)}
     * @param hasHeaderRow when true the first row is not treated as data
     * @throws NumberFormatException if a known cell of a numeric column cannot be parsed
     */
    public static DataTable fromStrings(List<String[]> rows, VariableType[] colTypes, String[] colNames,
            boolean hasHeaderRow) {
        if (rows == null || rows.size() == 0) {
            return empty(colTypes, colNames);
        }
        if (rows.size() == 1 && hasHeaderRow) {
            return empty(colTypes, colNames);
        }

        if (hasHeaderRow && colNames == null) {
            colNames = rows.get(0);
        }

        int dim = rows.get(0).length;
        if (colTypes == null) {
            colTypes = getDefaultColTypes(dim);
        }
        if (colNames == null) {
            colNames = getDefaultColNames(colTypes, dim);
        }
        StringTable[] colStringToIntMap = new StringTable[colTypes.length];
        // don't waste space on string tables unless we need to; every string-typed
        // column (including UNUSED_STRING) needs one because its cells are encoded below
        for (int j = 0; j < colTypes.length; j++) {
            if (colTypes[j] == CATEGORICAL_STRING || colTypes[j] == TARGET_CATEGORICAL_STRING
                    || colTypes[j] == UNUSED_STRING) {
                colStringToIntMap[j] = new StringTable();
            }
        }
        // process strings into ints using appropriate conversion
        List<int[]> rows2 = new ArrayList<>();
        for (int i = hasHeaderRow ? 1 : 0; i < rows.size(); i++) {
            String[] row = rows.get(i);
            int[] rowAsInts = new int[row.length];
            for (int j = 0; j < row.length; j++) {
                VariableType colType = colTypes[j];
                String colValue = row[j];
                int col = 0; // unknown values keep this default encoding
                if (!UNKNOWN_VALUE_STRINGS.contains(colValue)) {
                    switch (colType) {
                    case CATEGORICAL_INT:
                    case NUMERICAL_INT:
                    case UNUSED_INT:
                    case TARGET_CATEGORICAL_INT:
                        col = Integer.parseInt(colValue);
                        break;
                    case CATEGORICAL_STRING:
                    case TARGET_CATEGORICAL_STRING:
                    case UNUSED_STRING:
                        col = colStringToIntMap[j].add(colValue);
                        break;
                    case NUMERICAL_FLOAT:
                    case UNUSED_FLOAT:
                        col = Float.floatToIntBits(Float.parseFloat(colValue));
                        break;
                    default:
                        // INVALID columns are left as 0
                        break;
                    }
                }
                rowAsInts[j] = col;
            }
            rows2.add(rowAsInts);
        }
        return new DataTable(rows2, colTypes, colNames, null, colStringToIntMap);
    }

    /**
     * Loads a delimited text file with Apache Commons CSV and converts it to a
     * table via {@link #fromStrings(List, VariableType[], String[], boolean)}.
     * The file is read as UTF-8; a byte-order marker at the start is skipped.
     *
     * @param fileName         path of the file to read
     * @param formatType       "tsv", "mysql", "excel" or "rfc4180" (case-insensitive);
     *                         null or anything else selects RFC 4180
     * @param colTypesOverride column types to use instead of the inferred ones, or null
     * @param colNamesOverride column names to use instead of the header names, or null
     * @param hasHeaderRow     whether the first record holds the column names
     * @throws IllegalArgumentException if the file cannot be opened, read or converted
     */
    public static DataTable loadCSV(String fileName, String formatType, VariableType[] colTypesOverride,
            String[] colNamesOverride, boolean hasHeaderRow) {
        try {
            CSVFormat format;
            String kind = formatType == null ? "rfc4180" : formatType.toLowerCase();
            switch (kind) {
            case "tsv":
                format = CSVFormat.TDF;
                break;
            case "mysql":
                format = CSVFormat.MYSQL;
                break;
            case "excel":
                format = CSVFormat.EXCEL;
                break;
            case "rfc4180":
            default:
                format = CSVFormat.RFC4180;
                break;
            }
            if (hasHeaderRow) {
                format = format.withHeader();
            }

            // use apache commons io + csv to load but convert to list of String[]
            // byte-order markers are handled if present at start of file.
            // try-with-resources closes parser, reader and stream even when
            // construction of a later resource or the parse itself fails.
            try (FileInputStream fis = new FileInputStream(fileName);
                    Reader reader = new InputStreamReader(new BOMInputStream(fis), "UTF-8");
                    CSVParser parser = new CSVParser(reader, format)) {
                // the header map is null when the format declares no header
                java.util.Map<String, Integer> headerMap = parser.getHeaderMap();

                List<String[]> rows = new ArrayList<>();
                for (final CSVRecord record : parser) {
                    String[] row = new String[record.size()];
                    for (int j = 0; j < record.size(); j++) {
                        row[j] = record.get(j);
                    }
                    rows.add(row);
                }

                // Column names/count come from the header when one was parsed;
                // otherwise the width of the first record decides and names are
                // left null so that fromStrings() generates defaults.
                String[] colNames = null;
                int numCols = rows.isEmpty() ? 0 : rows.get(0).length;
                if (headerMap != null && (!headerMap.isEmpty() || rows.isEmpty())) {
                    Set<String> colNameSet = headerMap.keySet();
                    colNames = colNameSet.toArray(new String[colNameSet.size()]);
                    numCols = colNames.length;
                }

                VariableType[] actualTypes = computeColTypes(rows, numCols);
                if (colNamesOverride != null) {
                    colNames = colNamesOverride;
                }
                if (colTypesOverride != null) {
                    actualTypes = colTypesOverride;
                }
                return fromStrings(rows, actualTypes, colNames, false);
            }
        } catch (Exception e) {
            throw new IllegalArgumentException("Can't open and/or read " + fileName, e);
        }
    }

    /**
     * Loads a purely numeric comma-separated file with a hand-rolled splitter:
     * each line is cut at every ',' (no quoting or escaping is recognised) and
     * every field is converted with {@link #getValue(VariableType, String)}.
     * Blank lines are skipped. The file is read as UTF-8 and a leading
     * byte-order marker is skipped. A running row count is printed to
     * {@code System.out} every 10000 rows.
     *
     * @param fileName     path of the file to read
     * @param colTypes     type of every column; its length fixes the row width
     * @param hasHeaderRow when true the first line supplies the column names
     * @throws IllegalArgumentException if the file cannot be opened or read
     */
    public static DataTable loadCSV(String fileName, VariableType[] colTypes, boolean hasHeaderRow) {
        int numCols = colTypes.length;
        // try-with-resources: the original never closed the reader/stream
        try (FileInputStream fis = new FileInputStream(fileName);
                Reader r = new InputStreamReader(new BOMInputStream(fis), "UTF-8");
                BufferedReader bf = new BufferedReader(r)) {
            List<int[]> rows = new ArrayList<>();
            String line;
            String[] colNames = null;
            if (hasHeaderRow) {
                line = bf.readLine();
                if (line != null) {
                    line = line.trim();
                    if (line.length() > 0) {
                        colNames = line.split(",");
                        for (int i = 0; i < colNames.length; i++) {
                            colNames[i] = colNames[i].trim();
                        }
                    }
                }
            }
            int n = 0;
            while ((line = bf.readLine()) != null) {
                if (n > 0 && n % 10000 == 0) {
                    System.out.println(n); // progress indicator
                }
                line = line.trim();
                if (line.length() == 0) {
                    continue;
                }
                int[] row = new int[numCols];
                int comma = line.indexOf(',', 0);
                int prev = 0;
                int col = 0;
                while (comma >= 0) {
                    String v = line.substring(prev, comma);
                    row[col] = getValue(colTypes[col], v);

                    prev = comma + 1;
                    comma = line.indexOf(',', comma + 1);
                    col++;
                }
                // grab last element after last comma
                String lastv = line.substring(prev, line.length());
                row[col] = getValue(colTypes[col], lastv);

                rows.add(row);
                n++;
            }

            return new DataTable(rows, colTypes, colNames, null);
        } catch (IOException ioe) {
            throw new IllegalArgumentException("Can't open and/or read " + fileName, ioe);
        }
    }

    /**
     * Converts one text field to its int encoding for the given column type:
     * float-typed columns are parsed and stored as raw float bits, int-typed
     * columns are parsed as ints. Uses the public {@link Float#parseFloat}
     * instead of the JDK-internal {@code sun.misc.FloatingDecimal}.
     *
     * @throws NumberFormatException         if {@code v} is not a valid number
     * @throws UnsupportedOperationException for string-typed (and INVALID) columns
     */
    protected static int getValue(VariableType colType, String v) {
        switch (colType) {
        case NUMERICAL_FLOAT:
        case UNUSED_FLOAT:
            return Float.floatToIntBits(Float.parseFloat(v));
        case NUMERICAL_INT:
        case CATEGORICAL_INT:
        case TARGET_CATEGORICAL_INT:
        case UNUSED_INT:
            return Integer.parseInt(v);
        default:
            throw new UnsupportedOperationException("can't handle strings yet");
        }
    }

    /**
     * Infers a type for each of the first {@code numCols} columns by scanning
     * every row. A column starts as INVALID and is only ever widened:
     * INVALID -> int -> float -> string. Cells in
     * {@link #UNKNOWN_VALUE_STRINGS} carry no type information and are ignored,
     * so a column holding only unknown values stays INVALID. A string column
     * in the last position of its row is assumed to be the predicted variable.
     *
     * @param rows    data rows (no header); each must have at least numCols cells
     * @param numCols number of columns to type
     * @return the inferred type of each column
     */
    protected static VariableType[] computeColTypes(List<String[]> rows, int numCols) {
        VariableType[] actualTypes = new VariableType[numCols];
        for (int j = 0; j < numCols; j++) {
            actualTypes[j] = INVALID;
        }
        for (String[] row : rows) {
            for (int j = 0; j < numCols; j++) {
                String cell = row[j];
                VariableType soFar = actualTypes[j];
                if (intPattern.matcher(cell).find()) {
                    if (soFar == INVALID) { // only choose int if first type seen
                        actualTypes[j] = NUMERICAL_INT;
                    }
                } else if (floatPattern.matcher(cell).find()) { // let int become float but not vice versa
                    if (soFar == INVALID || soFar == NUMERICAL_INT) {
                        actualTypes[j] = NUMERICAL_FLOAT;
                    }
                } else if (!UNKNOWN_VALUE_STRINGS.contains(cell)) { // if NA, N/A don't know type
                    // if we ever see a string, convert and don't change back; this must
                    // also demote a float column, otherwise the string cell would later
                    // be handed to the float parser
                    if (soFar == INVALID || soFar == NUMERICAL_INT || soFar == NUMERICAL_FLOAT) {
                        if (j == row.length - 1) { // assume last column is predicted var
                            actualTypes[j] = TARGET_CATEGORICAL_STRING;
                        } else {
                            actualTypes[j] = CATEGORICAL_STRING;
                        }
                    }
                }
            }
        }
        return actualTypes;
    }

    /**
     * Recomputes {@code colMaxes} from the current rows. Each column's maximum
     * starts at 0 and is replaced by any cell that compares greater under the
     * column's type. Does nothing when the column types are unknown.
     */
    public void computeColMaxes() {
        if (colTypes == null) {
            return;
        }
        this.colMaxes = new int[colTypes.length];
        for (int col = 0; col < getNumberOfColumns(); col++) {
            VariableType type = colTypes[col];
            int best = 0;
            for (int r = 0; r < size(); r++) {
                int candidate = getRow(r)[col];
                boolean larger = compare(candidate, best, type) == 1;
                if (larger) {
                    best = candidate;
                }
            }
            colMaxes[col] = best;
        }
    }

    /** Returns the largest value among the prediction categories, computing and caching it on first use. */
    public int getMaxPredictionCategoryValue() {
        // -1 marks "not computed yet"
        if (cachedMaxPredictionCategoryValue != -1) {
            return cachedMaxPredictionCategoryValue;
        }
        cachedMaxPredictionCategoryValue = max(getPredictionCategories());
        return cachedMaxPredictionCategoryValue;
    }

    /** Returns the distinct values found in the predicted column, computing and caching them on first use. */
    public Set<Integer> getPredictionCategories() {
        if (cachedPredictionCategories != null) {
            return cachedPredictionCategories;
        }
        cachedPredictionCategories = getUniqueValues(getPredictedCol());
        return cachedPredictionCategories;
    }

    /**
     * Collects the distinct encoded values of one column into a dense int set
     * created from that column's recorded maximum.
     */
    public Set<Integer> getUniqueValues(int colIndex) {
        DenseIntSet distinct = new DenseIntSet(colMaxes[colIndex]);
        int numRows = 0;
        while (numRows < size()) {
            // pretend everything is an int
            distinct.add(getAsInt(numRows, colIndex));
            numRows++;
        }
        return distinct;
    }

    /**
     * Counts how often each encoded value occurs in one column, using a
     * counting set created from that column's recorded maximum.
     */
    public CountingSet<Integer> getColValueCounts(int colIndex) {
        CountingDenseIntSet counts = new CountingDenseIntSet(colMaxes[colIndex]);
        int r = 0;
        while (r < size()) {
            // pretend everything is an int
            counts.add(getAsInt(r, colIndex));
            r++;
        }
        return counts;
    }

    /** Returns a fresh array holding the encoded values of one column, in row order. */
    public int[] getColValues(int colIndex) {
        int[] column = new int[size()];
        int r = 0;
        while (r < size()) {
            column[r] = getAsInt(r, colIndex);
            r++;
        }
        return column;
    }

    /** Returns a new table, sharing this table's metadata, built from the rows selected by {@code pred}. */
    public DataTable filter(Predicate<int[]> pred) {
        return new DataTable(this, ParrtCollections.filter(rows, pred));
    }

    /** Returns the entropy reported by the value-count set of the given column. */
    public double entropy(int colIndex) {
        return valueCountsInColumn(colIndex).entropy();
    }

    /**
     * Picks {@code m} predictor-column indexes at random and returns them in
     * ascending order. When {@code m} is non-positive or at least the number of
     * predictor columns, all predictor indexes are returned without shuffling.
     *
     * @param m      how many predictor columns to select
     * @param random source of randomness; a new {@link Random} is used when null
     */
    public List<Integer> getSubsetOfVarIndexes(int m, Random random) {
        // create set of all predictor vars
        List<Integer> predictors = new ArrayList<>(colTypes.length);
        int col = 0;
        for (VariableType type : colTypes) {
            if (isPredictorVar(type)) {
                predictors.add(col);
            }
            col++;
        }
        int available = predictors.size(); // number of usable predictor variables
        if (m <= 0 || m >= available) {
            // all of them requested: don't bother to shuffle then sort
            return predictors;
        }
        Random rng = random != null ? random : new Random();
        Collections.shuffle(predictors, rng);
        List<Integer> chosen = predictors.subList(0, m);
        Collections.sort(chosen);
        return chosen;
    }

    /** Partition rows[low..high] in-place per splitVariable and splitCategory. Use
     *  left/right cursors moving in from the edges until they cross. Return
     *  the index of the first element not in category: everything in
     *  [low..result) is == splitCategory and everything in [result..high]
     *  is != splitCategory. The result is always within [low..high+1].
     *
     *  The cursors are bounded by low/high rather than by the list ends, so
     *  rows outside the requested range can never influence the result.
     *
     *  https://en.wikipedia.org/wiki/Quicksort#Hoare_partition_scheme
     *
     *  says
     *
     *  "The original partition scheme described by C.A.R. Hoare uses two indices that
     *  start at the ends of the array being partitioned, then move toward each other,
     *  until they detect an inversion: a pair of elements, one greater or equal than
     *  the pivot, one lesser or equal, that are in the wrong order relative to each
     *  other. The inverted elements are then swapped.[16] When the indices meet,
     *  the algorithm stops and returns the final index."
     *
     *  @param low  first index of the range (inclusive)
     *  @param high last index of the range (inclusive)
     */
    public static int categoricalPartition(List<int[]> rows, int splitVariable, int splitCategory, int low,
            int high) {
        int i = low - 1;
        int j = high + 1;
        while (true) {
            // advance past rows already on the correct (left) side
            do {
                i++;
            } while (i <= high && rows.get(i)[splitVariable] == splitCategory);
            // retreat past rows already on the correct (right) side
            do {
                j--;
            } while (j >= low && rows.get(j)[splitVariable] != splitCategory);
            if (i >= j) {
                return i;
            }
            Collections.swap(rows, i, j);
        }
    }

    /**
     * Partitions rows[low..high] in place around {@code splitValue}, reading
     * column {@code splitVariable} as raw float bits. On return every row in
     * [low..result) has a value &lt; splitValue and every row in [result..high]
     * has a value &gt;= splitValue; the result is always within [low..high+1].
     * The cursors are bounded by low/high so rows outside the range are never
     * inspected.
     *
     * @param low  first index of the range (inclusive)
     * @param high last index of the range (inclusive)
     * @return index of the first row whose value is not below splitValue
     */
    public static int numericalFloatPartition(List<int[]> rows, int splitVariable, double splitValue, int low,
            int high) {
        int i = low - 1;
        int j = high + 1;
        while (true) {
            do {
                i++;
            } while (i <= high && Float.intBitsToFloat(rows.get(i)[splitVariable]) < splitValue);
            do {
                j--;
            } while (j >= low && Float.intBitsToFloat(rows.get(j)[splitVariable]) >= splitValue);
            if (i >= j) {
                return i;
            }
            Collections.swap(rows, i, j);
        }
    }

    /**
     * Partitions rows[low..high] in place around {@code splitValue}, reading
     * column {@code splitVariable} as plain ints. On return every row in
     * [low..result) has a value &lt; splitValue and every row in [result..high]
     * has a value &gt;= splitValue; the result is always within [low..high+1].
     * The cursors are bounded by low/high so rows outside the range are never
     * inspected.
     *
     * @param low  first index of the range (inclusive)
     * @param high last index of the range (inclusive)
     * @return index of the first row whose value is not below splitValue
     */
    public static int numericalIntPartition(List<int[]> rows, int splitVariable, double splitValue, int low,
            int high) {
        int i = low - 1;
        int j = high + 1;
        while (true) {
            do {
                i++;
            } while (i <= high && rows.get(i)[splitVariable] < splitValue);
            do {
                j--;
            } while (j >= low && rows.get(j)[splitVariable] >= splitValue);
            if (i >= j) {
                return i;
            }
            Collections.swap(rows, i, j);
        }
    }

    /**
     * Return new table holding rows [i1..i2] inclusive. The new table's row
     * list is a {@code subList} view of this table's rows, not a copy.
     */
    public DataTable subset(int i1, int i2) {
        List<int[]> window = rows.subList(i1, i2 + 1);
        return new DataTable(this, window);
    }

    /** Return new table with every row except those in [i1..i2] inclusive; the rows themselves are shared. */
    public DataTable subsetNot(int i1, int i2) {
        List<int[]> kept = new ArrayList<>();
        // rows before the excluded chunk
        int r = 0;
        while (r < i1) {
            kept.add(rows.get(r));
            r++;
        }
        // rows after the excluded chunk
        r = i2 + 1;
        while (r < rows.size()) {
            kept.add(rows.get(r));
            r++;
        }
        return new DataTable(this, kept);
    }

    /** Return new table with row i missing from table; makes shallow copy of the row list to do so. */
    public DataTable subsetNot(int i) {
        List<int[]> remaining = new ArrayList<>(rows);
        remaining.remove(i); // remove by index
        return new DataTable(this, remaining);
    }

    /**
     * Draws a random sample of n rows (with replacement, per the
     * {@code ParrtStats.bootstrapWithRepl} helper) and wraps it in a new
     * DataTable that shares this table's metadata.
     */
    public DataTable randomSubset(int n) {
        List<int[]> sample = ParrtStats.bootstrapWithRepl(this.rows, n);
        return new DataTable(this, sample);
    }

    /** Returns how many columns are predictor variables (see {@link #isPredictorVar(VariableType)}). */
    public int getNumberOfPredictorVar() {
        List<Integer> predictors = getSubsetOfVarIndexes(getNumberOfColumns(), null);
        return predictors.size();
    }

    /** Returns the number of columns, taken from the length of the column-type array. */
    public int getNumberOfColumns() {
        final int numCols = colTypes.length;
        return numCols;
    }

    /** Returns true unless the type is one of the UNUSED_* or TARGET_* types. */
    public static boolean isPredictorVar(VariableType colType) {
        boolean unused = colType == UNUSED_INT || colType == UNUSED_FLOAT || colType == UNUSED_STRING;
        boolean target = colType == TARGET_CATEGORICAL_INT || colType == TARGET_CATEGORICAL_STRING;
        return !unused && !target;
    }

    /** Returns true for CATEGORICAL_INT and CATEGORICAL_STRING; target and unused types do not count. */
    public static boolean isCategoricalVar(VariableType colType) {
        if (colType == CATEGORICAL_INT) {
            return true;
        }
        return colType == CATEGORICAL_STRING;
    }

    /**
     * Create a set that counts how many of each value there is in column
     * colIndex. Only works on int-valued columns: numerical int, categorical
     * int/string and the two target types.
     *
     * @throws IllegalArgumentException if the column has any other type
     */
    public CountingSet<Integer> valueCountsInColumn(int colIndex) {
        CountingSet<Integer> counts = new CountingDenseIntSet(colMaxes[colIndex]);
        VariableType type = colTypes[colIndex];
        boolean intBased = type == NUMERICAL_INT
                || type == CATEGORICAL_INT
                || type == CATEGORICAL_STRING
                || type == TARGET_CATEGORICAL_INT
                || type == TARGET_CATEGORICAL_STRING;
        if (!intBased) {
            throw new IllegalArgumentException(
                    colNames[colIndex] + " is not an int-based column; type is " + colTypes[colIndex]);
        }
        // tally the value found in this column for every row
        for (int r = 0; r < size(); r++) {
            counts.add(getRow(r)[colIndex]);
        }
        return counts;
    }

    /**
     * Sorts the rows in place, ascending by the given column. Int- and
     * string-typed columns compare their encoded ints; float-typed columns
     * compare the decoded floats. Columns of any other type are left unsorted.
     */
    public void sortBy(int colIndex) {
        switch (colTypes[colIndex]) {
        case NUMERICAL_FLOAT:
        case UNUSED_FLOAT:
            rows.sort((left, right) -> Float.compare(Float.intBitsToFloat(left[colIndex]),
                    Float.intBitsToFloat(right[colIndex])));
            break;
        case CATEGORICAL_INT:
        case NUMERICAL_INT:
        case TARGET_CATEGORICAL_INT:
        case UNUSED_INT:
        case CATEGORICAL_STRING: // strings are encoded as ints
        case TARGET_CATEGORICAL_STRING:
        case UNUSED_STRING:
            rows.sort((left, right) -> Integer.compare(left[colIndex], right[colIndex]));
            break;
        }
    }

    /** Randomly reorders the rows in place using the supplied source of randomness. */
    public void shuffle(Random random) {
        Collections.shuffle(this.rows, random);
    }

    /** Returns the number of rows in the table. */
    public int size() {
        return this.rows.size();
    }

    /**
     * Return the data[i,j] item as an appropriate object: Integer, Float, String.
     * Same as {@link #getValue(int, int)}.
     */
    public Object get(int i, int j) {
        Object decoded = getValue(i, j);
        return decoded;
    }

    /** Returns the raw int stored at row i, column j, whatever the column's type. */
    public int getAsInt(int i, int j) {
        int[] row = rows.get(i);
        return row[j];
    }

    /** Returns the cell at row i, column j with its stored int reinterpreted as float bits. */
    public float getAsFloat(int i, int j) {
        int bits = rows.get(i)[j];
        return getAsFloat(bits);
    }

    /** Reinterprets an encoded int as the float whose raw bits it holds. */
    public static float getAsFloat(int a) {
        final float decoded = Float.intBitsToFloat(a);
        return decoded;
    }

    /** Returns the backing array of row i (not a copy). */
    public int[] getRow(int i) {
        return this.rows.get(i);
    }

    /** Removes the row at index i from the row list. */
    public void removeRow(int i) {
        this.rows.remove(i); // int argument: removal by position
    }

    /** Returns the backing row list itself (not a copy). */
    public List<int[]> getRows() {
        return this.rows;
    }

    /** Returns the column-name array itself (not a copy). */
    public String[] getColNames() {
        return this.colNames;
    }

    /** Returns the column-type array itself (not a copy). */
    public VariableType[] getColTypes() {
        return this.colTypes;
    }

    /** Replaces the column-type array; the given array is stored as-is and nothing else is recomputed. */
    public void setColTypes(VariableType[] colTypes) {
        this.colTypes = colTypes;
    }

    /** Sets the type of the column at the given index. */
    public void setColType(int colIndex, VariableType colType) {
        colTypes[colIndex] = colType;
    }

    /**
     * Sets the type of the column with the given name.
     *
     * @throws IllegalArgumentException if no column with that name has a type slot
     */
    public void setColType(String colName, VariableType colType) {
        int col = indexOf(colNames, colName);
        boolean known = col >= 0 && col < colTypes.length;
        if (!known) {
            throw new IllegalArgumentException("Column " + colName + " unknown");
        }
        this.colTypes[col] = colType;
    }

    /**
     * Returns the recorded maximum of column j as a number of the column's
     * natural kind: a Float for float-typed columns (NUMERICAL_FLOAT and
     * UNUSED_FLOAT, whose maxima are stored as raw float bits) and an Integer
     * for everything else.
     */
    public Number getColMax(int j) {
        // both float types store raw bits, as in getValue() and compare()
        if (colTypes[j] == NUMERICAL_FLOAT || colTypes[j] == UNUSED_FLOAT) {
            return getAsFloat(colMaxes[j]);
        }
        return colMaxes[j];
    }

    /** Returns the cell at (rowi, colj) decoded according to the column's type. */
    public Object getValue(int rowi, int colj) {
        int encoded = this.rows.get(rowi)[colj];
        return getValue(this, encoded, colj);
    }

    /**
     * Return an object representing the true value of 'value' relative to
     * column colj in table 'data': the int itself for int-typed columns, the
     * string-table entry for string-typed columns, and the decoded float for
     * float-typed columns.
     *
     * @throws IllegalArgumentException if the column's type is none of those
     */
    public static Object getValue(DataTable data, int value, int colj) {
        VariableType type = data.colTypes[colj];
        switch (type) {
        case NUMERICAL_FLOAT:
        case UNUSED_FLOAT:
            return Float.intBitsToFloat(value);
        case CATEGORICAL_STRING:
        case TARGET_CATEGORICAL_STRING:
        case UNUSED_STRING:
            return data.colStringToIntMap[colj].get(value);
        case CATEGORICAL_INT:
        case NUMERICAL_INT:
        case TARGET_CATEGORICAL_INT:
        case UNUSED_INT:
            return value;
        default:
            throw new IllegalArgumentException(data.colNames[colj] + " has invalid type: " + data.colTypes[colj]);
        }
    }

    /** Returns every cell of row rowi decoded via {@link #getValue(int, int)}, one entry per column type. */
    public Object[] getValues(int rowi) {
        Object[] decoded = new Object[colTypes.length];
        for (int col = 0; col < decoded.length; col++) {
            decoded[col] = getValue(rowi, col);
        }
        return decoded;
    }

    /**
     * Compares the cells of two rows in one column: as floats for float-typed
     * columns, as ints for int- and string-typed columns.
     *
     * @return negative, zero or positive, as for {@link Integer#compare} / {@link Float#compare}
     * @throws IllegalArgumentException if the column's type is none of those
     */
    public int compare(int rowi, int rowj, int colIndex) {
        VariableType type = colTypes[colIndex];
        switch (type) {
        case NUMERICAL_FLOAT:
        case UNUSED_FLOAT:
            return Float.compare(getAsFloat(rowi, colIndex), getAsFloat(rowj, colIndex));
        case CATEGORICAL_INT:
        case NUMERICAL_INT:
        case TARGET_CATEGORICAL_INT:
        case UNUSED_INT:
        case CATEGORICAL_STRING: // strings are encoded as ints
        case TARGET_CATEGORICAL_STRING:
        case UNUSED_STRING:
            int left = getAsInt(rowi, colIndex);
            int right = getAsInt(rowj, colIndex);
            return Integer.compare(left, right);
        default:
            throw new IllegalArgumentException(colNames[colIndex] + " has invalid type: " + type);
        }
    }

    /**
     * Compares two encoded cell values under the given column type: as floats
     * (after decoding the raw bits) for float types, as ints otherwise.
     *
     * @throws IllegalArgumentException if the type is neither int-, string- nor float-based
     */
    public int compare(int a, int b, VariableType colType) {
        switch (colType) {
        case NUMERICAL_FLOAT:
        case UNUSED_FLOAT:
            return Float.compare(getAsFloat(a), getAsFloat(b));
        case CATEGORICAL_INT:
        case NUMERICAL_INT:
        case TARGET_CATEGORICAL_INT:
        case UNUSED_INT:
        case CATEGORICAL_STRING: // strings are encoded as ints
        case TARGET_CATEGORICAL_STRING:
        case UNUSED_STRING:
            return Integer.compare(a, b);
        default:
            throw new IllegalArgumentException("invalid type: " + colType);
        }
    }

    /** Returns the index of the first TARGET_* column, or the last column when there is none. */
    public int getPredictedCol() {
        int target = indexOf(colTypes, t -> t == TARGET_CATEGORICAL_STRING || t == TARGET_CATEGORICAL_INT);
        if (target >= 0) {
            return target;
        }
        return getNumberOfColumns() - 1; // default to last column
    }

    /** Returns a new {@code DataTableIterator} over this table. */
    @Override
    public Iterator<int[]> iterator() {
        DataTableIterator cursor = new DataTableIterator(this);
        return cursor;
    }

    /**
     * Applies {@code action} to every element produced by {@link #iterator()}.
     * Delegates to the {@link Iterable} default implementation instead of
     * rejecting the call, so {@code table.forEach(...)} behaves like an
     * enhanced for loop over the table.
     *
     * @throws NullPointerException if {@code action} is null
     */
    @Override
    public void forEach(Consumer<? super int[]> action) {
        Iterable.super.forEach(action);
    }

    /**
     * Returns a spliterator over the elements produced by {@link #iterator()}.
     * Delegates to the {@link Iterable} default implementation instead of
     * rejecting the call, so the table can back a stream.
     */
    @Override
    public Spliterator<int[]> spliterator() {
        return Iterable.super.spliterator();
    }

    /**
     * Generates placeholder column names: "y" for target columns and
     * "x" followed by the column index for all others.
     *
     * @param colTypes type of each column; must have at least {@code dim} entries
     * @param dim      number of names to generate
     */
    public static String[] getDefaultColNames(VariableType[] colTypes, int dim) {
        String[] names = new String[dim];
        for (int col = 0; col < dim; col++) {
            VariableType type = colTypes[col];
            boolean isTarget = type == TARGET_CATEGORICAL_INT || type == TARGET_CATEGORICAL_STRING;
            names[col] = isTarget ? "y" : "x" + col;
        }
        return names;
    }

    /**
     * Generates default column types for a table of width {@code dim}: every
     * column is NUMERICAL_INT except the last, which is TARGET_CATEGORICAL_INT.
     * A width of 0 yields an empty array (there is no last column to mark).
     *
     * @param dim number of columns; must not be negative
     */
    public static VariableType[] getDefaultColTypes(int dim) {
        VariableType[] colTypes = new VariableType[dim];
        if (dim == 0) {
            return colTypes;
        }
        for (int i = 0; i < dim - 1; i++) {
            colTypes[i] = NUMERICAL_INT;
        }
        colTypes[dim - 1] = TARGET_CATEGORICAL_INT;
        return colTypes;
    }

    /**
     * Renders the table in a compact comma-separated form: an optional header
     * line of names (each followed by its short type name in parentheses when
     * types are known), then one line of decoded values per row.
     */
    public String toTestString() {
        StringBuilder out = new StringBuilder();
        if (colNames != null) {
            List<String> header = map(colNames, Object::toString);
            if (colTypes != null) {
                for (int col = 0; col < header.size(); col++) {
                    String labelled = header.get(col) + "(" + varTypeShortNames[colTypes[col].ordinal()] + ")";
                    header.set(col, labelled);
                }
            }
            out.append(join(header, ", ")).append("\n");
        }
        for (int r = 0; r < rows.size(); r++) {
            out.append(join(getValues(r), ", ")).append("\n");
        }
        return out.toString();
    }

    /** Renders the table as aligned text using {@link #defaultVarFormats}. */
    @Override
    public String toString() {
        String rendered = toString(defaultVarFormats);
        return rendered;
    }

    /**
     * Renders the table as space-separated, fixed-width text: a line of
     * centred column names followed by one line per row. Each column is as
     * wide as the longer of its name and its widest decoded value.
     *
     * @param colFormats alignment to use per column type, indexed by {@code VariableType} ordinal
     */
    public String toString(VariableFormat[] colFormats) {
        StringBuilder out = new StringBuilder();
        List<Integer> widths = map(colNames, name -> name.length());
        // header line; also widens each column to fit its widest value
        for (int col = 0; col < widths.size(); col++) {
            int width = Math.max(widths.get(col), getColumnMaxWidth(col));
            widths.set(col, width);
            if (col > 0) {
                out.append(" ");
            }
            out.append(StringUtils.center(colNames[col], width));
        }
        out.append("\n");
        // one line per row, each cell aligned per its column type
        for (int r = 0; r < rows.size(); r++) {
            Object[] cells = getValues(r);
            for (int col = 0; col < widths.size(); col++) {
                int width = widths.get(col);
                String cell = cells[col].toString();
                VariableFormat alignment = colFormats[colTypes[col].ordinal()];
                switch (alignment) {
                case RIGHT:
                    cell = String.format("%" + width + "s", cell);
                    break;
                case CENTER:
                    cell = StringUtils.center(cell, width);
                    break;
                case LEFT:
                    cell = String.format("%-" + width + "s", cell);
                    break;
                }
                if (col > 0) {
                    out.append(" ");
                }
                out.append(cell);
            }
            out.append("\n");
        }
        return out.toString();
    }

    /** Returns the length of the longest decoded value (as text) in the given column, or 0 for an empty table. */
    public int getColumnMaxWidth(int colIndex) {
        int widest = 0;
        for (int r = 0; r < rows.size(); r++) {
            int length = getValue(r, colIndex).toString().length();
            widest = Math.max(widest, length);
        }
        return widest;
    }
}