com.davidbracewell.data.DataFrame.java Source code

Java tutorial

Introduction

Here is the source code for com.davidbracewell.data.DataFrame.java

Source

/*
 * (c) 2005 David B. Bracewell
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.davidbracewell.data;

import com.davidbracewell.Collector;
import com.davidbracewell.collection.Index;
import com.davidbracewell.collection.Indexes;
import com.davidbracewell.conversion.Convert;
import com.davidbracewell.conversion.Val;
import com.davidbracewell.io.structured.ElementType;
import com.davidbracewell.io.structured.StructuredReader;
import com.davidbracewell.string.StringUtils;
import com.google.common.base.Functions;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.Collections2;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;

import java.io.IOException;
import java.io.Serializable;
import java.util.*;

/**
 * <p>An implementation of an R-style data frame. A data frame allows for the manipulation of a matrix-like data
 * structure of string cells with named columns.</p>
 *
 * <p>The frame grows on demand. Setting a cell beyond the current bounds appends the missing rows and/or
 * auto-named ({@code COLUMN_i}) columns. Data frames produced by {@link #slice(int, int)} and
 * {@link #selectRows(Predicate)} share their column names and {@link Row} objects with the frame they were
 * created from.</p>
 *
 * @author David B. Bracewell
 */
public class DataFrame implements Iterable<DataFrame.Row> {

    /**
     * The column names. A row's logical width ({@link Row#size()}) is always the size of this index.
     */
    private final Index<String> columnNames;
    /**
     * The rows of the frame. For a frame created by {@link #slice(int, int)} this is a view onto another
     * frame's row list.
     */
    private final List<Row> data;

    /**
     * Instantiates a new, empty data frame with no columns and no rows.
     */
    public DataFrame() {
        this.columnNames = Indexes.newIndex();
        this.data = Lists.newArrayList();
    }

    /**
     * Instantiates a new, empty data frame with the given column names.
     *
     * @param columnNames the column names
     * @throws NullPointerException if {@code columnNames} is null
     */
    public DataFrame(String... columnNames) {
        this(Arrays.asList(Preconditions.checkNotNull(columnNames)));
    }

    /**
     * Instantiates a new, empty data frame with the given column names.
     *
     * @param columnNames the column names
     * @throws NullPointerException if {@code columnNames} is null
     */
    public DataFrame(Collection<String> columnNames) {
        this();
        this.columnNames.addAll(Preconditions.checkNotNull(columnNames));
    }

    /**
     * Creates a data frame backed directly by the given row list. No copy is made, so changes to the list
     * are visible through the new frame and vice versa.
     */
    private DataFrame(Index<String> columnNames, List<Row> rows) {
        this.columnNames = columnNames;
        this.data = rows;
    }

    /**
     * Creates a data frame holding its own list of the given rows. The {@link Row} objects themselves are
     * shared, not copied.
     */
    private DataFrame(Index<String> columnNames, Iterable<Row> rows) {
        this.columnNames = columnNames;
        this.data = Lists.newArrayList(rows);
    }

    /**
     * Slices a section of the data frame. Changes made to the slice are reflected in this data frame.
     * The slice is backed by {@link List#subList(int, int)}, so structural changes made to this frame other
     * than through the slice invalidate it.
     *
     * @param rowFrom The first row of the slice (inclusive, must be {@code >= 0})
     * @param rowTo   The end of the slice (exclusive, must be {@code > rowFrom} and {@code <= rowSize()})
     * @return The resulting data frame.
     * @throws IllegalArgumentException if the range is invalid
     */
    public DataFrame slice(int rowFrom, int rowTo) {
        Preconditions.checkArgument(rowFrom >= 0 && rowFrom < rowTo && rowTo <= data.size(),
                "Invalid range [%s, %s]", rowFrom, rowTo);
        return new DataFrame(columnNames, data.subList(rowFrom, rowTo));
    }

    /**
     * Select rows matching the given filter.
     *
     * <p>The result holds its own list of the matching rows, but the {@link Row} objects and the column names
     * are shared with this frame. Cell edits therefore show up in both frames. The column names are kept
     * even when no row matches.</p>
     *
     * @param filter the filter to match
     * @return a data frame containing the matching rows
     * @throws NullPointerException if {@code filter} is null
     */
    public DataFrame selectRows(Predicate<Row> filter) {
        Preconditions.checkNotNull(filter);
        return new DataFrame(columnNames, Iterables.filter(data, filter));
    }

    /**
     * Reads a data frame from a structured document.
     *
     * <p>Each top-level object becomes a row whose keys name the columns. Keys that are not yet known,
     * including every key when there is no header, are appended as new columns. Each top-level array
     * becomes a row filled by position.</p>
     *
     * @param reader     the reader to consume
     * @param hasHeaders whether the first element of the document is an array of column names
     * @return the data frame read from the document
     * @throws IOException if reading fails or a top-level element is neither an object nor an array
     */
    public static DataFrame read(StructuredReader reader, boolean hasHeaders) throws IOException {
        Preconditions.checkNotNull(reader);
        DataFrame dataFrame;
        reader.beginDocument();
        if (hasHeaders) {
            dataFrame = new DataFrame(Collections2.transform(reader.nextArray(), Functions.toStringFunction()));
        } else {
            dataFrame = new DataFrame();
        }
        int row = 0;
        while (reader.peek() != ElementType.END_DOCUMENT) {
            ElementType next = reader.peek();
            if (next == ElementType.BEGIN_OBJECT) {
                dataFrame.addRow(Lists.<String>newArrayList());
                reader.beginObject();
                for (Map.Entry<String, Val> entry : reader.readObjectToMap().entrySet()) {
                    // Resolve (or create) the column by name so unknown keys do not map to a negative index.
                    dataFrame.set(row, dataFrame.indexOfOrAddColumn(entry.getKey()), entry.getValue().asString());
                }
                reader.endObject();
            } else if (next == ElementType.BEGIN_ARRAY) {
                dataFrame.addRow(Lists.<String>newArrayList());
                reader.beginArray();
                for (int col = 0; reader.peek() != ElementType.END_ARRAY; col++) {
                    dataFrame.set(row, col, reader.nextValue().asString());
                }
                reader.endArray();
            } else {
                throw new IOException("Could not parse input file [" + next + "]");
            }
            row++;
        }
        reader.endDocument();
        return dataFrame;
    }

    /**
     * Returns the index of the named column, appending it as a new column if it does not exist yet.
     */
    private int indexOfOrAddColumn(String name) {
        int index = columnNames.indexOf(name);
        if (index < 0) {
            columnNames.add(name);
            index = columnNames.indexOf(name);
        }
        return index;
    }

    /**
     * Appends empty rows until {@code row} is a valid row index.
     */
    private void ensureRows(int row) {
        while (data.size() <= row) {
            data.add(new Row(this));
        }
    }

    /**
     * Gets the column names.
     *
     * @return the column names, as returned by the underlying index's {@code asList()}
     */
    public List<String> getColumnNames() {
        return columnNames.asList();
    }

    /**
     * Ensures a column exists at {@code index} by appending auto-named {@code COLUMN_i} columns as needed.
     * A negative index is a no-op.
     *
     * @param index the column index that must exist afterwards
     */
    private void addColumnNames(int index) {
        while (columnNames.size() <= index) {
            columnNames.add("COLUMN_" + columnNames.size());
        }
    }

    /**
     * Gets the value of a cell as a string. Reading a row past the current end appends empty rows up to
     * it. A column past the row's stored values yields the empty string.
     *
     * @param row the row index
     * @param col the column index
     * @return the cell value
     */
    public String getString(int row, int col) {
        ensureRows(row);
        return data.get(row).get(col);
    }

    /**
     * Gets the value of a cell as a {@link Val}, with the same row-growing behavior as
     * {@link #getString(int, int)}.
     *
     * @param row the row index
     * @param col the column index
     * @return the cell value
     */
    public Val get(int row, int col) {
        return Val.of(getString(row, col));
    }

    /**
     * Gets the value of a cell as a string, addressing the column by name.
     * NOTE(review): an unknown column name presumably resolves to a negative index and fails.
     *
     * @param row the row index
     * @param col the column name
     * @return the cell value
     */
    public String getString(int row, String col) {
        return getString(row, columnNames.indexOf(col));
    }

    /**
     * Gets the value of a cell as a {@link Val}, addressing the column by name.
     *
     * @param row the row index
     * @param col the column name
     * @return the cell value
     */
    public Val get(int row, String col) {
        return Val.of(getString(row, col));
    }

    /**
     * Sets the value of a cell to {@code o.toString()}. Missing rows and columns are added as needed.
     *
     * @param row the row index
     * @param col the column index (must be {@code >= 0})
     * @param o   the value to store
     * @throws NullPointerException      if {@code o} is null
     * @throws IndexOutOfBoundsException if {@code col} is negative
     */
    public void set(int row, int col, Object o) {
        ensureRows(row);
        List<String> rowList = data.get(row);
        rowList.set(col, o.toString());
    }

    /**
     * Sets the value of a cell to {@code o.toString()}, addressing the column by name.
     *
     * @param row the row index
     * @param col the column name
     * @param o   the value to store
     */
    public void set(int row, String col, Object o) {
        set(row, columnNames.indexOf(col), o);
    }

    /**
     * Sets the name of a column. Auto-named columns are added first if {@code index} is past the last
     * column.
     *
     * @param index the column index
     * @param name  the new name of the column
     */
    public void setColumnName(int index, String name) {
        addColumnNames(index);
        columnNames.set(index, name);
    }

    /**
     * Appends a copy of the given values as a new row. Auto-named columns are added first if the row is
     * wider than the frame.
     *
     * @param row the row values
     * @return the newly added row
     * @throws NullPointerException if {@code row} is null
     */
    public Row addRow(List<String> row) {
        Preconditions.checkNotNull(row);
        Row newRow = new Row(row, this);
        // Only indices 0 .. row.size() - 1 need a column; an empty row adds none.
        addColumnNames(row.size() - 1);
        data.add(newRow);
        return newRow;
    }

    /**
     * Gets a row, appending empty rows if {@code row} is past the current end.
     *
     * @param row the row index
     * @return the row
     */
    public Row getRow(int row) {
        ensureRows(row);
        return data.get(row);
    }

    /**
     * Removes the row at the given index from the data frame.
     *
     * @param row The index of the row to remove
     * @return A list containing the row's stored values, or an empty list if the index is out of range
     */
    public List<String> removeRow(int row) {
        if (row < 0 || row >= data.size()) {
            return Collections.emptyList();
        }
        return data.remove(row).row;
    }

    /**
     * Removes the column at the given index from the data frame. Values in later columns shift left.
     *
     * @param col The column index
     * @return A list of the removed column's values (one per row), or an empty list if the index is invalid
     */
    public List<String> removeColumn(int col) {
        if (col < 0 || col >= columnSize()) {
            return Collections.emptyList();
        }
        List<String> columnValues = Lists.newArrayListWithCapacity(data.size());
        for (Row row : data) {
            // Physically delete the cell so each row stays aligned with the column names after removal.
            columnValues.add(row.removeCell(col));
        }
        columnNames.remove(col);
        return columnValues;
    }

    /**
     * Removes the named column from the data frame.
     *
     * @param columnName The name of the column to remove
     * @return A list of the column values, or an empty list if the column does not exist
     */
    public List<String> removeColumn(String columnName) {
        return removeColumn(columnNames.indexOf(columnName));
    }

    /**
     * @return A live list view of the columns in the data frame
     */
    public List<Column> columns() {
        return new ColumnList(this);
    }

    /**
     * Gets a live view of a column.
     *
     * @param col the column index
     * @return the column view
     */
    public Column getColumn(int col) {
        return new Column(this, col);
    }

    /**
     * Gets a live view of a column, addressed by name.
     *
     * @param col the column name
     * @return the column view
     */
    public Column getColumn(String col) {
        return getColumn(columnNames.indexOf(col));
    }

    /**
     * Gets a lazily converted view of a column, produced with {@link Lists#transform}.
     *
     * @param col   the column name
     * @param clazz the type each cell is converted to
     * @param <T>   the converted type
     * @return the converted column view
     */
    public <T> List<T> getColumn(String col, final Class<T> clazz) {
        return Lists.transform(getColumn(columnNames.indexOf(col)), Convert.getConverter(clazz));
    }

    /**
     * Computes a value over a column using a {@link Collector}.
     *
     * @param col       The index of the column
     * @param collector The collector that processes the column
     * @param <RESULT>  The type returned by the collector
     * @return The result of the collector or null if the column index is invalid
     */
    public <RESULT> RESULT computeColumn(int col, final Collector<? super String, RESULT> collector) {
        Preconditions.checkNotNull(collector);
        if (col < 0 || col >= columnSize()) {
            return null;
        }
        for (String string : getColumn(col)) {
            collector.collect(string);
        }
        return collector.result();
    }

    /**
     * Computes a value over a column using a {@link Collector}.
     *
     * @param columnName The name of the column
     * @param collector  The collector that processes the column
     * @param <RESULT>   The type returned by the collector
     * @return The result of the collector or null if the column does not exist
     */
    public <RESULT> RESULT computeColumn(String columnName, final Collector<? super String, RESULT> collector) {
        Preconditions.checkNotNull(collector);
        return computeColumn(columnNames.indexOf(columnName), collector);
    }

    /**
     * Computes a value over a column using a {@link Collector}, converting each cell first.
     *
     * @param columnName The name of the column
     * @param clazz      The clazz to convert the cell value to
     * @param collector  The collector that processes the column
     * @param <IN>       The type to convert the strings in the data frame to
     * @param <RESULT>   The type returned by the collector
     * @return The result of the collector or null if the column does not exist
     */
    public <IN, RESULT> RESULT computeColumn(String columnName, final Class<IN> clazz,
            final Collector<? super IN, RESULT> collector) {
        Preconditions.checkNotNull(clazz);
        Preconditions.checkNotNull(collector);
        return computeColumn(columnNames.indexOf(columnName), clazz, collector);
    }

    /**
     * Computes a value over a column using a {@link Collector}, converting each cell first.
     *
     * @param col       The index of the column
     * @param clazz     The clazz to convert the cell value to
     * @param collector The collector that processes the column
     * @param <IN>      The type to convert the strings in the data frame to
     * @param <RESULT>  The type returned by the collector
     * @return The result of the collector or null if the column index is invalid
     */
    public <IN, RESULT> RESULT computeColumn(int col, final Class<IN> clazz,
            final Collector<? super IN, RESULT> collector) {
        Preconditions.checkNotNull(clazz);
        Preconditions.checkNotNull(collector);
        if (col < 0 || col >= columnSize()) {
            return null;
        }
        for (String string : getColumn(col)) {
            collector.collect(Val.of(string).as(clazz));
        }
        return collector.result();
    }

    @Override
    public Iterator<Row> iterator() {
        return data.iterator();
    }

    /**
     * @return A live list view of the rows in the data frame
     */
    public List<Row> rows() {
        return new RowList(this);
    }

    /**
     * @return the number of rows
     */
    public int rowSize() {
        return data.size();
    }

    /**
     * @return the number of columns
     */
    public int columnSize() {
        return columnNames.size();
    }

    /**
     * Read-only list view over the rows of a data frame.
     * NOTE(review): DataFrame is not Serializable, so serializing this view will fail on its dataFrame field.
     */
    private static class RowList extends AbstractList<Row> implements Serializable {
        private static final long serialVersionUID = -1290773507429779126L;
        private final DataFrame dataFrame;

        private RowList(DataFrame dataFrame) {
            this.dataFrame = dataFrame;
        }

        @Override
        public Row get(int index) {
            // IndexOutOfBoundsException per the List contract; AbstractList's iterator relies on it.
            Preconditions.checkElementIndex(index, dataFrame.rowSize());
            return dataFrame.getRow(index);
        }

        @Override
        public int size() {
            return dataFrame.rowSize();
        }
    }

    /**
     * Read-only list view over the columns of a data frame.
     * NOTE(review): DataFrame is not Serializable, so serializing this view will fail on its dataFrame field.
     */
    private static class ColumnList extends AbstractList<Column> implements Serializable {
        private static final long serialVersionUID = -1290773507429779126L;
        private final DataFrame dataFrame;

        private ColumnList(DataFrame dataFrame) {
            this.dataFrame = dataFrame;
        }

        @Override
        public Column get(int index) {
            // Rejects negative indices too, which previously produced a Column at a bogus index.
            Preconditions.checkElementIndex(index, dataFrame.columnSize());
            return new Column(dataFrame, index);
        }

        @Override
        public int size() {
            return dataFrame.columnSize();
        }
    }

    /**
     * Represents a view of a single column in a data frame. Its size is the frame's row count.
     */
    public static class Column extends AbstractList<String> implements Serializable {

        private static final long serialVersionUID = 8436046573531221784L;
        private final DataFrame dataFrame;
        private final int column;

        private Column(DataFrame dataFrame, int column) {
            this.dataFrame = dataFrame;
            this.column = column;
        }

        @Override
        public String get(int index) {
            return dataFrame.getString(index, column);
        }

        /**
         * Gets the cell as a {@link Val} for conversion into other types.
         *
         * @param index The row index
         * @return The value of the cell
         */
        public Val getVal(int index) {
            return dataFrame.get(index, column);
        }

        @Override
        public int size() {
            return dataFrame.rowSize();
        }

        /**
         * Appends a new row to the frame whose only non-empty cell is this column's value.
         *
         * @param s the value for this column in the new row
         * @return always {@code true}
         */
        @Override
        public boolean add(String s) {
            dataFrame.addRow(Lists.<String>newArrayList()).set(column, s);
            return true;
        }

        /**
         * Blanks every cell of this column. Rows are not removed, so the size is unchanged.
         */
        @Override
        public void clear() {
            for (int r = 0; r < dataFrame.rowSize(); r++) {
                dataFrame.set(r, column, StringUtils.EMPTY);
            }
        }

    }//END OF DataFrame$Column

    /**
     * A single row of a data frame. Its logical size is the frame's column count. Stored values may be
     * shorter than that, and the missing trailing cells read as the empty string.
     */
    public static class Row extends AbstractList<String> implements Serializable {

        private static final long serialVersionUID = 629132234337562319L;
        /**
         * The stored cell values; may be shorter than the frame's column count.
         */
        private final List<String> row = Lists.newArrayList();
        private final DataFrame dataFrame;

        private Row(DataFrame dataFrame) {
            this.dataFrame = dataFrame;
        }

        private Row(Collection<String> strings, DataFrame dataFrame) {
            row.addAll(strings);
            this.dataFrame = dataFrame;
        }

        @Override
        public String get(int index) {
            return index < row.size() ? row.get(index) : StringUtils.EMPTY;
        }

        /**
         * Gets a cell value by column name.
         *
         * @param colName the column name
         * @return the cell value
         */
        public String get(String colName) {
            return get(dataFrame.columnNames.indexOf(colName));
        }

        /**
         * Gets a cell value as a {@link Val}.
         *
         * @param index the column index
         * @return the cell value
         */
        public Val getVal(int index) {
            return Val.of(get(index));
        }

        /**
         * Gets a cell value as a {@link Val}, by column name.
         *
         * @param colName the column name
         * @return the cell value
         */
        public Val getVal(String colName) {
            return getVal(dataFrame.columnNames.indexOf(colName));
        }

        /**
         * Sets a cell value. Auto-named columns are added to the frame if {@code index} is past the last
         * column, and the stored values are padded with empty strings up to {@code index}.
         *
         * @param index   the column index
         * @param element the new value
         * @return the previous value (the empty string for a previously unset cell)
         * @throws IndexOutOfBoundsException if {@code index} is negative
         */
        @Override
        public String set(int index, String element) {
            if (index < 0) {
                throw new IndexOutOfBoundsException("Invalid column index: " + index);
            }
            dataFrame.addColumnNames(index);
            while (row.size() <= index) {
                row.add(StringUtils.EMPTY);
            }
            return row.set(index, element);
        }

        @Override
        public int size() {
            return dataFrame.columnNames.size();
        }

        @Override
        public String toString() {
            if (row.size() == size()) {
                return row.toString();
            } else {
                StringBuilder builder = new StringBuilder("[");
                for (int i = 0; i < size(); i++) {
                    if (i > 0) {
                        builder.append(", ");
                    }
                    builder.append(get(i));
                }
                return builder.append("]").toString();
            }
        }

        /**
         * Blanks the cell at {@code index}. Despite the name, this does not shrink the row, because a row's
         * size is tied to the frame's column count. It also never adds columns.
         *
         * @param index the column index
         * @return the previous value of the cell
         */
        @Override
        public String remove(int index) {
            String old = get(index);
            if (index < row.size()) {
                row.set(index, StringUtils.EMPTY);
            }
            return old;
        }

        /**
         * Physically deletes the stored cell at {@code index}, shifting later values left. Used when a
         * whole column is removed from the frame.
         */
        private String removeCell(int index) {
            return index < row.size() ? row.remove(index) : StringUtils.EMPTY;
        }

        /**
         * Appends a value in a new last column. This adds a column to the whole frame; other rows read the
         * empty string for it.
         *
         * @param s the value to append
         * @return always {@code true}
         */
        @Override
        public boolean add(String s) {
            set(dataFrame.columnSize(), s);
            return true;
        }

    }//END OF DataFrame$Row

}//END OF DataFrame